use crate::{
actor::{Actor, ActorContext},
error::{ActorError, Result},
message::{ExitReason, Message},
pid::Pid,
system::ActorSystem,
telemetry::{GenServerMetrics, actor_type_name},
};
use async_trait::async_trait;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::Arc;
use tokio::sync::oneshot;
/// Result of [`GenServer::handle_call`].
///
/// `Reply(value)` sends `value` back to the caller right away. `NoReply` means
/// the handler took the reply handle from its context and will answer later
/// (or not at all).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CallResponse<T> {
    /// Answer the caller immediately with this value.
    Reply(T),
    /// Do not answer now; the reply is sent later through a `ReplyHandle`.
    NoReply,
}
/// One-shot handle for answering a pending [`GenServer::handle_call`] later.
///
/// Obtained from the call's `GenServerContext` when a handler wants to defer
/// its reply. Dropping the handle without replying logs a warning and closes
/// the caller's reply channel.
pub struct ReplyHandle<T> {
    /// Sender half of the caller's reply channel; `None` once it has been used.
    reply_tx: Option<oneshot::Sender<T>>,
}

impl<T> ReplyHandle<T> {
    /// Wraps the sender half of a call's reply channel.
    pub(crate) fn new(reply_tx: oneshot::Sender<T>) -> Self {
        Self {
            reply_tx: Some(reply_tx),
        }
    }

    /// Sends `value` to the waiting caller, consuming the handle.
    ///
    /// # Errors
    ///
    /// Returns an error if the caller has already dropped its receiver, or
    /// (defensively) if the sender was already used.
    pub fn reply(mut self, value: T) -> Result<()> {
        match self.reply_tx.take() {
            Some(tx) => tx
                .send(value)
                .map_err(|_| ActorError::other("reply failed: receiver dropped")),
            None => Err(ActorError::other("reply already sent")),
        }
    }
}

impl<T> Drop for ReplyHandle<T> {
    fn drop(&mut self) {
        // Dropping an unused oneshot sender closes the channel right away, so
        // the caller's receive fails immediately instead of timing out.
        if self.reply_tx.is_some() {
            tracing::warn!(
                "ReplyHandle dropped without calling reply() - caller will receive an error"
            );
        }
    }
}
/// Callback interface for a request/response server that runs as an actor.
///
/// Implementors are started with `spawn`/`spawn_named` and addressed through a
/// `GenServerRef`.
#[async_trait]
pub trait GenServer: Send + 'static {
    /// Server state built by [`init`](Self::init) and passed to every later callback.
    type State: Send + 'static;
    /// Request type accepted by synchronous calls.
    type Call: Send + Debug + 'static;
    /// Message type accepted by asynchronous casts.
    type Cast: Send + Debug + 'static;
    /// Reply type produced for a [`Self::Call`].
    type CallReply: Send + 'static;

    /// Builds the initial state; invoked from the underlying actor's `started` hook.
    async fn init(&mut self, ctx: &mut GenServerContext<'_, Self>) -> Self::State;

    /// Handles a synchronous request.
    ///
    /// Return [`CallResponse::Reply`] to answer immediately, or take the reply
    /// handle from `ctx` and return [`CallResponse::NoReply`] to answer later.
    async fn handle_call(
        &mut self,
        call: Self::Call,
        state: &mut Self::State,
        ctx: &mut GenServerContext<'_, Self>,
    ) -> CallResponse<Self::CallReply>;

    /// Handles a fire-and-forget message; the sender receives no reply.
    async fn handle_cast(
        &mut self,
        cast: Self::Cast,
        state: &mut Self::State,
        ctx: &mut GenServerContext<'_, Self>,
    );

    /// Handles an out-of-band message, such as one sent with
    /// `GenServerRef::send_info` or `GenServerContext::send_info_after`.
    ///
    /// The default implementation ignores the message.
    async fn handle_info(
        &mut self,
        _msg: Message,
        _state: &mut Self::State,
        _ctx: &mut GenServerContext<'_, Self>,
    ) {}

    /// Called while the actor is stopping, provided `init` produced a state.
    ///
    /// The default implementation does nothing.
    async fn terminate(
        &mut self,
        _reason: &ExitReason,
        _state: &mut Self::State,
        _ctx: &mut GenServerContext<'_, Self>,
    ) {}
}
/// Context passed to every [`GenServer`] callback.
///
/// Wraps the actor's [`ActorContext`]. While a call is being dispatched it
/// also holds the [`ReplyHandle`] for that call.
pub struct GenServerContext<'a, G: GenServer + ?Sized> {
    actor_ctx: &'a mut ActorContext,
    /// Reply handle for the call currently being handled, if any.
    reply_handle: Option<ReplyHandle<G::CallReply>>,
    _phantom: PhantomData<G>,
}

