// arbiter_proxy/middleware.rs
use bytes::Bytes;
use http_body_util::Full;
use hyper::{Request, Response, StatusCode};

use crate::config::MiddlewareConfig;

13pub type MiddlewareResult = Result<Request<hyper::body::Incoming>, Box<Response<Full<Bytes>>>>;
16
17pub trait Middleware: Send + Sync {
19 fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult;
21}
22
/// Rejects requests whose path exactly equals one of a configured set of paths.
///
/// Each entry is compared verbatim (no prefix matching) against both the raw
/// request path and its normalised form.
pub struct PathBlocker {
    /// Paths to refuse, stored exactly as configured.
    blocked: Vec<String>,
}

impl PathBlocker {
    /// Creates a blocker that refuses every path in `blocked`.
    pub fn new(blocked: Vec<String>) -> Self {
        PathBlocker { blocked }
    }
}

/// Canonicalises a request path so that equivalent spellings compare equal.
///
/// Processing steps:
/// - Percent-escapes are decoded once.
/// - Empty segments (from the leading `/` or from runs such as `//`) are
///   dropped.
/// - `.` and `..` segments are resolved; a `..` at the root is discarded.
///
/// The result always starts with exactly one `/`. A trailing `/` is kept when
/// the decoded input ended with one and the result is not just the root.
///
/// Returns `None` if the path contains a NUL byte, either literally or as a
/// `%00` escape.
fn normalize_path(path: &str) -> Option<String> {
    let decoded = percent_decode(path);
    // `percent_decode` copies unescaped bytes through unchanged, so a literal
    // NUL in `path` is still present here alongside any decoded `%00`.
    if decoded.contains('\0') {
        return None;
    }

    let mut segments: Vec<&str> = Vec::new();
    for segment in decoded.split('/') {
        match segment {
            // Dropping empty segments collapses any number of consecutive
            // slashes, so `//admin` or `/%2fadmin` cannot dodge an exact
            // `/admin` rule.
            "" | "." => {}
            ".." => {
                segments.pop();
            }
            s => segments.push(s),
        }
    }

    let mut normalized = format!("/{}", segments.join("/"));
    if !segments.is_empty() && decoded.ends_with('/') {
        normalized.push('/');
    }
    Some(normalized)
}

/// Decodes `%XX` escapes (hex digits in either case) in a single pass.
///
/// A `%` that is not followed by two hex digits is copied through literally.
/// The decoded bytes are converted with [`String::from_utf8_lossy`], so
/// invalid UTF-8 sequences become U+FFFD.
fn percent_decode(s: &str) -> String {
    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        // Both digits are at most 15, so `hi * 16 + lo` fits in a u8.
        let escaped = match (bytes[i], bytes.get(i + 1), bytes.get(i + 2)) {
            (b'%', Some(&h), Some(&l)) => {
                hex_val(h).zip(hex_val(l)).map(|(hi, lo)| hi * 16 + lo)
            }
            _ => None,
        };
        match escaped {
            Some(byte) => {
                out.push(byte);
                i += 3;
            }
            None => {
                out.push(bytes[i]);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&out).into_owned()
}

/// Returns the value of an ASCII hex digit, or `None` for any other byte.
fn hex_val(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}

94impl Middleware for PathBlocker {
95 fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult {
96 let raw_path = req.uri().path();
97 let path = match normalize_path(raw_path) {
98 Some(p) => p,
99 None => {
100 tracing::warn!(path = raw_path, "request blocked: path contains null bytes");
101 let resp = Response::builder()
102 .status(StatusCode::BAD_REQUEST)
103 .body(Full::new(Bytes::from("Bad Request")))
104 .expect("building static response cannot fail");
105 return Err(Box::new(resp));
106 }
107 };
108 if self
109 .blocked
110 .iter()
111 .any(|b| path == *b || raw_path == b.as_str())
112 {
113 tracing::warn!(path, "request blocked by path blocker");
114 let resp = Response::builder()
115 .status(StatusCode::FORBIDDEN)
116 .body(Full::new(Bytes::from("Forbidden")))
117 .expect("building static response cannot fail");
118 return Err(Box::new(resp));
119 }
120 Ok(req)
121 }
122}
123
/// Rejects requests that are missing any of a configured set of headers.
pub struct RequiredHeaders {
    /// Header names that must be present on every request.
    required: Vec<String>,
}

impl RequiredHeaders {
    /// Creates a check that demands every header named in `required`.
    pub fn new(required: Vec<String>) -> Self {
        RequiredHeaders { required }
    }
}

136impl Middleware for RequiredHeaders {
137 fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult {
138 for header in &self.required {
139 if !req.headers().contains_key(header.as_str()) {
140 tracing::warn!(header, "request rejected: missing required header");
141 let resp = Response::builder()
144 .status(StatusCode::BAD_REQUEST)
145 .body(Full::new(Bytes::from("Bad Request")))
146 .expect("building static response cannot fail");
147 return Err(Box::new(resp));
148 }
149 }
150 Ok(req)
151 }
152}
153
154pub struct MethodAllowlist {
158 allowed: Vec<hyper::Method>,
159}
160
161impl MethodAllowlist {
162 pub fn default_safe() -> Self {
164 Self {
165 allowed: vec![
166 hyper::Method::GET,
167 hyper::Method::POST,
168 hyper::Method::PUT,
169 hyper::Method::DELETE,
170 hyper::Method::PATCH,
171 hyper::Method::HEAD,
172 hyper::Method::OPTIONS,
173 ],
174 }
175 }
176}
177
178impl Middleware for MethodAllowlist {
179 fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult {
180 if !self.allowed.contains(req.method()) {
181 tracing::warn!(method = %req.method(), "request blocked: method not in allowlist");
182 let resp = Response::builder()
183 .status(StatusCode::METHOD_NOT_ALLOWED)
184 .body(Full::new(Bytes::from("Method Not Allowed")))
185 .expect("building static response cannot fail");
186 return Err(Box::new(resp));
187 }
188 Ok(req)
189 }
190}
191
192pub struct MiddlewareChain {
194 middlewares: Vec<Box<dyn Middleware>>,
195}
196
197impl MiddlewareChain {
198 pub fn from_config(config: &MiddlewareConfig) -> Self {
200 let mut middlewares: Vec<Box<dyn Middleware>> = Vec::new();
201
202 middlewares.push(Box::new(MethodAllowlist::default_safe()));
204
205 if !config.blocked_paths.is_empty() {
206 middlewares.push(Box::new(PathBlocker::new(config.blocked_paths.clone())));
207 }
208 if !config.required_headers.is_empty() {
209 middlewares.push(Box::new(RequiredHeaders::new(
210 config.required_headers.clone(),
211 )));
212 }
213
214 Self { middlewares }
215 }
216
217 pub fn empty() -> Self {
219 Self {
220 middlewares: Vec::new(),
221 }
222 }
223
224 pub fn execute(&self, mut req: Request<hyper::body::Incoming>) -> MiddlewareResult {
227 for mw in &self.middlewares {
228 req = mw.process(req)?;
229 }
230 Ok(req)
231 }
232}