use std::collections::HashMap;
use ractor::cast;
use ractor::concurrency::JoinHandle;
use ractor::message::SerializedMessage;
use ractor::Actor;
use ractor::ActorCell;
use ractor::ActorId;
use ractor::ActorName;
use ractor::ActorProcessingErr;
use ractor::ActorRef;
use ractor::RpcReplyPort;
use ractor::SpawnErr;
use ractor_cluster_derive::RactorMessage;
use crate::node::NodeSessionMessage;
use crate::NodeId;
#[cfg(test)]
mod tests;
/// Local placeholder actor standing in for an actor that lives on another node.
pub(crate) struct RemoteActor;

impl RemoteActor {
    /// Spawn this placeholder, linked to `supervisor`, under the remote identity
    /// `(node_id, pid)`.
    ///
    /// The resulting actor's id is `ActorId::Remote { node_id, pid }`. `session` is handed to
    /// `spawn_linked_remote` and, presumably, becomes this actor's startup arguments
    /// (`Actor::Arguments` is the same session type) — confirm against ractor's runtime.
    ///
    /// Returns the new actor's reference and the join handle of its processing task, or the
    /// `SpawnErr` reported by the runtime.
    pub(crate) async fn spawn_linked(
        self,
        session: ActorRef<super::node::NodeSessionMessage>,
        name: Option<ActorName>,
        pid: u64,
        node_id: NodeId,
        supervisor: ActorCell,
    ) -> Result<(ActorRef<RemoteActorMessage>, JoinHandle<()>), SpawnErr> {
        ractor::ActorRuntime::<Self>::spawn_linked_remote(
            name,
            self,
            ActorId::Remote { node_id, pid },
            session,
            supervisor,
        )
        .await
    }
}
/// Mutable state owned by a `RemoteActor` while it is running.
pub(crate) struct RemoteActorState {
    /// Counter from which outgoing call tags are drawn.
    message_tag: u64,
    /// Reply ports of forwarded calls that have not been answered yet, keyed by call tag.
    pending_requests: HashMap<u64, RpcReplyPort<Vec<u8>>>,
    /// Node session to which outgoing node messages are handed for sending.
    session: ActorRef<crate::node::NodeSessionMessage>,
}

impl RemoteActorState {
    /// Bump the tag counter and return the bumped value.
    ///
    /// Despite the name this is a pre-increment: the value returned is the counter *after*
    /// it was advanced, so a counter starting at 0 yields 1 as its first tag.
    fn get_and_increment_mtag(&mut self) -> u64 {
        let next = self.message_tag + 1;
        self.message_tag = next;
        next
    }
}
/// Message type of `RemoteActor`. It carries no data of its own: locally-typed messages are
/// rejected by the actor, and real traffic arrives as serialized messages.
///
/// NOTE(review): the `RactorMessage` derive presumably supplies the message-trait plumbing
/// that `ractor_cluster` needs for this type — confirm against `ractor_cluster_derive`.
#[derive(RactorMessage)]
pub(crate) struct RemoteActorMessage;
#[cfg_attr(feature = "async-trait", ractor::async_trait)]
impl Actor for RemoteActor {
    type Msg = RemoteActorMessage;
    type State = RemoteActorState;
    type Arguments = ActorRef<crate::node::NodeSessionMessage>;

    /// Build the initial state: remember the node session, start the tag counter at 0 and
    /// begin with no outstanding calls.
    async fn pre_start(
        &self,
        _myself: ActorRef<Self::Msg>,
        session: ActorRef<crate::node::NodeSessionMessage>,
    ) -> Result<Self::State, ActorProcessingErr> {
        Ok(Self::State {
            session,
            message_tag: 0,
            pending_requests: HashMap::new(),
        })
    }

    /// Always fails: a `RemoteActor` only proxies serialized traffic, so a locally-typed
    /// message reaching it is an error.
    async fn handle(
        &self,
        _myself: ActorRef<Self::Msg>,
        _message: Self::Msg,
        _state: &mut Self::State,
    ) -> Result<(), ActorProcessingErr> {
        Err(From::from("`RemoteActor`s cannot handle local messages!"))
    }

    /// Proxy a serialized message to or from the remote actor.
    ///
    /// * `Call` — tag the call, forward it to the node session and remember the reply port
    ///   under that tag so the eventual reply can be routed back.
    /// * `Cast` — forward it to the node session (fire-and-forget).
    /// * `CallReply` — deliver the reply to the port registered under its tag, if any.
    ///
    /// Always returns `Ok(())`; failures to hand a message to the session are not escalated.
    async fn handle_serialized(
        &self,
        myself: ActorRef<Self::Msg>,
        message: SerializedMessage,
        state: &mut Self::State,
    ) -> Result<(), ActorProcessingErr> {
        let to = myself.get_id().pid();
        match message {
            SerializedMessage::Call {
                args,
                reply,
                variant,
                metadata,
            } => {
                let tag = state.get_and_increment_mtag();
                // `Duration::as_millis` yields a u128; saturate instead of silently truncating.
                let timeout_ms = reply
                    .get_timeout()
                    .map(|t| u64::try_from(t.as_millis()).unwrap_or(u64::MAX));
                let node_msg = crate::protocol::node::NodeMessage {
                    msg: Some(crate::protocol::node::node_message::Msg::Call(
                        crate::protocol::node::Call {
                            to,
                            tag,
                            what: args,
                            timeout_ms,
                            variant,
                            metadata,
                        },
                    )),
                };
                let sent = cast!(state.session, NodeSessionMessage::SendMessage(node_msg));
                if sent.is_ok() {
                    // The reply comes back as a later `CallReply` handled by this same actor,
                    // so registering the port after the hand-off cannot miss it.
                    state.pending_requests.insert(tag, reply);
                }
                // Otherwise `reply` is dropped here, closing the port so the caller's wait
                // ends with an error now rather than waiting for a reply that cannot come.
                //
                // NOTE(review): entries whose reply never arrives (e.g. the caller timed out
                // and the remote never answers) are never removed from `pending_requests`;
                // consider pruning ports whose receiver has gone away.
            }
            SerializedMessage::Cast {
                args,
                variant,
                metadata,
            } => {
                let node_msg = crate::protocol::node::NodeMessage {
                    msg: Some(crate::protocol::node::node_message::Msg::Cast(
                        crate::protocol::node::Cast {
                            to,
                            what: args,
                            variant,
                            metadata,
                        },
                    )),
                };
                // Casts are fire-and-forget: a failed hand-off is deliberately ignored.
                let _ = cast!(state.session, NodeSessionMessage::SendMessage(node_msg));
            }
            SerializedMessage::CallReply(message_tag, reply_data) => {
                if let Some(port) = state.pending_requests.remove(&message_tag) {
                    // The caller may have given up already; a failed send is not an error here.
                    let _ = port.send(reply_data);
                }
            }
        }
        Ok(())
    }
}