use std::collections::HashMap;
use std::sync::{LazyLock, Mutex};
use std::thread::{self, JoinHandle};
use crate::error::SerDeError;
use crate::ipc::OpaqueIpcReceiver;
use crate::ipc::{self, IpcMessage, IpcReceiver, IpcReceiverSet, IpcSelectionResult, IpcSender};
use crossbeam_channel::{self, Receiver, Sender};
use serde_core::{Deserialize, Serialize};
pub static ROUTER: LazyLock<RouterProxy> = LazyLock::new(RouterProxy::new);
pub struct RouterProxy {
comm: Mutex<RouterProxyComm>,
}
impl Drop for RouterProxy {
fn drop(&mut self) {
self.shutdown();
}
}
// Creating a `RouterProxy` spawns a thread and opens an IPC channel. That is too much
// hidden work for a `Default` impl, so `new` is deliberately the only constructor.
#[allow(clippy::new_without_default)]
impl RouterProxy {
    /// Creates a proxy and spawns the `router-proxy` thread that services its routes.
    ///
    /// The proxy keeps the sending halves of a crossbeam channel (route and shutdown
    /// requests) and of an IPC channel (wakeups). The router thread owns both receiving
    /// halves.
    ///
    /// # Panics
    ///
    /// Panics if the wakeup IPC channel cannot be created or the thread cannot be
    /// spawned.
    pub fn new() -> RouterProxy {
        let (msg_sender, msg_receiver) = crossbeam_channel::unbounded();
        let (wakeup_sender, wakeup_receiver) = ipc::channel().unwrap();
        let handle = thread::Builder::new()
            .name("router-proxy".to_string())
            .spawn(move || Router::new(msg_receiver, wakeup_receiver).run())
            .expect("Failed to spawn router proxy thread");
        RouterProxy {
            comm: Mutex::new(RouterProxyComm {
                msg_sender,
                wakeup_sender,
                shutdown: false,
                handle: Some(handle),
            }),
        }
    }

    /// Hands `receiver` and its `callback` to the router thread.
    ///
    /// The request is queued on the crossbeam channel and then announced with an IPC
    /// wakeup. The router thread waits in `select` on its IPC receivers and reads one
    /// queued request per wakeup. The call is silently ignored once
    /// [`shutdown`](Self::shutdown) has started.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned or the router thread can no longer be reached.
    fn add_route(&self, receiver: OpaqueIpcReceiver, callback: RouterHandler) {
        let comm = self.comm.lock().unwrap();
        if comm.shutdown {
            return;
        }
        comm.msg_sender
            .send(RouterMsg::AddRoute(receiver, callback))
            .unwrap();
        comm.wakeup_sender.send(()).unwrap();
    }

    /// Routes every message arriving on `receiver` to `callback`.
    ///
    /// Each message is converted to `T` on the router thread before the callback runs.
    /// The callback receives the conversion `Result`, so it also sees conversion errors.
    pub fn add_typed_route<T>(
        &self,
        receiver: IpcReceiver<T>,
        mut callback: TypedRouterMultiHandler<T>,
    ) where
        T: Serialize + for<'de> Deserialize<'de> + 'static,
    {
        let modified_callback = move |msg: IpcMessage| {
            let typed_message = msg.to::<T>();
            callback(typed_message)
        };
        self.add_route(
            receiver.to_opaque(),
            RouterHandler::Multi(Box::new(modified_callback)),
        );
    }

    /// Routes only the first message arriving on `receiver` to `callback`.
    ///
    /// The message is converted to `T` as in [`add_typed_route`](Self::add_typed_route).
    /// Later messages on the same receiver are still received by the router but are
    /// discarded.
    pub fn add_typed_one_shot_route<T>(
        &self,
        receiver: IpcReceiver<T>,
        callback: TypedRouterOneShotHandler<T>,
    ) where
        T: Serialize + for<'de> Deserialize<'de> + 'static,
    {
        let modified_callback = move |msg: IpcMessage| {
            let typed_message = msg.to::<T>();
            callback(typed_message)
        };
        self.add_route(
            receiver.to_opaque(),
            RouterHandler::Once(Some(Box::new(modified_callback))),
        );
    }

