Skip to main content

anvil_core/
middleware.rs

1//! Named middleware registry. Strings like `"auth"`, `"throttle:60,1"`, `"csrf"`
2//! are resolved at app-init time to tower `Layer`s.
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use axum::body::Body;
8use axum::http::{Request, Response, StatusCode};
9use axum::middleware::Next;
10use tower::ServiceBuilder;
11use tower_http::trace::TraceLayer;
12
13use crate::container::Container;
14use crate::Error;
15
16pub type MiddlewareFn = Arc<
17    dyn Fn(Request<Body>, Next) -> futures::future::BoxFuture<'static, Result<Response<Body>, Error>>
18        + Send
19        + Sync,
20>;
21
22/// A named middleware: a function that takes a request, optionally consumes args,
23/// and returns a response.
24#[derive(Clone)]
25pub struct NamedMiddleware {
26    pub name: String,
27    pub handler: MiddlewareFn,
28}
29
30#[derive(Default, Clone)]
31pub struct MiddlewareRegistry {
32    middleware: Arc<parking_lot::RwLock<HashMap<String, MiddlewareFn>>>,
33}
34
35impl MiddlewareRegistry {
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    pub fn register<F, Fut>(&self, name: impl Into<String>, handler: F)
41    where
42        F: Fn(Request<Body>, Next) -> Fut + Send + Sync + 'static,
43        Fut: std::future::Future<Output = Result<Response<Body>, Error>> + Send + 'static,
44    {
45        let wrapped: MiddlewareFn = Arc::new(move |req, next| Box::pin(handler(req, next)));
46        self.middleware.write().insert(name.into(), wrapped);
47    }
48
49    pub fn get(&self, name: &str) -> Option<MiddlewareFn> {
50        let parsed = MiddlewareSpec::parse(name);
51        self.middleware.read().get(&parsed.name).cloned()
52    }
53
54    pub fn names(&self) -> Vec<String> {
55        self.middleware.read().keys().cloned().collect()
56    }
57}
58
/// Parsed form of a middleware spec string.
///
/// `"throttle:60,1"` → `MiddlewareSpec { name: "throttle", args: ["60", "1"] }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MiddlewareSpec {
    /// Middleware name: the text before the first `:`, with surrounding whitespace removed.
    pub name: String,
    /// Comma-separated arguments after the first `:`, each trimmed. Empty when the
    /// spec has no `:` or nothing but whitespace follows it.
    pub args: Vec<String>,
}

impl MiddlewareSpec {
    /// Splits `spec` into a name and its argument list.
    ///
    /// Only the first `:` separates the name from the arguments, so later colons
    /// stay inside the arguments. The name and every argument are trimmed. A spec
    /// with a trailing colon and no arguments (`"throttle:"`) yields an empty `args`
    /// rather than a single empty string. Empty arguments between commas are kept,
    /// so positional meaning is preserved (`"x:a,,b"` → `["a", "", "b"]`).
    pub fn parse(spec: &str) -> Self {
        let (name, raw_args) = match spec.split_once(':') {
            Some((name, rest)) => (name, rest.trim()),
            None => (spec, ""),
        };

        let args = if raw_args.is_empty() {
            Vec::new()
        } else {
            raw_args
                .split(',')
                .map(|arg| arg.trim().to_string())
                .collect()
        };

        MiddlewareSpec {
            name: name.trim().to_string(),
            args,
        }
    }
}
81
82/// Built-in middleware: install on the registry during bootstrap.
83pub mod builtin {
84    use super::*;
85    use axum::extract::Request;
86
87    /// Stub `auth` middleware: passes through. Real apps register their own
88    /// auth middleware that pulls the session and validates the user.
89    pub async fn auth_passthrough(req: Request, next: Next) -> Result<Response<Body>, Error> {
90        Ok(next.run(req).await)
91    }
92
93    /// Stub `csrf` middleware: passes through. Real CSRF validation lives in the
94    /// session-aware version that's installed after `tower-sessions`.
95    pub async fn csrf_passthrough(req: Request, next: Next) -> Result<Response<Body>, Error> {
96        Ok(next.run(req).await)
97    }
98
99    /// Stub throttle middleware: passes through. Real rate-limiting is deferred to v1.1.
100    pub async fn throttle_passthrough(req: Request, next: Next) -> Result<Response<Body>, Error> {
101        Ok(next.run(req).await)
102    }
103}
104
105pub fn install_defaults(registry: &MiddlewareRegistry) {
106    registry.register("auth", builtin::auth_passthrough);
107    registry.register("csrf", builtin::csrf_passthrough);
108    registry.register("throttle", builtin::throttle_passthrough);
109}
110
111/// Apply a middleware by name to an axum router-style handler chain.
112/// The error from `MiddlewareFn` is converted to a 500 response if not handled.
113pub async fn invoke(
114    mw: MiddlewareFn,
115    req: Request<Body>,
116    next: Next,
117) -> Response<Body> {
118    match mw(req, next).await {
119        Ok(resp) => resp,
120        Err(err) => {
121            tracing::error!(?err, "middleware error");
122            axum::response::IntoResponse::into_response((StatusCode::INTERNAL_SERVER_ERROR, err))
123        }
124    }
125}
126
127/// Convenience for constructing a tracing layer with sensible defaults.
128pub fn trace_layer() -> TraceLayer<tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>> {
129    TraceLayer::new_for_http()
130}
131
132/// Container injection middleware. Installs the container into the task-local context
133/// for the duration of the request so facade functions work.
134pub async fn inject_container_mw(
135    container: Container,
136    req: Request<Body>,
137    next: Next,
138) -> Response<Body> {
139    crate::container::with_container(container, async move { next.run(req).await }).await
140}
141
142pub fn standard_layers() -> ServiceBuilder<tower::layer::util::Stack<TraceLayer<tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>>, tower::layer::util::Identity>> {
143    ServiceBuilder::new().layer(trace_layer())
144}