use crate::interceptor::SendMode;
use crate::node::{ActorId, NodeId};
use uuid::Uuid;
/// Marker trait for messages that can be routed to actors on other nodes.
///
/// It declares no methods of its own; it only bundles the supertrait bounds
/// [`crate::message::Message`], `Send` and `'static`.
pub trait RemoteMessage: crate::message::Message + Send + 'static {}
/// Codec that converts type-erased message values to and from raw bytes.
///
/// Implementors must be `Send + Sync + 'static`, so one instance can be shared
/// across threads.
pub trait MessageSerializer: Send + Sync + 'static {
    /// Identifier for this serializer implementation.
    fn name(&self) -> &'static str;
    /// Encodes `value` into bytes.
    ///
    /// `value` is type-erased (`&dyn Any`). NOTE(review): implementations presumably
    /// downcast it to the concrete message types they support and return a
    /// [`SerializationError`] otherwise — confirm against implementors.
    fn serialize(&self, value: &dyn std::any::Any) -> Result<Vec<u8>, SerializationError>;
    /// Decodes `bytes` into a boxed, type-erased value of the type named by `type_name`.
    ///
    /// NOTE(review): `type_name` presumably corresponds to `WireEnvelope::message_type`
    /// — confirm against callers.
    fn deserialize(
        &self,
        bytes: &[u8],
        type_name: &str,
    ) -> Result<Box<dyn std::any::Any + Send>, SerializationError>;
}
/// Failure to encode a value to, or decode it from, its wire bytes.
#[derive(Debug, Clone)]
pub struct SerializationError {
    /// Human-readable explanation of the failure.
    pub message: String,
}
impl SerializationError {
    /// Wraps `message` (anything convertible into a `String`) in a new error.
    pub fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        SerializationError { message }
    }
}
impl std::fmt::Display for SerializationError {
    /// Renders the error as `serialization error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let detail = &self.message;
        write!(f, "serialization error: {}", detail)
    }
}
impl std::error::Error for SerializationError {}
/// A message in transit between nodes: routing info, headers and an encoded body.
///
/// With the `serde` feature the struct derives `Serialize`/`Deserialize`, so field
/// order may matter for non-self-describing formats; do not reorder casually.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct WireEnvelope {
    /// Actor the message is addressed to.
    pub target: ActorId,
    /// Name associated with the target actor.
    pub target_name: String,
    /// Type key used to pick a deserializer on the receiving side.
    ///
    /// The `build_*_envelope` helpers fill this with `std::any::type_name::<M>()`.
    /// NOTE(review): std does not guarantee `type_name` output is stable across
    /// compiler versions, so nodes built with different toolchains may disagree — confirm.
    pub message_type: String,
    /// Delivery mode (`SendMode::Tell` or `SendMode::Ask` in the helpers in this module).
    pub send_mode: SendMode,
    /// Encoded headers that travel with the message.
    pub headers: WireHeaders,
    /// Encoded message payload (JSON when built by the `build_*_envelope` helpers).
    pub body: Vec<u8>,
    /// Correlation id; `build_ask_envelope` sets it, `build_tell_envelope` leaves it `None`.
    pub request_id: Option<Uuid>,
    /// Sender's message version; compared against the receiver's expectation by
    /// `receive_envelope_body_versioned`.
    pub version: Option<u32>,
}
/// Headers in their transport form: header name mapped to raw encoded bytes.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct WireHeaders {
    /// Encoded header values keyed by header name.
    pub entries: std::collections::HashMap<String, Vec<u8>>,
}
impl WireHeaders {
    /// Returns an empty header set.
    pub fn new() -> Self {
        Self::default()
    }
    /// Stores `value` under `name`, replacing any existing entry with that name.
    pub fn insert(&mut self, name: String, value: Vec<u8>) {
        self.entries.insert(name, value);
    }
    /// Looks up the raw bytes stored under `name`.
    pub fn get(&self, name: &str) -> Option<&[u8]> {
        self.entries.get(name).map(Vec::as_slice)
    }
    /// Returns `true` when no header is stored.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
    /// Number of stored headers.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
    /// Rebuilds typed headers using the decoders in `registry`.
    ///
    /// Entries with no registered decoder, or whose decoder yields `None`, are
    /// silently skipped.
    pub fn to_headers(&self, registry: &HeaderRegistry) -> crate::message::Headers {
        let mut restored = crate::message::Headers::new();
        let decoded = self
            .entries
            .iter()
            .filter_map(|(name, bytes)| registry.deserialize(name, bytes));
        for value in decoded {
            restored.insert_boxed(value);
        }
        restored
    }
}
/// Boxed decoder that rebuilds a typed header from its wire bytes.
///
/// Returning `None` means the bytes could not be decoded; `WireHeaders::to_headers`
/// then drops that header.
pub type HeaderDeserializerFn =
    Box<dyn Fn(&[u8]) -> Option<Box<dyn crate::message::HeaderValue>> + Send + Sync>;
