brainwires_proxy/middleware/
mod.rs1pub mod auth;
4pub mod header_inject;
5pub mod inspector;
6pub mod logging;
7pub mod rate_limit;
8
9use crate::error::ProxyResult;
10use crate::types::{ProxyRequest, ProxyResponse};
11
12pub enum LayerAction {
14 Forward(ProxyRequest),
16 Respond(ProxyResponse),
18}
19
20#[async_trait::async_trait]
25pub trait ProxyLayer: Send + Sync {
26 async fn on_request(&self, request: ProxyRequest) -> ProxyResult<LayerAction>;
29
30 async fn on_response(&self, response: ProxyResponse) -> ProxyResult<ProxyResponse> {
33 Ok(response)
34 }
35
36 fn name(&self) -> &str;
38}
39
40pub struct MiddlewareStack {
42 layers: Vec<Box<dyn ProxyLayer>>,
43}
44
45impl MiddlewareStack {
46 pub fn new() -> Self {
47 Self { layers: Vec::new() }
48 }
49
50 pub fn push(&mut self, layer: impl ProxyLayer + 'static) {
53 self.layers.push(Box::new(layer));
54 }
55
56 pub async fn process_request(
60 &self,
61 mut request: ProxyRequest,
62 ) -> ProxyResult<Result<(ProxyRequest, usize), ProxyResponse>> {
63 for layer in self.layers.iter() {
64 match layer.on_request(request).await? {
65 LayerAction::Forward(req) => request = req,
66 LayerAction::Respond(resp) => return Ok(Err(resp)),
67 }
68 }
69 Ok(Ok((request, self.layers.len())))
70 }
71
72 pub async fn process_response(
75 &self,
76 mut response: ProxyResponse,
77 depth: usize,
78 ) -> ProxyResult<ProxyResponse> {
79 for layer in self.layers[..depth].iter().rev() {
80 response = layer.on_response(response).await?;
81 }
82 Ok(response)
83 }
84
85 pub fn is_empty(&self) -> bool {
86 self.layers.is_empty()
87 }
88
89 pub fn len(&self) -> usize {
90 self.layers.len()
91 }
92}
93
94impl Default for MiddlewareStack {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use http::header::{HeaderName, HeaderValue};
    use http::{Method, StatusCode};
    use std::sync::{Arc, Mutex};

    /// Ordered record of hook invocations, shared between a test and its layers.
    type CallLog = Arc<Mutex<Vec<String>>>;

    /// Builds a `GET /test` request carrying the body `"hello"`.
    fn make_request() -> ProxyRequest {
        ProxyRequest::new(Method::GET, "/test".parse().unwrap()).with_body("hello")
    }

    /// Layer that writes `<name>-req` / `<name>-resp` into a shared [`CallLog`].
    /// It also tags every request it sees with a `<name>: true` header.
    struct MarkerLayer {
        name: String,
        log: CallLog,
    }

    impl MarkerLayer {
        fn new(name: &str, log: &CallLog) -> Self {
            Self {
                name: name.to_string(),
                log: Arc::clone(log),
            }
        }

        /// Appends `"<name>-<phase>"` to the shared log.
        fn record(&self, phase: &str) {
            self.log
                .lock()
                .unwrap()
                .push(format!("{}-{}", self.name, phase));
        }
    }

    #[async_trait::async_trait]
    impl ProxyLayer for MarkerLayer {
        fn name(&self) -> &str {
            &self.name
        }

        async fn on_request(&self, mut request: ProxyRequest) -> ProxyResult<LayerAction> {
            self.record("req");
            let header = HeaderName::from_bytes(self.name.as_bytes()).unwrap();
            request
                .headers
                .insert(header, HeaderValue::from_static("true"));
            Ok(LayerAction::Forward(request))
        }

        async fn on_response(&self, response: ProxyResponse) -> ProxyResult<ProxyResponse> {
            self.record("resp");
            Ok(response)
        }
    }

    /// Layer that rejects every request with `403 Forbidden` and the body `"blocked"`.
    struct BlockingLayer;

    #[async_trait::async_trait]
    impl ProxyLayer for BlockingLayer {
        fn name(&self) -> &str {
            "blocker"
        }

        async fn on_request(&self, request: ProxyRequest) -> ProxyResult<LayerAction> {
            let rejection =
                ProxyResponse::for_request(request.id, StatusCode::FORBIDDEN).with_body("blocked");
            Ok(LayerAction::Respond(rejection))
        }
    }

    /// Stack of marker layers `a`, `b`, `c` (outermost first), all writing to `log`.
    fn abc_stack(log: &CallLog) -> MiddlewareStack {
        let mut stack = MiddlewareStack::new();
        for name in &["a", "b", "c"] {
            stack.push(MarkerLayer::new(name, log));
        }
        stack
    }

    #[tokio::test]
    async fn empty_stack_passes_through() {
        let stack = MiddlewareStack::new();
        let (forwarded, depth) = match stack.process_request(make_request()).await.unwrap() {
            Ok(pair) => pair,
            Err(_) => panic!("an empty stack must not short-circuit"),
        };
        assert_eq!(depth, 0);
        assert_eq!(forwarded.body.as_bytes(), b"hello");
    }

    #[tokio::test]
    async fn onion_model_order() {
        let log = CallLog::default();
        let stack = abc_stack(&log);

        let (_, depth) = stack
            .process_request(make_request())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(depth, 3);

        stack
            .process_response(ProxyResponse::new(StatusCode::OK), depth)
            .await
            .unwrap();

        assert_eq!(
            *log.lock().unwrap(),
            vec!["a-req", "b-req", "c-req", "c-resp", "b-resp", "a-resp"]
        );
    }

    #[tokio::test]
    async fn short_circuit_stops_processing() {
        let log = CallLog::default();
        let mut stack = MiddlewareStack::new();
        stack.push(MarkerLayer::new("a", &log));
        stack.push(BlockingLayer);
        stack.push(MarkerLayer::new("c", &log));

        let rejection = match stack.process_request(make_request()).await.unwrap() {
            Err(resp) => resp,
            Ok(_) => panic!("expected the blocking layer to short-circuit"),
        };
        assert_eq!(rejection.status, StatusCode::FORBIDDEN);

        // Only the layer in front of the blocker may have run.
        assert_eq!(*log.lock().unwrap(), vec!["a-req"]);
    }

    #[tokio::test]
    async fn response_depth_limits_reverse_traversal() {
        let log = CallLog::default();
        let stack = abc_stack(&log);

        // Depth 2 covers only layers `a` and `b`, so `c` must not see the response.
        stack
            .process_response(ProxyResponse::new(StatusCode::OK), 2)
            .await
            .unwrap();

        assert_eq!(*log.lock().unwrap(), vec!["b-resp", "a-resp"]);
    }
}