Skip to main content

brainwires_proxy/middleware/
mod.rs

1//! Middleware pipeline with onion-model request/response processing.
2
3pub mod auth;
4pub mod header_inject;
5pub mod inspector;
6pub mod logging;
7pub mod rate_limit;
8
9use crate::error::ProxyResult;
10use crate::types::{ProxyRequest, ProxyResponse};
11
12/// Action a middleware layer can take on a request.
13pub enum LayerAction {
14    /// Forward the (possibly modified) request to the next layer.
15    Forward(ProxyRequest),
16    /// Short-circuit and return this response immediately.
17    Respond(ProxyResponse),
18}
19
20/// A single middleware layer in the proxy pipeline.
21///
22/// Layers form an onion: requests flow inward through `on_request()`,
23/// responses flow outward through `on_response()` in reverse order.
24#[async_trait::async_trait]
25pub trait ProxyLayer: Send + Sync {
26    /// Process an incoming request. Return `Forward` to pass it on,
27    /// or `Respond` to short-circuit with an immediate response.
28    async fn on_request(&self, request: ProxyRequest) -> ProxyResult<LayerAction>;
29
30    /// Process a response before it's sent back to the client.
31    /// Called in reverse layer order.
32    async fn on_response(&self, response: ProxyResponse) -> ProxyResult<ProxyResponse> {
33        Ok(response)
34    }
35
36    /// Human-readable name for logging.
37    fn name(&self) -> &str;
38}
39
40/// Ordered stack of middleware layers implementing the onion model.
41pub struct MiddlewareStack {
42    layers: Vec<Box<dyn ProxyLayer>>,
43}
44
45impl MiddlewareStack {
46    pub fn new() -> Self {
47        Self { layers: Vec::new() }
48    }
49
50    /// Push a layer onto the stack. Layers are processed in insertion order
51    /// for requests and reverse order for responses.
52    pub fn push(&mut self, layer: impl ProxyLayer + 'static) {
53        self.layers.push(Box::new(layer));
54    }
55
56    /// Process a request through all layers.
57    /// Returns the (possibly modified) request and the index of the deepest
58    /// layer reached, or a short-circuit response.
59    pub async fn process_request(
60        &self,
61        mut request: ProxyRequest,
62    ) -> ProxyResult<Result<(ProxyRequest, usize), ProxyResponse>> {
63        for layer in self.layers.iter() {
64            match layer.on_request(request).await? {
65                LayerAction::Forward(req) => request = req,
66                LayerAction::Respond(resp) => return Ok(Err(resp)),
67            }
68        }
69        Ok(Ok((request, self.layers.len())))
70    }
71
72    /// Process a response back through layers in reverse order.
73    /// `depth` is the number of layers the request passed through.
74    pub async fn process_response(
75        &self,
76        mut response: ProxyResponse,
77        depth: usize,
78    ) -> ProxyResult<ProxyResponse> {
79        for layer in self.layers[..depth].iter().rev() {
80            response = layer.on_response(response).await?;
81        }
82        Ok(response)
83    }
84
85    pub fn is_empty(&self) -> bool {
86        self.layers.is_empty()
87    }
88
89    pub fn len(&self) -> usize {
90        self.layers.len()
91    }
92}
93
94impl Default for MiddlewareStack {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Method, StatusCode};
    use std::sync::{Arc, Mutex};

    /// Shared, ordered record of which layer hooks ran.
    type CallLog = Arc<Mutex<Vec<String>>>;

    /// Build a GET `/test` request whose body is `hello`.
    fn make_request() -> ProxyRequest {
        ProxyRequest::new(Method::GET, "/test".parse().unwrap()).with_body("hello")
    }

    /// Test layer: logs every hook call into a shared `CallLog` and tags
    /// forwarded requests with a header named after itself.
    struct MarkerLayer {
        name: String,
        order: CallLog,
    }

    impl MarkerLayer {
        /// Append one entry to the shared call log.
        fn record(&self, entry: String) {
            self.order.lock().unwrap().push(entry);
        }
    }

    /// Convenience constructor for a `MarkerLayer` sharing `order`.
    fn marker(name: &str, order: &CallLog) -> MarkerLayer {
        MarkerLayer {
            name: name.to_owned(),
            order: Arc::clone(order),
        }
    }

    #[async_trait::async_trait]
    impl ProxyLayer for MarkerLayer {
        async fn on_request(&self, mut request: ProxyRequest) -> ProxyResult<LayerAction> {
            self.record(format!("{}-req", self.name));
            let header = http::header::HeaderName::from_bytes(self.name.as_bytes()).unwrap();
            request
                .headers
                .insert(header, http::header::HeaderValue::from_static("true"));
            Ok(LayerAction::Forward(request))
        }

        async fn on_response(&self, response: ProxyResponse) -> ProxyResult<ProxyResponse> {
            self.record(format!("{}-resp", self.name));
            Ok(response)
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    /// Test layer that always short-circuits with 403 Forbidden.
    struct BlockingLayer;

    #[async_trait::async_trait]
    impl ProxyLayer for BlockingLayer {
        async fn on_request(&self, request: ProxyRequest) -> ProxyResult<LayerAction> {
            let denied =
                ProxyResponse::for_request(request.id, StatusCode::FORBIDDEN).with_body("blocked");
            Ok(LayerAction::Respond(denied))
        }

        fn name(&self) -> &str {
            "blocker"
        }
    }

    #[tokio::test]
    async fn empty_stack_passes_through() {
        let stack = MiddlewareStack::new();
        let outcome = stack.process_request(make_request()).await.unwrap();
        assert!(outcome.is_ok());

        let (forwarded, depth) = outcome.unwrap();
        assert_eq!(depth, 0);
        assert_eq!(forwarded.body.as_bytes(), b"hello");
    }

    #[tokio::test]
    async fn onion_model_order() {
        let order: CallLog = Arc::default();
        let mut stack = MiddlewareStack::new();
        for name in ["a", "b", "c"] {
            stack.push(marker(name, &order));
        }

        let (_, depth) = stack
            .process_request(make_request())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(depth, 3);

        stack
            .process_response(ProxyResponse::new(StatusCode::OK), depth)
            .await
            .unwrap();

        // Requests travel a -> b -> c; responses unwind c -> b -> a.
        assert_eq!(
            *order.lock().unwrap(),
            vec!["a-req", "b-req", "c-req", "c-resp", "b-resp", "a-resp"]
        );
    }

    #[tokio::test]
    async fn short_circuit_stops_processing() {
        let order: CallLog = Arc::default();
        let mut stack = MiddlewareStack::new();
        stack.push(marker("a", &order));
        stack.push(BlockingLayer);
        stack.push(marker("c", &order));

        let outcome = stack.process_request(make_request()).await.unwrap();
        // The blocking layer must have short-circuited.
        assert!(outcome.is_err());

        let blocked = outcome.unwrap_err();
        assert_eq!(blocked.status, StatusCode::FORBIDDEN);

        // Only "a" ran; "c" sits behind the blocker and was never reached.
        assert_eq!(*order.lock().unwrap(), vec!["a-req"]);
    }

    #[tokio::test]
    async fn response_depth_limits_reverse_traversal() {
        let order: CallLog = Arc::default();
        let mut stack = MiddlewareStack::new();
        for name in ["a", "b", "c"] {
            stack.push(marker(name, &order));
        }

        // With depth = 2, only the outer two layers (a, b) unwind, inner first.
        stack
            .process_response(ProxyResponse::new(StatusCode::OK), 2)
            .await
            .unwrap();

        assert_eq!(*order.lock().unwrap(), vec!["b-resp", "a-resp"]);
    }
}