Skip to main content

arbiter_proxy/
middleware.rs

//! Middleware trait and built-in middleware implementations.
//!
//! Middleware inspects or modifies an incoming request. It either passes the
//! request forward (returning `Ok(request)`) or rejects it (returning
//! `Err(response)`).
6
7use bytes::Bytes;
8use http_body_util::Full;
9use hyper::{Request, Response, StatusCode};
10
11use crate::config::MiddlewareConfig;
12
13/// Outcome of a middleware decision: either the (possibly modified) request
14/// continues downstream, or a response is returned immediately.
15pub type MiddlewareResult = Result<Request<hyper::body::Incoming>, Box<Response<Full<Bytes>>>>;
16
17/// A single middleware in the proxy pipeline.
18pub trait Middleware: Send + Sync {
19    /// Process the request. Return `Ok(req)` to forward, `Err(resp)` to reject.
20    fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult;
21}
22
/// Middleware that rejects any request whose path exactly equals one of a
/// configured list of paths.
pub struct PathBlocker {
    /// Paths to reject; compared verbatim against the request path.
    blocked: Vec<String>,
}

impl PathBlocker {
    /// Builds a blocker that rejects every path listed in `blocked`.
    pub fn new(blocked: Vec<String>) -> Self {
        PathBlocker { blocked }
    }
}
34
35impl Middleware for PathBlocker {
36    fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult {
37        let path = req.uri().path();
38        if self.blocked.iter().any(|b| path == b.as_str()) {
39            tracing::warn!(path, "request blocked by path blocker");
40            let resp = Response::builder()
41                .status(StatusCode::FORBIDDEN)
42                .body(Full::new(Bytes::from("Forbidden")))
43                .expect("building static response cannot fail");
44            return Err(Box::new(resp));
45        }
46        Ok(req)
47    }
48}
49
/// Middleware that rejects requests lacking any of a configured set of
/// header names.
pub struct RequiredHeaders {
    /// Header names every request must carry, checked in this order.
    required: Vec<String>,
}

impl RequiredHeaders {
    /// Builds the middleware from the list of header names to enforce.
    pub fn new(required: Vec<String>) -> Self {
        RequiredHeaders { required }
    }
}
61
62impl Middleware for RequiredHeaders {
63    fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult {
64        for header in &self.required {
65            if !req.headers().contains_key(header.as_str()) {
66                tracing::warn!(header, "request rejected: missing required header");
67                let resp = Response::builder()
68                    .status(StatusCode::BAD_REQUEST)
69                    .body(Full::new(Bytes::from(format!(
70                        "Missing required header: {header}"
71                    ))))
72                    .expect("building static response cannot fail");
73                return Err(Box::new(resp));
74            }
75        }
76        Ok(req)
77    }
78}
79
80/// Ordered chain of middleware. Each middleware is executed in sequence.
81pub struct MiddlewareChain {
82    middlewares: Vec<Box<dyn Middleware>>,
83}
84
85impl MiddlewareChain {
86    /// Build a middleware chain from the proxy configuration.
87    pub fn from_config(config: &MiddlewareConfig) -> Self {
88        let mut middlewares: Vec<Box<dyn Middleware>> = Vec::new();
89
90        if !config.blocked_paths.is_empty() {
91            middlewares.push(Box::new(PathBlocker::new(config.blocked_paths.clone())));
92        }
93        if !config.required_headers.is_empty() {
94            middlewares.push(Box::new(RequiredHeaders::new(
95                config.required_headers.clone(),
96            )));
97        }
98
99        Self { middlewares }
100    }
101
102    /// Create an empty middleware chain (no-op passthrough).
103    pub fn empty() -> Self {
104        Self {
105            middlewares: Vec::new(),
106        }
107    }
108
109    /// Run the request through all middleware in order.
110    /// Returns `Ok(req)` if all middleware pass, or `Err(resp)` on first rejection.
111    pub fn execute(&self, mut req: Request<hyper::body::Incoming>) -> MiddlewareResult {
112        for mw in &self.middlewares {
113            req = mw.process(req)?;
114        }
115        Ok(req)
116    }
117}