use crate::{
actor::{Actor, ActorContext},
error::{ActorError, Result},
message::{ExitReason, Message},
pid::Pid,
system::{ActorRef, ActorSystem},
telemetry::{GenStatemMetrics, GenStatemStateSpan, actor_type_name},
};
use async_trait::async_trait;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::Arc;
use tokio::sync::oneshot;
/// A state-machine behaviour hosted inside an actor.
///
/// The hosting actor calls [`init`](GenStatem::init) once when it starts,
/// [`handle_event`](GenStatem::handle_event) for every call or cast it
/// receives, [`state_enter`](GenStatem::state_enter) whenever a handler
/// returns a next state, and [`terminate`](GenStatem::terminate) when it stops.
#[async_trait]
pub trait GenStatem: Send + 'static {
    /// Current state of the machine. It is cloned for every event and
    /// rendered with `Debug` for telemetry labels.
    type State: Send + Clone + Debug + 'static;
    /// Data carried across states; every callback receives it by `&mut`.
    type Data: Send + 'static;
    /// Input delivered to the machine through `call` or `cast`.
    type Event: Send + Debug + 'static;
    /// Value produced by each `handle_event`; returned to the sender of a
    /// `call` and discarded for a `cast`.
    type Reply: Send + 'static;
    /// Runs once when the actor starts and returns the initial state and data.
    async fn init(&mut self, ctx: &mut StateMachineContext<'_, Self>) -> (Self::State, Self::Data);
    /// Handles one event.
    ///
    /// `state` is a clone of the current state. The returned
    /// [`StateTransition`] carries the reply, an optional next state and a
    /// stop flag.
    async fn handle_event(
        &mut self,
        event: Self::Event,
        state: Self::State,
        data: &mut Self::Data,
        ctx: &mut StateMachineContext<'_, Self>,
    ) -> StateTransition<Self>;
    /// Called after `handle_event` returns a next state and before that state
    /// is stored. States are not compared, so this also runs when the "next"
    /// state equals the old one. The default does nothing.
    async fn state_enter(
        &mut self,
        _old_state: &Self::State,
        _new_state: &Self::State,
        _data: &mut Self::Data,
        _ctx: &mut StateMachineContext<'_, Self>,
    ) {
    }
    /// Called when the actor stops, with the last stored state and data
    /// (only if both were initialised by `init`). The default does nothing.
    async fn terminate(
        &mut self,
        _reason: &ExitReason,
        _state: &Self::State,
        _data: &mut Self::Data,
        _ctx: &mut StateMachineContext<'_, Self>,
    ) {
    }
}
/// Outcome of [`GenStatem::handle_event`]: an optional next state, the reply
/// for the sender, and whether the machine should stop.
///
/// Build one with [`StateTransition::keep_state`],
/// [`StateTransition::next_state`] or [`StateTransition::stop`]. Dropping it
/// unused silently discards the reply and any requested transition, hence
/// `#[must_use]`.
#[must_use = "a StateTransition has no effect unless it is returned from handle_event"]
pub struct StateTransition<G: GenStatem + ?Sized> {
    /// State to move to; `None` keeps the current state.
    pub(crate) new_state: Option<G::State>,
    /// Sent back to the sender of a `call`; discarded for a `cast`.
    pub(crate) reply: G::Reply,
    /// When `true`, the actor is stopped with `ExitReason::Normal` after the
    /// reply has been sent.
    pub(crate) stop: bool,
}

impl<G: GenStatem> StateTransition<G> {
    /// Stays in the current state and answers with `reply`.
    pub fn keep_state(reply: G::Reply) -> Self {
        Self {
            new_state: None,
            reply,
            stop: false,
        }
    }

    /// Moves to `new_state` (triggering `state_enter`) and answers with `reply`.
    pub fn next_state(new_state: G::State, reply: G::Reply) -> Self {
        Self {
            new_state: Some(new_state),
            reply,
            stop: false,
        }
    }

