1use std::collections::HashMap;
5use std::sync::Arc;
6
7use axum::body::Body;
8use axum::http::{Request, Response, StatusCode};
9use axum::middleware::Next;
10use tower::ServiceBuilder;
11use tower_http::trace::TraceLayer;
12
13use crate::container::Container;
14use crate::Error;
15
16pub type MiddlewareFn = Arc<
17 dyn Fn(Request<Body>, Next) -> futures::future::BoxFuture<'static, Result<Response<Body>, Error>>
18 + Send
19 + Sync,
20>;
21
22#[derive(Clone)]
25pub struct NamedMiddleware {
26 pub name: String,
27 pub handler: MiddlewareFn,
28}
29
30#[derive(Default, Clone)]
31pub struct MiddlewareRegistry {
32 middleware: Arc<parking_lot::RwLock<HashMap<String, MiddlewareFn>>>,
33}
34
35impl MiddlewareRegistry {
36 pub fn new() -> Self {
37 Self::default()
38 }
39
40 pub fn register<F, Fut>(&self, name: impl Into<String>, handler: F)
41 where
42 F: Fn(Request<Body>, Next) -> Fut + Send + Sync + 'static,
43 Fut: std::future::Future<Output = Result<Response<Body>, Error>> + Send + 'static,
44 {
45 let wrapped: MiddlewareFn = Arc::new(move |req, next| Box::pin(handler(req, next)));
46 self.middleware.write().insert(name.into(), wrapped);
47 }
48
49 pub fn get(&self, name: &str) -> Option<MiddlewareFn> {
50 let parsed = MiddlewareSpec::parse(name);
51 self.middleware.read().get(&parsed.name).cloned()
52 }
53
54 pub fn names(&self) -> Vec<String> {
55 self.middleware.read().keys().cloned().collect()
56 }
57}
58
/// A parsed middleware reference of the form `name` or `name:arg1,arg2,...`.
#[derive(Debug, Clone)]
pub struct MiddlewareSpec {
    /// Middleware name: the text before the first `:`, with surrounding whitespace removed.
    pub name: String,
    /// Comma-separated arguments after the first `:`, each trimmed. Empty when there
    /// is no `:` or only whitespace follows it; interior empty positions (`"x:,b"`)
    /// are kept so positional arguments do not shift.
    pub args: Vec<String>,
}

