use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::time::Duration;
use arc_swap::ArcSwap;
use futures_util::future::BoxFuture;
use futures_util::future::join_all;
use once_cell::sync::Lazy;
use scc::HashMap as SccHashMap;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
#[cfg(not(feature = "compio"))]
use tokio::time::timeout;
use crate::types::BuildHasher;
/// Shared, atomically swappable list of the handlers registered for one signal id.
type HandlerList = Arc<ArcSwap<Vec<SignalHandler>>>;

/// Capacity given to new topic channels until it is overridden at runtime.
const DEFAULT_BROADCAST_CAPACITY: usize = 64;

/// Process-wide capacity read whenever a new topic channel is created.
static GLOBAL_BROADCAST_CAPACITY: AtomicUsize = AtomicUsize::new(DEFAULT_BROADCAST_CAPACITY);

/// Process-wide counter that hands out a distinct key to every registered exporter.
static EXPORTER_KEY_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Well-known signal identifiers emitted by the framework.
pub mod ids {
    /// A server began listening.
    pub const SERVER_STARTED: &str = "server.started";
    /// A server stopped.
    pub const SERVER_STOPPED: &str = "server.stopped";

    /// A client connection was accepted.
    pub const CONNECTION_OPENED: &str = "connection.opened";
    /// A client connection ended.
    pub const CONNECTION_CLOSED: &str = "connection.closed";

    /// Processing of a request began.
    pub const REQUEST_STARTED: &str = "request.started";
    /// Processing of a request finished.
    pub const REQUEST_COMPLETED: &str = "request.completed";

    /// The router was hot-reloaded.
    pub const ROUTER_HOT_RELOAD: &str = "router.hot_reload";
    /// An RPC call failed.
    pub const RPC_ERROR: &str = "rpc.error";

    /// Route-level counterpart of `REQUEST_STARTED`.
    pub const ROUTE_REQUEST_STARTED: &str = "route.request.started";
    /// Route-level counterpart of `REQUEST_COMPLETED`.
    pub const ROUTE_REQUEST_COMPLETED: &str = "route.request.completed";
}
/// Destinations that signals can be published to.
pub mod bus {
    use super::Signal;
    use async_trait::async_trait;

    /// A sink that accepts published signals.
    #[async_trait]
    pub trait SignalBus: Send + Sync + 'static {
        /// Publishes `signal` to this bus.
        async fn publish(&self, signal: &Signal);
    }

    /// A bus implementation that ignores every signal it is given.
    #[derive(Clone, Default)]
    pub struct LocalBus;

    #[async_trait]
    impl SignalBus for LocalBus {
        async fn publish(&self, _signal: &Signal) {
            // Deliberately a no-op.
        }
    }
}
/// A named event carrying free-form string metadata.
#[derive(Clone, Debug, Default)]
pub struct Signal {
    /// Identifier matched against handler ids and topic keys when the signal is emitted.
    pub id: String,
    /// Arbitrary key/value pairs attached to the signal.
    pub metadata: HashMap<String, String, BuildHasher>,
}
impl Signal {
    /// Creates a signal with the given id and no metadata.
    #[inline]
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            metadata: HashMap::with_hasher(BuildHasher::default()),
        }
    }

    /// Creates a signal whose metadata map is pre-sized for `capacity` entries.
    #[inline]
    #[must_use]
    pub fn with_capacity(id: impl Into<String>, capacity: usize) -> Self {
        Self {
            id: id.into(),
            metadata: HashMap::with_capacity_and_hasher(capacity, BuildHasher::default()),
        }
    }

    /// Builder-style setter: inserts `key = value` into the metadata and returns the signal.
    ///
    /// An existing value for `key` is overwritten. The signal is taken by value, so the
    /// result must be used; otherwise the inserted metadata is lost.
    #[inline]
    #[must_use]
    pub fn meta(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Creates a signal from an id and an already-built metadata map.
    #[inline]
    #[must_use]
    pub fn with_metadata(
        id: impl Into<String>,
        metadata: HashMap<String, String, BuildHasher>,
    ) -> Self {
        Self {
            id: id.into(),
            metadata,
        }
    }

    /// Converts a typed payload into a signal using its `id()` and `to_metadata()`.
    #[inline]
    #[must_use]
    pub fn from_payload<P: SignalPayload>(payload: &P) -> Self {
        Self {
            id: payload.id().to_string(),
            metadata: payload.to_metadata(),
        }
    }
}
/// A typed event that can be turned into a `Signal` (see `Signal::from_payload`).
pub trait SignalPayload {
    /// The signal id this payload is emitted under.
    fn id(&self) -> &'static str;

