1use std::collections::HashMap;
5use std::sync::Arc;
6
7use axum::body::Body;
8use axum::http::{Request, Response, StatusCode};
9use axum::middleware::Next;
10use tower::ServiceBuilder;
11use tower_http::trace::TraceLayer;
12
13use crate::container::Container;
14use crate::Error;
15
16pub type MiddlewareFn = Arc<
17 dyn Fn(
18 Request<Body>,
19 Next,
20 ) -> futures::future::BoxFuture<'static, Result<Response<Body>, Error>>
21 + Send
22 + Sync,
23>;
24
25#[derive(Clone)]
28pub struct NamedMiddleware {
29 pub name: String,
30 pub handler: MiddlewareFn,
31}
32
33#[derive(Default, Clone)]
34pub struct MiddlewareRegistry {
35 middleware: Arc<parking_lot::RwLock<HashMap<String, MiddlewareFn>>>,
36}
37
38impl MiddlewareRegistry {
39 pub fn new() -> Self {
40 Self::default()
41 }
42
43 pub fn register<F, Fut>(&self, name: impl Into<String>, handler: F)
44 where
45 F: Fn(Request<Body>, Next) -> Fut + Send + Sync + 'static,
46 Fut: std::future::Future<Output = Result<Response<Body>, Error>> + Send + 'static,
47 {
48 let wrapped: MiddlewareFn = Arc::new(move |req, next| Box::pin(handler(req, next)));
49 self.middleware.write().insert(name.into(), wrapped);
50 }
51
52 pub fn get(&self, name: &str) -> Option<MiddlewareFn> {
53 let parsed = MiddlewareSpec::parse(name);
54 self.middleware.read().get(&parsed.name).cloned()
55 }
56
57 pub fn names(&self) -> Vec<String> {
58 self.middleware.read().keys().cloned().collect()
59 }
60}
61
/// A parsed middleware reference of the form `name` or `name:arg1,arg2,...`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MiddlewareSpec {
    /// Middleware name with surrounding whitespace removed.
    pub name: String,
    /// Trimmed, comma-separated arguments following the first `:`. Empty when there is no `:`,
    /// or when only whitespace follows it.
    pub args: Vec<String>,
}

