use std::any::Any;
use std::collections::HashMap;
use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, OnceLock, RwLock};
use std::task::{Context, Poll};
use std::time::Instant;
use axum::body::Body;
use axum::http::{Request, Response, StatusCode};
use tower::Service;
/// Boxed future produced by every connected receiver; it is awaited to completion
/// by the matching `send_*` function.
pub type ReceiverFuture = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;

/// Opaque handle returned by the `connect_*` functions.
///
/// Pass it to the matching `disconnect_*` function to remove the receiver again.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ReceiverId(u64);
/// Payload delivered to `request_started` receivers.
#[derive(Clone, Debug)]
pub struct RequestStartedContext {
    /// HTTP method as text, e.g. `"GET"`.
    pub method: String,
    /// Request URI path, without the query string.
    pub path: String,
    /// Raw query string without the leading `?`; empty when the URI has none.
    pub query: String,
}
/// Payload delivered to `request_finished` receivers.
#[derive(Clone, Debug)]
pub struct RequestFinishedContext {
    /// HTTP method as text, e.g. `"POST"`.
    pub method: String,
    /// Request URI path, without the query string.
    pub path: String,
    /// Numeric HTTP status code of the response.
    pub status: u16,
    /// Elapsed wall-clock time in milliseconds, as measured by `RequestSignalsService`.
    pub elapsed_ms: f64,
}
/// Payload delivered to `got_request_exception` receivers.
#[derive(Clone, Debug)]
pub struct RequestExceptionContext {
    /// HTTP method as text.
    pub method: String,
    /// Request URI path, without the query string.
    pub path: String,
    /// Human-readable description of what went wrong.
    pub error: String,
}
/// Key under which receivers are stored in the registry: one bucket per signal.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum SignalKind {
    /// `request_started`
    Started,
    /// `request_finished`
    Finished,
    /// `got_request_exception`
    Exception,
}
/// One connected receiver: its id paired with the type-erased callback.
type ReceiverEntry = (ReceiverId, Box<dyn Any + Send + Sync>);
/// Every receiver connected to a single signal, kept in connection order.
type Bag = Vec<ReceiverEntry>;

/// Returns the process-wide receiver registry, creating an empty one on first use.
///
/// Callers recover from lock poisoning with `into_inner`, so a panic while the
/// lock is held does not permanently disable signalling.
fn registry() -> &'static RwLock<HashMap<SignalKind, Bag>> {
    static RECEIVERS: OnceLock<RwLock<HashMap<SignalKind, Bag>>> = OnceLock::new();
    RECEIVERS.get_or_init(RwLock::default)
}
/// Hands out a fresh receiver id; ids start at 1 and increase by one per call.
///
/// `Relaxed` is sufficient because only uniqueness of the returned value matters,
/// not ordering relative to other memory operations. Ids repeat only if the
/// 64-bit counter wraps.
fn next_id() -> ReceiverId {
    static NEXT_RAW_ID: AtomicU64 = AtomicU64::new(1);
    let raw = NEXT_RAW_ID.fetch_add(1, Ordering::Relaxed);
    ReceiverId(raw)
}
fn insert_receiver<R: Any + Send + Sync>(kind: SignalKind, receiver: R) -> ReceiverId {
let id = next_id();
let mut reg = registry().write().unwrap_or_else(|e| e.into_inner());
reg.entry(kind).or_default().push((id, Box::new(receiver)));
id
}
/// Removes the receiver `id` from the bag for `kind`.
///
/// Returns `true` when at least one entry was removed, and `false` when the
/// signal has no bag or the id is not in it.
fn remove_receiver(kind: SignalKind, id: ReceiverId) -> bool {
    let mut guard = registry().write().unwrap_or_else(|e| e.into_inner());
    match guard.get_mut(&kind) {
        None => false,
        Some(bag) => {
            let len_before = bag.len();
            bag.retain(|entry| entry.0 != id);
            len_before != bag.len()
        }
    }
}
/// Clones every receiver for `kind` that downcasts to `R`, in connection order.
///
/// The read lock is released when this function returns, so callers can await
/// the returned receivers without holding the registry lock. Entries of any
/// other concrete type are skipped.
fn snapshot<R: Any + Send + Sync + Clone>(kind: SignalKind) -> Vec<R> {
    let guard = registry().read().unwrap_or_else(|e| e.into_inner());
    guard.get(&kind).map_or_else(Vec::new, |bag| {
        bag.iter()
            .filter_map(|(_, erased)| erased.downcast_ref::<R>())
            .cloned()
            .collect()
    })
}
/// Shared, type-erased async callback for a signal whose payload type is `C`.
type AsyncReceiver<C> = Arc<dyn Fn(C) -> ReceiverFuture + Send + Sync>;
/// Stored form of a `request_started` receiver.
type StartedReceiver = AsyncReceiver<RequestStartedContext>;
/// Stored form of a `request_finished` receiver.
type FinishedReceiver = AsyncReceiver<RequestFinishedContext>;
/// Stored form of a `got_request_exception` receiver.
type ExceptionReceiver = AsyncReceiver<RequestExceptionContext>;
pub fn connect_request_started<F, Fut>(receiver: F) -> ReceiverId
where
F: Fn(RequestStartedContext) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let boxed: StartedReceiver = Arc::new(move |ctx| Box::pin(receiver(ctx)));
insert_receiver(SignalKind::Started, boxed)
}
pub fn disconnect_request_started(id: ReceiverId) -> bool {
remove_receiver(SignalKind::Started, id)
}
pub async fn send_request_started(ctx: RequestStartedContext) {
let receivers: Vec<StartedReceiver> = snapshot(SignalKind::Started);
for r in receivers {
r(ctx.clone()).await;
}
}
pub fn connect_request_finished<F, Fut>(receiver: F) -> ReceiverId
where
F: Fn(RequestFinishedContext) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let boxed: FinishedReceiver = Arc::new(move |ctx| Box::pin(receiver(ctx)));
insert_receiver(SignalKind::Finished, boxed)
}
pub fn disconnect_request_finished(id: ReceiverId) -> bool {
remove_receiver(SignalKind::Finished, id)
}
pub async fn send_request_finished(ctx: RequestFinishedContext) {
let receivers: Vec<FinishedReceiver> = snapshot(SignalKind::Finished);
for r in receivers {
r(ctx.clone()).await;
}
}
pub fn connect_got_request_exception<F, Fut>(receiver: F) -> ReceiverId
where
F: Fn(RequestExceptionContext) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let boxed: ExceptionReceiver = Arc::new(move |ctx| Box::pin(receiver(ctx)));
insert_receiver(SignalKind::Exception, boxed)
}
pub fn disconnect_got_request_exception(id: ReceiverId) -> bool {
remove_receiver(SignalKind::Exception, id)
}
pub async fn send_got_request_exception(ctx: RequestExceptionContext) {
let receivers: Vec<ExceptionReceiver> = snapshot(SignalKind::Exception);
for r in receivers {
r(ctx.clone()).await;
}
}
/// Disconnects every receiver from every signal.
pub fn clear_all() {
    let mut guard = registry().write().unwrap_or_else(|e| e.into_inner());
    guard.clear();
}
#[must_use]
pub fn receiver_count() -> usize {
let reg = registry().read().unwrap_or_else(|e| e.into_inner());
[
SignalKind::Started,
SignalKind::Finished,
SignalKind::Exception,
]
.iter()
.map(|k| reg.get(k).map_or(0, Vec::len))
.sum()
}
/// Tower layer that wraps a service so that every request fires the
/// request-lifecycle signals.
#[derive(Debug, Default, Clone)]
pub struct RequestSignalsLayer;

