use std::any::Any;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, OnceLock, RwLock};
/// Boxed, type-erased future produced by every signal receiver.
///
/// `Send + 'static`: it borrows nothing from the sender and may move between
/// threads while it is awaited.
pub type ReceiverFuture = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
/// Opaque handle for one receiver connection.
///
/// Returned by the `connect_*` functions and accepted by the matching
/// `disconnect_*` function. Values come from a process-wide counter, so a
/// handle is never handed out twice.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ReceiverId(u64);
/// HTTP request details attached to every auth signal. Each field is `None`
/// when the information was not available.
#[derive(Debug, Clone, Default)]
pub struct AuthRequestMeta {
    /// Client IP as reported by request headers (see `meta_from_headers`).
    /// Not verified — treat as informational.
    pub ip_address: Option<String>,
    /// Value of the `user-agent` header, if any.
    pub user_agent: Option<String>,
    /// Request path supplied by the caller, if any.
    pub path: Option<String>,
}
/// Payload delivered to receivers of the "user logged in" signal.
///
/// Each receiver gets its own clone.
#[derive(Debug, Clone)]
pub struct UserLoggedInContext {
    /// Static label chosen by the emitting code path.
    /// NOTE(review): the set of allowed values is not defined here — document it.
    pub source: &'static str,
    /// Id of the user that logged in.
    pub user_id: i64,
    /// Username of the user that logged in.
    pub username: String,
    /// Superuser flag of the user, copied into the event.
    pub is_superuser: bool,
    /// Request details for the login request.
    pub request: AuthRequestMeta,
}
/// Payload delivered to receivers of the "user logged out" signal.
///
/// The user fields are optional, so a logout can be reported even when the
/// user could not be identified.
#[derive(Debug, Clone)]
pub struct UserLoggedOutContext {
    /// Static label chosen by the emitting code path.
    /// NOTE(review): the set of allowed values is not defined here — document it.
    pub source: &'static str,
    /// Id of the user that logged out, if known.
    pub user_id: Option<i64>,
    /// Username of the user that logged out, if known.
    pub username: Option<String>,
    /// Request details for the logout request.
    pub request: AuthRequestMeta,
}
/// Reason a login attempt was rejected, carried by [`UserLoginFailedContext`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthFailureReason {
    /// The supplied credentials did not match.
    InvalidCredentials,
    /// The account exists but is not active.
    Inactive,
    /// The account is locked.
    Locked,
    /// Any reason not covered by the variants above.
    Other,
}
/// Payload delivered to receivers of the "login failed" signal.
#[derive(Debug, Clone)]
pub struct UserLoginFailedContext {
    /// Static label chosen by the emitting code path.
    /// NOTE(review): the set of allowed values is not defined here — document it.
    pub source: &'static str,
    /// Username the client tried to log in with, if one was supplied.
    /// This is unauthenticated client input; treat it as untrusted.
    pub attempted_username: Option<String>,
    /// Why the attempt was rejected.
    pub reason: AuthFailureReason,
    /// Request details for the failed request.
    pub request: AuthRequestMeta,
}
/// Private registry key selecting which bucket a receiver lives in.
///
/// Each variant pairs with exactly one receiver alias (`LoggedInReceiver`,
/// `LoggedOutReceiver`, `LoginFailedReceiver`). The `connect_*` and `send_*`
/// functions for a kind must agree on that alias, because `snapshot`
/// silently skips entries of any other type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum SignalKind {
    LoggedIn,
    LoggedOut,
    LoginFailed,
}
/// One registered receiver: its id plus the receiver erased to `dyn Any`.
///
/// Erasure lets a single map hold receivers of different signal types;
/// `snapshot` downcasts them back.
type ReceiverEntry = (ReceiverId, Box<dyn Any + Send + Sync>);
/// All receivers of one signal kind, in registration order.
type Bag = Vec<ReceiverEntry>;
/// Returns the process-wide receiver map, creating it empty on first access.
///
/// Readers (`snapshot`, `receiver_count`) take the read lock. Connect,
/// disconnect and clear take the write lock.
fn registry() -> &'static RwLock<HashMap<SignalKind, Bag>> {
    static RECEIVERS: OnceLock<RwLock<HashMap<SignalKind, Bag>>> = OnceLock::new();
    // `RwLock::default()` is `RwLock::new(HashMap::new())`.
    RECEIVERS.get_or_init(RwLock::default)
}
/// Allocates a fresh, process-unique receiver id.
///
/// Ids start at 1 and grow by one per call. `Relaxed` ordering is enough
/// because only the uniqueness of each value matters, not its ordering
/// relative to other memory operations.
fn next_id() -> ReceiverId {
    static NEXT: AtomicU64 = AtomicU64::new(1);
    let raw = NEXT.fetch_add(1, Ordering::Relaxed);
    ReceiverId(raw)
}
/// Registers `receiver` under `kind` and returns its freshly allocated id.
///
/// The receiver is type-erased to `dyn Any` and appended, so registration
/// order is preserved. The id is allocated before the write lock is taken.
/// A poisoned lock is recovered rather than propagated.
fn insert_receiver<R: Any + Send + Sync>(kind: SignalKind, receiver: R) -> ReceiverId {
    let id = next_id();
    let erased: Box<dyn Any + Send + Sync> = Box::new(receiver);
    let mut guard = match registry().write() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    guard.entry(kind).or_insert_with(Vec::new).push((id, erased));
    id
}
/// Detaches the receiver registered under `kind` with `id`.
///
/// Returns `true` if an entry was removed. Returns `false` if `kind` has no
/// bucket or no entry in it carries `id`. An emptied bucket stays in the map.
/// A poisoned lock is recovered rather than propagated.
fn remove_receiver(kind: SignalKind, id: ReceiverId) -> bool {
    let mut guard = match registry().write() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    match guard.get_mut(&kind) {
        None => false,
        Some(receivers) => {
            let len_before = receivers.len();
            receivers.retain(|entry| entry.0 != id);
            // `retain` can only shrink the vector, so "shorter" means "removed".
            receivers.len() < len_before
        }
    }
}
/// Clones every receiver registered under `kind` that is stored as type `R`.
///
/// The read lock is held only while copying, so callers can await the
/// returned receivers without holding it. Entries of any other type are
/// skipped; the `connect_*` functions always store the alias that matches
/// their kind. A poisoned lock is recovered rather than propagated.
fn snapshot<R: Any + Send + Sync + Clone>(kind: SignalKind) -> Vec<R> {
    let guard = match registry().read() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    guard
        .get(&kind)
        .map(|entries| {
            entries
                .iter()
                .filter_map(|(_, receiver)| receiver.downcast_ref::<R>())
                .cloned()
                .collect::<Vec<R>>()
        })
        .unwrap_or_default()
}
/// Type-erased receiver for the "user logged in" signal. Held in an `Arc` so
/// `snapshot` can clone it out of the registry cheaply.
type LoggedInReceiver = Arc<dyn Fn(UserLoggedInContext) -> ReceiverFuture + Send + Sync>;
/// Type-erased receiver for the "user logged out" signal.
type LoggedOutReceiver = Arc<dyn Fn(UserLoggedOutContext) -> ReceiverFuture + Send + Sync>;
/// Type-erased receiver for the "login failed" signal.
type LoginFailedReceiver = Arc<dyn Fn(UserLoginFailedContext) -> ReceiverFuture + Send + Sync>;
/// Connects an async `receiver` to the "user logged in" signal.
///
/// Returns an id for `disconnect_user_logged_in`. Receivers run in
/// registration order. Connecting the same closure twice creates two
/// independent connections with distinct ids.
pub fn connect_user_logged_in<F, Fut>(receiver: F) -> ReceiverId
where
    F: Fn(UserLoggedInContext) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    // The closure stays inline so its return type is deduced from the
    // `dyn Fn` alias, which lets the concrete future coerce to the boxed one.
    let erased: LoggedInReceiver = Arc::new(move |ctx| {
        let pending = receiver(ctx);
        Box::pin(pending)
    });
    insert_receiver(SignalKind::LoggedIn, erased)
}
/// Disconnects the "user logged in" receiver registered under `id`.
///
/// Returns `true` if a receiver was removed, or `false` if `id` is not
/// connected to this signal.
pub fn disconnect_user_logged_in(id: ReceiverId) -> bool {
    let kind = SignalKind::LoggedIn;
    remove_receiver(kind, id)
}
/// Notifies every connected "user logged in" receiver.
///
/// Receivers are awaited one after another in registration order, and each
/// gets its own clone of `ctx`. The receiver list is copied up front, so the
/// registry lock is not held across `.await`. Connections made or removed
/// during the send do not affect this call. A panicking receiver propagates
/// the panic and the remaining receivers are skipped.
pub async fn send_user_logged_in(ctx: UserLoggedInContext) {
    let connected = snapshot::<LoggedInReceiver>(SignalKind::LoggedIn);
    for receiver in connected {
        let event = ctx.clone();
        receiver(event).await;
    }
}
/// Connects an async `receiver` to the "user logged out" signal.
///
/// Returns an id for `disconnect_user_logged_out`. Receivers run in
/// registration order. Connecting the same closure twice creates two
/// independent connections with distinct ids.
pub fn connect_user_logged_out<F, Fut>(receiver: F) -> ReceiverId
where
    F: Fn(UserLoggedOutContext) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    // The closure stays inline so its return type is deduced from the
    // `dyn Fn` alias, which lets the concrete future coerce to the boxed one.
    let erased: LoggedOutReceiver = Arc::new(move |ctx| {
        let pending = receiver(ctx);
        Box::pin(pending)
    });
    insert_receiver(SignalKind::LoggedOut, erased)
}
/// Disconnects the "user logged out" receiver registered under `id`.
///
/// Returns `true` if a receiver was removed, or `false` if `id` is not
/// connected to this signal.
pub fn disconnect_user_logged_out(id: ReceiverId) -> bool {
    let kind = SignalKind::LoggedOut;
    remove_receiver(kind, id)
}
/// Notifies every connected "user logged out" receiver.
///
/// Receivers are awaited one after another in registration order, and each
/// gets its own clone of `ctx`. The receiver list is copied up front, so the
/// registry lock is not held across `.await`. A panicking receiver propagates
/// the panic and the remaining receivers are skipped.
pub async fn send_user_logged_out(ctx: UserLoggedOutContext) {
    let connected = snapshot::<LoggedOutReceiver>(SignalKind::LoggedOut);
    for receiver in connected {
        let event = ctx.clone();
        receiver(event).await;
    }
}
/// Connects an async `receiver` to the "login failed" signal.
///
/// Returns an id for `disconnect_user_login_failed`. Receivers run in
/// registration order. Connecting the same closure twice creates two
/// independent connections with distinct ids.
pub fn connect_user_login_failed<F, Fut>(receiver: F) -> ReceiverId
where
    F: Fn(UserLoginFailedContext) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    // The closure stays inline so its return type is deduced from the
    // `dyn Fn` alias, which lets the concrete future coerce to the boxed one.
    let erased: LoginFailedReceiver = Arc::new(move |ctx| {
        let pending = receiver(ctx);
        Box::pin(pending)
    });
    insert_receiver(SignalKind::LoginFailed, erased)
}
/// Disconnects the "login failed" receiver registered under `id`.
///
/// Returns `true` if a receiver was removed, or `false` if `id` is not
/// connected to this signal.
pub fn disconnect_user_login_failed(id: ReceiverId) -> bool {
    let kind = SignalKind::LoginFailed;
    remove_receiver(kind, id)
}
/// Notifies every connected "login failed" receiver.
///
/// Receivers are awaited one after another in registration order, and each
/// gets its own clone of `ctx`. The receiver list is copied up front, so the
/// registry lock is not held across `.await`. A panicking receiver propagates
/// the panic and the remaining receivers are skipped.
pub async fn send_user_login_failed(ctx: UserLoginFailedContext) {
    let connected = snapshot::<LoginFailedReceiver>(SignalKind::LoginFailed);
    for receiver in connected {
        let event = ctx.clone();
        receiver(event).await;
    }
}
/// Builds an [`AuthRequestMeta`] from HTTP request headers and an optional path.
///
/// * `ip_address` — the `x-real-ip` header if it is present and non-blank.
///   Otherwise, the first non-blank entry of the comma-separated
///   `x-forwarded-for` list.
/// * `user_agent` — the `user-agent` header if it is present and non-blank.
/// * `path` — copied from the `path` argument.
///
/// Header values are trimmed. A header that is missing, not representable as
/// text (`to_str` fails), or blank is treated as absent. This way an empty
/// `x-real-ip` no longer hides a usable `x-forwarded-for`.
///
/// NOTE: `x-real-ip` and `x-forwarded-for` are set by the client unless a
/// trusted reverse proxy overwrites them. Treat the resulting IP as
/// untrusted, informational data.
pub fn meta_from_headers(headers: &axum::http::HeaderMap, path: Option<&str>) -> AuthRequestMeta {
    // Header value as owned text, or `None` when missing, non-text or blank.
    let header_str = |name: &str| {
        headers
            .get(name)
            .and_then(|v| v.to_str().ok())
            .map(str::trim)
            .filter(|s| !s.is_empty())
            .map(str::to_owned)
    };
    let ip_address = header_str("x-real-ip").or_else(|| {
        // The left-most hop is the original client. Skip blank hops so a
        // malformed list like ", 1.2.3.4" does not produce an empty address.
        header_str("x-forwarded-for").and_then(|chain| {
            chain
                .split(',')
                .map(str::trim)
                .find(|hop| !hop.is_empty())
                .map(str::to_owned)
        })
    });
    let user_agent = header_str("user-agent");
    AuthRequestMeta {
        ip_address,
        user_agent,
        path: path.map(str::to_owned),
    }
}
/// Disconnects every receiver from every signal.
///
/// Ids handed out earlier become stale: ids are never reused, so
/// disconnecting them afterwards returns `false`. A poisoned lock is
/// recovered rather than propagated.
pub fn clear_all() {
    let mut receivers_by_kind = match registry().write() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    receivers_by_kind.clear();
}
pub fn receiver_count() -> usize {
let reg = registry().read().unwrap_or_else(|e| e.into_inner());
[
SignalKind::LoggedIn,
SignalKind::LoggedOut,
SignalKind::LoginFailed,
]
.iter()
.map(|kind| reg.get(kind).map_or(0, Vec::len))
.sum()
}