use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::Arc;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use crate::actor::{Actor, ActorRef, AskReply, ReduceHandler, Handler, ExpandHandler, TransformHandler};
use crate::errors::{ActorSendError, RuntimeError};
use crate::interceptor::{Disposition, OutboundContext, OutboundInterceptor, SendMode};
use crate::message::{Headers, Message, RuntimeHeaders};
use crate::node::ActorId;
use crate::remote::{SerializationError, WireEnvelope};
use crate::stream::{BatchConfig, BoxStream};
use crate::transport::Transport;
/// Type-erased message encoder stored per message `TypeId`.
///
/// Receives the outgoing message as `&dyn Any` and returns the wire
/// `message_type` name together with the encoded body bytes.
type SerializeFn = Arc<
    dyn Fn(&dyn Any) -> Result<(String, Vec<u8>), SerializationError> + Send + Sync + 'static,
>;

/// Type-erased decoder for an ask reply body.
///
/// The returned box is downcast to the concrete `M::Reply` by the caller.
type DeserializeReplyFn =
    Arc<dyn Fn(&[u8]) -> Result<Box<dyn Any + Send>, SerializationError> + Send + Sync + 'static>;
/// Codec pair registered for one ask message type.
struct AskEntry {
    /// Encodes the request message (and supplies its wire type name).
    serialize: SerializeFn,
    /// Decodes the reply body into a boxed `M::Reply`.
    deserialize_reply: DeserializeReplyFn,
}
/// Reference to an actor of type `A` hosted on another node.
///
/// Messages are encoded with the per-type codecs registered through
/// `RemoteActorRefBuilder` and sent over `transport`. All shared state is
/// behind `Arc`, so clones are cheap.
pub struct RemoteActorRef<A: Actor> {
    /// Target actor id; `id.node` is the node messages are sent to.
    id: ActorId,
    /// Target actor name, copied into every outgoing envelope.
    name: String,
    transport: Arc<dyn Transport>,
    /// Tell serializers keyed by message `TypeId`.
    tell_serializers: Arc<HashMap<TypeId, SerializeFn>>,
    /// Ask request/reply codecs keyed by message `TypeId`.
    ask_entries: Arc<HashMap<TypeId, AskEntry>>,
    /// Run in order on every outgoing tell/ask.
    outbound_interceptors: Arc<Vec<Arc<dyn OutboundInterceptor>>>,
    _phantom: PhantomData<A>,
}
// Implemented by hand rather than with `#[derive(Clone)]`: the derive would add
// an `A: Clone` bound, but every field here is cloneable regardless of `A`.
impl<A: Actor> Clone for RemoteActorRef<A> {
    /// Produces another handle to the same remote actor, sharing the transport,
    /// codec tables and interceptor list through their `Arc`s.
    fn clone(&self) -> Self {
        let Self {
            id,
            name,
            transport,
            tell_serializers,
            ask_entries,
            outbound_interceptors,
            _phantom: _,
        } = self;
        Self {
            id: id.clone(),
            name: name.clone(),
            transport: Arc::clone(transport),
            tell_serializers: Arc::clone(tell_serializers),
            ask_entries: Arc::clone(ask_entries),
            outbound_interceptors: Arc::clone(outbound_interceptors),
            _phantom: PhantomData,
        }
    }
}
/// Verdict of the outbound interceptor chain for one outgoing message.
///
/// `Reject` and `Retry` dispositions are not represented here; the pipeline
/// turns those into an `ActorSendError` instead.
enum PipelineResult {
    /// No interceptor delayed or dropped the message (also the result when
    /// there are no interceptors); send immediately.
    Continue,
    /// Send after waiting this long: the longest delay any interceptor requested.
    Delay(std::time::Duration),
    /// An interceptor dropped the message; nothing is sent.
    Dropped,
}
impl<A: Actor> RemoteActorRef<A> {
    /// Node hosting the referenced actor (the `node` part of its `ActorId`).
    pub fn target_node(&self) -> &crate::node::NodeId {
        &self.id.node
    }

