1use std::collections::HashMap;
5use std::sync::Arc;
6
7use axum::body::Body;
8use axum::http::{Request, Response, StatusCode};
9use axum::middleware::Next;
10use tower::ServiceBuilder;
11use tower_http::trace::TraceLayer;
12
13use crate::container::Container;
14use crate::Error;
15
16pub type MiddlewareFn = Arc<
17 dyn Fn(Request<Body>, Next) -> futures::future::BoxFuture<'static, Result<Response<Body>, Error>>
18 + Send
19 + Sync,
20>;
21
22#[derive(Clone)]
25pub struct NamedMiddleware {
26 pub name: String,
27 pub handler: MiddlewareFn,
28}
29
30#[derive(Default, Clone)]
31pub struct MiddlewareRegistry {
32 middleware: Arc<parking_lot::RwLock<HashMap<String, MiddlewareFn>>>,
33}
34
35impl MiddlewareRegistry {
36 pub fn new() -> Self {
37 Self::default()
38 }
39
40 pub fn register<F, Fut>(&self, name: impl Into<String>, handler: F)
41 where
42 F: Fn(Request<Body>, Next) -> Fut + Send + Sync + 'static,
43 Fut: std::future::Future<Output = Result<Response<Body>, Error>> + Send + 'static,
44 {
45 let wrapped: MiddlewareFn = Arc::new(move |req, next| Box::pin(handler(req, next)));
46 self.middleware.write().insert(name.into(), wrapped);
47 }
48
49 pub fn get(&self, name: &str) -> Option<MiddlewareFn> {
50 let parsed = MiddlewareSpec::parse(name);
51 self.middleware.read().get(&parsed.name).cloned()
52 }
53
54 pub fn names(&self) -> Vec<String> {
55 self.middleware.read().keys().cloned().collect()
56 }
57}
58
/// A middleware reference parsed from a string of the form `name` or `name:arg1,arg2,...`.
///
/// For example, `"throttle:60,1"` parses to name `"throttle"` with args `["60", "1"]`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MiddlewareSpec {
    /// The text before the first `:`, with surrounding whitespace trimmed.
    pub name: String,
    /// The comma-separated arguments after the first `:`, each trimmed.
    ///
    /// Empty when there is no `:`, or when only whitespace follows it.
    pub args: Vec<String>,
}

impl MiddlewareSpec {
    /// Parses a middleware spec string.
    ///
    /// Only the first `:` separates the name from the arguments, so arguments may themselves
    /// contain `:`. Interior empty arguments (e.g. `"a:1,,2"`) are kept as `""` so argument
    /// positions stay stable.
    pub fn parse(spec: &str) -> Self {
        let (name, args): (&str, Vec<String>) = match spec.split_once(':') {
            // `"name:"` / `"name:   "` means "no arguments", not a single empty argument.
            Some((name, raw_args)) if raw_args.trim().is_empty() => (name, Vec::new()),
            Some((name, raw_args)) => (
                name,
                raw_args.split(',').map(|arg| arg.trim().to_string()).collect(),
            ),
            None => (spec, Vec::new()),
        };

        MiddlewareSpec {
            // Trim the name just like the args, so `" auth "` resolves to `auth`.
            name: name.trim().to_string(),
            args,
        }
    }
}
81
82pub mod builtin {
84 use super::*;
85 use axum::extract::Request;
86
87 pub async fn auth_passthrough(req: Request, next: Next) -> Result<Response<Body>, Error> {
90 Ok(next.run(req).await)
91 }
92
93 pub async fn csrf_passthrough(req: Request, next: Next) -> Result<Response<Body>, Error> {
96 Ok(next.run(req).await)
97 }
98
99 pub async fn throttle_passthrough(req: Request, next: Next) -> Result<Response<Body>, Error> {
101 Ok(next.run(req).await)
102 }
103}
104
105pub fn install_defaults(registry: &MiddlewareRegistry) {
106 registry.register("auth", builtin::auth_passthrough);
107 registry.register("csrf", builtin::csrf_passthrough);
108 registry.register("throttle", builtin::throttle_passthrough);
109}
110
111pub async fn invoke(
114 mw: MiddlewareFn,
115 req: Request<Body>,
116 next: Next,
117) -> Response<Body> {
118 match mw(req, next).await {
119 Ok(resp) => resp,
120 Err(err) => {
121 tracing::error!(?err, "middleware error");
122 axum::response::IntoResponse::into_response((StatusCode::INTERNAL_SERVER_ERROR, err))
123 }
124 }
125}
126
127pub fn trace_layer() -> TraceLayer<tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>> {
129 TraceLayer::new_for_http()
130}
131
132pub async fn inject_container_mw(
135 container: Container,
136 req: Request<Body>,
137 next: Next,
138) -> Response<Body> {
139 crate::container::with_container(container, async move { next.run(req).await }).await
140}
141
142pub fn standard_layers() -> ServiceBuilder<tower::layer::util::Stack<TraceLayer<tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>>, tower::layer::util::Identity>> {
143 ServiceBuilder::new().layer(trace_layer())
144}