Skip to main content

anvil_core/
middleware.rs

//! Named middleware registry. Strings like `"auth"`, `"throttle:60,1"`, `"csrf"`
//! are resolved at app-init time to tower `Layer`s.
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use axum::body::Body;
8use axum::http::{Request, Response, StatusCode};
9use axum::middleware::Next;
10use tower::ServiceBuilder;
11use tower_http::trace::TraceLayer;
12
13use crate::container::Container;
14use crate::Error;
15
16pub type MiddlewareFn = Arc<
17    dyn Fn(Request<Body>, Next) -> futures::future::BoxFuture<'static, Result<Response<Body>, Error>>
18        + Send
19        + Sync,
20>;
21
22/// A named middleware: a function that takes a request, optionally consumes args,
23/// and returns a response.
24#[derive(Clone)]
25pub struct NamedMiddleware {
26    pub name: String,
27    pub handler: MiddlewareFn,
28}
29
30#[derive(Default, Clone)]
31pub struct MiddlewareRegistry {
32    middleware: Arc<parking_lot::RwLock<HashMap<String, MiddlewareFn>>>,
33}
34
35impl MiddlewareRegistry {
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    pub fn register<F, Fut>(&self, name: impl Into<String>, handler: F)
41    where
42        F: Fn(Request<Body>, Next) -> Fut + Send + Sync + 'static,
43        Fut: std::future::Future<Output = Result<Response<Body>, Error>> + Send + 'static,
44    {
45        let wrapped: MiddlewareFn = Arc::new(move |req, next| Box::pin(handler(req, next)));
46        self.middleware.write().insert(name.into(), wrapped);
47    }
48
49    pub fn get(&self, name: &str) -> Option<MiddlewareFn> {
50        let parsed = MiddlewareSpec::parse(name);
51        self.middleware.read().get(&parsed.name).cloned()
52    }
53
54    pub fn names(&self) -> Vec<String> {
55        self.middleware.read().keys().cloned().collect()
56    }
57}
58
/// Parsed form of a middleware spec string.
///
/// `"throttle:60,1"` parses to `MiddlewareSpec { name: "throttle", args: ["60", "1"] }`;
/// a spec without `:` (e.g. `"csrf"`) has no arguments.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MiddlewareSpec {
    /// Middleware name, with surrounding whitespace removed.
    pub name: String,
    /// Comma-separated arguments after the first `:`, each trimmed.
    pub args: Vec<String>,
}