impl RequestSignalsLayer {
    /// Creates the layer; equivalent to `RequestSignalsLayer::default()`.
    #[must_use]
    pub fn new() -> Self {
        RequestSignalsLayer
    }
}

impl<S> tower::Layer<S> for RequestSignalsLayer {
    type Service = RequestSignalsService<S>;

    /// Wraps `service` in a `RequestSignalsService`.
    fn layer(&self, service: S) -> Self::Service {
        RequestSignalsService { inner: service }
    }
}
/// Middleware service produced by `RequestSignalsLayer`.
///
/// It forwards each request to `inner` and emits the request-lifecycle signals
/// around the call. `Debug` is derived (bounded on `S: Debug`) so the wrapper can
/// be inspected like the layer that creates it.
#[derive(Clone, Debug)]
pub struct RequestSignalsService<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
}
/// Fires `request_started` before and `request_finished` after the inner service.
///
/// `elapsed_ms` covers only the inner service's response future. The clock starts
/// after the `request_started` receivers have run and stops before the
/// `request_finished` receivers run, so slow receivers do not inflate it.
///
/// The inner service's error type is `Infallible`, so this middleware never
/// emits `got_request_exception` itself. Code that wants that signal must call
/// `send_got_request_exception` directly.
impl<S> Service<Request<Body>> for RequestSignalsService<S>
where
    S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = Response<Body>;
    type Error = Infallible;
    type Future =
        Pin<Box<dyn Future<Output = Result<Response<Body>, Infallible>> + Send + 'static>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        // `poll_ready` readied `self.inner`. Move that instance into the future and
        // leave a fresh clone behind, so the readied one serves this request.
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);
        let method = req.method().as_str().to_owned();
        let path = req.uri().path().to_owned();
        let query = req.uri().query().unwrap_or_default().to_owned();
        Box::pin(async move {
            send_request_started(RequestStartedContext {
                method: method.clone(),
                path: path.clone(),
                query,
            })
            .await;

            // Start timing only after the started-receivers finish, so their
            // latency is not attributed to the handler.
            let started_at = Instant::now();
            let resp = match inner.call(req).await {
                Ok(resp) => resp,
                // `Infallible` has no values; this empty match proves the arm is
                // unreachable on every compiler version.
                Err(never) => match never {},
            };
            // `as_secs_f64` keeps sub-microsecond precision and avoids a lossy
            // u128 -> f64 cast.
            let elapsed_ms = started_at.elapsed().as_secs_f64() * 1000.0;

            send_request_finished(RequestFinishedContext {
                method,
                path,
                status: resp.status().as_u16(),
                elapsed_ms,
            })
            .await;
            Ok(resp)
        })
    }
}
#[allow(dead_code)]
fn error_response() -> Response<Body> {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Internal Server Error"))
.expect("static response builder")
}