use std::collections::HashMap;
use std::panic;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use tokio::sync::{mpsc, oneshot};
/// Outcome delivered to a registered callback.
///
/// `Success` carries a string payload and `Error` carries a numeric error code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CallbackResult {
    /// The operation succeeded with the given payload.
    Success(String),
    /// The operation failed with the given error code.
    Error(u32),
}

impl CallbackResult {
    /// Converts into a standard [`Result`], mapping `Success` to `Ok` and
    /// `Error` to `Err`.
    pub fn into_result(self) -> Result<String, u32> {
        match self {
            CallbackResult::Success(data) => Ok(data),
            CallbackResult::Error(code) => Err(code),
        }
    }
}

impl From<Result<String, u32>> for CallbackResult {
    /// Maps `Ok` to `Success` and `Err` to `Error`; the inverse of
    /// [`CallbackResult::into_result`].
    fn from(result: Result<String, u32>) -> Self {
        match result {
            Ok(data) => CallbackResult::Success(data),
            Err(code) => CallbackResult::Error(code),
        }
    }
}
/// The delivery mechanism behind a registered callback id.
enum CallbackEntry {
    /// Single-use: the first delivered result consumes the entry.
    Oneshot(oneshot::Sender<CallbackResult>),
    /// Multi-use: results are pushed into a bounded channel.
    Stream(mpsc::Sender<CallbackResult>),
    /// Multi-use: results are passed to a shared closure.
    Handler(Arc<dyn Fn(CallbackResult) + Sync + Send>),
}
/// Thread-safe table of registered callbacks keyed by numeric id.
struct CallbackRegistry {
    /// Registered callbacks, keyed by the id handed out at registration.
    callbacks: Mutex<HashMap<u64, CallbackEntry>>,
    /// The next id to hand out; bumped atomically on every registration.
    next_id: AtomicU64,
}
impl CallbackRegistry {
fn new() -> Self {
Self {
callbacks: Mutex::new(HashMap::new()),
next_id: AtomicU64::new(1),
}
}
fn register_oneshot(&self) -> (u64, oneshot::Receiver<CallbackResult>) {
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let (sender, receiver) = oneshot::channel();
{
let mut callbacks = self.callbacks.lock().unwrap();
callbacks.insert(id, CallbackEntry::Oneshot(sender));
}
(id, receiver)
}
fn register_stream(&self) -> (u64, mpsc::Receiver<CallbackResult>) {
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let (sender, receiver) = mpsc::channel(16);
{
let mut callbacks = self.callbacks.lock().unwrap();
callbacks.insert(id, CallbackEntry::Stream(sender));
}
(id, receiver)
}
fn register_handler<F>(&self, handler: F) -> u64
where
F: Fn(CallbackResult) + Send + Sync + 'static,
{
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
{
let mut callbacks = self.callbacks.lock().unwrap();
callbacks.insert(id, CallbackEntry::Handler(Arc::new(handler)));
}
id
}
fn unregister(&self, id: u64) -> bool {
let mut callbacks = self.callbacks.lock().unwrap();
callbacks.remove(&id).is_some()
}
fn invoke(&self, id: u64, result: CallbackResult) -> bool {
enum Action {
Oneshot(oneshot::Sender<CallbackResult>),
Stream(mpsc::Sender<CallbackResult>),
Handler(Arc<dyn Fn(CallbackResult) + Send + Sync>),
None,
}
let action = {
let mut callbacks = self.callbacks.lock().unwrap();
match callbacks.get(&id) {
Some(CallbackEntry::Oneshot(_)) => {
if let Some(CallbackEntry::Oneshot(sender)) = callbacks.remove(&id) {
Action::Oneshot(sender)
} else {
Action::None
}
}
Some(CallbackEntry::Stream(sender)) => Action::Stream(sender.clone()),
Some(CallbackEntry::Handler(handler)) => Action::Handler(handler.clone()),
None => Action::None,
}
};
match action {
Action::Oneshot(sender) => {
let _ = sender.send(result);
true
}
Action::Stream(sender) => match sender.try_send(result) {
Ok(_) => true,
Err(mpsc::error::TrySendError::Full(_payload)) => {
false
}
Err(mpsc::error::TrySendError::Closed(_payload)) => {
let mut callbacks = self.callbacks.lock().unwrap();
callbacks.remove(&id);
false
}
},
Action::Handler(handler) => {
let handled = panic::catch_unwind(panic::AssertUnwindSafe(|| (handler)(result)));
if handled.is_err() {
let mut callbacks = self.callbacks.lock().unwrap();
callbacks.remove(&id);
false
} else {
true
}
}
Action::None => false,
}
}
}
/// Process-wide callback registry, created lazily on first use.
static CALLBACK_REGISTRY: OnceLock<CallbackRegistry> = OnceLock::new();