impl MiddlewareSpec {
    /// Parses a spec string such as `"auth"` or `"throttle:60, 1"`.
    ///
    /// Only the first `:` separates the name from the arguments, so any later
    /// colons stay inside the argument text.
    pub fn parse(spec: &str) -> Self {
        let (name, raw_args) = spec.split_once(':').unwrap_or((spec, ""));

        // `"csrf:"` means "no arguments", not "one empty argument".
        let args = if raw_args.trim().is_empty() {
            Vec::new()
        } else {
            raw_args.split(',').map(|arg| arg.trim().to_string()).collect()
        };

        MiddlewareSpec {
            // Trim the name like the args, so `"throttle :60"` still resolves to `throttle`.
            name: name.trim().to_string(),
            args,
        }
    }
}
81
82pub mod builtin {
84 use super::*;
85 use axum::extract::{FromRequestParts, Request};
86 use axum::http::Method;
87 use rand::RngCore;
88 use tower_sessions::Session;
89
90 pub const CSRF_SESSION_KEY: &str = "_csrf.token";
91 pub const CSRF_HEADER: &str = "x-csrf-token";
92
93 pub async fn ensure_csrf_token(session: &Session) -> Result<String, Error> {
96 if let Some(existing) = session
97 .get::<String>(CSRF_SESSION_KEY)
98 .await
99 .map_err(|e| Error::Internal(e.to_string()))?
100 {
101 return Ok(existing);
102 }
103 let token = generate_csrf_token();
104 session
105 .insert(CSRF_SESSION_KEY, token.clone())
106 .await
107 .map_err(|e| Error::Internal(e.to_string()))?;
108 Ok(token)
109 }
110
111 fn generate_csrf_token() -> String {
112 use base64::engine::Engine;
113 let mut bytes = [0u8; 32];
114 rand::thread_rng().fill_bytes(&mut bytes);
115 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
116 }
117
118 pub async fn csrf(req: Request, next: Next) -> Result<Response<Body>, Error> {
122 let method = req.method().clone();
123 let (mut parts, body) = req.into_parts();
124
125 let session = match Session::from_request_parts(&mut parts, &()).await {
126 Ok(s) => s,
127 Err(_) => {
128 let req = Request::from_parts(parts, body);
131 return Ok(next.run(req).await);
132 }
133 };
134
135 let session_token = ensure_csrf_token(&session).await?;
136
137 if matches!(method, Method::GET | Method::HEAD | Method::OPTIONS) {
139 let req = Request::from_parts(parts, body);
140 return Ok(next.run(req).await);
141 }
142
143 let header_token = parts
145 .headers
146 .get(CSRF_HEADER)
147 .and_then(|v| v.to_str().ok())
148 .map(|s| s.to_string());
149
150 let body_bytes = axum::body::to_bytes(body, 16 * 1024 * 1024)
151 .await
152 .map_err(|e| Error::bad_request(format!("body read failed: {e}")))?;
153
154 let body_token = extract_body_token(&parts, &body_bytes);
155
156 let submitted = header_token.or(body_token);
157
158 if submitted.as_deref() != Some(session_token.as_str()) {
159 return Err(Error::forbidden("CSRF token mismatch"));
160 }
161
162 let req = Request::from_parts(parts, axum::body::Body::from(body_bytes));
163 Ok(next.run(req).await)
164 }
165
166 fn extract_body_token(parts: &axum::http::request::Parts, body: &[u8]) -> Option<String> {
167 let content_type = parts
168 .headers
169 .get(axum::http::header::CONTENT_TYPE)
170 .and_then(|v| v.to_str().ok())
171 .unwrap_or("");
172
173 if content_type.starts_with("application/x-www-form-urlencoded") {
174 let pairs: Vec<(String, String)> =
175 serde_urlencoded::from_bytes(body).unwrap_or_default();
176 return pairs.into_iter().find_map(|(k, v)| (k == "_token").then_some(v));
177 }
178 if content_type.starts_with("application/json") {
179 let value: serde_json::Value = serde_json::from_slice(body).ok()?;
180 return value
181 .get("_token")
182 .and_then(|v| v.as_str())
183 .map(|s| s.to_string());
184 }
185 None
186 }
187
188 pub async fn auth_passthrough(req: Request, next: Next) -> Result<Response<Body>, Error> {
191 Ok(next.run(req).await)
192 }
193
194 pub async fn throttle_passthrough(req: Request, next: Next) -> Result<Response<Body>, Error> {
196 Ok(next.run(req).await)
197 }
198}
199
200pub fn install_defaults(registry: &MiddlewareRegistry) {
201 registry.register("auth", builtin::auth_passthrough);
202 registry.register("csrf", builtin::csrf);
203 registry.register("throttle", builtin::throttle_passthrough);
204}
205
206pub async fn invoke(
209 mw: MiddlewareFn,
210 req: Request<Body>,
211 next: Next,
212) -> Response<Body> {
213 match mw(req, next).await {
214 Ok(resp) => resp,
215 Err(err) => {
216 tracing::error!(?err, "middleware error");
217 axum::response::IntoResponse::into_response((StatusCode::INTERNAL_SERVER_ERROR, err))
218 }
219 }
220}
221
222pub fn trace_layer() -> TraceLayer<tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>> {
224 TraceLayer::new_for_http()
225}
226
227pub async fn inject_container_mw(
230 container: Container,
231 req: Request<Body>,
232 next: Next,
233) -> Response<Body> {
234 crate::container::with_container(container, async move { next.run(req).await }).await
235}
236
237pub fn standard_layers() -> ServiceBuilder<tower::layer::util::Stack<TraceLayer<tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>>, tower::layer::util::Identity>> {
238 ServiceBuilder::new().layer(trace_layer())
239}