arbiter_proxy/
middleware.rs1use bytes::Bytes;
8use http_body_util::Full;
9use hyper::{Request, Response, StatusCode};
10
11use crate::config::MiddlewareConfig;
12
13pub type MiddlewareResult = Result<Request<hyper::body::Incoming>, Box<Response<Full<Bytes>>>>;
16
17pub trait Middleware: Send + Sync {
19 fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult;
21}
22
23pub struct PathBlocker {
25 blocked: Vec<String>,
26}
27
28impl PathBlocker {
29 pub fn new(blocked: Vec<String>) -> Self {
31 Self { blocked }
32 }
33}
34
35impl Middleware for PathBlocker {
36 fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult {
37 let path = req.uri().path();
38 if self.blocked.iter().any(|b| path == b.as_str()) {
39 tracing::warn!(path, "request blocked by path blocker");
40 let resp = Response::builder()
41 .status(StatusCode::FORBIDDEN)
42 .body(Full::new(Bytes::from("Forbidden")))
43 .expect("building static response cannot fail");
44 return Err(Box::new(resp));
45 }
46 Ok(req)
47 }
48}
49
50pub struct RequiredHeaders {
52 required: Vec<String>,
53}
54
55impl RequiredHeaders {
56 pub fn new(required: Vec<String>) -> Self {
58 Self { required }
59 }
60}
61
62impl Middleware for RequiredHeaders {
63 fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult {
64 for header in &self.required {
65 if !req.headers().contains_key(header.as_str()) {
66 tracing::warn!(header, "request rejected: missing required header");
67 let resp = Response::builder()
68 .status(StatusCode::BAD_REQUEST)
69 .body(Full::new(Bytes::from(format!(
70 "Missing required header: {header}"
71 ))))
72 .expect("building static response cannot fail");
73 return Err(Box::new(resp));
74 }
75 }
76 Ok(req)
77 }
78}
79
80pub struct MiddlewareChain {
82 middlewares: Vec<Box<dyn Middleware>>,
83}
84
85impl MiddlewareChain {
86 pub fn from_config(config: &MiddlewareConfig) -> Self {
88 let mut middlewares: Vec<Box<dyn Middleware>> = Vec::new();
89
90 if !config.blocked_paths.is_empty() {
91 middlewares.push(Box::new(PathBlocker::new(config.blocked_paths.clone())));
92 }
93 if !config.required_headers.is_empty() {
94 middlewares.push(Box::new(RequiredHeaders::new(
95 config.required_headers.clone(),
96 )));
97 }
98
99 Self { middlewares }
100 }
101
102 pub fn empty() -> Self {
104 Self {
105 middlewares: Vec::new(),
106 }
107 }
108
109 pub fn execute(&self, mut req: Request<hyper::body::Incoming>) -> MiddlewareResult {
112 for mw in &self.middlewares {
113 req = mw.process(req)?;
114 }
115 Ok(req)
116 }
117}