use std::sync::Arc;
use async_trait::async_trait;
use axum::extract::{Request, State};
use axum::middleware::Next;
use axum::response::Response;
/// A pluggable request/response hook executed by [`MiddlewareStack`].
///
/// Every method has a default implementation, so implementors only override the hooks
/// they actually need.
#[async_trait]
pub trait Middleware: Send + Sync + 'static {
    /// Identifier for this middleware, e.g. for diagnostics.
    ///
    /// Defaults to the implementing type's name as reported by `type_name`; note that the
    /// standard library does not guarantee the exact format of that string.
    fn name(&self) -> &'static str {
        core::any::type_name::<Self>()
    }

    /// Sort key applied when the owning [`MiddlewareStack`] is installed on a router.
    ///
    /// Lower values run their `before_request` hook earlier and their `after_response`
    /// hook later. Defaults to `0`.
    fn order(&self) -> i32 {
        0
    }

    /// Inspects or rewrites the incoming request.
    ///
    /// Return `Ok(request)` to hand the (possibly modified) request to the next stage, or
    /// `Err(response)` to stop processing and answer with `response` instead. The default
    /// forwards the request unchanged.
    async fn before_request(&self, request: Request) -> Result<Request, Response> {
        Ok(request)
    }

    /// Inspects or rewrites the outgoing response. The default returns it unchanged.
    async fn after_response(&self, response: Response) -> Response {
        response
    }
}
/// An ordered collection of [`Middleware`] that is installed on an [`axum::Router`] as a
/// single layer via `MiddlewareStack::apply`.
///
/// Cloning copies the `Arc` handles (bumping their reference counts); the middleware values
/// themselves are shared, not duplicated.
#[derive(Default, Clone)]
pub struct MiddlewareStack {
    /// Registered middleware in insertion order; sorted by `Middleware::order` only when the
    /// stack is applied to a router.
    middleware: Vec<Arc<dyn Middleware>>,
}
impl MiddlewareStack {
    /// Creates a stack with no middleware registered.
    pub fn new() -> Self {
        Self {
            middleware: Vec::new(),
        }
    }

    /// Appends one middleware to the end of the stack.
    pub fn push(&mut self, mw: Arc<dyn Middleware>) {
        self.middleware.push(mw);
    }

    /// Appends every middleware yielded by `other`, in iteration order.
    pub fn extend(&mut self, other: impl IntoIterator<Item = Arc<dyn Middleware>>) {
        self.middleware.extend(other);
    }

    /// Returns `true` when no middleware has been registered.
    pub fn is_empty(&self) -> bool {
        self.middleware.is_empty()
    }

    /// Number of registered middleware.
    pub fn len(&self) -> usize {
        self.middleware.len()
    }

    /// Consumes the stack and installs it on `router` as a single middleware layer.
    ///
    /// The middleware are sorted ascending by `Middleware::order` first. The sort is stable,
    /// so entries with equal `order` keep their registration order. An empty stack returns
    /// `router` untouched, without adding a layer.
    pub fn apply(self, router: axum::Router) -> axum::Router {
        let mut ordered = self.middleware;
        if ordered.is_empty() {
            router
        } else {
            ordered.sort_by_key(|mw| mw.order());
            let layer = axum::middleware::from_fn_with_state(Arc::new(ordered), run_stack);
            router.layer(layer)
        }
    }
}
/// Axum middleware function that drives every [`Middleware`] hook in `stack` for one request.
///
/// Hooks run in "onion" order:
/// 1. `before_request` is called on each middleware in `stack` order. Each hook receives the
///    request returned by the previous one.
/// 2. If every hook returns `Ok`, the final request is forwarded to the inner service via
///    `next`. If a hook returns `Err(response)`, iteration stops immediately and that response
///    is used instead. The inner service and the remaining before-hooks are not called.
/// 3. `after_response` is called in reverse order on exactly those middleware whose
///    `before_request` returned `Ok`. The short-circuiting middleware, and every middleware
///    after it, get no after-hook call.
async fn run_stack(
    State(stack): State<Arc<Vec<Arc<dyn Middleware>>>>,
    mut req: Request,
    next: Next,
) -> Response {
    // Number of middleware whose `before_request` returned `Ok`; only these are unwound.
    let mut ran = 0usize;

    // `req` is moved into each hook and re-bound from its `Ok` value, and every path that does
    // not re-bind it leaves the loop. The compiler can therefore prove the request is always
    // present, so no `Option` + `expect` bookkeeping is needed.
    let mut res = loop {
        let mw = match stack.get(ran) {
            Some(mw) => mw,
            // Every before-hook accepted the request: hand it to the inner service.
            None => break next.run(req).await,
        };
        match mw.before_request(req).await {
            Ok(modified) => {
                req = modified;
                ran += 1;
            }
            // Short-circuit with the hook's own response.
            Err(resp) => break resp,
        }
    };

    for mw in stack.iter().take(ran).rev() {
        res = mw.after_response(res).await;
    }
    res
}