Skip to main content

arbiter_proxy/
middleware.rs

1//! Middleware trait and built-in middleware implementations.
2//!
3//! Middleware inspects or modifies an incoming request. It can either pass the
4//! request forward (returning `Ok(request)`) or reject it (returning
5//! `Err(response)`).
6
7use bytes::Bytes;
8use http_body_util::Full;
9use hyper::{Request, Response, StatusCode};
10
11use crate::config::MiddlewareConfig;
12
13/// Outcome of a middleware decision: either the (possibly modified) request
14/// continues downstream, or a response is returned immediately.
15pub type MiddlewareResult = Result<Request<hyper::body::Incoming>, Box<Response<Full<Bytes>>>>;
16
17/// A single middleware in the proxy pipeline.
18pub trait Middleware: Send + Sync {
19    /// Process the request. Return `Ok(req)` to forward, `Err(resp)` to reject.
20    fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult;
21}
22
/// Blocks requests whose path matches a configured set of paths.
///
/// Matching is exact string equality: a request is blocked when either its
/// raw URI path or the output of `normalize_path` for that path equals one of
/// the configured entries. Entries are not normalized themselves, so they
/// should be written in the form `normalize_path` produces (leading `/`, no
/// dot segments) to catch encoded or dot-segment variants of the path.
#[derive(Debug, Clone)]
pub struct PathBlocker {
    /// Paths to reject, compared verbatim against request paths.
    blocked: Vec<String>,
}