impl<'a, G: GenServer> GenServerContext<'a, G> {
    /// Creates a context with no pending reply handle.
    fn new(actor_ctx: &'a mut ActorContext) -> Self {
        Self {
            actor_ctx,
            reply_handle: None,
            _phantom: PhantomData,
        }
    }

    /// Installs the reply handle for the call about to be dispatched.
    fn set_reply_handle(&mut self, handle: ReplyHandle<G::CallReply>) {
        self.reply_handle = Some(handle);
    }

    /// Returns this server's process id.
    pub fn pid(&self) -> Pid {
        self.actor_ctx.pid()
    }

    /// Asks the underlying actor to stop with `reason`.
    pub fn stop(&mut self, reason: ExitReason) {
        self.actor_ctx.stop(reason);
    }

    /// Enables or disables exit trapping on the underlying actor.
    pub fn trap_exit(&mut self, trap: bool) {
        self.actor_ctx.trap_exit(trap);
    }

    /// Takes the reply handle for the current call without panicking.
    ///
    /// Returns `None` when called outside `handle_call` or when the handle has
    /// already been taken.
    pub fn try_reply_handle(&mut self) -> Option<ReplyHandle<G::CallReply>> {
        self.reply_handle.take()
    }

    /// Takes the reply handle for the current call so it can be answered later.
    ///
    /// A handler that takes the handle should return [`CallResponse::NoReply`].
    /// If it returns [`CallResponse::Reply`] instead, that value is discarded
    /// because the handle is no longer available to send it.
    ///
    /// # Panics
    ///
    /// Panics if called outside `handle_call` or if the handle was already
    /// taken. Use [`try_reply_handle`](Self::try_reply_handle) to handle those
    /// cases without panicking.
    pub fn reply_handle(&mut self) -> ReplyHandle<G::CallReply> {
        self.try_reply_handle()
            .expect("reply_handle() called outside handle_call or already taken")
    }

    /// Looks up a registered name through the underlying actor context.
    pub fn whereis(&self, name: &str) -> Option<Pid> {
        self.actor_ctx.whereis(name)
    }

    /// Schedules a raw `msg` for `dest` after `duration`.
    ///
    /// Returns whatever the underlying `ActorContext::send_after` returns.
    /// NOTE(review): the meaning of `None` is defined there; confirm.
    pub fn send_after(
        &self,
        dest: crate::scheduler::Destination,
        msg: Message,
        duration: std::time::Duration,
    ) -> Option<crate::scheduler::TimerRef> {
        self.actor_ctx.send_after(dest, msg, duration)
    }

    /// Schedules `msg` to reach this server's `handle_info` after `duration`.
    ///
    /// The message is wrapped in an info envelope addressed to this server's own pid.
    pub fn send_info_after(
        &self,
        msg: Message,
        duration: std::time::Duration,
    ) -> Option<crate::scheduler::TimerRef> {
        let info_msg: GenServerMsg<G> = GenServerMsg::Info { message: msg };
        self.actor_ctx.send_after(
            crate::scheduler::Destination::Pid(self.pid()),
            Box::new(info_msg),
            duration,
        )
    }
}
/// Mailbox envelope used to carry GenServer traffic to the wrapping actor.
enum GenServerMsg<G: GenServer> {
    /// Synchronous request; the answer goes back over `reply_tx`.
    Call {
        request: G::Call,
        reply_tx: oneshot::Sender<G::CallReply>,
    },
    /// Asynchronous request that expects no answer.
    Cast {
        message: G::Cast,
    },
    /// Out-of-band payload destined for `handle_info`.
    Info {
        message: Message,
    },
}
/// Adapter that runs a [`GenServer`] implementation as a plain [`Actor`].
struct GenServerActor<G: GenServer> {
    /// The user-supplied callback implementation.
    server: G,
    /// State returned by `init`; stays `None` until the actor has started.
    state: Option<G::State>,
    /// Type name used to label telemetry emitted for this server.
    server_type: &'static str,
}

