use anyhow::anyhow;
use std::{collections::HashMap, pin::Pin, time::Duration};
use tokio::sync::{mpsc, oneshot};
use std::future::Future;
use crate::command::CommandMessage;
/// Heap-allocated, `Send` future yielding `T`; the return type of the
/// asynchronous callback aliases below.
type BoxedFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;

/// Boxed synchronous route handler: borrows the incoming [`CommandMessage`]
/// and returns a [`CommandMessage`].
type BoxedCallback = Box<dyn FnMut(&CommandMessage) -> CommandMessage + Send + 'static>;

/// Boxed synchronous tick handler: takes no arguments and returns nothing.
type BoxedTickCallback = Box<dyn FnMut() + Send + 'static>;

/// Boxed asynchronous route handler: borrows the incoming [`CommandMessage`]
/// and returns a boxed future resolving to a [`CommandMessage`].
type BoxedAsyncCallback =
    Box<dyn FnMut(&CommandMessage) -> BoxedFuture<CommandMessage> + Send + 'static>;

/// Boxed asynchronous tick handler: returns a boxed future resolving to `()`.
type BoxedAsyncTickCallback = Box<dyn FnMut() -> BoxedFuture<()> + Send + 'static>;
/// Requests sent from a `MechActor` handle to the actor task over the mpsc
/// channel.
///
/// Every variant except `Shutdown` carries a `respond_to` oneshot sender on
/// which a `u32` can be sent back to the requester.
enum ActorMessage {
    /// Store the synchronous callback `cb` under the name `route`.
    RegisterRoute {
        respond_to: oneshot::Sender<u32>,
        route: String,
        cb: BoxedCallback,
    },
    /// Store the asynchronous callback `cb` under the name `route`.
    RegisterAsyncRoute {
        respond_to: oneshot::Sender<u32>,
        route: String,
        cb: BoxedAsyncCallback,
    },
    /// Store the synchronous callback `cb` to be run on every tick.
    RegisterTick {
        respond_to: oneshot::Sender<u32>,
        cb: BoxedTickCallback,
    },
    /// Store the asynchronous callback `cb` to be run on every tick.
    RegisterAsyncTick {
        respond_to: oneshot::Sender<u32>,
        cb: BoxedAsyncTickCallback,
    },
    /// Invoke the callbacks stored under `route`, passing them `arg`.
    ProcessRoute {
        respond_to: oneshot::Sender<u32>,
        route: String,
        arg: CommandMessage,
    },
    /// Ask the actor loop to stop.
    Shutdown,
}
/// State owned by the actor task: the inbound message queue plus every
/// registered route and tick callback.
struct MechActorInternal {
    /// Inbound requests from `MechActor` handles.
    receiver: mpsc::Receiver<ActorMessage>,
    /// Id that will be reported for the next route registration.
    next_id: u32,
    /// Synchronous route callbacks keyed by route name.
    routes: HashMap<String, BoxedCallback>,
    /// Synchronous tick callbacks keyed by their tick uid.
    tick_routes: HashMap<u32, BoxedTickCallback>,
    /// Last tick uid handed out; shared by sync and async tick callbacks.
    tick_uid: u32,
    /// Asynchronous route callbacks keyed by route name.
    async_routes: HashMap<String, BoxedAsyncCallback>,
    /// Asynchronous tick callbacks keyed by their tick uid.
    async_tick_routes: HashMap<u32, BoxedAsyncTickCallback>,
}

impl MechActorInternal {
    /// Creates an actor state with no registered callbacks that reads its
    /// requests from `receiver`.
    fn new(receiver: mpsc::Receiver<ActorMessage>) -> Self {
        MechActorInternal {
            receiver,
            next_id: 0,
            routes: HashMap::new(),
            tick_routes: HashMap::new(),
            tick_uid: 0,
            async_routes: HashMap::new(),
            async_tick_routes: HashMap::new(),
        }
    }

    /// Returns the id for a new route registration and advances the counter,
    /// so successive registrations are acknowledged with distinct ids
    /// (starting at 0).
    fn take_route_id(&mut self) -> u32 {
        let id = self.next_id;
        // The id is informational only (it is not a map key), so wrapping
        // around is harmless and avoids an overflow panic in debug builds.
        self.next_id = self.next_id.wrapping_add(1);
        id
    }

    /// Advances the tick uid counter and returns the new value (first uid is 1).
    fn take_tick_uid(&mut self) -> u32 {
        self.tick_uid += 1;
        self.tick_uid
    }

    /// Applies a registration request and acknowledges it on `respond_to`.
    ///
    /// Registering a route under an existing name replaces the previous
    /// callback. Replies are best-effort: if the requester has already dropped
    /// its receiver, the failed send is ignored.
    ///
    /// `ProcessRoute` and `Shutdown` are ignored here; the actor loop handles
    /// them before delegating to this method.
    fn handle_message(&mut self, msg: ActorMessage) {
        match msg {
            ActorMessage::RegisterRoute {
                respond_to,
                route,
                cb,
            } => {
                self.routes.insert(route, cb);
                let _ = respond_to.send(self.take_route_id());
            }
            ActorMessage::RegisterAsyncRoute {
                respond_to,
                route,
                cb,
            } => {
                self.async_routes.insert(route, cb);
                let _ = respond_to.send(self.take_route_id());
            }
            ActorMessage::RegisterTick { respond_to, cb } => {
                let uid = self.take_tick_uid();
                self.tick_routes.insert(uid, cb);
                let _ = respond_to.send(uid);
            }
            ActorMessage::RegisterAsyncTick { respond_to, cb } => {
                let uid = self.take_tick_uid();
                self.async_tick_routes.insert(uid, cb);
                let _ = respond_to.send(uid);
            }
            // Listed explicitly (instead of `_`) so that adding a new variant
            // forces a decision here at compile time.
            ActorMessage::ProcessRoute { .. } | ActorMessage::Shutdown => {}
        }
    }