    /// Flattens the payload's fields into signal metadata.
    fn to_metadata(&self) -> HashMap<String, String, BuildHasher>;
}
/// Type-erased async callback run by `emit` for each signal with a matching id.
pub type SignalHandler = Arc<dyn Fn(Signal) -> BoxFuture<'static, ()> + Send + Sync>;

/// Type-erased RPC endpoint: receives a boxed request, resolves to a boxed response.
pub type RpcHandler = Arc<
    dyn Fn(Arc<dyn Any + Send + Sync>) -> BoxFuture<'static, Arc<dyn Any + Send + Sync>>
        + Send
        + Sync,
>;

/// Synchronous observer invoked by `emit` for every emitted signal.
pub type SignalExporter = Arc<dyn Fn(&Signal) + Send + Sync>;

/// Receiving end of a filtered subscription.
pub type SignalStream = mpsc::Receiver<Signal>;

/// Capacity of the channel backing a filtered subscription.
pub const FILTERED_SUBSCRIPTION_BUFFER: usize = 1024;

/// Upper bound (2^20) accepted for the global broadcast capacity.
pub const MAX_BROADCAST_CAPACITY: usize = 1 << 20;
/// Registry shared by every clone of a `SignalArbiter`.
#[derive(Default)]
struct Inner {
    /// Async handlers, keyed by exact signal id.
    handlers: SccHashMap<String, HandlerList>,
    /// Broadcast channels, keyed by exact id or by `prefix*` for prefix subscriptions.
    topics: SccHashMap<String, broadcast::Sender<Signal>>,
    /// RPC endpoints, keyed by method id.
    rpc: SccHashMap<String, RpcHandler>,
    /// Exporters, keyed by a value drawn from the global exporter counter.
    exporters: SccHashMap<u64, SignalExporter>,
}
/// Builds an empty, shareable handler list.
fn new_handler_list() -> HandlerList {
    let empty = ArcSwap::from_pointee(Vec::new());
    Arc::new(empty)
}
/// Hub for emitting signals, topic subscriptions, RPC endpoints and exporters.
///
/// Cloning is cheap and every clone shares the same underlying registry.
#[derive(Clone, Default)]
pub struct SignalArbiter {
    inner: Arc<Inner>,
}
/// Application-wide arbiter, created on first access.
static APP_SIGNAL_ARBITER: Lazy<SignalArbiter> = Lazy::new(SignalArbiter::new);

/// Returns the application-wide signal arbiter.
pub fn app_signals() -> &'static SignalArbiter {
    &APP_SIGNAL_ARBITER
}

/// Alias of `app_signals`; both return the same arbiter instance.
pub fn app_events() -> &'static SignalArbiter {
    &APP_SIGNAL_ARBITER
}
/// Reasons an RPC dispatch through the arbiter can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RpcError {
    /// Nothing is registered under the requested method id.
    NoHandler,
    /// The handler's response is not of the type the caller asked for.
    TypeMismatch,
}

impl std::fmt::Display for RpcError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NoHandler => "no handler registered for RPC method",
            Self::TypeMismatch => "RPC response type mismatch",
        };
        f.write_str(message)
    }
}

impl std::error::Error for RpcError {}

/// Result of an RPC call that can fail with `RpcError`.
pub type RpcResult<T> = Result<T, RpcError>;

/// Failure of an RPC call made with a deadline.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RpcTimeoutError {
    /// The deadline elapsed before the handler produced a response.
    Timeout,
    /// The call completed in time but failed.
    Rpc(RpcError),
}

impl std::fmt::Display for RpcTimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Self::Rpc(inner) = self {
            write!(f, "{inner}")
        } else {
            f.write_str("RPC call timed out")
        }
    }
}

impl std::error::Error for RpcTimeoutError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Rpc(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}