impl MiddlewareSpec {
    /// Parses a spec string such as `"auth"` or `"throttle:60,1"`.
    ///
    /// Only the first `:` separates the name from the arguments, so later colons remain part of
    /// the argument text. Empty segments between commas are kept (`"a:1,,2"` gives
    /// `["1", "", "2"]`) to preserve argument positions. A blank argument list (`"a:"`) gives
    /// no arguments rather than one empty string.
    pub fn parse(spec: &str) -> Self {
        let (name, args) = match spec.split_once(':') {
            Some((name, raw)) if raw.trim().is_empty() => (name, Vec::new()),
            Some((name, raw)) => (
                name,
                raw.split(',').map(|arg| arg.trim().to_string()).collect(),
            ),
            None => (spec, Vec::new()),
        };
        MiddlewareSpec {
            name: name.trim().to_string(),
            args,
        }
    }
}
84
85pub mod builtin {
87 use super::*;
88 use axum::extract::{FromRequestParts, Request};
89 use axum::http::Method;
90 use rand::RngCore;
91 use tower_sessions::Session;
92
93 pub const CSRF_SESSION_KEY: &str = "_csrf.token";
94 pub const CSRF_HEADER: &str = "x-csrf-token";
95
96 pub async fn ensure_csrf_token(session: &Session) -> Result<String, Error> {
99 if let Some(existing) = session
100 .get::<String>(CSRF_SESSION_KEY)
101 .await
102 .map_err(|e| Error::Internal(e.to_string()))?
103 {
104 return Ok(existing);
105 }
106 let token = generate_csrf_token();
107 session
108 .insert(CSRF_SESSION_KEY, token.clone())
109 .await
110 .map_err(|e| Error::Internal(e.to_string()))?;
111 Ok(token)
112 }
113
114 fn generate_csrf_token() -> String {
115 use base64::engine::Engine;
116 let mut bytes = [0u8; 32];
117 rand::thread_rng().fill_bytes(&mut bytes);
118 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
119 }
120
121 pub async fn csrf(req: Request, next: Next) -> Result<Response<Body>, Error> {
125 let method = req.method().clone();
126 let (mut parts, body) = req.into_parts();
127
128 let session = match Session::from_request_parts(&mut parts, &()).await {
129 Ok(s) => s,
130 Err(_) => {
131 let req = Request::from_parts(parts, body);
134 return Ok(next.run(req).await);
135 }
136 };
137
138 let session_token = ensure_csrf_token(&session).await?;
139
140 if matches!(method, Method::GET | Method::HEAD | Method::OPTIONS) {
142 let req = Request::from_parts(parts, body);
143 return Ok(next.run(req).await);
144 }
145
146 let header_token = parts
148 .headers
149 .get(CSRF_HEADER)
150 .and_then(|v| v.to_str().ok())
151 .map(|s| s.to_string());
152
153 let body_bytes = axum::body::to_bytes(body, 16 * 1024 * 1024)
154 .await
155 .map_err(|e| Error::bad_request(format!("body read failed: {e}")))?;
156
157 let body_token = extract_body_token(&parts, &body_bytes);
158
159 let submitted = header_token.or(body_token);
160
161 if submitted.as_deref() != Some(session_token.as_str()) {
162 return Err(Error::forbidden("CSRF token mismatch"));
163 }
164
165 let req = Request::from_parts(parts, axum::body::Body::from(body_bytes));
166 Ok(next.run(req).await)
167 }
168
169 fn extract_body_token(parts: &axum::http::request::Parts, body: &[u8]) -> Option<String> {
170 let content_type = parts
171 .headers
172 .get(axum::http::header::CONTENT_TYPE)
173 .and_then(|v| v.to_str().ok())
174 .unwrap_or("");
175
176 if content_type.starts_with("application/x-www-form-urlencoded") {
177 let pairs: Vec<(String, String)> =
178 serde_urlencoded::from_bytes(body).unwrap_or_default();
179 return pairs
180 .into_iter()
181 .find_map(|(k, v)| (k == "_token").then_some(v));
182 }
183 if content_type.starts_with("application/json") {
184 let value: serde_json::Value = serde_json::from_slice(body).ok()?;
185 return value
186 .get("_token")
187 .and_then(|v| v.as_str())
188 .map(|s| s.to_string());
189 }
190 None
191 }
192
193 pub async fn auth_passthrough(req: Request, next: Next) -> Result<Response<Body>, Error> {
196 Ok(next.run(req).await)
197 }
198
199 pub async fn throttle_passthrough(req: Request, next: Next) -> Result<Response<Body>, Error> {
201 Ok(next.run(req).await)
202 }
203}
204
205pub fn install_defaults(registry: &MiddlewareRegistry) {
206 registry.register("auth", builtin::auth_passthrough);
207 registry.register("csrf", builtin::csrf);
208 registry.register("throttle", builtin::throttle_passthrough);
209}
210
211pub async fn invoke(mw: MiddlewareFn, req: Request<Body>, next: Next) -> Response<Body> {
214 match mw(req, next).await {
215 Ok(resp) => resp,
216 Err(err) => {
217 tracing::error!(?err, "middleware error");
218 axum::response::IntoResponse::into_response((StatusCode::INTERNAL_SERVER_ERROR, err))
219 }
220 }
221}
222
223pub fn trace_layer(
225) -> TraceLayer<tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>>
226{
227 TraceLayer::new_for_http()
228}
229
230pub async fn inject_container_mw(
233 container: Container,
234 req: Request<Body>,
235 next: Next,
236) -> Response<Body> {
237 crate::container::with_container(container, async move { next.run(req).await }).await
238}
239
240pub fn standard_layers() -> ServiceBuilder<
241 tower::layer::util::Stack<
242 TraceLayer<
243 tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>,
244 >,
245 tower::layer::util::Identity,
246 >,
247> {
248 ServiceBuilder::new().layer(trace_layer())
249}