    /// Runs every outbound interceptor, in registration order, on one message.
    ///
    /// Interceptors may add entries to `headers`. The chain stops at the first
    /// interceptor that drops (returns `Ok(PipelineResult::Dropped)`), rejects,
    /// or asks for a retry (both return `Err`). Otherwise the result is
    /// `Delay` with the longest delay requested, or `Continue` if none was.
    fn run_outbound_pipeline(
        &self,
        message_type: &'static str,
        send_mode: SendMode,
        headers: &mut Headers,
        message: &dyn Any,
    ) -> Result<PipelineResult, ActorSendError> {
        // Fast path: nothing to consult.
        if self.outbound_interceptors.is_empty() {
            return Ok(PipelineResult::Continue);
        }

        let runtime_headers = RuntimeHeaders::new();
        let ctx = OutboundContext {
            target_id: self.id.clone(),
            target_name: &self.name,
            message_type,
            send_mode,
            remote: true,
        };

        let mut longest_delay: Option<std::time::Duration> = None;
        for interceptor in self.outbound_interceptors.iter() {
            let disposition = interceptor.on_send(&ctx, &runtime_headers, headers, message);
            match disposition {
                Disposition::Continue => continue,
                Disposition::Drop => return Ok(PipelineResult::Dropped),
                Disposition::Delay(requested) => {
                    longest_delay = Some(
                        longest_delay.map_or(requested, |current| current.max(requested)),
                    );
                }
                Disposition::Reject(reason) => {
                    let message = format!(
                        "rejected by outbound interceptor '{}': {}",
                        interceptor.name(),
                        reason
                    );
                    return Err(ActorSendError(message));
                }
                Disposition::Retry(retry_after) => {
                    let message = format!(
                        "retry after {:?} (from outbound interceptor '{}')",
                        retry_after,
                        interceptor.name()
                    );
                    return Err(ActorSendError(message));
                }
            }
        }

        Ok(longest_delay.map_or(PipelineResult::Continue, PipelineResult::Delay))
    }
}
impl<A: Actor + Sync> ActorRef<A> for RemoteActorRef<A> {
fn id(&self) -> ActorId {
self.id.clone()
}
fn name(&self) -> String {
self.name.clone()
}
fn is_alive(&self) -> bool {
let transport = Arc::clone(&self.transport);
let node = self.id.node.clone();
let _ = (transport, node);
true
}
fn stop(&self) {
tracing::warn!(
actor_id = %self.id,
"RemoteActorRef::stop() is a no-op — remote actor stop requires CancelManager"
);
}
fn tell<M>(&self, msg: M) -> Result<(), ActorSendError>
where
A: Handler<M>,
M: Message<Reply = ()>,
{
let mut headers = Headers::new();
let pipeline_result = self.run_outbound_pipeline(
std::any::type_name::<M>(),
SendMode::Tell,
&mut headers,
&msg as &dyn Any,
)?;
if matches!(pipeline_result, PipelineResult::Dropped) {
return Ok(());
}
let type_id = TypeId::of::<M>();
let serializer = self.tell_serializers.get(&type_id).ok_or_else(|| {
ActorSendError(format!(
"message type '{}' not registered for remote send to {}",
std::any::type_name::<M>(),
self.id
))
})?;
let (type_name, body) = serializer(&msg as &dyn Any)
.map_err(|e| ActorSendError(format!("serialization failed: {}", e.message)))?;
let wire_headers = headers.to_wire();
let envelope = WireEnvelope {
target: self.id.clone(),
target_name: self.name.clone(),
message_type: type_name,
send_mode: SendMode::Tell,
headers: wire_headers,
body,
request_id: None,
version: None,
};
let transport = Arc::clone(&self.transport);
let target_node = self.id.node.clone();
let delay = match pipeline_result {
PipelineResult::Delay(d) => Some(d),
_ => None,
};
tokio::spawn(async move {
if let Some(d) = delay {
tokio::time::sleep(d).await;
}
if let Err(e) = transport.send(&target_node, envelope).await {
tracing::error!(
target_node = %target_node,
error = %e,
"remote tell failed"
);
}
});
Ok(())
}
fn ask<M>(
&self,
msg: M,
cancel: Option<CancellationToken>,
) -> Result<AskReply<M::Reply>, ActorSendError>
where
A: Handler<M>,
M: Message,
{
let mut headers = Headers::new();
let pipeline_result = self.run_outbound_pipeline(
std::any::type_name::<M>(),
SendMode::Ask,
&mut headers,
&msg as &dyn Any,
)?;
if matches!(pipeline_result, PipelineResult::Dropped) {
return Err(ActorSendError(
"message dropped by outbound interceptor".into(),
));
}
let type_id = TypeId::of::<M>();
let entry = self.ask_entries.get(&type_id).ok_or_else(|| {
ActorSendError(format!(
"ask message type '{}' not registered for remote send to {}",
std::any::type_name::<M>(),
self.id
))
})?;
let (type_name, body) = (entry.serialize)(&msg as &dyn Any)
.map_err(|e| ActorSendError(format!("serialization failed: {}", e.message)))?;
let wire_headers = headers.to_wire();
let request_id = uuid::Uuid::new_v4();
let envelope = WireEnvelope {
target: self.id.clone(),
target_name: self.name.clone(),
message_type: type_name,
send_mode: SendMode::Ask,
headers: wire_headers,
body,
request_id: Some(request_id),
version: None,
};
let (tx, rx) = oneshot::channel::<Result<M::Reply, RuntimeError>>();
let transport = Arc::clone(&self.transport);
let target_node = self.id.node.clone();
let deserialize_reply = Arc::clone(&entry.deserialize_reply);
let outbound_interceptors = Arc::clone(&self.outbound_interceptors);
let ref_id = self.id.clone();
let ref_name = self.name.clone();
let request_headers = headers;
let delay = match pipeline_result {
PipelineResult::Delay(d) => Some(d),
_ => None,
};
tokio::spawn(async move {
if let Some(d) = delay {
tokio::time::sleep(d).await;
}
let result = if let Some(cancel_token) = cancel {
tokio::select! {
result = transport.send_request(&target_node, envelope) => result,
_ = cancel_token.cancelled() => {
let _ = tx.send(Err(RuntimeError::Cancelled));
return;
}
}
} else {
transport.send_request(&target_node, envelope).await
};
let notify_interceptors = |outcome: &crate::interceptor::Outcome<'_>| {
if !outbound_interceptors.is_empty() {
let ctx = OutboundContext {
target_id: ref_id.clone(),
target_name: &ref_name,
message_type: std::any::type_name::<M>(),
send_mode: SendMode::Ask,
remote: true,
};
let runtime_headers = RuntimeHeaders::new();
for interceptor in outbound_interceptors.iter() {
interceptor.on_reply(&ctx, &runtime_headers, &request_headers, outcome);
}
}
};
match result {
Ok(reply_envelope) => match deserialize_reply(&reply_envelope.body) {
Ok(any_reply) => {
if let Ok(reply) = (any_reply as Box<dyn Any + Send>).downcast::<M::Reply>()
{
let outcome = crate::interceptor::Outcome::AskSuccess {
reply: reply.as_ref(),
};
notify_interceptors(&outcome);
let _ = tx.send(Ok(*reply));
} else {
let _ = tx.send(Err(RuntimeError::Send(ActorSendError(
"reply type mismatch".into(),
))));
}
}
Err(e) => {
let error_msg = format!("reply deserialization failed: {}", e.message);
let outcome = crate::interceptor::Outcome::HandlerError {
error: crate::actor::ActorError::internal(&error_msg),
};
notify_interceptors(&outcome);
let _ = tx.send(Err(RuntimeError::Send(ActorSendError(error_msg))));
}
},
Err(e) => {
let error_msg = format!("remote ask failed: {}", e);
let outcome = crate::interceptor::Outcome::HandlerError {
error: crate::actor::ActorError::internal(&error_msg),
};
notify_interceptors(&outcome);
let _ = tx.send(Err(RuntimeError::Send(ActorSendError(error_msg))));
}
}
});
Ok(AskReply::new(rx))
}
fn expand<M, OutputItem>(
&self,
_msg: M,
_buffer: usize,
_batch_config: Option<BatchConfig>,
_cancel: Option<CancellationToken>,
) -> Result<BoxStream<OutputItem>, ActorSendError>
where
A: ExpandHandler<M, OutputItem>,
M: Send + 'static,
OutputItem: Send + 'static,
{
Err(ActorSendError(
"remote stream not yet implemented — requires streaming transport support".into(),
))
}
fn reduce<InputItem, Reply>(
&self,
_input: BoxStream<InputItem>,
_buffer: usize,
_batch_config: Option<BatchConfig>,
_cancel: Option<CancellationToken>,
) -> Result<AskReply<Reply>, ActorSendError>
where
A: ReduceHandler<InputItem, Reply>,
InputItem: Send + 'static,
Reply: Send + 'static,
{
Err(ActorSendError(
"remote feed not yet implemented — requires streaming transport support".into(),
))
}
fn transform<InputItem, OutputItem>(
&self,
_input: BoxStream<InputItem>,
_buffer: usize,
_batch_config: Option<BatchConfig>,
_cancel: Option<CancellationToken>,
) -> Result<BoxStream<OutputItem>, ActorSendError>
where
A: TransformHandler<InputItem, OutputItem>,
InputItem: Send + 'static,
OutputItem: Send + 'static,
{
Err(ActorSendError(
"remote transform not yet implemented — requires streaming transport support".into(),
))
}
}
/// Builder that collects message codecs and outbound interceptors before
/// producing an immutable `RemoteActorRef`.
pub struct RemoteActorRefBuilder<A: Actor> {
    id: ActorId,
    name: String,
    transport: Arc<dyn Transport>,
    /// Tell serializers keyed by message `TypeId`; wrapped in `Arc` by `build`.
    tell_serializers: HashMap<TypeId, SerializeFn>,
    /// Ask codecs keyed by message `TypeId`; wrapped in `Arc` by `build`.
    ask_entries: HashMap<TypeId, AskEntry>,
    /// Interceptors in the order they were added.
    outbound_interceptors: Vec<Arc<dyn OutboundInterceptor>>,
    _phantom: PhantomData<A>,
}
impl<A: Actor> RemoteActorRefBuilder<A> {
pub fn new(id: ActorId, name: impl Into<String>, transport: Arc<dyn Transport>) -> Self {
Self {
id,
name: name.into(),
transport,
tell_serializers: HashMap::new(),
ask_entries: HashMap::new(),
outbound_interceptors: Vec::new(),
_phantom: PhantomData,
}
}
pub fn add_outbound_interceptor(mut self, interceptor: Arc<dyn OutboundInterceptor>) -> Self {
self.outbound_interceptors.push(interceptor);
self
}
pub fn register_tell_with(
mut self,
type_id: TypeId,
type_name: &'static str,
serialize: impl Fn(&dyn Any) -> Result<Vec<u8>, SerializationError> + Send + Sync + 'static,
) -> Self {
self.tell_serializers.insert(
type_id,
Arc::new(move |any: &dyn Any| {
let bytes = serialize(any)?;
Ok((type_name.to_string(), bytes))
}),
);
self
}
pub fn register_ask_with(
mut self,
type_id: TypeId,
type_name: &'static str,
serialize: impl Fn(&dyn Any) -> Result<Vec<u8>, SerializationError> + Send + Sync + 'static,
deserialize_reply: impl Fn(&[u8]) -> Result<Box<dyn Any + Send>, SerializationError>
+ Send
+ Sync
+ 'static,
) -> Self {
let type_name_owned = type_name.to_string();
self.ask_entries.insert(
type_id,
AskEntry {
serialize: Arc::new(move |any: &dyn Any| {
let bytes = serialize(any)?;
Ok((type_name_owned.clone(), bytes))
}),
deserialize_reply: Arc::new(deserialize_reply),
},
);
self
}
#[cfg(feature = "serde")]
pub fn register_tell<M>(self) -> Self
where
M: Message<Reply = ()> + serde::Serialize + 'static,
{
self.register_tell_with(
TypeId::of::<M>(),
std::any::type_name::<M>(),
|any: &dyn Any| {
let msg = any.downcast_ref::<M>().ok_or_else(|| {
SerializationError::new("type downcast failed in tell serializer")
})?;
serde_json::to_vec(msg)
.map_err(|e| SerializationError::new(format!("json serialize: {e}")))
},
)
}
#[cfg(feature = "serde")]
pub fn register_ask<M>(self) -> Self
where
M: Message + serde::Serialize + 'static,
M::Reply: serde::de::DeserializeOwned + Send + 'static,
{
self.register_ask_with(
TypeId::of::<M>(),
std::any::type_name::<M>(),
|any: &dyn Any| {
let msg = any.downcast_ref::<M>().ok_or_else(|| {
SerializationError::new("type downcast failed in ask serializer")
})?;
serde_json::to_vec(msg)
.map_err(|e| SerializationError::new(format!("json serialize: {e}")))
},
|bytes: &[u8]| {
let reply: M::Reply = serde_json::from_slice(bytes)
.map_err(|e| SerializationError::new(format!("json deserialize reply: {e}")))?;
Ok(Box::new(reply) as Box<dyn Any + Send>)
},
)
}
pub fn build(self) -> RemoteActorRef<A> {
RemoteActorRef {
id: self.id,
name: self.name,
transport: self.transport,
tell_serializers: Arc::new(self.tell_serializers),
ask_entries: Arc::new(self.ask_entries),
outbound_interceptors: Arc::new(self.outbound_interceptors),
_phantom: PhantomData,
}
}
}
/// Transportable description of an actor reference: its id, name, and actor
/// type name.
///
/// When produced by `from_ref`, `actor_type` is `std::any::type_name::<A>()`.
/// NOTE(review): std documents `type_name` output as not guaranteed stable
/// across compiler versions, so `is_type` / `try_into_builder` comparisons
/// presumably assume both peers use the same toolchain. Confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ActorRefEnvelope {
    pub actor_id: ActorId,
    pub actor_name: String,
    pub actor_type: String,
}
/// Error returned by `ActorRefEnvelope::try_into_builder` when the envelope
/// names a different actor type than the one requested.
#[derive(Clone, Debug)]
pub struct ActorRefTypeMismatch {
    /// Type name the caller asked for.
    pub expected: String,
    /// Type name recorded in the envelope.
    pub actual: String,
}