impl From<RpcError> for RpcTimeoutError {
    #[inline]
    fn from(err: RpcError) -> Self {
        Self::Rpc(err)
    }
}
impl SignalArbiter {
pub fn new() -> Self {
Self::default()
}
pub fn set_global_broadcast_capacity(capacity: usize) {
let cap = capacity.clamp(1, MAX_BROADCAST_CAPACITY);
GLOBAL_BROADCAST_CAPACITY.store(cap, Ordering::SeqCst);
}
pub fn global_broadcast_capacity() -> usize {
GLOBAL_BROADCAST_CAPACITY.load(Ordering::SeqCst)
}
pub(crate) fn topic_sender(&self, id: &str) -> broadcast::Sender<Signal> {
if let Some(existing) = self.inner.topics.get_sync(id) {
existing.clone()
} else {
let cap = GLOBAL_BROADCAST_CAPACITY.load(Ordering::SeqCst);
let (tx, _rx) = broadcast::channel(cap);
let entry = self.inner.topics.entry_sync(id.to_string()).or_insert(tx);
entry.clone()
}
}
pub fn on<F, Fut>(&self, id: impl Into<String>, handler: F)
where
F: Fn(Signal) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = ()> + Send + 'static,
{
let id = id.into();
let handler: SignalHandler = Arc::new(move |signal: Signal| {
let fut = handler(signal);
Box::pin(fut)
});
let list = self.handler_list_for(id);
list.rcu(|current| {
let mut next = Vec::with_capacity(current.len() + 1);
next.extend(current.iter().cloned());
next.push(handler.clone());
Arc::new(next)
});
}
fn handler_list_for(&self, id: String) -> HandlerList {
let entry = self
.inner
.handlers
.entry_sync(id)
.or_insert_with(new_handler_list);
entry.clone()
}
pub fn subscribe(&self, id: impl AsRef<str>) -> broadcast::Receiver<Signal> {
let id_str = id.as_ref();
let sender = self.topic_sender(id_str);
sender.subscribe()
}
pub fn subscribe_prefix(&self, prefix: impl AsRef<str>) -> broadcast::Receiver<Signal> {
let mut key = prefix.as_ref().to_string();
if !key.ends_with('*') {
key.push('*');
}
let sender = self.topic_sender(&key);
sender.subscribe()
}
pub fn subscribe_all(&self) -> broadcast::Receiver<Signal> {
self.subscribe_prefix("")
}
pub(crate) fn broadcast(&self, signal: Signal) {
if let Some(sender) = self.inner.topics.get_sync(&signal.id) {
let _ = sender.send(signal.clone());
}
let mut targets: Vec<broadcast::Sender<Signal>> = Vec::new();
self.inner.topics.iter_sync(|key, v| {
if let Some(prefix) = key.strip_suffix('*')
&& signal.id.starts_with(prefix)
{
targets.push(v.clone());
}
true
});
for sender in targets {
let _ = sender.send(signal.clone());
}
}
pub fn subscribe_filtered<F>(&self, id: impl AsRef<str>, filter: F) -> SignalStream
where
F: Fn(&Signal) -> bool + Send + Sync + 'static,
{
let mut rx = self.subscribe(id);
let (tx, out_rx) = mpsc::channel(FILTERED_SUBSCRIPTION_BUFFER);
let filter = Arc::new(filter);
#[cfg(not(feature = "compio"))]
tokio::spawn(async move {
while let Ok(signal) = rx.recv().await {
if filter(&signal) {
match tx.try_send(signal) {
Ok(()) => {}
Err(mpsc::error::TrySendError::Full(_)) => {
}
Err(mpsc::error::TrySendError::Closed(_)) => break,
}
}
}
});
#[cfg(feature = "compio")]
compio::runtime::spawn(async move {
while let Ok(signal) = rx.recv().await {
if filter(&signal) {
match tx.try_send(signal) {
Ok(()) => {}
Err(mpsc::error::TrySendError::Full(_)) => {}
Err(mpsc::error::TrySendError::Closed(_)) => break,
}
}
}
})
.detach();
out_rx
}
pub async fn once(&self, id: impl AsRef<str>) -> Option<Signal> {
let mut rx = self.subscribe(id);
loop {
match rx.recv().await {
Ok(sig) => return Some(sig),
Err(broadcast::error::RecvError::Lagged(_)) => {}
Err(_) => return None,
}
}
}
pub fn register_rpc<Req, Res, F, Fut>(&self, id: impl Into<String>, f: F)
where
Req: Send + Sync + 'static,
Res: Send + Sync + 'static,
F: Fn(Arc<Req>) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Res> + Send + 'static,
{
let id_str = id.into();
let id_for_panic = id_str.clone();
let func = Arc::new(f);
let handler: RpcHandler = Arc::new(move |raw: Arc<dyn Any + Send + Sync>| {
let func = func.clone();
let id_for_panic = id_for_panic.clone();
Box::pin(async move {
let req = raw
.downcast::<Req>()
.unwrap_or_else(|_| panic!("Signal RPC type mismatch for id: {id_for_panic}"));
let res = func(req).await;
Arc::new(res) as Arc<dyn Any + Send + Sync>
})
});
self.inner.rpc.upsert_sync(id_str, handler);
}
pub async fn call_rpc_arc<Req, Res>(&self, id: impl AsRef<str>, req: Req) -> Option<Arc<Res>>
where
Req: Send + Sync + 'static,
Res: Send + Sync + 'static,
{
let id_str = id.as_ref();
let entry = self.inner.rpc.get_async(id_str).await?;
let handler = entry.clone();
drop(entry);
let raw_req: Arc<dyn Any + Send + Sync> = Arc::new(req);
let raw_res = handler(raw_req).await;
raw_res.downcast::<Res>().ok()
}
pub async fn call_rpc_result<Req, Res>(&self, id: impl AsRef<str>, req: Req) -> RpcResult<Res>
where
Req: Send + Sync + 'static,
Res: Send + Sync + Clone + 'static,
{
let id_str = id.as_ref();
let Some(entry) = self.inner.rpc.get_async(id_str).await else {
return Err(RpcError::NoHandler);
};
let handler = entry.clone();
drop(entry);
let raw_req: Arc<dyn Any + Send + Sync> = Arc::new(req);
let raw_res = handler(raw_req).await;
match raw_res.downcast::<Res>() {
Ok(res) => Ok((*res).clone()),
Err(_) => Err(RpcError::TypeMismatch),
}
}
pub async fn call_rpc<Req, Res>(&self, id: impl AsRef<str>, req: Req) -> Option<Res>
where
Req: Send + Sync + 'static,
Res: Send + Sync + Clone + 'static,
{
self.call_rpc_result::<Req, Res>(id, req).await.ok()
}
#[cfg(not(feature = "compio"))]
pub async fn call_rpc_timeout<Req, Res>(
&self,
id: impl AsRef<str>,
req: Req,
dur: Duration,
) -> Result<Res, RpcTimeoutError>
where
Req: Send + Sync + 'static,
Res: Send + Sync + Clone + 'static,
{
match timeout(dur, self.call_rpc_result::<Req, Res>(id, req)).await {
Ok(Ok(res)) => Ok(res),
Ok(Err(e)) => Err(RpcTimeoutError::Rpc(e)),
Err(_) => Err(RpcTimeoutError::Timeout),
}
}
#[cfg(feature = "compio")]
pub async fn call_rpc_timeout<Req, Res>(
&self,
id: impl AsRef<str>,
req: Req,
dur: Duration,
) -> Result<Res, RpcTimeoutError>
where
Req: Send + Sync + 'static,
Res: Send + Sync + Clone + 'static,
{
let sleep = std::pin::pin!(compio::time::sleep(dur));
let work = std::pin::pin!(self.call_rpc_result::<Req, Res>(id, req));
match futures_util::future::select(work, sleep).await {
futures_util::future::Either::Left((Ok(res), _)) => Ok(res),
futures_util::future::Either::Left((Err(e), _)) => Err(RpcTimeoutError::Rpc(e)),
futures_util::future::Either::Right(((), _)) => Err(RpcTimeoutError::Timeout),
}
}
pub async fn emit(&self, signal: Signal) {
self.broadcast(signal.clone());
self
.inner
.exporters
.iter_async(|_, v| {
v(&signal);
true
})
.await;
if let Some(entry) = self.inner.handlers.get_async(&signal.id).await {
let list = entry.clone();
drop(entry);
let handlers = list.load_full();
let futures = handlers.iter().map(|handler| {
let s = signal.clone();
handler(s)
});
let _ = join_all(futures).await;
}
}
pub async fn emit_app(signal: Signal) {
app_signals().emit(signal).await;
}
pub fn register_exporter<F>(&self, exporter: F)
where
F: Fn(&Signal) + Send + Sync + 'static,
{
let key = EXPORTER_KEY_COUNTER.fetch_add(1, Ordering::Relaxed);
let exporter: SignalExporter = Arc::new(exporter);
self.inner.exporters.upsert_sync(key, exporter);
}
pub(crate) fn merge_from(&self, other: &SignalArbiter) {
other.inner.handlers.iter_sync(|k, other_list| {
let other_handlers = other_list.load_full();
if other_handlers.is_empty() {
return true;
}
let target_list = self.handler_list_for(k.clone());
target_list.rcu(|current| {
let mut next = Vec::with_capacity(current.len() + other_handlers.len());
next.extend(current.iter().cloned());
next.extend(other_handlers.iter().cloned());
Arc::new(next)
});
true
});
other.inner.topics.iter_sync(|k, v| {
self.inner.topics.entry_sync(k.clone()).or_insert(v.clone());
true
});
other.inner.rpc.iter_sync(|k, v| {
self.inner.rpc.upsert_sync(k.clone(), v.clone());
true
});
other.inner.exporters.iter_sync(|k, v| {
self.inner.exporters.upsert_sync(*k, v.clone());
true
});
}
pub fn signal_ids(&self) -> Vec<String> {
let mut ids = Vec::new();
self.inner.topics.iter_sync(|k, _| {
if !k.ends_with('*') {
ids.push(k.clone());
}
true
});
ids
}
pub fn signal_prefixes(&self) -> Vec<String> {
let mut prefixes = Vec::new();
self.inner.topics.iter_sync(|k, _| {
if k.ends_with('*') {
prefixes.push(k.clone());
}
true
});
prefixes
}
pub fn rpc_ids(&self) -> Vec<String> {
let mut ids = Vec::new();
self.inner.rpc.iter_sync(|k, _| {
ids.push(k.clone());
true
});
ids
}
}
/// Helpers that emit transport lifecycle signals on the application-wide arbiter.
pub mod transport {
    use super::Signal;
    use super::SignalArbiter;
    use super::ids;