    /// Runs every registered tick callback once: all synchronous callbacks
    /// first, then each asynchronous callback awaited one after another.
    ///
    /// Callbacks are visited in `HashMap` iteration order, which is unspecified.
    async fn tick(&mut self) {
        for cb in self.tick_routes.values_mut() {
            cb();
        }
        for cb in self.async_tick_routes.values_mut() {
            cb().await;
        }
    }
}
/// Event loop of the actor task.
///
/// Alternates between two event sources until told to stop:
/// * every `interval`, all registered tick callbacks are run;
/// * each message received on the actor's channel is processed.
///
/// The loop ends when a `Shutdown` message arrives or when the channel is
/// closed because every sender has been dropped. Returning drops `actor`,
/// including its receiver and all registered callbacks.
///
/// Callbacks are awaited inline, so a slow callback delays both ticks and
/// message handling.
async fn run_my_actor(mut actor: MechActorInternal, interval: Duration) {
    let mut ticker = tokio::time::interval(interval);
    loop {
        tokio::select! {
            _ = ticker.tick() => {
                actor.tick().await;
            },
            msg = actor.receiver.recv() => {
                match msg {
                    // `None` means all handles are gone: no further message can
                    // ever arrive and `recv()` would resolve immediately from
                    // now on, so looping again would only busy-spin.
                    None | Some(ActorMessage::Shutdown) => break,
                    // The callbacks' return values are discarded and
                    // `respond_to` is dropped without a reply.
                    Some(ActorMessage::ProcessRoute { route, arg, .. }) => {
                        if let Some(cb) = actor.async_routes.get_mut(&route) {
                            cb(&arg).await;
                        }
                        if let Some(cb) = actor.routes.get_mut(&route) {
                            cb(&arg);
                        }
                    },
                    Some(other) => actor.handle_message(other),
                }
            }
        }
    }
}
#[derive(Clone)]
pub struct MechActor {
sender: mpsc::Sender<ActorMessage>,
}
impl MechActor {
pub fn new(interval: Duration) -> Self {
let (sender, receiver) = mpsc::channel(8);
let actor = MechActorInternal::new(receiver);
tokio::spawn(run_my_actor(actor, interval));
Self { sender: sender }
}
pub async fn register_route<C>(&self, route: &str, cb: C) -> Result<(), anyhow::Error>
where
C: Fn(&CommandMessage) -> CommandMessage + Send + 'static,
{
let (send, recv) = oneshot::channel();
let msg = ActorMessage::RegisterRoute {
respond_to: send,
route: route.to_string(),
cb: Box::new(cb),
};
let _ = self.sender.send(msg).await;
match recv.await {
Ok(_) => Ok(()),
Err(err) => return Err(anyhow!("Error registering route: {}", err)),
}
}
pub async fn register_tick<C>(&self, mut cb: C) -> Result<(), anyhow::Error>
where
C: FnMut() -> () + Send + 'static,
{
let (send, recv) = oneshot::channel();
let msg = ActorMessage::RegisterTick {
respond_to: send,
cb: Box::new(move || cb()),
};
let _ = self.sender.send(msg).await;
match recv.await {
Ok(_) => Ok(()),
Err(err) => return Err(anyhow!("Error registering tick: {}", err)),
}
}
pub async fn register_async_route<C, Fut>(
&self,
route: &str,
cb: C,
) -> Result<(), anyhow::Error>
where
C: Fn(&CommandMessage) -> Fut + Send + 'static,
Fut: Future<Output = CommandMessage> + Send + 'static,
{
let (send, recv) = oneshot::channel();
let msg = ActorMessage::RegisterAsyncRoute {
respond_to: send,
route: route.to_string(),
cb: Box::new(move |cmd| Box::pin(cb(cmd))),
};
let _ = self.sender.send(msg).await;
match recv.await {
Ok(_) => Ok(()),
Err(err) => return Err(anyhow!("Error registering route: {}", err)),
}
}
pub async fn register_async_tick<C, Fut>(&self, mut cb: C) -> Result<(), anyhow::Error>
where
C: FnMut() -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let (send, recv) = oneshot::channel();
let msg = ActorMessage::RegisterAsyncTick {
respond_to: send,
cb: Box::new(move || Box::pin(cb())),
};
let _ = self.sender.send(msg).await;
match recv.await {
Ok(_) => Ok(()),
Err(err) => return Err(anyhow!("Error registering tick: {}", err)),
}
}
pub async fn shutdown(&self) -> Result<(), anyhow::Error> {
let (_, recv) = oneshot::channel::<u32>();
let msg = ActorMessage::Shutdown;
let _ = self.sender.send(msg).await;
match recv.await {
Ok(_) => Ok(()),
Err(err) => return Err(anyhow!("Error shutting down Actor: {}", err)),
}
}
}