impl std::fmt::Display for ActorRefTypeMismatch {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { expected, actual } = self;
        write!(f, "ActorRef type mismatch: expected '{expected}', got '{actual}'")
    }
}

impl std::error::Error for ActorRefTypeMismatch {}
impl ActorRefEnvelope {
    /// Captures the id, name and `std::any::type_name::<A>()` of an actor ref.
    pub fn from_ref<A: Actor, R: crate::actor::ActorRef<A>>(actor_ref: &R) -> Self {
        Self::new(actor_ref.id(), actor_ref.name(), std::any::type_name::<A>())
    }

    /// Builds an envelope from its raw parts; `actor_type` is not validated.
    pub fn new(
        actor_id: ActorId,
        actor_name: impl Into<String>,
        actor_type: impl Into<String>,
    ) -> Self {
        let actor_name = actor_name.into();
        let actor_type = actor_type.into();
        Self {
            actor_id,
            actor_name,
            actor_type,
        }
    }

    /// Converts into a builder for `RemoteActorRef<A>` after checking that the
    /// recorded type name equals `std::any::type_name::<A>()`.
    ///
    /// # Errors
    /// Returns `ActorRefTypeMismatch` (carrying both names) on mismatch.
    pub fn try_into_builder<A: Actor>(
        self,
        transport: Arc<dyn Transport>,
    ) -> Result<RemoteActorRefBuilder<A>, ActorRefTypeMismatch> {
        if self.is_type::<A>() {
            Ok(self.into_builder_unchecked::<A>(transport))
        } else {
            Err(ActorRefTypeMismatch {
                expected: std::any::type_name::<A>().to_owned(),
                actual: self.actor_type,
            })
        }
    }