impl<G: GenServer> GenServerActor<G> {
    /// Wraps `server`; its state is filled in later when the actor starts.
    fn new(server: G) -> Self {
        Self {
            server_type: actor_type_name::<G>(),
            state: None,
            server,
        }
    }
}
#[async_trait]
impl<G: GenServer> Actor for GenServerActor<G> {
async fn started(&mut self, ctx: &mut ActorContext) {
let mut gen_ctx = GenServerContext::new(ctx);
let state = self.server.init(&mut gen_ctx).await;
self.state = Some(state);
}
async fn handle_message(&mut self, msg: Message, ctx: &mut ActorContext) {
let state = match self.state.as_mut() {
Some(s) => s,
None => return, };
let mut gen_ctx = GenServerContext::new(ctx);
if let Ok(gen_msg) = msg.downcast::<GenServerMsg<G>>() {
match *gen_msg {
GenServerMsg::Call { request, reply_tx } => {
let _span = GenServerMetrics::call_span(self.server_type);
GenServerMetrics::calls_in_flight_inc(self.server_type);
gen_ctx.set_reply_handle(ReplyHandle::new(reply_tx));
let response = self.server.handle_call(request, state, &mut gen_ctx).await;
match response {
CallResponse::Reply(value) => {
if let Some(handle) = gen_ctx.reply_handle.take() {
let _ = handle.reply(value); }
}
CallResponse::NoReply => {
}
}
GenServerMetrics::calls_in_flight_dec(self.server_type);
}
GenServerMsg::Cast { message } => {
GenServerMetrics::cast(self.server_type);
self.server.handle_cast(message, state, &mut gen_ctx).await;
}
GenServerMsg::Info { message } => {
self.server.handle_info(message, state, &mut gen_ctx).await;
}
}
}
}
async fn stopped(&mut self, reason: &ExitReason, ctx: &mut ActorContext) {
if let Some(state) = self.state.as_mut() {
let mut gen_ctx = GenServerContext::new(ctx);
self.server.terminate(reason, state, &mut gen_ctx).await;
}
}
}
/// Typed, cloneable handle to a running [`GenServer`].
///
/// Holds the server's [`Pid`] and the [`ActorSystem`] used to deliver messages to it.
pub struct GenServerRef<G: GenServer> {
    pid: Pid,
    system: Arc<ActorSystem>,
    _phantom: PhantomData<G>,
}

// Implemented by hand rather than derived so cloning does not require `G: Clone`.
impl<G: GenServer> Clone for GenServerRef<G> {
    fn clone(&self) -> Self {
        GenServerRef {
            pid: self.pid,
            system: Arc::clone(&self.system),
            _phantom: PhantomData,
        }
    }
}

impl<G: GenServer> std::fmt::Debug for GenServerRef<G> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("GenServerRef");
        builder.field("pid", &self.pid);
        builder.finish()
    }
}

impl<G: GenServer> GenServerRef<G> {
    /// Returns the process id of the referenced server.
    pub fn pid(&self) -> Pid {
        self.pid
    }

    /// Boxes `envelope` and hands it to the actor system for delivery to this server.
    async fn deliver(&self, envelope: GenServerMsg<G>) -> Result<()> {
        self.system.send(self.pid, Box::new(envelope)).await
    }

    /// Sends `request` to the server and waits for its reply.
    ///
    /// # Errors
    ///
    /// Fails if the message cannot be sent. It also fails with
    /// `ActorError::ActorNotFound` if the reply channel closes before an
    /// answer arrives.
    pub async fn call(&self, request: G::Call) -> Result<G::CallReply> {
        let (reply_tx, reply_rx) = oneshot::channel();
        self.deliver(GenServerMsg::Call { request, reply_tx }).await?;
        reply_rx
            .await
            .map_err(|_| ActorError::ActorNotFound(self.pid))
    }

    /// Sends `message` to the server's `handle_cast` without waiting for a reply.
    pub async fn cast(&self, message: G::Cast) -> Result<()> {
        self.deliver(GenServerMsg::Cast { message }).await
    }

    /// Sends `message` to the server's `handle_info`.
    pub async fn send_info(&self, message: Message) -> Result<()> {
        self.deliver(GenServerMsg::Info { message }).await
    }
}
/// Starts `server` as a new actor in `system` and returns a typed handle to it.
pub fn spawn<G: GenServer>(system: &Arc<ActorSystem>, server: G) -> GenServerRef<G> {
    let spawned = system.spawn(GenServerActor::new(server));
    let pid = spawned.pid();
    GenServerRef {
        pid,
        system: Arc::clone(system),
        _phantom: PhantomData,
    }
}
/// Starts `server` as a new actor and registers it under `name`.
///
/// # Errors
///
/// Propagates the error from `ActorSystem::register` (for example, if the
/// name cannot be registered).
pub fn spawn_named<G: GenServer>(
    system: &Arc<ActorSystem>,
    server: G,
    name: impl Into<String>,
) -> Result<GenServerRef<G>> {
    let spawned = system.spawn(GenServerActor::new(server));
    let pid = spawned.pid();
    // NOTE(review): on a registration failure the actor spawned above is not
    // stopped here and no handle to it is returned; consider stopping it.
    system.register(name, pid)?;
    let server_ref = GenServerRef {
        pid,
        system: Arc::clone(system),
        _phantom: PhantomData,
    };
    Ok(server_ref)
}