use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
use crate::datatype::{Rank, Tag};
/// Wire-protocol version stamped into every [`Envelope`] built by [`Envelope::new`].
///
/// [`Envelope::validate`] rejects any envelope whose `protocol_version` differs
/// from this value, so peers built against different values will reject each
/// other's envelopes.
pub const PROTOCOL_VERSION: u16 = 1;
/// Category of message carried by an [`Envelope`].
///
/// Variants are serialized in `snake_case`, so `PointToPoint` appears on the
/// wire as `"point_to_point"`. It is currently the only variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EnvelopeKind {
    /// A message sent from a single source rank (`src`) to a single
    /// destination rank (`dst`).
    PointToPoint,
}
/// Serde fallback for `Envelope::kind` when the field is absent from the input:
/// envelopes that omit it are treated as point-to-point.
fn default_envelope_kind() -> EnvelopeKind {
    EnvelopeKind::PointToPoint
}

/// Serde fallback for `Envelope::protocol_version` when the field is absent from
/// the input: envelopes that omit it are assumed to speak the current version.
fn default_protocol_version() -> u16 {
    PROTOCOL_VERSION
}
/// A routed message: source and destination ranks, a tag, and a byte payload.
///
/// `protocol_version` and `kind` carry `#[serde(default = ...)]`, so serialized
/// envelopes that omit them still deserialize, filled in from the named default
/// functions. Serde's derive emits fields in declaration order, so the field
/// order below is also the field order of the serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Envelope {
    /// Protocol version the sender used; `validate` requires it to equal `PROTOCOL_VERSION`.
    #[serde(default = "default_protocol_version")]
    pub protocol_version: u16,
    /// Message category; filled in by `default_envelope_kind` when missing from the input.
    #[serde(default = "default_envelope_kind")]
    pub kind: EnvelopeKind,
    /// Sending rank; `validate` requires `0 <= src < size`.
    pub src: Rank,
    /// Receiving rank; `validate` requires `0 <= dst < size`.
    pub dst: Rank,
    /// Message tag; carried through but not inspected anywhere in this module.
    pub tag: Tag,
    /// Message body as raw bytes. NOTE(review): presumably produced by `encode`
    /// on the sending side — confirm at call sites.
    pub payload: Vec<u8>,
}
impl Envelope {
pub fn new(src: Rank, dst: Rank, tag: Tag, payload: Vec<u8>) -> Self {
Self {
protocol_version: PROTOCOL_VERSION,
kind: EnvelopeKind::PointToPoint,
src,
dst,
tag,
payload,
}
}
pub fn validate(&self, size: Rank) -> Result<()> {
if self.protocol_version != PROTOCOL_VERSION {
return Err(Error::Protocol(format!(
"PROTO_UNSUPPORTED_VERSION: expected={}, got={}",
PROTOCOL_VERSION, self.protocol_version
)));
}
if size <= 0 {
return Err(Error::Protocol(
"PROTO_INVALID_ENVELOPE: communicator size must be > 0".to_string(),
));
}
if self.src < 0 || self.src >= size {
return Err(Error::Protocol(format!(
"PROTO_INVALID_ENVELOPE: invalid source rank {}, size={}",
self.src, size
)));
}
if self.dst < 0 || self.dst >= size {
return Err(Error::Protocol(format!(
"PROTO_INVALID_ENVELOPE: invalid destination rank {}, size={}",
self.dst, size
)));
}
Ok(())
}
}
/// Serializes `value` to JSON bytes.
///
/// `T` may be unsized, so borrowed slices and string slices (for example
/// `&[Envelope]` or `&str`) can be encoded directly without first copying them
/// into an owned container.
///
/// # Errors
///
/// Returns `Error::Serialization` carrying the serializer's message. This
/// happens when `T`'s `Serialize` impl fails, or when `T` contains a map whose
/// keys JSON cannot represent.
pub fn encode<T>(value: &T) -> Result<Vec<u8>>
where
    // `serde_json::to_vec` already accepts `?Sized`; mirroring it here lets
    // callers pass unsized borrows. Every previously valid call still compiles.
    T: ?Sized + Serialize,
{
    serde_json::to_vec(value).map_err(|err| Error::Serialization(err.to_string()))
}
pub fn decode<T>(payload: &[u8]) -> Result<T>
where
T: DeserializeOwned,
{
serde_json::from_slice(payload).map_err(|err| Error::Serialization(err.to_string()))
}
#[cfg(test)]
mod tests {
    use super::{decode, encode, Envelope, EnvelopeKind, PROTOCOL_VERSION};

