use crate::request::Request;
use actus_reply::{ReplyData, WebError};
use async_trait::async_trait;
use std::sync::Arc;
/// Verdict produced by [`Middleware::before`] and by
/// [`MiddlewareChain::process_request`].
#[derive(Debug)]
pub enum Outcome {
    /// No reply was produced; the chain moves on to the next middleware.
    Continue,
    /// Short-circuit: the chain stops and hands this reply data back to its
    /// caller.
    Respond(ReplyData),
}
/// Pair of hooks that a [`MiddlewareChain`] invokes on a request and on its
/// reply data.
///
/// Both hooks ship with default bodies, so an implementor overrides only the
/// one(s) it cares about.
#[async_trait]
pub trait Middleware: Send + Sync {
    /// Hook that receives mutable access to the request.
    ///
    /// Returning [`Outcome::Respond`] makes the chain stop and return that
    /// reply. The default implementation leaves the request alone and yields
    /// [`Outcome::Continue`].
    ///
    /// # Errors
    ///
    /// The default never fails; an override may report a [`WebError`].
    async fn before(&self, _request: &mut Request) -> Result<Outcome, WebError> {
        Ok(Outcome::Continue)
    }

    /// Hook that receives the request read-only and the reply data mutably.
    ///
    /// The default implementation changes nothing.
    ///
    /// # Errors
    ///
    /// The default never fails; an override may report a [`WebError`].
    async fn after(
        &self,
        _request: &Request,
        _response: &mut ReplyData,
    ) -> Result<(), WebError> {
        Ok(())
    }
}
/// Ordered list of [`Middleware`] objects.
///
/// Cloning a chain clones the `Arc` handles, so the clone refers to the same
/// middleware instances as the original.
#[derive(Clone, Default)]
pub struct MiddlewareChain {
    /// Registered middlewares, in insertion order.
    middlewares: Vec<Arc<dyn Middleware>>,
}
impl MiddlewareChain {
    /// Creates a chain with no middlewares registered.
    pub fn new() -> Self {
        Self {
            middlewares: Vec::new(),
        }
    }

    /// Appends `middleware` to the end of the chain.
    pub fn add<M: Middleware + 'static>(&mut self, middleware: M) {
        let shared: Arc<dyn Middleware> = Arc::new(middleware);
        self.middlewares.push(shared);
    }

    /// Calls [`Middleware::before`] on each middleware in insertion order.
    ///
    /// The first [`Outcome::Respond`] is returned immediately and later
    /// middlewares are not invoked. If every middleware continues (or the
    /// chain is empty) the result is [`Outcome::Continue`].
    ///
    /// # Errors
    ///
    /// Returns the first [`WebError`] a `before` hook reports; middlewares
    /// after the failing one are not invoked.
    pub async fn process_request(&self, request: &mut Request) -> Result<Outcome, WebError> {
        for hook in self.middlewares.iter() {
            let verdict = hook.before(request).await?;
            match verdict {
                Outcome::Respond(reply) => return Ok(Outcome::Respond(reply)),
                Outcome::Continue => continue,
            }
        }
        Ok(Outcome::Continue)
    }

    /// Calls [`Middleware::after`] on each middleware in reverse insertion
    /// order, so the most recently added middleware sees the reply first.
    ///
    /// # Errors
    ///
    /// Returns the first [`WebError`] an `after` hook reports; the remaining
    /// (earlier-added) middlewares are not invoked.
    pub async fn process_response(
        &self,
        request: &Request,
        response: &mut ReplyData,
    ) -> Result<(), WebError> {
        let newest_first = self.middlewares.iter().rev();
        for hook in newest_first {
            hook.after(request, response).await?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http::{HeaderMap, Method};
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Mutex;

    /// Builds a `GET` request whose path, query, body and headers are empty.
    fn req() -> Request {
        Request {
            method: Method::GET,
            headers: HeaderMap::new(),
            body: Bytes::new(),
            path_parts: Vec::new(),
            query_params: HashMap::new(),
            rate_limit_class: None,
        }
    }

    /// Middleware that relies entirely on the trait's default hooks.
    struct Continues;

    #[async_trait]
    impl Middleware for Continues {}

    /// Middleware whose `before` hook always answers with an empty reply.
    struct ShortCircuits;

    #[async_trait]
    impl Middleware for ShortCircuits {
        async fn before(&self, _request: &mut Request) -> Result<Outcome, WebError> {
            Ok(Outcome::Respond(ReplyData::Empty))
        }
    }

    /// Middleware that raises the shared flag when its `before` hook runs.
    struct Records(Arc<AtomicBool>);

    #[async_trait]
    impl Middleware for Records {
        async fn before(&self, _request: &mut Request) -> Result<Outcome, WebError> {
            let Records(flag) = self;
            flag.store(true, Ordering::SeqCst);
            Ok(Outcome::Continue)
        }
    }

    #[tokio::test]
    async fn before_short_circuit_skips_the_rest() {
        let later_ran = Arc::new(AtomicBool::new(false));

        let mut chain = MiddlewareChain::new();
        chain.add(Continues);
        chain.add(ShortCircuits);
        chain.add(Records(Arc::clone(&later_ran)));

        let mut request = req();
        let outcome = chain.process_request(&mut request).await;
        match outcome {
            Ok(Outcome::Respond(ReplyData::Empty)) => {}
            other => panic!("expected Respond(Empty), got {other:?}"),
        }

        let ran = later_ran.load(Ordering::SeqCst);
        assert!(!ran, "a middleware after the short-circuit still ran");
    }

    #[tokio::test]
    async fn all_continue_yields_continue() {
        let mut chain = MiddlewareChain::new();
        for _ in 0..2 {
            chain.add(Continues);
        }

        let mut request = req();
        let outcome = chain.process_request(&mut request).await;
        assert!(matches!(outcome, Ok(Outcome::Continue)));
    }

    /// Middleware whose `after` hook appends `"<tag>:<joined path>"` to a
    /// shared log.
    struct Trace {
        tag: &'static str,
        log: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl Middleware for Trace {
        async fn after(
            &self,
            request: &Request,
            _response: &mut ReplyData,
        ) -> Result<(), WebError> {
            let entry = format!("{}:{}", self.tag, request.path_parts.join("/"));
            // The guard is a temporary, so the lock is released at the end of
            // this statement.
            self.log.lock().unwrap().push(entry);
            Ok(())
        }
    }

    #[tokio::test]
    async fn after_runs_in_reverse_order_and_sees_the_request() {
        let log = Arc::new(Mutex::new(Vec::new()));

        let mut chain = MiddlewareChain::new();
        for tag in ["A", "B", "C"] {
            chain.add(Trace {
                tag,
                log: Arc::clone(&log),
            });
        }

        let mut request = req();
        request.path_parts = vec!["api".into(), "users".into()];
        let mut data = ReplyData::Empty;

        chain.process_response(&request, &mut data).await.unwrap();

        let recorded = log.lock().unwrap();
        assert_eq!(
            *recorded,
            vec!["C:api/users", "B:api/users", "A:api/users"]
        );
    }
}