    /// Converts into a builder for `RemoteActorRef<A>` without checking
    /// `actor_type`.
    pub fn into_builder_unchecked<A: Actor>(
        self,
        transport: Arc<dyn Transport>,
    ) -> RemoteActorRefBuilder<A> {
        let Self {
            actor_id,
            actor_name,
            ..
        } = self;
        RemoteActorRefBuilder::<A>::new(actor_id, actor_name, transport)
    }

    /// Whether the recorded type name equals `std::any::type_name::<A>()`.
    pub fn is_type<A: Actor>(&self) -> bool {
        self.actor_type.as_str() == std::any::type_name::<A>()
    }
}
impl std::fmt::Display for ActorRefEnvelope {
    /// Formats as `ActorRef(<id>, name=<name>, type=<type>)`.
    ///
    /// `actor_name` and `actor_type` are each capped at 128 bytes so oversized
    /// values do not flood logs. The cut is moved back to the nearest UTF-8
    /// character boundary, because slicing at a fixed byte offset panics when
    /// that offset falls inside a multi-byte character (e.g. a name received
    /// from a remote peer).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Maximum number of bytes shown for each string field.
        const MAX_FIELD_BYTES: usize = 128;

        // Longest prefix of `s` that is at most `max` bytes and ends on a char
        // boundary. Byte 0 is always a boundary, so the loop terminates.
        fn truncate(s: &str, max: usize) -> &str {
            if s.len() <= max {
                return s;
            }
            let mut end = max;
            while !s.is_char_boundary(end) {
                end -= 1;
            }
            &s[..end]
        }