impl PathBlocker {
    /// Create a new path blocker from a list of blocked paths.
    ///
    /// The entries are stored exactly as given.
    pub fn new(blocked: Vec<String>) -> Self {
        Self { blocked }
    }
}
34
35/// Normalize a URL path by percent-decoding, collapsing dot segments, and
36/// rejecting null bytes. This prevents path traversal bypasses via encoded
37/// sequences like `%2e%2e` or `%2F`.
38fn normalize_path(path: &str) -> Option<String> {
39    // Reject null bytes.
40    if path.bytes().any(|b| b == 0) {
41        return None;
42    }
43    // Percent-decode the path.
44    let decoded = percent_decode(path);
45    // Reject null bytes in decoded form.
46    if decoded.bytes().any(|b| b == 0) {
47        return None;
48    }
49    // Collapse dot segments (RFC 3986 Section 5.2.4).
50    let mut segments: Vec<&str> = Vec::new();
51    for segment in decoded.split('/') {
52        match segment {
53            "." => {}
54            ".." => {
55                segments.pop();
56            }
57            s => segments.push(s),
58        }
59    }
60    let normalized = format!("/{}", segments.join("/"));
61    // Remove double slashes.
62    Some(normalized.replace("//", "/"))
63}
64
65/// Simple percent-decoding.
66fn percent_decode(s: &str) -> String {
67    let mut result = Vec::new();
68    let bytes = s.as_bytes();
69    let mut i = 0;
70    while i < bytes.len() {
71        if bytes[i] == b'%'
72            && i + 2 < bytes.len()
73            && let (Some(hi), Some(lo)) = (hex_val(bytes[i + 1]), hex_val(bytes[i + 2]))
74        {
75            result.push(hi * 16 + lo);
76            i += 3;
77            continue;
78        }
79        result.push(bytes[i]);
80        i += 1;
81    }
82    String::from_utf8_lossy(&result).into_owned()
83}
84
/// Value of a single ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`), or `None`
/// for any other byte.
fn hex_val(b: u8) -> Option<u8> {
    // `to_digit(16)` accepts only ASCII digits and ASCII letters a-f/A-F, and
    // the result is < 16, so the narrowing cast is lossless.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
93
94impl Middleware for PathBlocker {
95    fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult {
96        let raw_path = req.uri().path();
97        let path = match normalize_path(raw_path) {
98            Some(p) => p,
99            None => {
100                tracing::warn!(path = raw_path, "request blocked: path contains null bytes");
101                let resp = Response::builder()
102                    .status(StatusCode::BAD_REQUEST)
103                    .body(Full::new(Bytes::from("Bad Request")))
104                    .expect("building static response cannot fail");
105                return Err(Box::new(resp));
106            }
107        };
108        if self
109            .blocked
110            .iter()
111            .any(|b| path == *b || raw_path == b.as_str())
112        {
113            tracing::warn!(path, "request blocked by path blocker");
114            let resp = Response::builder()
115                .status(StatusCode::FORBIDDEN)
116                .body(Full::new(Bytes::from("Forbidden")))
117                .expect("building static response cannot fail");
118            return Err(Box::new(resp));
119        }
120        Ok(req)
121    }
122}
123
124/// Rejects requests that are missing required headers.
125pub struct RequiredHeaders {
126    required: Vec<String>,
127}
128
129impl RequiredHeaders {
130    /// Create a new required-headers middleware.
131    pub fn new(required: Vec<String>) -> Self {
132        Self { required }
133    }
134}
135
136impl Middleware for RequiredHeaders {
137    fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult {
138        for header in &self.required {
139            if !req.headers().contains_key(header.as_str()) {
140                tracing::warn!(header, "request rejected: missing required header");
141                // Don't expose the header name in the response body -- it leaks
142                // security policy configuration and enables iterative probing.
143                let resp = Response::builder()
144                    .status(StatusCode::BAD_REQUEST)
145                    .body(Full::new(Bytes::from("Bad Request")))
146                    .expect("building static response cannot fail");
147                return Err(Box::new(resp));
148            }
149        }
150        Ok(req)
151    }
152}
153
154/// Rejects requests using HTTP methods not in the allowlist.
155/// TRACE and CONNECT are always blocked as they can enable cross-site
156/// tracing and proxy tunneling attacks.
157pub struct MethodAllowlist {
158    allowed: Vec<hyper::Method>,
159}
160
161impl MethodAllowlist {
162    /// Create with the default safe set: GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS.
163    pub fn default_safe() -> Self {
164        Self {
165            allowed: vec![
166                hyper::Method::GET,
167                hyper::Method::POST,
168                hyper::Method::PUT,
169                hyper::Method::DELETE,
170                hyper::Method::PATCH,
171                hyper::Method::HEAD,
172                hyper::Method::OPTIONS,
173            ],
174        }
175    }
176}
177
178impl Middleware for MethodAllowlist {
179    fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult {
180        if !self.allowed.contains(req.method()) {
181            tracing::warn!(method = %req.method(), "request blocked: method not in allowlist");
182            let resp = Response::builder()
183                .status(StatusCode::METHOD_NOT_ALLOWED)
184                .body(Full::new(Bytes::from("Method Not Allowed")))
185                .expect("building static response cannot fail");
186            return Err(Box::new(resp));
187        }
188        Ok(req)
189    }
190}
191
192/// Ordered chain of middleware. Each middleware is executed in sequence.
193pub struct MiddlewareChain {
194    middlewares: Vec<Box<dyn Middleware>>,
195}
196
197impl MiddlewareChain {
198    /// Build a middleware chain from the proxy configuration.
199    pub fn from_config(config: &MiddlewareConfig) -> Self {
200        let mut middlewares: Vec<Box<dyn Middleware>> = Vec::new();
201
202        // Method allowlist first: reject dangerous verbs before any other processing.
203        middlewares.push(Box::new(MethodAllowlist::default_safe()));
204
205        if !config.blocked_paths.is_empty() {
206            middlewares.push(Box::new(PathBlocker::new(config.blocked_paths.clone())));
207        }
208        if !config.required_headers.is_empty() {
209            middlewares.push(Box::new(RequiredHeaders::new(
210                config.required_headers.clone(),
211            )));
212        }
213
214        Self { middlewares }
215    }
216
217    /// Create an empty middleware chain (no-op passthrough).
218    pub fn empty() -> Self {
219        Self {
220            middlewares: Vec::new(),
221        }
222    }
223
224    /// Run the request through all middleware in order.
225    /// Returns `Ok(req)` if all middleware pass, or `Err(resp)` on first rejection.
226    pub fn execute(&self, mut req: Request<hyper::body::Incoming>) -> MiddlewareResult {
227        for mw in &self.middlewares {
228            req = mw.process(req)?;
229        }
230        Ok(req)
231    }
232}