impl MiddlewareSpec {
    /// Parses `spec` of the form `name` or `name:arg1,arg2,...`.
    ///
    /// Only the first `:` separates the name from the arguments. The name and
    /// every argument are trimmed. An argument list that is empty or all
    /// whitespace (e.g. `"throttle:"`) yields no arguments rather than a single
    /// empty-string argument. Empty arguments *between* commas (`"a:1,,2"`) are
    /// kept, since their position may be meaningful to the middleware.
    pub fn parse(spec: &str) -> Self {
        let (name, args) = match spec.split_once(':') {
            // `"name:"` — previously produced `[""]`, which is never a useful argument list.
            Some((name, raw)) if raw.trim().is_empty() => (name, Vec::new()),
            Some((name, raw)) => (
                name,
                raw.split(',').map(|arg| arg.trim().to_string()).collect(),
            ),
            None => (spec, Vec::new()),
        };
        MiddlewareSpec {
            // Trim the name like the args so registry lookups tolerate stray spaces.
            name: name.trim().to_string(),
            args,
        }
    }
}
81
82/// Built-in middleware: installed on the registry during bootstrap.
83pub mod builtin {
84    use super::*;
85    use axum::extract::{FromRequestParts, Request};
86    use axum::http::Method;
87    use rand::RngCore;
88    use tower_sessions::Session;
89
90    pub const CSRF_SESSION_KEY: &str = "_csrf.token";
91    pub const CSRF_HEADER: &str = "x-csrf-token";
92
93    /// Read the current CSRF token from the session, generating one if missing.
94    /// Used by templates (`@csrf` directive) and the CSRF middleware itself.
95    pub async fn ensure_csrf_token(session: &Session) -> Result<String, Error> {
96        if let Some(existing) = session
97            .get::<String>(CSRF_SESSION_KEY)
98            .await
99            .map_err(|e| Error::Internal(e.to_string()))?
100        {
101            return Ok(existing);
102        }
103        let token = generate_csrf_token();
104        session
105            .insert(CSRF_SESSION_KEY, token.clone())
106            .await
107            .map_err(|e| Error::Internal(e.to_string()))?;
108        Ok(token)
109    }
110
111    fn generate_csrf_token() -> String {
112        use base64::engine::Engine;
113        let mut bytes = [0u8; 32];
114        rand::thread_rng().fill_bytes(&mut bytes);
115        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
116    }
117
118    /// Real CSRF middleware: ensures every session has a token; verifies
119    /// state-changing requests carry a matching token via `_token` form field
120    /// or `X-CSRF-TOKEN` header.
121    pub async fn csrf(req: Request, next: Next) -> Result<Response<Body>, Error> {
122        let method = req.method().clone();
123        let (mut parts, body) = req.into_parts();
124
125        let session = match Session::from_request_parts(&mut parts, &()).await {
126            Ok(s) => s,
127            Err(_) => {
128                // No session installed — request passes through unchallenged.
129                // Apps that want CSRF protection must add the session layer.
130                let req = Request::from_parts(parts, body);
131                return Ok(next.run(req).await);
132            }
133        };
134
135        let session_token = ensure_csrf_token(&session).await?;
136
137        // Safe methods don't need verification.
138        if matches!(method, Method::GET | Method::HEAD | Method::OPTIONS) {
139            let req = Request::from_parts(parts, body);
140            return Ok(next.run(req).await);
141        }
142
143        // Look for the token in headers first, then body.
144        let header_token = parts
145            .headers
146            .get(CSRF_HEADER)
147            .and_then(|v| v.to_str().ok())
148            .map(|s| s.to_string());
149
150        let body_bytes = axum::body::to_bytes(body, 16 * 1024 * 1024)
151            .await
152            .map_err(|e| Error::bad_request(format!("body read failed: {e}")))?;
153
154        let body_token = extract_body_token(&parts, &body_bytes);
155
156        let submitted = header_token.or(body_token);
157
158        if submitted.as_deref() != Some(session_token.as_str()) {
159            return Err(Error::forbidden("CSRF token mismatch"));
160        }
161
162        let req = Request::from_parts(parts, axum::body::Body::from(body_bytes));
163        Ok(next.run(req).await)
164    }
165
166    fn extract_body_token(parts: &axum::http::request::Parts, body: &[u8]) -> Option<String> {
167        let content_type = parts
168            .headers
169            .get(axum::http::header::CONTENT_TYPE)
170            .and_then(|v| v.to_str().ok())
171            .unwrap_or("");
172
173        if content_type.starts_with("application/x-www-form-urlencoded") {
174            let pairs: Vec<(String, String)> =
175                serde_urlencoded::from_bytes(body).unwrap_or_default();
176            return pairs.into_iter().find_map(|(k, v)| (k == "_token").then_some(v));
177        }
178        if content_type.starts_with("application/json") {
179            let value: serde_json::Value = serde_json::from_slice(body).ok()?;
180            return value
181                .get("_token")
182                .and_then(|v| v.as_str())
183                .map(|s| s.to_string());
184        }
185        None
186    }
187
188    /// Stub `auth` middleware: passes through. Real auth lives in the per-app
189    /// middleware (registered against the app's User model via `Auth<User>`).
190    pub async fn auth_passthrough(req: Request, next: Next) -> Result<Response<Body>, Error> {
191        Ok(next.run(req).await)
192    }
193
194    /// Stub throttle middleware: passes through. Real rate-limiting is deferred to v0.2.
195    pub async fn throttle_passthrough(req: Request, next: Next) -> Result<Response<Body>, Error> {
196        Ok(next.run(req).await)
197    }
198}
199
200pub fn install_defaults(registry: &MiddlewareRegistry) {
201    registry.register("auth", builtin::auth_passthrough);
202    registry.register("csrf", builtin::csrf);
203    registry.register("throttle", builtin::throttle_passthrough);
204}
205
206/// Apply a middleware by name to an axum router-style handler chain.
207/// The error from `MiddlewareFn` is converted to a 500 response if not handled.
208pub async fn invoke(
209    mw: MiddlewareFn,
210    req: Request<Body>,
211    next: Next,
212) -> Response<Body> {
213    match mw(req, next).await {
214        Ok(resp) => resp,
215        Err(err) => {
216            tracing::error!(?err, "middleware error");
217            axum::response::IntoResponse::into_response((StatusCode::INTERNAL_SERVER_ERROR, err))
218        }
219    }
220}
221
222/// Convenience for constructing a tracing layer with sensible defaults.
223pub fn trace_layer() -> TraceLayer<tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>> {
224    TraceLayer::new_for_http()
225}
226
227/// Container injection middleware. Installs the container into the task-local context
228/// for the duration of the request so facade functions work.
229pub async fn inject_container_mw(
230    container: Container,
231    req: Request<Body>,
232    next: Next,
233) -> Response<Body> {
234    crate::container::with_container(container, async move { next.run(req).await }).await
235}
236
237pub fn standard_layers() -> ServiceBuilder<tower::layer::util::Stack<TraceLayer<tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>>, tower::layer::util::Identity>> {
238    ServiceBuilder::new().layer(trace_layer())
239}