        write!(
            f,
            "ActorRef({}, name={}, type={})",
            self.actor_id,
            truncate(&self.actor_name, MAX_FIELD_BYTES),
            truncate(&self.actor_type, MAX_FIELD_BYTES)
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::NodeId;
    use crate::remote::WireHeaders;
    use crate::transport::InMemoryTransport;
    use async_trait::async_trait;

    // Minimal actor used as the target type of the remote refs under test.
    struct Counter {
        count: u64,
    }

    impl Actor for Counter {
        type Args = u64;
        type Deps = ();
        fn create(initial: u64, _: ()) -> Self {
            Self { count: initial }
        }
    }

    struct Increment;
    impl Message for Increment {
        type Reply = ();
    }

    struct GetCount;
    impl Message for GetCount {
        type Reply = u64;
    }

    #[async_trait]
    impl Handler<Increment> for Counter {
        async fn handle(&mut self, _msg: Increment, _ctx: &mut crate::actor::ActorContext) {
            self.count += 1;
        }
    }

    #[async_trait]
    impl Handler<GetCount> for Counter {
        async fn handle(&mut self, _msg: GetCount, _ctx: &mut crate::actor::ActorContext) -> u64 {
            self.count
        }
    }

    // Accessors reflect what the builder was given.
    #[test]
    fn remote_ref_id_and_name() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let remote = RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("node-2".into()),
                local: 42,
            },
            "counter",
            transport,
        )
        .build();
        assert_eq!(remote.id().local, 42);
        assert_eq!(remote.name(), "counter");
        assert_eq!(remote.target_node().0, "node-2");
        // Liveness is not probed for remote refs; is_alive() is always true.
        assert!(remote.is_alive());
    }

    // Telling a message type with no registered serializer fails up front.
    #[test]
    fn remote_ref_tell_unregistered_type_returns_error() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let remote = RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("node-2".into()),
                local: 1,
            },
            "counter",
            transport,
        )
        .build();
        let result = remote.tell(Increment);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not registered"));
    }

    // Asking a message type with no registered ask entry fails up front.
    #[test]
    fn remote_ref_ask_unregistered_type_returns_error() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let remote = RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("node-2".into()),
                local: 1,
            },
            "counter",
            transport,
        )
        .build();
        let result = remote.ask(GetCount, None);
        assert!(result.is_err());
        let err = match result {
            Err(e) => e,
            Ok(_) => panic!("expected error"),
        };
        assert!(err.to_string().contains("not registered"));
    }

    // NOTE: `Counter` implements no `ExpandHandler`, so `expand()` cannot be
    // called here; this only checks that the ref builds. Add an
    // `ExpandHandler` impl to exercise the "not yet implemented" error.
    #[test]
    fn remote_ref_stream_returns_not_implemented() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let _remote = RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("node-2".into()),
                local: 1,
            },
            "counter",
            transport,
        )
        .build();
    }

    // Clones refer to the same remote actor.
    #[test]
    fn remote_ref_is_clone() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let remote = RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("node-2".into()),
                local: 1,
            },
            "counter",
            transport,
        )
        .build();
        let cloned = remote.clone();
        assert_eq!(cloned.id().local, 1);
        assert_eq!(cloned.name(), "counter");
    }

    // A custom tell serializer is stored under the message's TypeId and pairs
    // the configured wire type name with the closure's bytes.
    #[test]
    fn remote_ref_tell_with_custom_serializer() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let remote = RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("node-2".into()),
                local: 1,
            },
            "counter",
            transport,
        )
        .register_tell_with(
            TypeId::of::<Increment>(),
            "test::Increment",
            |_any: &dyn Any| Ok(vec![1, 2, 3]),
        )
        .build();
        let serializer = remote
            .tell_serializers
            .get(&TypeId::of::<Increment>())
            .expect("Increment should have a tell serializer");
        let (type_name, body) = match serializer(&Increment as &dyn Any) {
            Ok(encoded) => encoded,
            Err(_) => panic!("custom serializer should succeed"),
        };
        assert_eq!(type_name, "test::Increment");
        assert_eq!(body, vec![1, 2, 3]);
        assert!(remote.ask_entries.is_empty());
    }

    // tell() puts a Tell envelope with the serialized body on the transport.
    #[tokio::test]
    async fn remote_ref_tell_delivers_via_transport() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let mut rx = transport.register_node(NodeId("node-2".into())).await;
        let remote = RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("node-2".into()),
                local: 1,
            },
            "counter",
            Arc::clone(&transport) as Arc<dyn Transport>,
        )
        .register_tell_with(
            TypeId::of::<Increment>(),
            "test::Increment",
            |_any: &dyn Any| Ok(vec![42]),
        )
        .build();
        transport.connect(&NodeId("node-2".into())).await.unwrap();
        remote.tell(Increment).unwrap();
        let received = tokio::time::timeout(std::time::Duration::from_millis(100), rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(received.body, vec![42]);
        assert_eq!(received.message_type, "test::Increment");
        assert_eq!(received.send_mode, SendMode::Tell);
        assert!(received.request_id.is_none());
    }

    // ask() sends an Ask envelope with a request id and resolves with the
    // decoded reply once the transport completes that request.
    #[tokio::test]
    async fn remote_ref_ask_delivers_and_receives_reply() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let mut rx = transport.register_node(NodeId("node-2".into())).await;
        transport.connect(&NodeId("node-2".into())).await.unwrap();
        let remote = RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("node-2".into()),
                local: 1,
            },
            "counter",
            Arc::clone(&transport) as Arc<dyn Transport>,
        )
        .register_ask_with(
            TypeId::of::<GetCount>(),
            "test::GetCount",
            |_any: &dyn Any| Ok(vec![0]),
            |bytes: &[u8]| {
                if bytes.len() != 8 {
                    return Err(SerializationError::new("expected 8 bytes"));
                }
                let val = u64::from_be_bytes(bytes.try_into().unwrap());
                Ok(Box::new(val) as Box<dyn Any + Send>)
            },
        )
        .build();
        let reply_future = remote.ask(GetCount, None).unwrap();
        let received = tokio::time::timeout(std::time::Duration::from_millis(100), rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(received.message_type, "test::GetCount");
        assert_eq!(received.send_mode, SendMode::Ask);
        let request_id = received.request_id.unwrap();
        let reply_envelope = WireEnvelope {
            target: ActorId {
                node: NodeId("local".into()),
                local: 0,
            },
            target_name: "reply".into(),
            message_type: "reply".into(),
            send_mode: SendMode::Ask,
            headers: WireHeaders::new(),
            body: 99u64.to_be_bytes().to_vec(),
            request_id: Some(request_id),
            version: None,
        };
        transport
            .complete_request(request_id, reply_envelope)
            .await
            .unwrap();
        let count = reply_future.await.unwrap();
        assert_eq!(count, 99);
    }

    // Cancelling the token before a reply arrives resolves the ask with an error.
    #[tokio::test]
    async fn remote_ref_ask_with_cancellation() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let _rx = transport.register_node(NodeId("node-2".into())).await;
        transport.connect(&NodeId("node-2".into())).await.unwrap();
        let remote = RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("node-2".into()),
                local: 1,
            },
            "counter",
            Arc::clone(&transport) as Arc<dyn Transport>,
        )
        .register_ask_with(
            TypeId::of::<GetCount>(),
            "test::GetCount",
            |_any: &dyn Any| Ok(vec![0]),
            |bytes: &[u8]| {
                let val = u64::from_be_bytes(bytes.try_into().unwrap());
                Ok(Box::new(val) as Box<dyn Any + Send>)
            },
        )
        .build();
        let token = CancellationToken::new();
        let reply_future = remote.ask(GetCount, Some(token.clone())).unwrap();
        token.cancel();
        let result = reply_future.await;
        assert!(result.is_err());
    }

    #[cfg(feature = "serde")]
    mod serde_tests {
        use super::*;
        use async_trait::async_trait;

        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct Add {
            amount: u64,
        }
        impl Message for Add {
            type Reply = ();
        }

        #[async_trait]
        impl Handler<Add> for Counter {
            async fn handle(&mut self, msg: Add, _ctx: &mut crate::actor::ActorContext) {
                self.count += msg.amount;
            }
        }

        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct GetValue;
        impl Message for GetValue {
            type Reply = u64;
        }

        #[async_trait]
        impl Handler<GetValue> for Counter {
            async fn handle(
                &mut self,
                _msg: GetValue,
                _ctx: &mut crate::actor::ActorContext,
            ) -> u64 {
                self.count
            }
        }

        // register_tell::<M>() encodes the message body as JSON.
        #[tokio::test]
        async fn serde_tell_roundtrip() {
            let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
            let mut rx = transport.register_node(NodeId("node-2".into())).await;
            transport.connect(&NodeId("node-2".into())).await.unwrap();
            let remote = RemoteActorRefBuilder::<Counter>::new(
                ActorId {
                    node: NodeId("node-2".into()),
                    local: 1,
                },
                "counter",
                Arc::clone(&transport) as Arc<dyn Transport>,
            )
            .register_tell::<Add>()
            .build();
            remote.tell(Add { amount: 42 }).unwrap();
            let received = tokio::time::timeout(std::time::Duration::from_millis(100), rx.recv())
                .await
                .unwrap()
                .unwrap();
            let msg: Add = serde_json::from_slice(&received.body).unwrap();
            assert_eq!(msg.amount, 42);
        }

        // register_ask::<M>() decodes a JSON reply body into M::Reply.
        #[tokio::test]
        async fn serde_ask_roundtrip() {
            let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
            let mut rx = transport.register_node(NodeId("node-2".into())).await;
            transport.connect(&NodeId("node-2".into())).await.unwrap();
            let remote = RemoteActorRefBuilder::<Counter>::new(
                ActorId {
                    node: NodeId("node-2".into()),
                    local: 1,
                },
                "counter",
                Arc::clone(&transport) as Arc<dyn Transport>,
            )
            .register_ask::<GetValue>()
            .build();
            let reply_future = remote.ask(GetValue, None).unwrap();
            let received = tokio::time::timeout(std::time::Duration::from_millis(100), rx.recv())
                .await
                .unwrap()
                .unwrap();
            let request_id = received.request_id.unwrap();
            let reply_body = serde_json::to_vec(&77u64).unwrap();
            transport
                .complete_request(
                    request_id,
                    WireEnvelope {
                        target: ActorId {
                            node: NodeId("local".into()),
                            local: 0,
                        },
                        target_name: "reply".into(),
                        message_type: "reply".into(),
                        send_mode: SendMode::Ask,
                        headers: WireHeaders::new(),
                        body: reply_body,
                        request_id: Some(request_id),
                        version: None,
                    },
                )
                .await
                .unwrap();
            let value = reply_future.await.unwrap();
            assert_eq!(value, 77);
        }
    }

    // Test interceptor: adds an "x-stamp: stamped" header to every send.
    struct HeaderStamper;
    impl OutboundInterceptor for HeaderStamper {
        fn name(&self) -> &'static str {
            "header-stamper"
        }
        fn on_send(
            &self,
            _ctx: &OutboundContext<'_>,
            _runtime_headers: &RuntimeHeaders,
            headers: &mut Headers,
            _message: &dyn Any,
        ) -> Disposition {
            use crate::message::HeaderValue;
            #[derive(Debug)]
            struct Stamp;
            impl HeaderValue for Stamp {
                fn header_name(&self) -> &'static str {
                    "x-stamp"
                }
                fn to_bytes(&self) -> Option<Vec<u8>> {
                    Some(b"stamped".to_vec())
                }
                fn as_any(&self) -> &dyn Any {
                    self
                }
            }
            headers.insert(Stamp);
            Disposition::Continue
        }
    }

    // Test interceptor: rejects every send.
    struct RejectAll;
    impl OutboundInterceptor for RejectAll {
        fn name(&self) -> &'static str {
            "reject-all"
        }
        fn on_send(
            &self,
            _ctx: &OutboundContext<'_>,
            _runtime_headers: &RuntimeHeaders,
            _headers: &mut Headers,
            _message: &dyn Any,
        ) -> Disposition {
            Disposition::Reject("blocked by policy".into())
        }
    }

    // Test interceptor: counts on_send calls and lets every send through.
    struct CallCounter(Arc<std::sync::atomic::AtomicU64>);
    impl OutboundInterceptor for CallCounter {
        fn name(&self) -> &'static str {
            "call-counter"
        }
        fn on_send(
            &self,
            _ctx: &OutboundContext<'_>,
            _runtime_headers: &RuntimeHeaders,
            _headers: &mut Headers,
            _message: &dyn Any,
        ) -> Disposition {
            self.0.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Disposition::Continue
        }
    }

    // A rejecting interceptor turns tell() into an error carrying the reason.
    #[test]
    fn outbound_interceptor_rejects_tell() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let remote = RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("node-2".into()),
                local: 1,
            },
            "counter",
            Arc::clone(&transport) as Arc<dyn Transport>,
        )
        .register_tell_with(
            TypeId::of::<Increment>(),
            "test::Increment",
            |_any: &dyn Any| Ok(vec![1]),
        )
        .add_outbound_interceptor(Arc::new(RejectAll))
        .build();
        let result = remote.tell(Increment);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("rejected"));
        assert!(err.to_string().contains("blocked by policy"));
    }

    // A rejecting interceptor turns ask() into an immediate error.
    #[test]
    fn outbound_interceptor_rejects_ask() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let remote = RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("node-2".into()),
                local: 1,
            },
            "counter",
            Arc::clone(&transport) as Arc<dyn Transport>,
        )
        .register_ask_with(
            TypeId::of::<GetCount>(),
            "test::GetCount",
            |_any: &dyn Any| Ok(vec![0]),
            |bytes: &[u8]| {
                let val = u64::from_be_bytes(bytes.try_into().unwrap());
                Ok(Box::new(val) as Box<dyn Any + Send>)
            },
        )
        .add_outbound_interceptor(Arc::new(RejectAll))
        .build();
        let result = remote.ask(GetCount, None);
        assert!(result.is_err());
        let err = match result {
            Err(e) => e,
            Ok(_) => panic!("expected error"),
        };
        assert!(err.to_string().contains("rejected"));
    }

    // Headers added by an interceptor reach the wire envelope.
    #[tokio::test]
    async fn outbound_interceptor_stamps_headers_on_tell() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let mut rx = transport.register_node(NodeId("node-2".into())).await;
        transport.connect(&NodeId("node-2".into())).await.unwrap();
        let remote = RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("node-2".into()),
                local: 1,
            },
            "counter",
            Arc::clone(&transport) as Arc<dyn Transport>,
        )
        .register_tell_with(
            TypeId::of::<Increment>(),
            "test::Increment",
            |_any: &dyn Any| Ok(vec![1]),
        )
        .add_outbound_interceptor(Arc::new(HeaderStamper))
        .build();
        remote.tell(Increment).unwrap();
        let received = tokio::time::timeout(std::time::Duration::from_millis(100), rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(received.headers.get("x-stamp").unwrap(), b"stamped");
    }

    // Interceptors run once per tell.
    #[tokio::test]
    async fn outbound_interceptor_counter_tracks_sends() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let _rx = transport.register_node(NodeId("node-2".into())).await;
        transport.connect(&NodeId("node-2".into())).await.unwrap();
        let count = Arc::new(std::sync::atomic::AtomicU64::new(0));
        let remote = RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("node-2".into()),
                local: 1,
            },
            "counter",
            Arc::clone(&transport) as Arc<dyn Transport>,
        )
        .register_tell_with(
            TypeId::of::<Increment>(),
            "test::Increment",
            |_any: &dyn Any| Ok(vec![1]),
        )
        .add_outbound_interceptor(Arc::new(CallCounter(Arc::clone(&count))))
        .build();
        remote.tell(Increment).unwrap();
        remote.tell(Increment).unwrap();
        remote.tell(Increment).unwrap();
        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 3);
    }

    // Interceptors run in insertion order: the counter sees the send before
    // the rejecting interceptor stops the chain.
    #[test]
    fn outbound_interceptor_chain_runs_in_order() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let count = Arc::new(std::sync::atomic::AtomicU64::new(0));
        let remote = RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("node-2".into()),
                local: 1,
            },
            "counter",
            Arc::clone(&transport) as Arc<dyn Transport>,
        )
        .register_tell_with(
            TypeId::of::<Increment>(),
            "test::Increment",
            |_any: &dyn Any| Ok(vec![1]),
        )
        .add_outbound_interceptor(Arc::new(CallCounter(Arc::clone(&count))))
        .add_outbound_interceptor(Arc::new(RejectAll))
        .build();
        let result = remote.tell(Increment);
        assert!(result.is_err());
        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 1);
    }

    // from_ref captures id, name and the actor's type name.
    #[test]
    fn envelope_from_remote_ref() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let remote = RemoteActorRefBuilder::<Counter>::new(
            ActorId {
                node: NodeId("node-2".into()),
                local: 42,
            },
            "counter",
            transport,
        )
        .build();
        let envelope = ActorRefEnvelope::from_ref::<Counter, _>(&remote);
        assert_eq!(envelope.actor_id.local, 42);
        assert_eq!(envelope.actor_name, "counter");
        assert!(envelope.actor_type.contains("Counter"));
        assert!(envelope.is_type::<Counter>());
    }

    // Display includes the name and type.
    #[test]
    fn envelope_new_and_display() {
        let envelope = ActorRefEnvelope::new(
            ActorId {
                node: NodeId("n1".into()),
                local: 7,
            },
            "worker",
            "myapp::Worker",
        );
        assert_eq!(envelope.actor_name, "worker");
        let display = format!("{envelope}");
        assert!(display.contains("worker"));
        assert!(display.contains("myapp::Worker"));
    }

    // A matching envelope converts back into a working RemoteActorRef.
    #[test]
    fn envelope_roundtrip_to_remote_ref() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let envelope = ActorRefEnvelope::new(
            ActorId {
                node: NodeId("node-3".into()),
                local: 99,
            },
            "service",
            std::any::type_name::<Counter>(),
        );
        assert!(envelope.is_type::<Counter>());
        let builder = match envelope
            .try_into_builder::<Counter>(Arc::clone(&transport) as Arc<dyn Transport>)
        {
            Ok(b) => b,
            Err(e) => panic!("type should match: {e}"),
        };
        let remote = builder.build();
        assert_eq!(remote.id().local, 99);
        assert_eq!(remote.name(), "service");
    }

    // The unchecked conversion yields a builder that still accepts registrations.
    #[test]
    fn envelope_into_builder() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let envelope = ActorRefEnvelope::new(
            ActorId {
                node: NodeId("node-2".into()),
                local: 5,
            },
            "actor",
            std::any::type_name::<Counter>(),
        );
        let remote = envelope
            .into_builder_unchecked::<Counter>(Arc::clone(&transport) as Arc<dyn Transport>)
            .register_tell_with(
                TypeId::of::<Increment>(),
                "test::Increment",
                |_any: &dyn Any| Ok(vec![1]),
            )
            .build();
        assert_eq!(remote.id().local, 5);
    }

    // is_type rejects a different type name.
    #[test]
    fn envelope_type_check() {
        let envelope = ActorRefEnvelope::new(
            ActorId {
                node: NodeId("n".into()),
                local: 1,
            },
            "x",
            "other::Actor",
        );
        assert!(!envelope.is_type::<Counter>());
    }

    // try_into_builder reports both type names on mismatch.
    #[test]
    fn envelope_type_mismatch_returns_error() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("local".into())));
        let envelope = ActorRefEnvelope::new(
            ActorId {
                node: NodeId("n".into()),
                local: 1,
            },
            "x",
            "wrong::Type",
        );
        let result =
            envelope.try_into_builder::<Counter>(Arc::clone(&transport) as Arc<dyn Transport>);
        assert!(result.is_err());
        let err = match result {
            Err(e) => e,
            Ok(_) => panic!("expected type mismatch error"),
        };
        assert!(err.expected.contains("Counter"));
        assert_eq!(err.actual, "wrong::Type");
        assert!(err.to_string().contains("mismatch"));
    }

    // Display truncates oversized names.
    #[test]
    fn envelope_display_truncates_long_names() {
        let long_name = "x".repeat(300);
        let envelope = ActorRefEnvelope::new(
            ActorId {
                node: NodeId("n".into()),
                local: 1,
            },
            &long_name,
            "type",
        );
        let display = format!("{envelope}");
        assert!(display.len() < 300);
    }

    // The envelope survives a JSON round trip unchanged.
    #[cfg(feature = "serde")]
    #[test]
    fn envelope_serde_roundtrip() {
        let envelope = ActorRefEnvelope::new(
            ActorId {
                node: NodeId("node-1".into()),
                local: 42,
            },
            "counter",
            "myapp::Counter",
        );
        let json = serde_json::to_string(&envelope).unwrap();
        let deserialized: ActorRefEnvelope = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, envelope);
    }
}