Skip to main content

anvil_core/
middleware.rs

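//! Named middleware registry. Specs such as `"auth"`, `"throttle:60,1"` and `"csrf"`
//! are resolved by name (the part before any `:`) to type-erased async handlers
//! ([`MiddlewareFn`]); tower layers such as request tracing are provided separately.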
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use axum::body::Body;
8use axum::http::{Request, Response, StatusCode};
9use axum::middleware::Next;
10use tower::ServiceBuilder;
11use tower_http::trace::TraceLayer;
12
13use crate::container::Container;
14use crate::Error;
15
16pub type MiddlewareFn = Arc<
17    dyn Fn(
18            Request<Body>,
19            Next,
20        ) -> futures::future::BoxFuture<'static, Result<Response<Body>, Error>>
21        + Send
22        + Sync,
23>;
24
25/// A named middleware: a function that takes a request, optionally consumes args,
26/// and returns a response.
27#[derive(Clone)]
28pub struct NamedMiddleware {
29    pub name: String,
30    pub handler: MiddlewareFn,
31}
32
33#[derive(Default, Clone)]
34pub struct MiddlewareRegistry {
35    middleware: Arc<parking_lot::RwLock<HashMap<String, MiddlewareFn>>>,
36}
37
38impl MiddlewareRegistry {
39    pub fn new() -> Self {
40        Self::default()
41    }
42
43    pub fn register<F, Fut>(&self, name: impl Into<String>, handler: F)
44    where
45        F: Fn(Request<Body>, Next) -> Fut + Send + Sync + 'static,
46        Fut: std::future::Future<Output = Result<Response<Body>, Error>> + Send + 'static,
47    {
48        let wrapped: MiddlewareFn = Arc::new(move |req, next| Box::pin(handler(req, next)));
49        self.middleware.write().insert(name.into(), wrapped);
50    }
51
52    pub fn get(&self, name: &str) -> Option<MiddlewareFn> {
53        let parsed = MiddlewareSpec::parse(name);
54        self.middleware.read().get(&parsed.name).cloned()
55    }
56
57    pub fn names(&self) -> Vec<String> {
58        self.middleware.read().keys().cloned().collect()
59    }
60}
61
/// Parsed form of a middleware spec string.
///
/// `"throttle:60,1"` parses to `MiddlewareSpec { name: "throttle", args: ["60", "1"] }`
/// and `"csrf"` parses to `MiddlewareSpec { name: "csrf", args: [] }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MiddlewareSpec {
    /// Middleware name: the text before the first `:`, trimmed.
    pub name: String,
    /// Comma-separated arguments after the first `:`, each trimmed. Empty when the
    /// spec has no `:` or nothing follows it.
    pub args: Vec<String>,
}

impl MiddlewareSpec {
    /// Parse a spec of the form `name` or `name:arg1,arg2,...`.
    ///
    /// Surrounding whitespace is ignored around the name and every argument, so
    /// `" throttle : 60 , 1 "` parses the same as `"throttle:60,1"`. A trailing `:`
    /// with nothing after it (`"throttle:"`) yields no arguments rather than a
    /// single empty one. Only the first `:` separates name from arguments.
    pub fn parse(spec: &str) -> Self {
        let spec = spec.trim();
        let (name, raw_args) = match spec.split_once(':') {
            Some((name, rest)) => (name, rest.trim()),
            None => (spec, ""),
        };
        let args = if raw_args.is_empty() {
            Vec::new()
        } else {
            raw_args.split(',').map(|arg| arg.trim().to_string()).collect()
        };
        MiddlewareSpec {
            name: name.trim().to_string(),
            args,
        }
    }
}
84
85/// Built-in middleware: installed on the registry during bootstrap.
86pub mod builtin {
87    use super::*;
88    use axum::extract::{FromRequestParts, Request};
89    use axum::http::Method;
90    use rand::RngCore;
91    use tower_sessions::Session;
92
93    pub const CSRF_SESSION_KEY: &str = "_csrf.token";
94    pub const CSRF_HEADER: &str = "x-csrf-token";
95
96    /// Read the current CSRF token from the session, generating one if missing.
97    /// Used by templates (`@csrf` directive) and the CSRF middleware itself.
98    pub async fn ensure_csrf_token(session: &Session) -> Result<String, Error> {
99        if let Some(existing) = session
100            .get::<String>(CSRF_SESSION_KEY)
101            .await
102            .map_err(|e| Error::Internal(e.to_string()))?
103        {
104            return Ok(existing);
105        }
106        let token = generate_csrf_token();
107        session
108            .insert(CSRF_SESSION_KEY, token.clone())
109            .await
110            .map_err(|e| Error::Internal(e.to_string()))?;
111        Ok(token)
112    }
113
114    fn generate_csrf_token() -> String {
115        use base64::engine::Engine;
116        let mut bytes = [0u8; 32];
117        rand::thread_rng().fill_bytes(&mut bytes);
118        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
119    }
120
121    /// Real CSRF middleware: ensures every session has a token; verifies
122    /// state-changing requests carry a matching token via `_token` form field
123    /// or `X-CSRF-TOKEN` header.
124    pub async fn csrf(req: Request, next: Next) -> Result<Response<Body>, Error> {
125        let method = req.method().clone();
126        let (mut parts, body) = req.into_parts();
127
128        let session = match Session::from_request_parts(&mut parts, &()).await {
129            Ok(s) => s,
130            Err(_) => {
131                // No session installed — request passes through unchallenged.
132                // Apps that want CSRF protection must add the session layer.
133                let req = Request::from_parts(parts, body);
134                return Ok(next.run(req).await);
135            }
136        };
137
138        let session_token = ensure_csrf_token(&session).await?;
139
140        // Safe methods don't need verification.
141        if matches!(method, Method::GET | Method::HEAD | Method::OPTIONS) {
142            let req = Request::from_parts(parts, body);
143            return Ok(next.run(req).await);
144        }
145
146        // Look for the token in headers first, then body.
147        let header_token = parts
148            .headers
149            .get(CSRF_HEADER)
150            .and_then(|v| v.to_str().ok())
151            .map(|s| s.to_string());
152
153        let body_bytes = axum::body::to_bytes(body, 16 * 1024 * 1024)
154            .await
155            .map_err(|e| Error::bad_request(format!("body read failed: {e}")))?;
156
157        let body_token = extract_body_token(&parts, &body_bytes);
158
159        let submitted = header_token.or(body_token);
160
161        if submitted.as_deref() != Some(session_token.as_str()) {
162            return Err(Error::forbidden("CSRF token mismatch"));
163        }
164
165        let req = Request::from_parts(parts, axum::body::Body::from(body_bytes));
166        Ok(next.run(req).await)
167    }
168
169    fn extract_body_token(parts: &axum::http::request::Parts, body: &[u8]) -> Option<String> {
170        let content_type = parts
171            .headers
172            .get(axum::http::header::CONTENT_TYPE)
173            .and_then(|v| v.to_str().ok())
174            .unwrap_or("");
175
176        if content_type.starts_with("application/x-www-form-urlencoded") {
177            let pairs: Vec<(String, String)> =
178                serde_urlencoded::from_bytes(body).unwrap_or_default();
179            return pairs
180                .into_iter()
181                .find_map(|(k, v)| (k == "_token").then_some(v));
182        }
183        if content_type.starts_with("application/json") {
184            let value: serde_json::Value = serde_json::from_slice(body).ok()?;
185            return value
186                .get("_token")
187                .and_then(|v| v.as_str())
188                .map(|s| s.to_string());
189        }
190        None
191    }
192
193    /// Stub `auth` middleware: passes through. Real auth lives in the per-app
194    /// middleware (registered against the app's User model via `Auth<User>`).
195    pub async fn auth_passthrough(req: Request, next: Next) -> Result<Response<Body>, Error> {
196        Ok(next.run(req).await)
197    }
198
199    /// Stub throttle middleware: passes through. Real rate-limiting is deferred to v0.2.
200    pub async fn throttle_passthrough(req: Request, next: Next) -> Result<Response<Body>, Error> {
201        Ok(next.run(req).await)
202    }
203}
204
205pub fn install_defaults(registry: &MiddlewareRegistry) {
206    registry.register("auth", builtin::auth_passthrough);
207    registry.register("csrf", builtin::csrf);
208    registry.register("throttle", builtin::throttle_passthrough);
209}
210
211/// Apply a middleware by name to an axum router-style handler chain.
212/// The error from `MiddlewareFn` is converted to a 500 response if not handled.
213pub async fn invoke(mw: MiddlewareFn, req: Request<Body>, next: Next) -> Response<Body> {
214    match mw(req, next).await {
215        Ok(resp) => resp,
216        Err(err) => {
217            tracing::error!(?err, "middleware error");
218            axum::response::IntoResponse::into_response((StatusCode::INTERNAL_SERVER_ERROR, err))
219        }
220    }
221}
222
223/// Convenience for constructing a tracing layer with sensible defaults.
224pub fn trace_layer(
225) -> TraceLayer<tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>>
226{
227    TraceLayer::new_for_http()
228}
229
230/// Container injection middleware. Installs the container into the task-local context
231/// for the duration of the request so facade functions work.
232pub async fn inject_container_mw(
233    container: Container,
234    req: Request<Body>,
235    next: Next,
236) -> Response<Body> {
237    crate::container::with_container(container, async move { next.run(req).await }).await
238}
239
240pub fn standard_layers() -> ServiceBuilder<
241    tower::layer::util::Stack<
242        TraceLayer<
243            tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>,
244        >,
245        tower::layer::util::Identity,
246    >,
247> {
248    ServiceBuilder::new().layer(trace_layer())
249}