use crate::alloc::string::ToString;
use crate::compat::string::String;
use crate::errcode::{Kind, Origin};
#[cfg(feature = "std")]
use crate::OpenTelemetryContext;
use crate::{compat::vec::Vec, Decodable, Encodable, Encoded, Message, Route};
use crate::{Error, Result};
use core::fmt::{self, Display, Formatter};
/// Wire-format version tag; it is written as the first byte of every encoded
/// [`TransportMessage`] and read back by [`TransportMessage::decode_message`] to choose a decoder.
pub type ProtocolVersion = u8;
/// Version used by [`TransportMessage::latest`]. Frames of this version carry a tag byte
/// (and optionally a tracing-context string) after the payload.
pub const LATEST_PROTOCOL_VERSION: ProtocolVersion = 2;
/// Legacy version; frames carrying it are decoded through [`TransportMessageV1`].
pub const PROTOCOL_VERSION_V1: ProtocolVersion = 1;
/// A message as carried over a transport: a protocol version tag, routing information,
/// an opaque payload and an optional serialized tracing context.
#[derive(Debug, Clone, Eq, PartialEq, Message)]
pub struct TransportMessage {
    /// Protocol version; encoded as the first byte of the frame.
    pub version: ProtocolVersion,
    /// Onward route of the message.
    pub onward_route: Route,
    /// Return route of the message.
    pub return_route: Route,
    /// Opaque payload bytes, written with `crate::bare::write_slice`.
    pub payload: Vec<u8>,
    /// Serialized tracing context, if any. With the `std` feature,
    /// [`TransportMessage::tracing_context`] hands it to
    /// `OpenTelemetryContext::from_remote_context`.
    pub tracing_context: Option<String>,
}
impl TransportMessage {
pub fn latest(
onward_route: impl Into<Route>,
return_route: impl Into<Route>,
payload: Vec<u8>,
) -> Self {
TransportMessage::new(
LATEST_PROTOCOL_VERSION,
onward_route.into(),
return_route.into(),
payload,
None,
)
}
pub fn v1(
onward_route: impl Into<Route>,
return_route: impl Into<Route>,
payload: Vec<u8>,
) -> Self {
TransportMessage::new(
PROTOCOL_VERSION_V1,
onward_route,
return_route,
payload,
None,
)
}
pub fn new(
version: ProtocolVersion,
onward_route: impl Into<Route>,
return_route: impl Into<Route>,
payload: Vec<u8>,
tracing_context: Option<String>,
) -> Self {
Self {
version,
onward_route: onward_route.into(),
return_route: return_route.into(),
payload,
tracing_context,
}
}
#[cfg(feature = "std")]
pub fn with_tracing_context(self, tracing_context: String) -> Self {
Self {
tracing_context: Some(tracing_context),
..self
}
}
pub fn decode_message(buf: Vec<u8>) -> Result<TransportMessage> {
if buf.is_empty() {
return Err(Error::new(
Origin::Transport,
Kind::Serialization,
"empty buffer, no transport message received".to_string(),
));
};
let version = buf[0];
match version {
PROTOCOL_VERSION_V1 => TransportMessageV1::decode(&buf)
.map(|t| t.to_latest())
.map_err(|e| {
Error::new(
Origin::Transport,
Kind::Serialization,
format!("Error decoding message: {:?}", e),
)
}),
LATEST_PROTOCOL_VERSION => TransportMessage::decode(&buf).map_err(|e| {
Error::new(
Origin::Transport,
Kind::Serialization,
format!("Error decoding message: {:?}", e),
)
}),
v => Err(Error::new(
Origin::Transport,
Kind::Serialization,
format!("Unsupported version: {v}"),
)),
}
}
#[cfg(feature = "std")]
pub fn tracing_context(&self) -> OpenTelemetryContext {
match self.tracing_context.as_ref() {
Some(tracing_context) => OpenTelemetryContext::from_remote_context(tracing_context),
None => OpenTelemetryContext::current(),
}
}
}
impl Display for TransportMessage {
    /// Shows both routes; the payload and tracing context are intentionally left out.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let (onward, ret) = (&self.onward_route, &self.return_route);
        write!(f, "Message (onward route: {onward}, return route: {ret})")
    }
}
impl Encodable for TransportMessage {
    /// Encodes the frame in this order:
    ///
    /// 1. the version byte;
    /// 2. the onward route, then the return route;
    /// 3. the payload;
    /// 4. a tag byte: `1` followed by the tracing-context string when present, otherwise `0`.
    fn encode(self) -> Result<Encoded> {
        // One tag byte, plus the encoded string when a tracing context is present.
        let tracing_size = 1 + self
            .tracing_context
            .as_ref()
            .map_or(0, |ctx| crate::bare::size_of_slice(ctx.as_bytes()));
        let capacity = 1
            + self.onward_route.encoded_size()
            + self.return_route.encoded_size()
            + crate::bare::size_of_slice(&self.payload)
            + tracing_size;

        let mut out = Vec::with_capacity(capacity);
        out.push(self.version);
        self.onward_route.manual_encode(&mut out);
        self.return_route.manual_encode(&mut out);
        crate::bare::write_slice(&mut out, &self.payload);
        match &self.tracing_context {
            Some(ctx) => {
                out.push(1);
                crate::bare::write_str(&mut out, ctx);
            }
            None => out.push(0),
        }
        Ok(out)
    }
}
impl Decodable for TransportMessage {
fn decode(slice: &[u8]) -> Result<Self> {
Self::internal_decode(slice).ok_or_else(|| {
Error::new(
Origin::Transport,
Kind::Protocol,
"Failed to decode TransportMessage",
)
})
}
}
impl TransportMessage {
    /// Parses the frame layout produced by `encode`. Returns `None` if the version byte,
    /// either route, or the payload cannot be read.
    ///
    /// The trailing tracing-context tag is read leniently, and none of these cases fail
    /// the whole decode:
    /// - a missing tag byte yields no tracing context;
    /// - any tag other than `1` yields no tracing context;
    /// - a `1` tag followed by an unreadable string yields no tracing context.
    fn internal_decode(slice: &[u8]) -> Option<Self> {
        let (&version, _) = slice.split_first()?;
        let mut cursor = 1;
        let onward_route = Route::manual_decode(slice, &mut cursor)?;
        let return_route = Route::manual_decode(slice, &mut cursor)?;
        let payload = crate::bare::read_slice(slice, &mut cursor)?.to_vec();

        let tracing_context = match slice.get(cursor).copied() {
            Some(1) => {
                // Skip the tag byte, then read the string that follows it.
                cursor += 1;
                crate::bare::read_str(slice, &mut cursor).map(|s| s.to_string())
            }
            _ => None,
        };

        Some(Self {
            version,
            onward_route,
            return_route,
            payload,
            tracing_context,
        })
    }
}
/// Legacy (protocol version 1) transport message. It has the same fields as
/// [`TransportMessage`] except the tracing context.
#[derive(Debug, Clone, Eq, PartialEq, Message)]
pub struct TransportMessageV1 {
    /// Version byte; [`TransportMessageV1::new`] sets it to `1`.
    pub version: u8,
    /// Onward route of the message.
    pub onward_route: Route,
    /// Return route of the message.
    pub return_route: Route,
    /// Opaque payload bytes, written with `crate::bare::write_slice`.
    pub payload: Vec<u8>,
}
impl TransportMessageV1 {
pub fn to_latest(self) -> TransportMessage {
TransportMessage {
version: PROTOCOL_VERSION_V1,
onward_route: self.onward_route,
return_route: self.return_route,
payload: self.payload,
tracing_context: None,
}
}
pub fn new(
onward_route: impl Into<Route>,
return_route: impl Into<Route>,
payload: Vec<u8>,
) -> Self {
Self {
version: 1,
onward_route: onward_route.into(),
return_route: return_route.into(),
payload,
}
}
}
impl Encodable for TransportMessageV1 {
    /// Encodes a V1 frame in this order:
    ///
    /// 1. the version byte;
    /// 2. the onward route, then the return route;
    /// 3. the payload;
    /// 4. a trailing `0` byte.
    ///
    /// The trailing `0` is the value [`TransportMessage`]'s encoder writes when it has no
    /// tracing context. A V1 frame therefore has the same layout as an untraced
    /// `TransportMessage` frame with the same version byte.
    /// NOTE(review): presumably kept for peers that expect that tag byte — confirm before
    /// removing. `TransportMessageV1`'s own decoder ignores it.
    fn encode(self) -> Result<Encoded> {
        // Leading `1` is the version byte; trailing `+ 1` reserves room for the final `0`
        // byte so it fits in the initial allocation instead of forcing a reallocation.
        let mut encoded = Vec::with_capacity(
            1 + self.onward_route.encoded_size()
                + self.return_route.encoded_size()
                + crate::bare::size_of_slice(&self.payload)
                + 1,
        );
        encoded.push(self.version);
        self.onward_route.manual_encode(&mut encoded);
        self.return_route.manual_encode(&mut encoded);
        crate::bare::write_slice(&mut encoded, &self.payload);
        encoded.push(0);
        Ok(encoded)
    }
}
impl Decodable for TransportMessageV1 {
fn decode(slice: &[u8]) -> Result<Self> {
Self::internal_decode(slice).ok_or_else(|| {
Error::new(
Origin::Transport,
Kind::Protocol,
"Failed to decode TransportMessage",
)
})
}
}
impl TransportMessageV1 {
    /// Parses a V1 frame: version byte, onward route, return route, payload.
    ///
    /// Returns `None` if any of those parts is missing or malformed. Bytes after the
    /// payload are ignored.
    fn internal_decode(slice: &[u8]) -> Option<Self> {
        let (&version, _) = slice.split_first()?;
        let mut cursor = 1;
        let onward_route = Route::manual_decode(slice, &mut cursor)?;
        let return_route = Route::manual_decode(slice, &mut cursor)?;
        crate::bare::read_slice(slice, &mut cursor).map(|payload| Self {
            version,
            onward_route,
            return_route,
            payload: payload.to_vec(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{route, Encodable, TransportMessageV1};

    #[test]
    fn test_encode_decode() {
        let transport_message_v1 =
            TransportMessageV1::new(route!["onward"], route!["return"], vec![]);
        let transport_message_v2 =
            TransportMessage::latest(route!["onward"], route!["return"], vec![]);
        let encoded_v1 = transport_message_v1.encode().unwrap();
        let expected = TransportMessage::new(
            PROTOCOL_VERSION_V1,
            route!["onward"],
            route!["return"],
            vec![],
            None,
        );
        assert_eq!(
            TransportMessage::decode_message(encoded_v1).unwrap(),
            expected
        );
        let encoded_v2 = transport_message_v2.clone().encode().unwrap();
        assert_eq!(
            TransportMessage::decode_message(encoded_v2).unwrap(),
            transport_message_v2
        );
        let encoded_v3 = TransportMessage {
            version: 3,
            onward_route: route![],
            return_route: route![],
            payload: vec![],
            tracing_context: None,
        }
        .encode()
        .unwrap();
        assert!(TransportMessage::decode_message(encoded_v3).is_err());
    }

    /// An empty buffer has no version byte and must be rejected.
    #[test]
    fn test_decode_empty_buffer_fails() {
        assert!(TransportMessage::decode_message(vec![]).is_err());
    }

    /// A tracing context survives an encode/decode round trip.
    #[test]
    fn test_tracing_context_round_trip() {
        let message = TransportMessage::new(
            LATEST_PROTOCOL_VERSION,
            route!["onward"],
            route!["return"],
            vec![1, 2, 3],
            Some("trace-context".into()),
        );
        let encoded = message.clone().encode().unwrap();
        assert_eq!(TransportMessage::decode_message(encoded).unwrap(), message);
    }

    /// The decoder treats a missing tracing-context tag byte as "no tracing context".
    #[test]
    fn test_missing_tracing_tag_is_tolerated() {
        let message = TransportMessage::latest(route!["onward"], route!["return"], vec![4, 5]);
        let mut encoded = message.clone().encode().unwrap();
        // Drop the trailing "no tracing context" tag byte.
        assert_eq!(encoded.pop(), Some(0));
        assert_eq!(TransportMessage::decode_message(encoded).unwrap(), message);
    }

    /// A legacy V1 frame is byte-identical to an untraced V1 `TransportMessage` frame.
    #[test]
    fn test_v1_frame_matches_untraced_v1_transport_message() {
        let legacy = TransportMessageV1::new(route!["onward"], route!["return"], vec![7]);
        let current = TransportMessage::v1(route!["onward"], route!["return"], vec![7]);
        assert_eq!(legacy.encode().unwrap(), current.encode().unwrap());
    }
}