    /// Renders a TLS flag as the metadata string used by every transport signal.
    fn tls_flag(tls: bool) -> &'static str {
        if tls { "true" } else { "false" }
    }

    /// Builds a connection signal with `remote_addr`, `tls` and, when given, `protocol`.
    fn connection_signal(
        id: &str,
        remote_addr: &str,
        tls: bool,
        protocol: Option<&str>,
    ) -> Signal {
        let sig = Signal::with_capacity(id, 3)
            .meta("remote_addr", remote_addr)
            .meta("tls", tls_flag(tls));
        match protocol {
            Some(p) => sig.meta("protocol", p),
            None => sig,
        }
    }

    /// Emits `ids::SERVER_STARTED` with `addr`, `transport` and `tls` metadata.
    pub async fn emit_server_started(addr: &str, transport: &str, tls: bool) {
        SignalArbiter::emit_app(
            Signal::with_capacity(ids::SERVER_STARTED, 3)
                .meta("addr", addr)
                .meta("transport", transport)
                .meta("tls", tls_flag(tls)),
        )
        .await;
    }

    /// Emits `ids::CONNECTION_OPENED` with `remote_addr`, `tls` and, if given, `protocol`.
    pub async fn emit_connection_opened(remote_addr: &str, tls: bool, protocol: Option<&str>) {
        let sig = connection_signal(ids::CONNECTION_OPENED, remote_addr, tls, protocol);
        SignalArbiter::emit_app(sig).await;
    }

    /// Emits `ids::CONNECTION_CLOSED` with `remote_addr`, `tls` and, if given, `protocol`.
    pub async fn emit_connection_closed(remote_addr: &str, tls: bool, protocol: Option<&str>) {
        let sig = connection_signal(ids::CONNECTION_CLOSED, remote_addr, tls, protocol);
        SignalArbiter::emit_app(sig).await;
    }
}