    /// Answers with `reply`, keeps the current state, and stops the machine.
    pub fn stop(reply: G::Reply) -> Self {
        Self {
            new_state: None,
            reply,
            stop: true,
        }
    }
}
/// Handle passed to every [`GenStatem`] callback, exposing a restricted view
/// of the hosting actor's context.
pub struct StateMachineContext<'a, G: GenStatem + ?Sized> {
    /// The hosting actor's context, borrowed for the duration of one callback.
    actor_ctx: &'a mut ActorContext,
    /// Ties the context to the machine type `G` without storing one.
    _phantom: PhantomData<G>,
}

impl<'a, G: GenStatem> StateMachineContext<'a, G> {
    /// Wraps `ctx` for a single callback invocation.
    fn new(ctx: &'a mut ActorContext) -> Self {
        Self {
            _phantom: PhantomData,
            actor_ctx: ctx,
        }
    }

    /// Pid of the actor running this state machine.
    pub fn pid(&self) -> Pid {
        self.actor_ctx.pid()
    }

    /// Asks the hosting actor to stop with `reason`.
    pub fn stop(&mut self, reason: ExitReason) {
        self.actor_ctx.stop(reason)
    }

    /// Forwards the trap-exit flag to the hosting actor.
    pub fn trap_exit(&mut self, trap: bool) {
        self.actor_ctx.trap_exit(trap)
    }
}
/// Internal envelope that `GenStatemRef` sends to the hosting actor.
enum StatemMsg<G: GenStatem> {
    /// Request/response: the handler's reply is sent back on `reply_tx`.
    Call {
        event: G::Event,
        reply_tx: oneshot::Sender<G::Reply>,
    },
    /// Fire-and-forget: the handler's reply is discarded.
    Cast {
        event: G::Event,
    },
}
/// Actor adapter that owns a [`GenStatem`] and drives it from actor messages.
struct GenStatemActor<G: GenStatem> {
    /// The user's state-machine implementation.
    statem: G,
    /// Current state; filled in by `init` when the actor starts.
    state: Option<G::State>,
    /// Machine data; filled in by `init` when the actor starts.
    data: Option<G::Data>,
    /// Type name of `G`, used as the telemetry label for this machine.
    fsm_type: &'static str,
    /// Open duration span for the current state, if any.
    state_span: Option<GenStatemStateSpan>,
}

impl<G: GenStatem> GenStatemActor<G> {
    /// Wraps `statem`; state, data and span stay empty until the actor starts.
    fn new(statem: G) -> Self {
        Self {
            fsm_type: actor_type_name::<G>(),
            statem,
            state: None,
            data: None,
            state_span: None,
        }
    }
}
#[async_trait]
impl<G: GenStatem> Actor for GenStatemActor<G> {
    /// Runs the machine's `init`, stores the initial state and data, and
    /// opens telemetry for the initial state.
    async fn started(&mut self, ctx: &mut ActorContext) {
        let mut sm_ctx = StateMachineContext::new(ctx);
        let (state, data) = self.statem.init(&mut sm_ctx).await;
        let state_str = format!("{:?}", state);
        GenStatemMetrics::current_state(self.fsm_type, &state_str);
        self.state_span = Some(GenStatemMetrics::state_duration_span(
            self.fsm_type,
            &state_str,
        ));
        self.state = Some(state);
        self.data = Some(data);
    }