    /// Stops the router thread and waits for it to exit.
    ///
    /// Queues a shutdown request that carries an acknowledgement channel, wakes the
    /// router, waits for the acknowledgement and joins the thread. Only the first call
    /// does any work. Later calls return immediately, and
    /// [`add_route`](Self::add_route) becomes a no-op. This method also runs from `Drop`.
    ///
    /// If the router thread has already stopped by itself (for example because its
    /// `select` failed), the handshake is skipped and the finished thread is just
    /// joined.
    ///
    /// # Panics
    ///
    /// Panics if the router thread panicked, or if the wakeup cannot be delivered while
    /// the router is still accepting requests.
    pub fn shutdown(&self) {
        // Recover the guard from poisoning. This runs from `Drop`, and panicking a
        // second time there is worse than proceeding with the flag and senders as they
        // are.
        let mut comm = self
            .comm
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if comm.shutdown {
            return;
        }
        comm.shutdown = true;

        let (ack_sender, ack_receiver) = crossbeam_channel::unbounded();
        // The crossbeam receiver is owned by the router thread, so a failed send means
        // that thread has already exited and nobody is left to acknowledge.
        if comm.msg_sender.send(RouterMsg::Shutdown(ack_sender)).is_ok() {
            comm.wakeup_sender
                .send(())
                .expect("Failed to wake the router proxy thread for shutdown");
            // `Err` means the router dropped the request without replying because it
            // exited in the meantime. The join below still surfaces a router panic.
            let _ = ack_receiver.recv();
        }

        // The lock stays held until the thread is joined, so concurrent callers only
        // return once the router has fully stopped.
        comm.handle
            .take()
            .expect("Should have a join handle at shutdown")
            .join()
            .expect("Failed to join on the router proxy thread");
    }

    /// Forwards every message from `ipc_receiver` to `crossbeam_sender`.
    ///
    /// Send failures (the crossbeam receiver was dropped) are ignored.
    ///
    /// NOTE: a message that fails to convert to `T` panics inside the callback, on the
    /// router thread. That ends the router loop and stops routing for every receiver.
    pub fn route_ipc_receiver_to_crossbeam_sender<T>(
        &self,
        ipc_receiver: IpcReceiver<T>,
        crossbeam_sender: Sender<T>,
    ) where
        T: for<'de> Deserialize<'de> + Serialize + Send + 'static,
    {
        self.add_typed_route(
            ipc_receiver,
            Box::new(move |message| drop(crossbeam_sender.send(message.unwrap()))),
        )
    }

    /// Creates an unbounded crossbeam channel and routes `ipc_receiver` into it.
    ///
    /// Returns the receiving half. The same conversion-failure caveat as
    /// [`route_ipc_receiver_to_crossbeam_sender`](Self::route_ipc_receiver_to_crossbeam_sender)
    /// applies.
    pub fn route_ipc_receiver_to_new_crossbeam_receiver<T>(
        &self,
        ipc_receiver: IpcReceiver<T>,
    ) -> Receiver<T>
    where
        T: for<'de> Deserialize<'de> + Serialize + Send + 'static,
    {
        let (crossbeam_sender, crossbeam_receiver) = crossbeam_channel::unbounded();
        self.route_ipc_receiver_to_crossbeam_sender(ipc_receiver, crossbeam_sender);
        crossbeam_receiver
    }
}
/// Channel state shared by all users of a `RouterProxy`, guarded by the proxy's mutex.
struct RouterProxyComm {
/// Queues requests (`AddRoute`, `Shutdown`) for the router thread.
msg_sender: Sender<RouterMsg>,
/// Wakes the router thread so it reads one request from `msg_sender`'s channel.
/// Exactly one wakeup is sent per queued request.
wakeup_sender: IpcSender<()>,
/// Set by the first `shutdown` call. Route additions are ignored afterwards.
shutdown: bool,
/// Join handle of the router thread. It is taken and joined by `shutdown`.
handle: Option<JoinHandle<()>>,
}
/// State owned by the `router-proxy` thread and driven by `Router::run`.
struct Router {
/// Receives requests from the proxy. One request is read per wakeup message.
msg_receiver: Receiver<RouterMsg>,
/// Id assigned to the wakeup receiver inside `ipc_receiver_set`.
/// It has no entry in `handlers`.
msg_wakeup_id: u64,
/// Every receiver the router selects on: the wakeup receiver plus each routed one.
ipc_receiver_set: IpcReceiverSet,
/// Callback for each routed receiver, keyed by its id in `ipc_receiver_set`.
handlers: HashMap<u64, RouterHandler>,
}
impl Router {
    /// Builds the router state: creates a fresh receiver set, registers
    /// `wakeup_receiver` in it, and remembers the id it was assigned.
    ///
    /// # Panics
    ///
    /// Panics if the receiver set cannot be created or the wakeup receiver cannot be
    /// added to it.
    fn new(msg_receiver: Receiver<RouterMsg>, wakeup_receiver: IpcReceiver<()>) -> Router {
        let mut ipc_receiver_set = IpcReceiverSet::new().unwrap();
        let msg_wakeup_id = ipc_receiver_set.add(wakeup_receiver).unwrap();
        Router {
            msg_receiver,
            msg_wakeup_id,
            ipc_receiver_set,
            handlers: HashMap::new(),
        }
    }

