Skip to main content

actus_server/middleware/
chain.rs

1use crate::request::Request;
2use actus_reply::{ReplyData, WebError};
3use async_trait::async_trait;
4use std::sync::Arc;
5
6/// What a middleware's [`before`](Middleware::before) hook decided.
7#[derive(Debug)]
8pub enum Outcome {
9    /// Carry on — run the next middleware (and, eventually, the handler).
10    Continue,
11    /// Short-circuit: build this reply and return it. The handler and any
12    /// remaining `before` hooks are skipped (`after` hooks still run on the
13    /// reply). Use `Outcome::Respond(reply::empty())`, a `reply::build_reply()`
14    /// for an explicit status/headers, etc. — anything a handler could return.
15    Respond(ReplyData),
16}
17
18/// A trait for implementing middleware.
19///
20/// `before` runs in registration order; `after` runs in reverse — so a
21/// middleware wraps the ones added after it (`[A, B]`: `A.before`, `B.before`,
22/// handler, `B.after`, `A.after`).
23#[async_trait]
24pub trait Middleware: Send + Sync {
25    /// Called before the request is routed. Return [`Outcome::Continue`] to
26    /// proceed, [`Outcome::Respond`] to short-circuit with a normal response,
27    /// or `Err(WebError)` to short-circuit with an error response. The default
28    /// implementation continues.
29    async fn before(&self, _request: &mut Request) -> Result<Outcome, WebError> {
30        Ok(Outcome::Continue)
31    }
32
33    /// Called after the handler returns (or after a `before`-hook
34    /// `Outcome::Respond` short-circuit), on the way back out. `request` is
35    /// the request the reply is for — so the hook can decide what to do based
36    /// on the request's headers / method / path. Mutate `response` in place
37    /// (including replacing it wholesale, or stamping headers via
38    /// [`ReplyData::add_header`](actus_reply::ReplyData::add_header));
39    /// `Err(WebError)` swaps the reply for an error response. The default
40    /// implementation does nothing.
41    ///
42    /// `after` is **not** called on [`ReplyData::Upgrade`] replies — a
43    /// WebSocket handshake's `101 Switching Protocols` has no body to decorate,
44    /// and the server short-circuits that variant to the upgrade machinery
45    /// before the `after` chain runs.
46    async fn after(&self, _request: &Request, _response: &mut ReplyData) -> Result<(), WebError> {
47        Ok(())
48    }
49}
50
51/// A chain of middleware to be executed.
52#[derive(Clone, Default)]
53pub struct MiddlewareChain {
54    middlewares: Vec<Arc<dyn Middleware>>,
55}
56
57impl MiddlewareChain {
58    /// Create an empty chain.
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    /// Append a middleware to the chain. `before` hooks run in insertion
64    /// order; `after` hooks run in reverse (outermost-last).
65    pub fn add<M: Middleware + 'static>(&mut self, middleware: M) {
66        self.middlewares.push(Arc::new(middleware));
67    }
68
69    /// Run every `before` hook in order. Returns [`Outcome::Continue`] if all
70    /// of them did; [`Outcome::Respond`] from the first that short-circuited
71    /// (the rest aren't run); or the first `Err`.
72    pub async fn process_request(&self, request: &mut Request) -> Result<Outcome, WebError> {
73        for middleware in &self.middlewares {
74            match middleware.before(request).await? {
75                Outcome::Continue => {}
76                respond @ Outcome::Respond(_) => return Ok(respond),
77            }
78        }
79        Ok(Outcome::Continue)
80    }
81
82    /// Run every `after` hook in reverse order (so the middleware added first
83    /// wraps outermost), letting each decorate the outgoing reply.
84    pub async fn process_response(
85        &self,
86        request: &Request,
87        response: &mut ReplyData,
88    ) -> Result<(), WebError> {
89        for middleware in self.middlewares.iter().rev() {
90            middleware.after(request, response).await?;
91        }
92        Ok(())
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http::{HeaderMap, Method};
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// A minimal `GET` request with no path, query, body or headers.
    fn req() -> Request {
        Request {
            method: Method::GET,
            path_parts: Vec::new(),
            query_params: HashMap::new(),
            body: Bytes::new(),
            headers: HeaderMap::new(),
            rate_limit_class: None,
        }
    }

    /// Uses the default `before` (continues).
    struct Continues;
    #[async_trait]
    impl Middleware for Continues {}

    /// Answers every request with an empty reply from `before`; keeps the
    /// default (no-op) `after`.
    struct ShortCircuits;
    #[async_trait]
    impl Middleware for ShortCircuits {
        async fn before(&self, _request: &mut Request) -> Result<Outcome, WebError> {
            Ok(Outcome::Respond(ReplyData::Empty))
        }
    }

    /// Flips a flag when its `before` runs — so a test can prove it *didn't*.
    struct Records(Arc<AtomicBool>);
    #[async_trait]
    impl Middleware for Records {
        async fn before(&self, _request: &mut Request) -> Result<Outcome, WebError> {
            self.0.store(true, Ordering::SeqCst);
            Ok(Outcome::Continue)
        }
    }

    #[tokio::test]
    async fn before_short_circuit_skips_the_rest() {
        let later_ran = Arc::new(AtomicBool::new(false));
        let mut chain = MiddlewareChain::new();
        chain.add(Continues);
        chain.add(ShortCircuits);
        chain.add(Records(later_ran.clone()));
        match chain.process_request(&mut req()).await {
            Ok(Outcome::Respond(ReplyData::Empty)) => {}
            other => panic!("expected Respond(Empty), got {other:?}"),
        }
        assert!(
            !later_ran.load(Ordering::SeqCst),
            "a middleware after the short-circuit still ran"
        );
    }

    #[tokio::test]
    async fn all_continue_yields_continue() {
        let mut chain = MiddlewareChain::new();
        chain.add(Continues);
        chain.add(Continues);
        assert!(matches!(
            chain.process_request(&mut req()).await,
            Ok(Outcome::Continue)
        ));
    }

    /// With nothing registered, both walks succeed and leave the reply alone.
    #[tokio::test]
    async fn empty_chain_is_a_no_op() {
        let chain = MiddlewareChain::new();
        assert!(matches!(
            chain.process_request(&mut req()).await,
            Ok(Outcome::Continue)
        ));
        let mut data = ReplyData::Empty;
        chain.process_response(&req(), &mut data).await.unwrap();
        assert!(matches!(data, ReplyData::Empty));
    }

    /// Appends an identifying letter (and the request's path) to a shared
    /// buffer when its `after` runs — so a test can prove ordering and that
    /// the hook saw the request.
    struct Trace {
        tag: &'static str,
        log: Arc<std::sync::Mutex<Vec<String>>>,
    }
    #[async_trait]
    impl Middleware for Trace {
        async fn after(
            &self,
            request: &Request,
            _response: &mut ReplyData,
        ) -> Result<(), WebError> {
            let path = request.path_parts.join("/");
            self.log
                .lock()
                .unwrap()
                .push(format!("{}:{}", self.tag, path));
            Ok(())
        }
    }

    #[tokio::test]
    async fn after_runs_in_reverse_order_and_sees_the_request() {
        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut chain = MiddlewareChain::new();
        chain.add(Trace {
            tag: "A",
            log: log.clone(),
        });
        chain.add(Trace {
            tag: "B",
            log: log.clone(),
        });
        chain.add(Trace {
            tag: "C",
            log: log.clone(),
        });
        let mut request = req();
        request.path_parts = vec!["api".into(), "users".into()];
        let mut data = ReplyData::Empty;
        chain.process_response(&request, &mut data).await.unwrap();
        // Registration order [A, B, C] → `after` order [C, B, A].
        assert_eq!(
            *log.lock().unwrap(),
            vec!["C:api/users", "B:api/users", "A:api/users"]
        );
    }

    /// Pins the documented asymmetry: a short-circuit in an earlier `before`
    /// does not stop a later middleware's `after` from running.
    #[tokio::test]
    async fn after_still_runs_for_middleware_past_a_short_circuit() {
        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut chain = MiddlewareChain::new();
        chain.add(ShortCircuits);
        chain.add(Trace {
            tag: "B",
            log: log.clone(),
        });
        let mut request = req();
        assert!(matches!(
            chain.process_request(&mut request).await,
            Ok(Outcome::Respond(ReplyData::Empty))
        ));
        let mut data = ReplyData::Empty;
        chain.process_response(&request, &mut data).await.unwrap();
        // `B.before` never ran, yet `B.after` did (empty path → "B:").
        assert_eq!(*log.lock().unwrap(), vec!["B:"]);
    }
}