    /// Feeds a call or cast to `handle_event` and applies the returned
    /// transition. Messages that are not a `StatemMsg<G>` are ignored.
    async fn handle_message(&mut self, msg: Message, ctx: &mut ActorContext) {
        let Ok(statem_msg) = msg.downcast::<StatemMsg<G>>() else {
            return;
        };
        // Calls and casts share one code path; only calls carry a reply channel.
        let (event, reply_tx) = match *statem_msg {
            StatemMsg::Call { event, reply_tx } => (event, Some(reply_tx)),
            StatemMsg::Cast { event } => (event, None),
        };
        // Both are set by `started`. If either is missing the event is dropped,
        // and with it any reply channel, so a pending `call` fails.
        let Some(data) = self.data.as_mut() else {
            return;
        };
        let Some(state) = self.state.take() else {
            return;
        };

        let mut sm_ctx = StateMachineContext::new(ctx);
        // The handler receives a clone; the original is still needed as the
        // "old" state for `state_enter`.
        let transition = self
            .statem
            .handle_event(event, state.clone(), data, &mut sm_ctx)
            .await;

        if let Some(reply_tx) = reply_tx {
            // The caller may have stopped waiting; that is not an error here.
            let _ = reply_tx.send(transition.reply);
        }

        if transition.stop {
            sm_ctx.stop(ExitReason::Normal);
            self.state = Some(state);
            return;
        }

        let Some(new_state) = transition.new_state else {
            // `keep_state`: put the unchanged state back.
            self.state = Some(state);
            return;
        };

        // Debug-format states only when a transition actually happens, rather
        // than allocating a label for every event.
        let old_state_str = format!("{:?}", state);
        let new_state_str = format!("{:?}", new_state);
        GenStatemMetrics::state_transition(self.fsm_type, &old_state_str, &new_state_str);

        // Close the outgoing state's duration span before entering the new one.
        if let Some(span) = self.state_span.take() {
            span.finish();
        }
        self.statem
            .state_enter(&state, &new_state, data, &mut sm_ctx)
            .await;
        GenStatemMetrics::current_state(self.fsm_type, &new_state_str);
        self.state_span = Some(GenStatemMetrics::state_duration_span(
            self.fsm_type,
            &new_state_str,
        ));
        self.state = Some(new_state);
    }

    /// Calls the machine's `terminate` with the last stored state and data
    /// (when both exist), then closes the final state's duration span.
    async fn stopped(&mut self, reason: &ExitReason, ctx: &mut ActorContext) {
        if let (Some(state), Some(data)) = (self.state.as_ref(), self.data.as_mut()) {
            let mut sm_ctx = StateMachineContext::new(ctx);
            self.statem
                .terminate(reason, state, data, &mut sm_ctx)
                .await;
        }
        // Finish the final state's span explicitly, the same way transitions
        // finish the outgoing state's span, instead of merely dropping it.
        if let Some(span) = self.state_span.take() {
            span.finish();
        }
    }
}
/// Typed handle for sending events to a spawned [`GenStatem`].
pub struct GenStatemRef<G: GenStatem> {
    /// Untyped reference to the hosting actor.
    actor_ref: ActorRef,
    /// Ties the handle to the machine type `G` without storing one.
    _phantom: PhantomData<G>,
}

impl<G: GenStatem> GenStatemRef<G> {
    /// Pid of the actor running the state machine.
    pub fn pid(&self) -> Pid {
        self.actor_ref.pid()
    }

    /// Sends `event` and waits for the reply produced by `handle_event`.
    ///
    /// # Errors
    ///
    /// Propagates any error from sending the message. Returns
    /// `ActorError::ActorNotFound` if the reply channel is dropped before a
    /// reply arrives (for example, because the event was never handled).
    pub async fn call(&self, event: G::Event) -> Result<G::Reply> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let envelope: StatemMsg<G> = StatemMsg::Call { event, reply_tx };
        self.actor_ref.send(Box::new(envelope)).await?;
        match reply_rx.await {
            Ok(reply) => Ok(reply),
            Err(_) => Err(ActorError::ActorNotFound(self.actor_ref.pid())),
        }
    }

    /// Sends `event` without waiting; the handler's reply is discarded.
    ///
    /// # Errors
    ///
    /// Propagates any error from sending the message.
    pub async fn cast(&self, event: G::Event) -> Result<()> {
        let envelope: StatemMsg<G> = StatemMsg::Cast { event };
        let boxed = Box::new(envelope);
        self.actor_ref.send(boxed).await
    }
}
/// Spawns `statem` as an actor on `system` and returns a typed handle to it.
///
/// The machine's `init` runs from the actor's `started` hook.
pub fn spawn<G: GenStatem>(system: &Arc<ActorSystem>, statem: G) -> GenStatemRef<G> {
    GenStatemRef {
        actor_ref: system.spawn(GenStatemActor::new(statem)),
        _phantom: PhantomData,
    }
}