/// Returns the global callback registry, initialising it on first access.
fn get_callback_registry() -> &'static CallbackRegistry {
    OnceLock::get_or_init(&CALLBACK_REGISTRY, CallbackRegistry::new)
}
/// Registers a single-use callback in the global registry.
///
/// Returns the new callback id and the receiver that gets the first result
/// delivered to that id via [`invoke_callback`].
pub fn get_callback() -> (u64, oneshot::Receiver<CallbackResult>) {
    let registry = get_callback_registry();
    registry.register_oneshot()
}
/// Registers a streaming callback that can receive many results.
///
/// Returns the new callback id and the receiving half of a channel buffered
/// for 16 results. The stream stays registered until it is removed with
/// [`remove_callback`] or a delivery finds its receiver closed.
pub fn get_stream_callback() -> (u64, mpsc::Receiver<CallbackResult>) {
    let registry = get_callback_registry();
    registry.register_stream()
}
/// Registers `handler` to be called with every result delivered to the
/// returned id.
///
/// The handler runs synchronously on the thread calling [`invoke_callback`],
/// without the registry lock held. If it panics, the panic is caught and the
/// handler is unregistered.
pub fn register_handler<F>(handler: F) -> u64
where
    F: Fn(CallbackResult) + Send + Sync + 'static,
{
    let registry = get_callback_registry();
    registry.register_handler(handler)
}
/// Unregisters the callback with the given `id`.
///
/// Returns `true` if a callback was registered under `id`, `false` otherwise.
pub fn remove_callback(id: u64) -> bool {
    let registry = get_callback_registry();
    registry.unregister(id)
}
pub fn invoke_callback(id: u64, result: Result<String, u32>) -> bool {
let cb_result = match result {
Ok(data) => CallbackResult::Success(data),
Err(code) => CallbackResult::Error(code),
};
get_callback_registry().invoke(id, cb_result)
}
/// A named event carrying a string payload, fanned out to subscribers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Event {
    /// The event name subscribers registered for.
    pub name: String,
    /// The event payload.
    pub data: String,
}
/// Fan-out table of event subscribers.
struct EventRegistry {
    /// For each event name, the sending halves of every subscribed channel.
    listeners: Mutex<HashMap<String, Vec<mpsc::Sender<Event>>>>,
}
impl EventRegistry {
fn new() -> Self {
Self {
listeners: Mutex::new(HashMap::new()),
}
}
fn subscribe(&self, event_name: String) -> mpsc::Receiver<Event> {
let (sender, receiver) = mpsc::channel(16);
let mut listeners = self.listeners.lock().unwrap();
listeners.entry(event_name).or_default().push(sender);
receiver
}
fn publish(&self, name: &str, data: &str) {
let mut listeners = self.listeners.lock().unwrap();
if let Some(senders) = listeners.get_mut(name) {
let event = Event {
name: name.to_string(),
data: data.to_string(),
};
senders.retain(|sender| {
match sender.try_send(event.clone()) {
Ok(_) => true, Err(mpsc::error::TrySendError::Full(_)) => true, Err(mpsc::error::TrySendError::Closed(_)) => false, }
});
}
}
}
/// Process-wide event registry, created lazily on first use.
static EVENT_REGISTRY: OnceLock<EventRegistry> = OnceLock::new();

/// Returns the global event registry, initialising it on first access.
fn get_event_registry() -> &'static EventRegistry {
    OnceLock::get_or_init(&EVENT_REGISTRY, EventRegistry::new)
}
/// Subscribes to events named `event_name` in the global event registry.
///
/// Each subscriber gets its own channel, buffered for 16 events.
pub fn subscribe(event_name: String) -> mpsc::Receiver<Event> {
    let registry = get_event_registry();
    registry.subscribe(event_name)
}
/// Publishes an event named `name` with payload `data` to every subscriber of
/// that name.
///
/// Delivery is non-blocking and best-effort: a subscriber whose buffer is full
/// misses the event.
pub fn publish(name: String, data: String) {
    let registry = get_event_registry();
    registry.publish(name.as_str(), data.as_str());
}