// actus_server/middleware/chain.rs
use std::fmt;
use std::sync::Arc;

use actus_reply::{ReplyData, WebError};
use async_trait::async_trait;

use crate::request::Request;

6#[derive(Debug)]
8pub enum Outcome {
9 Continue,
11 Respond(ReplyData),
16}
17
18#[async_trait]
24pub trait Middleware: Send + Sync {
25 async fn before(&self, _request: &mut Request) -> Result<Outcome, WebError> {
30 Ok(Outcome::Continue)
31 }
32
33 async fn after(&self, _request: &Request, _response: &mut ReplyData) -> Result<(), WebError> {
47 Ok(())
48 }
49}
50
51#[derive(Clone, Default)]
53pub struct MiddlewareChain {
54 middlewares: Vec<Arc<dyn Middleware>>,
55}
56
57impl MiddlewareChain {
58 pub fn new() -> Self {
60 Self::default()
61 }
62
63 pub fn add<M: Middleware + 'static>(&mut self, middleware: M) {
66 self.middlewares.push(Arc::new(middleware));
67 }
68
69 pub async fn process_request(&self, request: &mut Request) -> Result<Outcome, WebError> {
73 for middleware in &self.middlewares {
74 match middleware.before(request).await? {
75 Outcome::Continue => {}
76 respond @ Outcome::Respond(_) => return Ok(respond),
77 }
78 }
79 Ok(Outcome::Continue)
80 }
81
82 pub async fn process_response(
85 &self,
86 request: &Request,
87 response: &mut ReplyData,
88 ) -> Result<(), WebError> {
89 for middleware in self.middlewares.iter().rev() {
90 middleware.after(request, response).await?;
91 }
92 Ok(())
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http::{HeaderMap, Method};
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Mutex;

    /// Shared list of entries appended by `Trace` hooks, in the order they fired.
    type Log = Arc<Mutex<Vec<String>>>;

    /// A GET request with no path segments, query parameters, headers or body.
    fn blank_request() -> Request {
        Request {
            method: Method::GET,
            path_parts: Vec::new(),
            query_params: HashMap::new(),
            body: Bytes::new(),
            headers: HeaderMap::new(),
            rate_limit_class: None,
        }
    }

    /// Uses only the trait's default hooks, so it always continues.
    struct Continues;

    #[async_trait]
    impl Middleware for Continues {}

    /// Answers every request straight away with an empty reply.
    struct ShortCircuits;

    #[async_trait]
    impl Middleware for ShortCircuits {
        async fn before(&self, _request: &mut Request) -> Result<Outcome, WebError> {
            Ok(Outcome::Respond(ReplyData::Empty))
        }
    }

    /// Raises its flag when its `before` hook is reached, then continues.
    struct Marks(Arc<AtomicBool>);

    #[async_trait]
    impl Middleware for Marks {
        async fn before(&self, _request: &mut Request) -> Result<Outcome, WebError> {
            let Marks(flag) = self;
            flag.store(true, Ordering::SeqCst);
            Ok(Outcome::Continue)
        }
    }

    #[tokio::test]
    async fn before_short_circuit_skips_the_rest() {
        let reached = Arc::new(AtomicBool::new(false));

        let mut chain = MiddlewareChain::new();
        chain.add(Continues);
        chain.add(ShortCircuits);
        chain.add(Marks(Arc::clone(&reached)));

        let outcome = chain.process_request(&mut blank_request()).await;
        assert!(
            matches!(outcome, Ok(Outcome::Respond(ReplyData::Empty))),
            "expected Respond(Empty), got {outcome:?}"
        );
        assert!(
            !reached.load(Ordering::SeqCst),
            "a middleware after the short-circuit still ran"
        );
    }

    #[tokio::test]
    async fn all_continue_yields_continue() {
        let mut chain = MiddlewareChain::new();
        for _ in 0..2 {
            chain.add(Continues);
        }

        let outcome = chain.process_request(&mut blank_request()).await;
        assert!(matches!(outcome, Ok(Outcome::Continue)));
    }

    /// Appends `"<label>:<path segments joined by '/'>"` to `sink` from `after`.
    struct Trace {
        label: &'static str,
        sink: Log,
    }

    #[async_trait]
    impl Middleware for Trace {
        async fn after(
            &self,
            request: &Request,
            _response: &mut ReplyData,
        ) -> Result<(), WebError> {
            let entry = format!("{}:{}", self.label, request.path_parts.join("/"));
            // The guard is a temporary dropped at the end of this statement,
            // so it is never held across an `.await`.
            self.sink.lock().unwrap().push(entry);
            Ok(())
        }
    }

    #[tokio::test]
    async fn after_runs_in_reverse_order_and_sees_the_request() {
        let sink: Log = Arc::new(Mutex::new(Vec::new()));

        let mut chain = MiddlewareChain::new();
        for label in ["A", "B", "C"] {
            chain.add(Trace {
                label,
                sink: Arc::clone(&sink),
            });
        }

        let mut request = blank_request();
        request.path_parts = vec!["api".into(), "users".into()];
        let mut reply = ReplyData::Empty;
        chain.process_response(&request, &mut reply).await.unwrap();

        assert_eq!(
            *sink.lock().unwrap(),
            vec!["C:api/users", "B:api/users", "A:api/users"]
        );
    }
}