/// Table of header decoders keyed by header name.
pub struct HeaderRegistry {
    deserializers: std::collections::HashMap<String, HeaderDeserializerFn>,
}
impl HeaderRegistry {
    /// Returns a registry with no decoders.
    pub fn new() -> Self {
        let deserializers = std::collections::HashMap::new();
        HeaderRegistry { deserializers }
    }
    /// Installs `deserializer` for `header_name`, replacing any earlier decoder for it.
    pub fn register(
        &mut self,
        header_name: impl Into<String>,
        deserializer: impl Fn(&[u8]) -> Option<Box<dyn crate::message::HeaderValue>>
            + Send
            + Sync
            + 'static,
    ) {
        let key: String = header_name.into();
        let boxed: HeaderDeserializerFn = Box::new(deserializer);
        self.deserializers.insert(key, boxed);
    }
    /// Decodes `bytes` with the decoder registered for `header_name`.
    ///
    /// Returns `None` when no decoder is registered or the decoder rejects the bytes.
    pub fn deserialize(
        &self,
        header_name: &str,
        bytes: &[u8],
    ) -> Option<Box<dyn crate::message::HeaderValue>> {
        self.deserializers
            .get(header_name)
            .and_then(|decode| decode(bytes))
    }
    /// Number of registered decoders.
    pub fn len(&self) -> usize {
        self.deserializers.len()
    }
    /// Returns `true` when no decoder is registered.
    pub fn is_empty(&self) -> bool {
        self.deserializers.is_empty()
    }
}
impl Default for HeaderRegistry {
    /// Same as [`HeaderRegistry::new`].
    fn default() -> Self {
        HeaderRegistry::new()
    }
}
/// Converts encoded message payloads between versions of one message type.
pub trait MessageVersionHandler: Send + Sync + 'static {
    /// Message type this handler migrates.
    ///
    /// `receive_envelope_body_versioned` looks handlers up by the key of the map it
    /// is given, not through this method. NOTE(review): keep the two in sync — confirm
    /// how callers build that map.
    fn message_type(&self) -> &'static str;
    /// Rewrites `payload` from `from_version` to `to_version`.
    ///
    /// Returning `None` signals that the migration is impossible;
    /// `receive_envelope_body_versioned` reports that as a `SerializationError`.
    /// The directions are not constrained: `from_version` may be greater than `to_version`.
    fn migrate(&self, payload: &[u8], from_version: u32, to_version: u32) -> Option<Vec<u8>>;
}
/// Snapshot of cluster membership and version information.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ClusterState {
    /// Identity of the node that owns this snapshot.
    pub local_node: NodeId,
    /// Node ids known to be members.
    pub nodes: Vec<NodeId>,
    /// Whether the local node is leader; `new` starts with `false`.
    pub is_leader: bool,
    /// Local wire protocol version; `new` parses it from `DACTOR_WIRE_VERSION`.
    pub wire_version: crate::version::WireVersion,
    /// Optional application version string; `None` after `new`.
    pub app_version: Option<String>,
    /// Version details reported by peers, keyed by node id; empty after `new`.
    pub peer_versions: std::collections::HashMap<NodeId, PeerVersionInfo>,
}
impl ClusterState {
    /// Creates a non-leader snapshot with no application version and no peer versions.
    ///
    /// # Panics
    /// Panics if the crate constant `DACTOR_WIRE_VERSION` fails to parse; the `expect`
    /// treats that as an invariant violation rather than a runtime condition.
    pub fn new(local_node: NodeId, nodes: Vec<NodeId>) -> Self {
        let wire_version =
            crate::version::WireVersion::parse(crate::version::DACTOR_WIRE_VERSION)
                .expect("DACTOR_WIRE_VERSION must be valid");
        ClusterState {
            local_node,
            nodes,
            is_leader: false,
            wire_version,
            app_version: None,
            peer_versions: Default::default(),
        }
    }
    /// Number of entries in `nodes`.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }
    /// Returns `true` if `node_id` appears in `nodes`.
    pub fn contains(&self, node_id: &NodeId) -> bool {
        self.nodes.iter().any(|member| member == node_id)
    }
    /// Version details recorded for `node_id`, if any.
    pub fn peer_version(&self, node_id: &NodeId) -> Option<&PeerVersionInfo> {
        self.peer_versions.get(node_id)
    }
}
/// Version details reported by one peer node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PeerVersionInfo {
    /// Peer's wire protocol version.
    pub wire_version: crate::version::WireVersion,
    /// Peer's application version string, if it reported one.
    pub app_version: Option<String>,
    /// Name of the runtime adapter on the peer. NOTE(review): the tests use values
    /// like "ractor"; the full set of names is not established here — confirm.
    pub adapter: String,
}
/// Failure reported by a [`ClusterDiscovery`] source.
#[derive(Debug, Clone)]
pub struct DiscoveryError {
    /// Human-readable explanation; also the full `Display` output.
    pub message: String,
}
impl DiscoveryError {
    /// Wraps `message` (anything convertible into a `String`) in a new error.
    pub fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        DiscoveryError { message }
    }
}
impl std::fmt::Display for DiscoveryError {
    /// Writes the message verbatim, without any prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}
impl std::error::Error for DiscoveryError {}
/// A peer reported by a discovery source: its node identity plus its address string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DiscoveredPeer {
    /// Identity of the discovered node.
    pub node_id: NodeId,
    /// Address the peer was discovered at.
    pub address: String,
}
impl DiscoveredPeer {
    /// Builds a peer from an explicit node id and an address.
    pub fn new(node_id: NodeId, address: impl Into<String>) -> Self {
        let address: String = address.into();
        DiscoveredPeer { node_id, address }
    }
    /// Builds a peer whose node id is the address string itself.
    pub fn from_address(address: impl Into<String>) -> Self {
        let address: String = address.into();
        Self::new(NodeId(address.clone()), address)
    }
}
/// Source of candidate peers for joining or maintaining a cluster.
///
/// Declared through `async_trait` so implementations can use `async fn`.
#[async_trait::async_trait]
pub trait ClusterDiscovery: Send + Sync + 'static {
    /// Returns the peers currently known to this source.
    ///
    /// # Errors
    /// Returns a [`DiscoveryError`] when the source cannot produce a peer list.
    async fn discover(&self) -> Result<Vec<DiscoveredPeer>, DiscoveryError>;
}
/// Discovery source backed by a fixed, preconfigured list of peers.
pub struct StaticSeeds {
    /// Peers handed back by every `discover` call.
    pub peers: Vec<DiscoveredPeer>,
}
impl StaticSeeds {
    /// Builds seeds from bare addresses; each address also serves as the node id
    /// (see [`DiscoveredPeer::from_address`]).
    pub fn new(addresses: Vec<String>) -> Self {
        let peers = addresses
            .into_iter()
            .map(DiscoveredPeer::from_address)
            .collect();
        Self::from_peers(peers)
    }
    /// Builds seeds from fully specified peers.
    pub fn from_peers(peers: Vec<DiscoveredPeer>) -> Self {
        StaticSeeds { peers }
    }
}
#[async_trait::async_trait]
impl ClusterDiscovery for StaticSeeds {
    /// Returns a copy of the configured peers; this implementation never fails.
    async fn discover(&self) -> Result<Vec<DiscoveredPeer>, DiscoveryError> {
        let snapshot = self.peers.to_vec();
        Ok(snapshot)
    }
}
/// JSON codec built on `serde_json` (only with the `serde` feature).
#[cfg(feature = "serde")]
pub struct JsonSerializer;
#[cfg(feature = "serde")]
impl JsonSerializer {
    /// Encodes `value` as JSON bytes.
    ///
    /// # Errors
    /// Returns a [`SerializationError`] whose message starts with `json serialize:`.
    pub fn serialize_typed<T: serde::Serialize>(value: &T) -> Result<Vec<u8>, SerializationError> {
        match serde_json::to_vec(value) {
            Ok(encoded) => Ok(encoded),
            Err(e) => Err(SerializationError::new(format!("json serialize: {e}"))),
        }
    }
    /// Decodes JSON `bytes` into an owned `T`.
    ///
    /// # Errors
    /// Returns a [`SerializationError`] whose message starts with `json deserialize:`.
    pub fn deserialize_typed<T: serde::de::DeserializeOwned>(
        bytes: &[u8],
    ) -> Result<T, SerializationError> {
        let decoded: Result<T, serde_json::Error> = serde_json::from_slice(bytes);
        decoded.map_err(|e| SerializationError::new(format!("json deserialize: {e}")))
    }
}
/// Serializes `msg` as JSON and wraps it in a fire-and-forget (`SendMode::Tell`) envelope.
///
/// This is a thin wrapper over `build_wire_envelope` with no request id and no
/// version tag. The envelope layout is therefore defined in one place only.
///
/// # Errors
/// Returns a [`SerializationError`] if `msg` cannot be encoded as JSON.
#[cfg(feature = "serde")]
pub fn build_tell_envelope<M: serde::Serialize>(
    target: crate::node::ActorId,
    target_name: impl Into<String>,
    msg: &M,
    headers: WireHeaders,
) -> Result<WireEnvelope, SerializationError> {
    build_wire_envelope(
        target,
        target_name,
        msg,
        crate::interceptor::SendMode::Tell,
        headers,
        None,
        None,
    )
}
/// Serializes `msg` as JSON and wraps it in a request/response (`SendMode::Ask`) envelope.
///
/// `request_id` is stored in the envelope for correlation. This is a thin wrapper over
/// `build_wire_envelope` with no version tag, so the envelope layout is defined in one
/// place only.
///
/// # Errors
/// Returns a [`SerializationError`] if `msg` cannot be encoded as JSON.
#[cfg(feature = "serde")]
pub fn build_ask_envelope<M: serde::Serialize>(
    target: crate::node::ActorId,
    target_name: impl Into<String>,
    msg: &M,
    headers: WireHeaders,
    request_id: uuid::Uuid,
) -> Result<WireEnvelope, SerializationError> {
    build_wire_envelope(
        target,
        target_name,
        msg,
        crate::interceptor::SendMode::Ask,
        headers,
        Some(request_id),
        None,
    )
}
/// Serializes `msg` as JSON and builds an envelope with every field chosen by the caller.
///
/// `message_type` is set to `std::any::type_name::<M>()`.
///
/// # Errors
/// Returns a [`SerializationError`] if `msg` cannot be encoded as JSON.
#[cfg(feature = "serde")]
pub fn build_wire_envelope<M: serde::Serialize>(
    target: crate::node::ActorId,
    target_name: impl Into<String>,
    msg: &M,
    send_mode: crate::interceptor::SendMode,
    headers: WireHeaders,
    request_id: Option<uuid::Uuid>,
    version: Option<u32>,
) -> Result<WireEnvelope, SerializationError> {
    // Encode first so nothing else is computed when serialization fails.
    let body = JsonSerializer::serialize_typed(msg)?;
    let target_name: String = target_name.into();
    let message_type = std::any::type_name::<M>().to_string();
    let envelope = WireEnvelope {
        target,
        target_name,
        message_type,
        send_mode,
        headers,
        body,
        request_id,
        version,
    };
    Ok(envelope)
}
pub fn receive_envelope_body(
envelope: &WireEnvelope,
registry: &crate::type_registry::TypeRegistry,
) -> Result<Box<dyn std::any::Any + Send>, SerializationError> {
registry.deserialize(&envelope.message_type, &envelope.body)
}
/// Decodes an envelope's body, migrating it between message versions first when needed.
///
/// A migration is attempted only when `envelope.version` and `expected_version` are
/// both `Some` and differ. The handler registered in `version_handlers` under the
/// envelope's `message_type` then rewrites the payload from the received version to
/// the expected one. When no handler is registered, or either version is `None`, the
/// body is decoded unchanged.
///
/// # Errors
/// Returns a [`SerializationError`] if the handler returns `None` (cannot migrate), or
/// propagates the registry's error if decoding the (possibly migrated) body fails.
pub fn receive_envelope_body_versioned(
    envelope: &WireEnvelope,
    registry: &crate::type_registry::TypeRegistry,
    version_handlers: &std::collections::HashMap<String, Box<dyn MessageVersionHandler>>,
    expected_version: Option<u32>,
) -> Result<Box<dyn std::any::Any + Send>, SerializationError> {
    if let (Some(received), Some(expected)) = (envelope.version, expected_version) {
        if received != expected {
            if let Some(handler) = version_handlers.get(&envelope.message_type) {
                // Only a real migration needs an owned buffer. Every other path decodes
                // the envelope's body in place instead of cloning it per message.
                let migrated = handler
                    .migrate(&envelope.body, received, expected)
                    .ok_or_else(|| {
                        SerializationError::new(format!(
                            "{}: cannot migrate from v{received} to v{expected}",
                            envelope.message_type
                        ))
                    })?;
                return registry.deserialize(&envelope.message_type, &migrated);
            }
            // No handler registered: deliberately fall through and let the registry
            // try the payload as received.
        }
    }
    registry.deserialize(&envelope.message_type, &envelope.body)
}
// Unit tests for the wire envelope, header, cluster-state, discovery and versioned
// decoding helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::interceptor::SendMode;
    use crate::node::NodeId;

    #[test]
    fn test_wire_envelope_construction() {
        let envelope = WireEnvelope {
            target: ActorId {
                node: NodeId("node-1".into()),
                local: 42,
            },
            target_name: "test".into(),
            message_type: "my_crate::Increment".into(),
            send_mode: SendMode::Tell,
            headers: WireHeaders::new(),
            body: vec![1, 2, 3],
            request_id: None,
            version: Some(1),
        };
        assert_eq!(envelope.message_type, "my_crate::Increment");
        assert_eq!(envelope.body, vec![1, 2, 3]);
        assert_eq!(envelope.version, Some(1));
    }

    #[test]
    fn test_wire_headers() {
        let mut headers = WireHeaders::new();
        assert!(headers.is_empty());
        headers.insert("trace-id".into(), b"abc-123".to_vec());
        headers.insert("priority".into(), vec![128]);
        assert_eq!(headers.len(), 2);
        assert_eq!(headers.get("trace-id").unwrap(), b"abc-123");
        assert_eq!(headers.get("priority").unwrap(), &[128]);
        assert!(headers.get("missing").is_none());
    }

    #[test]
    fn test_serialization_error() {
        let err = SerializationError::new("invalid format");
        assert!(format!("{}", err).contains("invalid format"));
    }

    #[test]
    fn test_cluster_state() {
        let mut state = ClusterState::new(
            NodeId("node-1".into()),
            vec![
                NodeId("node-1".into()),
                NodeId("node-2".into()),
                NodeId("node-3".into()),
            ],
        );
        state.is_leader = true;
        assert_eq!(state.node_count(), 3);
        assert!(state.contains(&NodeId("node-2".into())));
        assert!(!state.contains(&NodeId("node-99".into())));
        assert!(state.is_leader);
        assert!(state.app_version.is_none());
        assert_eq!(
            state.wire_version,
            crate::version::WireVersion::parse(crate::version::DACTOR_WIRE_VERSION).unwrap()
        );
        assert!(state.peer_versions.is_empty());
    }

    #[test]
    fn test_cluster_state_with_app_version() {
        let mut state = ClusterState::new(
            NodeId("node-1".into()),
            vec![NodeId("node-1".into()), NodeId("node-2".into())],
        );
        state.app_version = Some("2.3.1".into());
        assert_eq!(state.app_version.as_deref(), Some("2.3.1"));
    }

    #[test]
    fn test_cluster_state_peer_versions() {
        let mut state = ClusterState::new(
            NodeId("node-1".into()),
            vec![
                NodeId("node-1".into()),
                NodeId("node-2".into()),
                NodeId("node-3".into()),
            ],
        );
        state.peer_versions.insert(
            NodeId("node-2".into()),
            PeerVersionInfo {
                wire_version: crate::version::WireVersion::parse("0.2.0").unwrap(),
                app_version: Some("1.0.0".into()),
                adapter: "ractor".into(),
            },
        );
        state.peer_versions.insert(
            NodeId("node-3".into()),
            PeerVersionInfo {
                wire_version: crate::version::WireVersion::parse("0.2.0").unwrap(),
                app_version: Some("1.0.1".into()),
                adapter: "ractor".into(),
            },
        );
        let p2 = state.peer_version(&NodeId("node-2".into())).unwrap();
        assert_eq!(p2.app_version.as_deref(), Some("1.0.0"));
        assert_eq!(p2.adapter, "ractor");
        let p3 = state.peer_version(&NodeId("node-3".into())).unwrap();
        assert_eq!(p3.app_version.as_deref(), Some("1.0.1"));
        assert!(state.peer_version(&NodeId("node-1".into())).is_none());
        assert!(state.peer_version(&NodeId("node-99".into())).is_none());
    }

    #[test]
    fn test_cluster_state_mixed_app_versions() {
        let mut state = ClusterState::new(
            NodeId("node-1".into()),
            vec![
                NodeId("node-1".into()),
                NodeId("node-2".into()),
                NodeId("node-3".into()),
            ],
        );
        state.app_version = Some("2.3.1".into());
        state.peer_versions.insert(
            NodeId("node-2".into()),
            PeerVersionInfo {
                wire_version: crate::version::WireVersion::parse("0.2.0").unwrap(),
                app_version: Some("2.3.0".into()),
                adapter: "ractor".into(),
            },
        );
        state.peer_versions.insert(
            NodeId("node-3".into()),
            PeerVersionInfo {
                wire_version: crate::version::WireVersion::parse("0.2.0").unwrap(),
                app_version: Some("2.3.1".into()),
                adapter: "ractor".into(),
            },
        );
        let total = state.node_count();
        // The local node (on "2.3.1") counts as one; add peers on the same release.
        let on_latest = 1 + state
            .peer_versions
            .values()
            .filter(|p| p.app_version.as_deref() == Some("2.3.1"))
            .count();
        assert_eq!(total, 3);
        assert_eq!(on_latest, 2);
    }

    #[tokio::test]
    async fn test_static_seeds() {
        let seeds = StaticSeeds::new(vec!["node1:4697".into(), "node2:4697".into()]);
        let discovered = seeds.discover().await.unwrap();
        assert_eq!(discovered.len(), 2);
        assert_eq!(discovered[0].address, "node1:4697");
        assert_eq!(discovered[0].node_id, NodeId("node1:4697".into()));
    }

    #[tokio::test]
    async fn test_static_seeds_from_peers() {
        let seeds = StaticSeeds::from_peers(vec![
            DiscoveredPeer::new(NodeId("node-a".into()), "10.0.0.1:9000"),
            DiscoveredPeer::new(NodeId("node-b".into()), "10.0.0.2:9000"),
        ]);
        let discovered = seeds.discover().await.unwrap();
        assert_eq!(discovered.len(), 2);
        assert_eq!(discovered[0].node_id, NodeId("node-a".into()));
        assert_eq!(discovered[0].address, "10.0.0.1:9000");
        assert_eq!(discovered[1].node_id, NodeId("node-b".into()));
    }

    #[test]
    fn test_discovered_peer_from_address() {
        let peer = DiscoveredPeer::from_address("10.0.0.1:9000");
        assert_eq!(peer.node_id, NodeId("10.0.0.1:9000".into()));
        assert_eq!(peer.address, "10.0.0.1:9000");
    }

    #[test]
    fn test_wire_envelope_with_request_id() {
        let envelope = WireEnvelope {
            target: ActorId {
                node: NodeId("n".into()),
                local: 1,
            },
            target_name: "test".into(),
            message_type: "Ask".into(),
            send_mode: SendMode::Ask,
            headers: WireHeaders::new(),
            body: vec![],
            request_id: Some(Uuid::new_v4()),
            version: None,
        };
        assert!(envelope.request_id.is_some());
        assert_eq!(envelope.send_mode, SendMode::Ask);
    }

    #[test]
    fn test_header_registry_roundtrip() {
        use crate::message::HeaderValue;
        use std::any::Any;

        #[derive(Debug, Clone)]
        struct TraceId(String);
        impl HeaderValue for TraceId {
            fn header_name(&self) -> &'static str {
                "trace-id"
            }
            fn to_bytes(&self) -> Option<Vec<u8>> {
                Some(self.0.as_bytes().to_vec())
            }
            fn as_any(&self) -> &dyn Any {
                self
            }
        }

        let mut registry = HeaderRegistry::new();
        registry.register("trace-id", |bytes: &[u8]| {
            let s = String::from_utf8(bytes.to_vec()).ok()?;
            Some(Box::new(TraceId(s)) as Box<dyn HeaderValue>)
        });
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
        let mut headers = crate::message::Headers::new();
        headers.insert(TraceId("abc-123".into()));
        let wire = headers.to_wire();
        assert_eq!(wire.len(), 1);
        assert_eq!(wire.get("trace-id").unwrap(), b"abc-123");
        let restored = wire.to_headers(&registry);
        let trace = restored.get::<TraceId>().unwrap();
        assert_eq!(trace.0, "abc-123");
    }

    #[test]
    fn test_header_registry_missing_deserializer() {
        let registry = HeaderRegistry::new();
        assert!(registry.deserialize("unknown", &[]).is_none());
    }

    #[test]
    fn test_headers_to_wire_skips_local_only() {
        use crate::message::HeaderValue;
        use std::any::Any;

        #[derive(Debug)]
        struct LocalOnlyHeader;
        impl HeaderValue for LocalOnlyHeader {
            fn header_name(&self) -> &'static str {
                "local-only"
            }
            fn to_bytes(&self) -> Option<Vec<u8>> {
                None
            }
            fn as_any(&self) -> &dyn Any {
                self
            }
        }

        let mut headers = crate::message::Headers::new();
        headers.insert(LocalOnlyHeader);
        let wire = headers.to_wire();
        assert!(wire.is_empty());
    }

    #[test]
    fn test_receive_envelope_body() {
        let mut registry = crate::type_registry::TypeRegistry::new();
        registry.register("test::Amount", |bytes: &[u8]| {
            if bytes.len() != 8 {
                return Err(SerializationError::new("expected 8 bytes"));
            }
            let val = u64::from_be_bytes(bytes.try_into().unwrap());
            Ok(Box::new(val))
        });
        let envelope = WireEnvelope {
            target: ActorId {
                node: NodeId("n".into()),
                local: 1,
            },
            target_name: "test".into(),
            message_type: "test::Amount".into(),
            send_mode: SendMode::Tell,
            headers: WireHeaders::new(),
            body: 42u64.to_be_bytes().to_vec(),
            request_id: None,
            version: None,
        };
        let any = receive_envelope_body(&envelope, &registry).unwrap();
        let val = any.downcast::<u64>().unwrap();
        assert_eq!(*val, 42);
    }

    #[test]
    fn test_receive_envelope_body_unknown_type() {
        let registry = crate::type_registry::TypeRegistry::new();
        let envelope = WireEnvelope {
            target: ActorId {
                node: NodeId("n".into()),
                local: 1,
            },
            target_name: "test".into(),
            message_type: "unknown::Type".into(),
            send_mode: SendMode::Tell,
            headers: WireHeaders::new(),
            body: vec![],
            request_id: None,
            version: None,
        };
        let result = receive_envelope_body(&envelope, &registry);
        assert!(result.is_err());
        assert!(result.unwrap_err().message.contains("no deserializer"));
    }

    #[test]
    fn test_version_mismatch_with_handler() {
        let mut registry = crate::type_registry::TypeRegistry::new();
        registry.register("test::Versioned", |bytes: &[u8]| {
            if bytes.len() != 8 {
                return Err(SerializationError::new("expected 8 bytes"));
            }
            let val = u64::from_be_bytes(bytes.try_into().unwrap());
            Ok(Box::new(val))
        });

        struct DoubleMigrator;
        impl MessageVersionHandler for DoubleMigrator {
            fn message_type(&self) -> &'static str {
                "test::Versioned"
            }
            fn migrate(&self, payload: &[u8], _from: u32, _to: u32) -> Option<Vec<u8>> {
                if payload.len() != 8 {
                    return None;
                }
                let val = u64::from_be_bytes(payload.try_into().unwrap());
                Some((val * 2).to_be_bytes().to_vec())
            }
        }

        let mut version_handlers: std::collections::HashMap<
            String,
            Box<dyn MessageVersionHandler>,
        > = std::collections::HashMap::new();
        version_handlers.insert("test::Versioned".into(), Box::new(DoubleMigrator));
        let envelope = WireEnvelope {
            target: ActorId {
                node: NodeId("n".into()),
                local: 1,
            },
            target_name: "test".into(),
            message_type: "test::Versioned".into(),
            send_mode: SendMode::Tell,
            headers: WireHeaders::new(),
            body: 21u64.to_be_bytes().to_vec(),
            request_id: None,
            // Sender is on v1.
            version: Some(1),
        };
        let any = receive_envelope_body_versioned(
            &envelope,
            &registry,
            &version_handlers,
            // Receiver expects v2, so the migrator must run.
            Some(2),
        )
        .unwrap();
        let val = any.downcast::<u64>().unwrap();
        // 21 doubled by the migrator.
        assert_eq!(*val, 42);
    }

    #[test]
    fn test_version_match_skips_migration() {
        let mut registry = crate::type_registry::TypeRegistry::new();
        registry.register("test::Same", |bytes: &[u8]| Ok(Box::new(bytes.to_vec())));

        struct PanicMigrator;
        impl MessageVersionHandler for PanicMigrator {
            fn message_type(&self) -> &'static str {
                "test::Same"
            }
            fn migrate(&self, _payload: &[u8], _from: u32, _to: u32) -> Option<Vec<u8>> {
                panic!("migrate should not be called when versions match");
            }
        }

        let mut version_handlers: std::collections::HashMap<
            String,
            Box<dyn MessageVersionHandler>,
        > = std::collections::HashMap::new();
        version_handlers.insert("test::Same".into(), Box::new(PanicMigrator));
        let envelope = WireEnvelope {
            target: ActorId {
                node: NodeId("n".into()),
                local: 1,
            },
            target_name: "test".into(),
            message_type: "test::Same".into(),
            send_mode: SendMode::Tell,
            headers: WireHeaders::new(),
            body: vec![1, 2, 3],
            request_id: None,
            version: Some(2),
        };
        let any = receive_envelope_body_versioned(
            &envelope,
            &registry,
            &version_handlers,
            // Same version as the sender: no migration.
            Some(2),
        )
        .unwrap();
        let val = any.downcast::<Vec<u8>>().unwrap();
        assert_eq!(*val, vec![1, 2, 3]);
    }

    #[test]
    fn test_version_mismatch_no_handler_falls_through() {
        let mut registry = crate::type_registry::TypeRegistry::new();
        registry.register("test::NoHandler", |bytes: &[u8]| Ok(Box::new(bytes.to_vec())));
        let version_handlers: std::collections::HashMap<String, Box<dyn MessageVersionHandler>> =
            std::collections::HashMap::new();
        let envelope = WireEnvelope {
            target: ActorId {
                node: NodeId("n".into()),
                local: 1,
            },
            target_name: "test".into(),
            message_type: "test::NoHandler".into(),
            send_mode: SendMode::Tell,
            headers: WireHeaders::new(),
            body: vec![10, 20],
            request_id: None,
            // Mismatched version, but no handler is registered for this type.
            version: Some(1),
        };
        let any = receive_envelope_body_versioned(
            &envelope,
            &registry,
            &version_handlers,
            Some(2),
        )
        .unwrap();
        let val = any.downcast::<Vec<u8>>().unwrap();
        assert_eq!(*val, vec![10, 20]);
    }

    #[test]
    fn test_version_mismatch_handler_returns_none_rejects() {
        let mut registry = crate::type_registry::TypeRegistry::new();
        registry.register("test::FailMigrate", |bytes: &[u8]| Ok(Box::new(bytes.to_vec())));

        struct RejectingMigrator;
        impl MessageVersionHandler for RejectingMigrator {
            fn message_type(&self) -> &'static str {
                "test::FailMigrate"
            }
            fn migrate(&self, _payload: &[u8], _from: u32, _to: u32) -> Option<Vec<u8>> {
                // Always refuses to migrate.
                None
            }
        }

        let mut version_handlers: std::collections::HashMap<
            String,
            Box<dyn MessageVersionHandler>,
        > = std::collections::HashMap::new();
        version_handlers.insert("test::FailMigrate".into(), Box::new(RejectingMigrator));
        let envelope = WireEnvelope {
            target: ActorId {
                node: NodeId("n".into()),
                local: 1,
            },
            target_name: "test".into(),
            message_type: "test::FailMigrate".into(),
            send_mode: SendMode::Tell,
            headers: WireHeaders::new(),
            body: vec![1, 2, 3],
            request_id: None,
            version: Some(1),
        };
        let result = receive_envelope_body_versioned(
            &envelope,
            &registry,
            &version_handlers,
            Some(2),
        );
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.message.contains("cannot migrate from v1 to v2"),
            "expected migration rejection, got: {}",
            err.message
        );
    }

    #[test]
    fn test_version_none_on_sender_skips_migration() {
        let mut registry = crate::type_registry::TypeRegistry::new();
        registry.register("test::OptionalVersion", |bytes: &[u8]| Ok(Box::new(bytes.to_vec())));

        struct PanicMigrator;
        impl MessageVersionHandler for PanicMigrator {
            fn message_type(&self) -> &'static str {
                "test::OptionalVersion"
            }
            fn migrate(&self, _payload: &[u8], _from: u32, _to: u32) -> Option<Vec<u8>> {
                panic!("migrate should not be called when sender has no version");
            }
        }

        let mut version_handlers: std::collections::HashMap<
            String,
            Box<dyn MessageVersionHandler>,
        > = std::collections::HashMap::new();
        version_handlers.insert("test::OptionalVersion".into(), Box::new(PanicMigrator));
        let envelope = WireEnvelope {
            target: ActorId {
                node: NodeId("n".into()),
                local: 1,
            },
            target_name: "test".into(),
            message_type: "test::OptionalVersion".into(),
            send_mode: SendMode::Tell,
            headers: WireHeaders::new(),
            body: vec![7, 8, 9],
            request_id: None,
            // Sender did not tag a version.
            version: None,
        };
        let any = receive_envelope_body_versioned(
            &envelope,
            &registry,
            &version_handlers,
            Some(2),
        )
        .unwrap();
        let val = any.downcast::<Vec<u8>>().unwrap();
        assert_eq!(*val, vec![7, 8, 9]);
    }

    #[test]
    fn test_version_none_on_both_sides_skips_migration() {
        let mut registry = crate::type_registry::TypeRegistry::new();
        registry.register("test::NoVersion", |bytes: &[u8]| Ok(Box::new(bytes.to_vec())));
        let version_handlers: std::collections::HashMap<String, Box<dyn MessageVersionHandler>> =
            std::collections::HashMap::new();
        let envelope = WireEnvelope {
            target: ActorId {
                node: NodeId("n".into()),
                local: 1,
            },
            target_name: "test".into(),
            message_type: "test::NoVersion".into(),
            send_mode: SendMode::Tell,
            headers: WireHeaders::new(),
            body: vec![4, 5, 6],
            request_id: None,
            version: None,
        };
        let any = receive_envelope_body_versioned(
            &envelope,
            &registry,
            &version_handlers,
            None,
        )
        .unwrap();
        let val = any.downcast::<Vec<u8>>().unwrap();
        assert_eq!(*val, vec![4, 5, 6]);
    }

    #[test]
    fn test_version_none_on_receiver_skips_migration() {
        let mut registry = crate::type_registry::TypeRegistry::new();
        registry.register("test::ReceiverNone", |bytes: &[u8]| Ok(Box::new(bytes.to_vec())));

        struct PanicMigrator;
        impl MessageVersionHandler for PanicMigrator {
            fn message_type(&self) -> &'static str {
                "test::ReceiverNone"
            }
            fn migrate(&self, _payload: &[u8], _from: u32, _to: u32) -> Option<Vec<u8>> {
                panic!("migrate should not be called when receiver has no version expectation");
            }
        }

        let mut version_handlers: std::collections::HashMap<
            String,
            Box<dyn MessageVersionHandler>,
        > = std::collections::HashMap::new();
        version_handlers.insert("test::ReceiverNone".into(), Box::new(PanicMigrator));
        let envelope = WireEnvelope {
            target: ActorId {
                node: NodeId("n".into()),
                local: 1,
            },
            target_name: "test".into(),
            message_type: "test::ReceiverNone".into(),
            send_mode: SendMode::Tell,
            headers: WireHeaders::new(),
            body: vec![11, 22, 33],
            request_id: None,
            version: Some(3),
        };
        let any = receive_envelope_body_versioned(
            &envelope,
            &registry,
            &version_handlers,
            // Receiver has no version expectation.
            None,
        )
        .unwrap();
        let val = any.downcast::<Vec<u8>>().unwrap();
        assert_eq!(*val, vec![11, 22, 33]);
    }

    #[test]
    fn test_version_backward_migration_v2_to_v1() {
        let mut registry = crate::type_registry::TypeRegistry::new();
        registry.register("test::Backward", |bytes: &[u8]| {
            if bytes.len() != 8 {
                return Err(SerializationError::new("expected 8 bytes"));
            }
            let val = u64::from_be_bytes(bytes.try_into().unwrap());
            Ok(Box::new(val))
        });

        struct HalveMigrator;
        impl MessageVersionHandler for HalveMigrator {
            fn message_type(&self) -> &'static str {
                "test::Backward"
            }
            fn migrate(&self, payload: &[u8], from: u32, to: u32) -> Option<Vec<u8>> {
                if from > to {
                    let val = u64::from_be_bytes(payload.try_into().ok()?);
                    Some((val / 2).to_be_bytes().to_vec())
                } else {
                    None
                }
            }
        }

        let mut version_handlers: std::collections::HashMap<
            String,
            Box<dyn MessageVersionHandler>,
        > = std::collections::HashMap::new();
        version_handlers.insert("test::Backward".into(), Box::new(HalveMigrator));
        let envelope = WireEnvelope {
            target: ActorId {
                node: NodeId("n".into()),
                local: 1,
            },
            target_name: "test".into(),
            message_type: "test::Backward".into(),
            send_mode: SendMode::Tell,
            headers: WireHeaders::new(),
            body: 100u64.to_be_bytes().to_vec(),
            request_id: None,
            // Newer sender (v2) talking to an older receiver (v1).
            version: Some(2),
        };
        let any = receive_envelope_body_versioned(
            &envelope,
            &registry,
            &version_handlers,
            Some(1),
        )
        .unwrap();
        let val = any.downcast::<u64>().unwrap();
        // 100 halved by the backward migrator.
        assert_eq!(*val, 50);
    }

    #[cfg(feature = "serde")]
    mod serde_tests {
        use super::*;

        #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
        struct Increment {
            amount: u64,
        }

        #[test]
        fn json_serializer_roundtrip() {
            let msg = Increment { amount: 42 };
            let bytes = JsonSerializer::serialize_typed(&msg).unwrap();
            let deserialized: Increment = JsonSerializer::deserialize_typed(&bytes).unwrap();
            assert_eq!(deserialized, msg);
        }

        #[test]
        fn json_serializer_invalid_bytes() {
            let result = JsonSerializer::deserialize_typed::<Increment>(b"not json");
            assert!(result.is_err());
            assert!(result.unwrap_err().message.contains("json deserialize"));
        }

        #[test]
        fn build_tell_envelope_roundtrip() {
            let target = ActorId {
                node: NodeId("node-2".into()),
                local: 7,
            };
            let msg = Increment { amount: 100 };
            let envelope =
                build_tell_envelope(target.clone(), "counter", &msg, WireHeaders::new()).unwrap();
            assert_eq!(envelope.target, target);
            assert_eq!(envelope.send_mode, SendMode::Tell);
            assert!(envelope.request_id.is_none());
            assert!(envelope.message_type.contains("Increment"));
            let deserialized: Increment =
                JsonSerializer::deserialize_typed(&envelope.body).unwrap();
            assert_eq!(deserialized.amount, 100);
        }

        #[test]
        fn build_ask_envelope_roundtrip() {
            let target = ActorId {
                node: NodeId("node-3".into()),
                local: 42,
            };
            let msg = Increment { amount: 5 };
            let request_id = Uuid::new_v4();
            let envelope = build_ask_envelope(
                target.clone(),
                "counter",
                &msg,
                WireHeaders::new(),
                request_id,
            )
            .unwrap();
            assert_eq!(envelope.target, target);
            assert_eq!(envelope.send_mode, SendMode::Ask);
            assert_eq!(envelope.request_id, Some(request_id));
        }

        #[test]
        fn full_pipeline_send_and_receive() {
            let target = ActorId {
                node: NodeId("node-2".into()),
                local: 1,
            };
            let msg = Increment { amount: 77 };
            let envelope =
                build_tell_envelope(target, "counter", &msg, WireHeaders::new()).unwrap();
            let mut registry = crate::type_registry::TypeRegistry::new();
            registry.register_type::<Increment>();
            let any = receive_envelope_body(&envelope, &registry).unwrap();
            let received = any.downcast::<Increment>().unwrap();
            assert_eq!(received.amount, 77);
        }

        #[test]
        fn full_pipeline_with_headers() {
            use crate::message::HeaderValue;
            use std::any::Any;

            #[derive(Debug, Clone)]
            struct Priority(u8);
            impl HeaderValue for Priority {
                fn header_name(&self) -> &'static str {
                    "priority"
                }
                fn to_bytes(&self) -> Option<Vec<u8>> {
                    Some(vec![self.0])
                }
                fn as_any(&self) -> &dyn Any {
                    self
                }
            }

            let mut headers = crate::message::Headers::new();
            headers.insert(Priority(5));
            let wire_headers = headers.to_wire();
            let target = ActorId {
                node: NodeId("node-2".into()),
                local: 1,
            };
            let msg = Increment { amount: 10 };
            let envelope = build_tell_envelope(target, "counter", &msg, wire_headers).unwrap();
            let mut header_registry = HeaderRegistry::new();
            header_registry.register("priority", |bytes: &[u8]| {
                if bytes.len() != 1 {
                    return None;
                }
                Some(Box::new(Priority(bytes[0])) as Box<dyn HeaderValue>)
            });
            let restored_headers = envelope.headers.to_headers(&header_registry);
            let priority = restored_headers.get::<Priority>().unwrap();
            assert_eq!(priority.0, 5);
        }
    }
}