use bytes::Bytes;
use http_body_util::Full;
use hyper::{Request, Response, StatusCode};
use crate::config::MiddlewareConfig;
/// Outcome of running a request through a [`Middleware`].
///
/// `Ok` carries the request on to the next stage; `Err` carries a complete
/// response to send back instead of processing the request further. The
/// response is boxed, presumably to keep this `Result` small since it is
/// returned on every middleware call.
pub type MiddlewareResult = Result<Request<hyper::body::Incoming>, Box<Response<Full<Bytes>>>>;
/// A request filter that either lets a request through or rejects it outright.
///
/// Implementors are required to be `Send + Sync`; presumably so a chain of
/// them can be shared across concurrently handled connections (confirm
/// against the server setup, which is not visible here).
pub trait Middleware: Send + Sync {
/// Inspects `req` and returns it as `Ok` to continue processing, or returns
/// a ready-made response as `Err` to short-circuit the request.
fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult;
}
/// Middleware that rejects requests whose path matches one of the configured
/// blocked paths, answering them with `403 Forbidden`.
#[derive(Debug, Clone)]
pub struct PathBlocker {
    /// Paths to block, as supplied by configuration.
    blocked: Vec<String>,
}

impl PathBlocker {
    /// Creates a blocker for the given list of paths.
    pub fn new(blocked: Vec<String>) -> Self {
        Self { blocked }
    }
}
/// Returns `true` when `path` is `blocked` itself or lies beneath it.
///
/// Matching is done on whole path segments: blocking `/admin` also blocks
/// `/admin/` and `/admin/users`, but not `/administrator`. A trailing slash on
/// the configured entry is ignored, so `/admin/` behaves like `/admin`. A root
/// entry (`/`) blocks only the root path itself, never every path.
///
/// The comparison is purely textual on the path as received: no
/// percent-decoding or `.`/`..` normalization is applied here.
fn is_blocked_path(path: &str, blocked: &str) -> bool {
    if path == blocked {
        return true;
    }
    let base = blocked.trim_end_matches('/');
    if base.is_empty() {
        // An empty base would be a prefix of every path; keep `/` exact-only.
        return false;
    }
    match path.strip_prefix(base) {
        // Require a segment boundary so `/admin` does not match `/administrator`.
        Some(rest) => rest.is_empty() || rest.starts_with('/'),
        None => false,
    }
}

impl Middleware for PathBlocker {
    /// Rejects the request with `403 Forbidden` if its path equals, or is
    /// nested under, any configured blocked path; otherwise returns it unchanged.
    fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult {
        let path = req.uri().path();
        if self.blocked.iter().any(|b| is_blocked_path(path, b)) {
            tracing::warn!(path, "request blocked by path blocker");
            let resp = Response::builder()
                .status(StatusCode::FORBIDDEN)
                .body(Full::new(Bytes::from("Forbidden")))
                .expect("building static response cannot fail");
            return Err(Box::new(resp));
        }
        Ok(req)
    }
}
/// Middleware that rejects requests lacking any of the configured headers,
/// answering them with `400 Bad Request`.
#[derive(Debug, Clone)]
pub struct RequiredHeaders {
    /// Header names that must be present, in the order they are checked.
    required: Vec<String>,
}

impl RequiredHeaders {
    /// Creates a checker for the given list of header names.
    pub fn new(required: Vec<String>) -> Self {
        Self { required }
    }
}
impl Middleware for RequiredHeaders {
    /// Verifies that every configured header is present on the request.
    ///
    /// The first absent header, in configuration order, short-circuits with a
    /// `400 Bad Request` whose body names that header. When all headers are
    /// present the request is handed back untouched.
    fn process(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult {
        let headers = req.headers();
        let first_missing = self
            .required
            .iter()
            .find(|name| !headers.contains_key(name.as_str()));

        match first_missing {
            None => Ok(req),
            Some(header) => {
                tracing::warn!(header, "request rejected: missing required header");
                let message = format!("Missing required header: {header}");
                let rejection = Response::builder()
                    .status(StatusCode::BAD_REQUEST)
                    .body(Full::new(Bytes::from(message)))
                    .expect("building static response cannot fail");
                Err(Box::new(rejection))
            }
        }
    }
}
/// An ordered list of [`Middleware`] applied to each request.
///
/// Middlewares run in insertion order; the first one that returns `Err` stops
/// the chain, and its response is returned to the caller.
#[derive(Default)]
pub struct MiddlewareChain {
    middlewares: Vec<Box<dyn Middleware>>,
}

impl MiddlewareChain {
    /// Builds a chain from configuration.
    ///
    /// A [`PathBlocker`] is added when `blocked_paths` is non-empty, followed by
    /// a [`RequiredHeaders`] when `required_headers` is non-empty. Blocked paths
    /// are therefore rejected before any header check runs. Empty lists add no
    /// middleware at all.
    pub fn from_config(config: &MiddlewareConfig) -> Self {
        let mut middlewares: Vec<Box<dyn Middleware>> = Vec::new();
        if !config.blocked_paths.is_empty() {
            middlewares.push(Box::new(PathBlocker::new(config.blocked_paths.clone())));
        }
        if !config.required_headers.is_empty() {
            middlewares.push(Box::new(RequiredHeaders::new(
                config.required_headers.clone(),
            )));
        }
        Self { middlewares }
    }

    /// Returns a chain with no middlewares, so [`execute`](Self::execute)
    /// passes every request through unchanged. Equivalent to
    /// [`Default::default`].
    pub fn empty() -> Self {
        Self::default()
    }

    /// Runs `req` through each middleware in order.
    ///
    /// # Errors
    ///
    /// Returns the boxed response produced by the first middleware that rejects
    /// the request. Later middlewares are not invoked.
    pub fn execute(&self, req: Request<hyper::body::Incoming>) -> MiddlewareResult {
        // `try_fold` threads the request through each stage and stops at the first `Err`.
        self.middlewares
            .iter()
            .try_fold(req, |req, mw| mw.process(req))
    }
}