    /// Event loop of the router thread.
    ///
    /// Blocks in `select` on every registered receiver and handles each result:
    /// - A message on the wakeup receiver means one request is queued on the crossbeam
    ///   channel.
    /// - Any other message goes to the handler registered for its receiver.
    /// - A closed routed receiver has its handler dropped.
    ///
    /// Returns after acknowledging a shutdown request, when `select` fails, or when the
    /// proxy side has gone away (wakeup receiver closed or crossbeam sender dropped).
    fn run(&mut self) {
        loop {
            let Ok(results) = self.ipc_receiver_set.select() else {
                break;
            };
            for result in results {
                match result {
                    IpcSelectionResult::MessageReceived(id, _) if id == self.msg_wakeup_id => {
                        // Every wakeup is paired with exactly one queued request. If the
                        // proxy's sender is gone, no request (not even shutdown) can come.
                        let Ok(request) = self.msg_receiver.recv() else {
                            return;
                        };
                        match request {
                            RouterMsg::AddRoute(receiver, handler) => {
                                let new_receiver_id =
                                    self.ipc_receiver_set.add_opaque(receiver).unwrap();
                                self.handlers.insert(new_receiver_id, handler);
                            },
                            RouterMsg::Shutdown(sender) => {
                                sender
                                    .send(())
                                    .expect("Failed to send comfirmation of shutdown.");
                                return;
                            },
                        }
                    },
                    IpcSelectionResult::MessageReceived(id, message) => {
                        let entry = self
                            .handlers
                            .get_mut(&id)
                            .expect("every routed receiver has a registered handler");
                        match entry {
                            RouterHandler::Once(handler) => {
                                // Only the first message is delivered. Later ones are
                                // received and discarded.
                                if let Some(handler) = handler.take() {
                                    handler(message);
                                }
                            },
                            RouterHandler::Multi(handler) => handler(message),
                        }
                    },
                    IpcSelectionResult::ChannelClosed(id) if id == self.msg_wakeup_id => {
                        // The wakeup receiver has no entry in `handlers`, so it must not
                        // reach the arm below. Its sender closing means the proxy is gone
                        // without a shutdown request. Nothing can extend or stop this
                        // router any more, so exit cleanly.
                        return;
                    },
                    IpcSelectionResult::ChannelClosed(id) => {
                        let _ = self
                            .handlers
                            .remove(&id)
                            .expect("every routed receiver has a registered handler");
                    },
                }
            }
        }
    }
}
/// Untyped callback run for every message received on a route.
pub type RouterMultiHandler = Box<dyn FnMut(IpcMessage) + Send>;

/// Untyped callback run for at most one message received on a route.
pub type RouterOneShotHandler = Box<dyn FnOnce(IpcMessage) + Send>;

/// Typed callback run for every message. It receives the result of converting the
/// message to `T`.
pub type TypedRouterMultiHandler<T> = Box<dyn FnMut(Result<T, SerDeError>) + Send>;

/// Typed callback run for at most one message. It receives the result of converting
/// the message to `T`.
pub type TypedRouterOneShotHandler<T> = Box<dyn FnOnce(Result<T, SerDeError>) + Send>;

/// Callback stored by the router thread for one routed receiver.
enum RouterHandler {
    /// Runs on the first message only. Holds `None` once it has been taken and run.
    Once(Option<RouterOneShotHandler>),
    /// Runs on every message.
    Multi(RouterMultiHandler),
}

/// Request sent from the proxy to the router thread over the crossbeam channel.
enum RouterMsg {
    /// Register the receiver with the router and dispatch its messages to the handler.
    AddRoute(OpaqueIpcReceiver, RouterHandler),
    /// Stop the router. It replies on the contained sender before returning.
    Shutdown(Sender<()>),
}