    #[test]
    fn envelope_new_sets_protocol_defaults() {
        let env = Envelope::new(0, 1, 7, vec![1, 2, 3]);
        assert_eq!(env.protocol_version, PROTOCOL_VERSION);
        assert_eq!(env.kind, EnvelopeKind::PointToPoint);
    }

    #[test]
    fn envelope_validate_accepts_in_range_ranks() {
        assert!(Envelope::new(0, 1, 0, vec![]).validate(2).is_ok());
        // Boundary: the highest addressable rank is `size - 1`.
        assert!(Envelope::new(1, 1, 0, vec![]).validate(2).is_ok());
    }

    #[test]
    fn envelope_validate_rejects_bad_version() {
        let mut env = Envelope::new(0, 1, 0, vec![]);
        env.protocol_version = PROTOCOL_VERSION + 1;
        let err = env.validate(2).expect_err("version mismatch must fail");
        assert!(err.to_string().contains("PROTO_UNSUPPORTED_VERSION"));
    }

    #[test]
    fn envelope_validate_rejects_bad_ranks() {
        let env = Envelope::new(5, 1, 0, vec![]);
        let err = env
            .validate(2)
            .expect_err("source rank out of range must fail");
        let msg = err.to_string();
        assert!(msg.contains("PROTO_INVALID_ENVELOPE"));
        assert!(msg.contains("source rank"));
    }

    #[test]
    fn envelope_validate_rejects_bad_destination() {
        // `dst == size` is the first out-of-range value.
        let env = Envelope::new(0, 2, 0, vec![]);
        let err = env
            .validate(2)
            .expect_err("destination rank out of range must fail");
        let msg = err.to_string();
        assert!(msg.contains("PROTO_INVALID_ENVELOPE"));
        assert!(msg.contains("destination rank"));
    }

    #[test]
    fn envelope_validate_rejects_non_positive_size() {
        let env = Envelope::new(0, 0, 0, vec![]);
        let err = env
            .validate(0)
            .expect_err("empty communicator must fail");
        assert!(err.to_string().contains("communicator size must be > 0"));
    }

    #[test]
    fn envelope_round_trips_through_codec() {
        let env = Envelope::new(0, 1, 7, vec![1, 2, 3]);
        let bytes = encode(&env).unwrap_or_else(|e| panic!("encode failed: {}", e));
        let decoded: Envelope =
            decode(&bytes).unwrap_or_else(|e| panic!("decode failed: {}", e));
        assert_eq!(decoded.protocol_version, PROTOCOL_VERSION);
        assert_eq!(decoded.kind, EnvelopeKind::PointToPoint);
        assert_eq!(decoded.src, 0);
        assert_eq!(decoded.dst, 1);
        assert_eq!(decoded.tag, 7);
        assert_eq!(decoded.payload, vec![1, 2, 3]);
    }

    #[test]
    fn envelope_decode_fills_missing_protocol_fields() {
        // Input lacking `protocol_version` and `kind` must pick up the serde defaults.
        let json = br#"{"src":0,"dst":1,"tag":3,"payload":[9]}"#;
        let env: Envelope = decode(json).unwrap_or_else(|e| panic!("decode failed: {}", e));
        assert_eq!(env.protocol_version, PROTOCOL_VERSION);
        assert_eq!(env.kind, EnvelopeKind::PointToPoint);
        assert_eq!(env.payload, vec![9]);
    }

    #[test]
    fn envelope_kind_serializes_as_snake_case() {
        let bytes = encode(&EnvelopeKind::PointToPoint)
            .unwrap_or_else(|e| panic!("encode failed: {}", e));
        assert_eq!(bytes, br#""point_to_point""#.to_vec());
    }
}