1pub mod logging;
7pub mod enhanced_logging;
8pub mod timing;
9pub mod tracing;
10pub mod timeout;
11pub mod body_limit;
12
13use std::future::Future;
14use std::pin::Pin;
15use axum::{
16 response::{Response, IntoResponse},
17 extract::Request,
18};
19
20use crate::{HttpResult, HttpError};
21
22pub type MiddlewareResult = HttpResult<Response>;
24
/// A pinned, heap-allocated, `Send` future that may borrow data for `'a` and
/// resolves to `T`.
///
/// Used as the return type of [`Middleware`] hooks so the trait stays
/// object-safe (it can be stored as `Box<dyn Middleware>`).
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
27
28pub trait Middleware: Send + Sync {
31 fn process_request<'a>(
34 &'a self,
35 request: Request
36 ) -> BoxFuture<'a, Result<Request, Response>> {
37 Box::pin(async move { Ok(request) })
38 }
39
40 fn process_response<'a>(
43 &'a self,
44 response: Response
45 ) -> BoxFuture<'a, Response> {
46 Box::pin(async move { response })
47 }
48
49 fn name(&self) -> &'static str {
51 "Middleware"
52 }
53}
54
55#[derive(Default)]
57pub struct MiddlewarePipeline {
58 middleware: Vec<Box<dyn Middleware>>,
59}
60
61impl MiddlewarePipeline {
62 pub fn new() -> Self {
64 Self {
65 middleware: Vec::new(),
66 }
67 }
68
69 pub fn add<M: Middleware + 'static>(mut self, middleware: M) -> Self {
71 self.middleware.push(Box::new(middleware));
72 self
73 }
74
75 pub async fn process_request(&self, mut request: Request) -> Result<Request, Response> {
77 for middleware in &self.middleware {
78 match middleware.process_request(request).await {
79 Ok(req) => request = req,
80 Err(response) => return Err(response),
81 }
82 }
83 Ok(request)
84 }
85
86 pub async fn process_response(&self, mut response: Response) -> Response {
88 for middleware in self.middleware.iter().rev() {
90 response = middleware.process_response(response).await;
91 }
92 response
93 }
94
95 pub fn len(&self) -> usize {
97 self.middleware.len()
98 }
99
100 pub fn is_empty(&self) -> bool {
102 self.middleware.is_empty()
103 }
104
105 pub fn names(&self) -> Vec<&'static str> {
107 self.middleware.iter().map(|m| m.name()).collect()
108 }
109}
110
111pub struct ErrorHandlingMiddleware<M> {
113 inner: M,
114}
115
116impl<M> ErrorHandlingMiddleware<M> {
117 pub fn new(middleware: M) -> Self {
118 Self { inner: middleware }
119 }
120}
121
122impl<M: Middleware> Middleware for ErrorHandlingMiddleware<M> {
123 fn process_request<'a>(
124 &'a self,
125 request: Request
126 ) -> BoxFuture<'a, Result<Request, Response>> {
127 Box::pin(async move {
128 match self.inner.process_request(request).await {
130 Ok(req) => Ok(req),
131 Err(response) => Err(response),
132 }
133 })
134 }
135
136 fn process_response<'a>(
137 &'a self,
138 response: Response
139 ) -> BoxFuture<'a, Response> {
140 Box::pin(async move {
141 self.inner.process_response(response).await
142 })
143 }
144
145 fn name(&self) -> &'static str {
146 self.inner.name()
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::ElifStatusCode as StatusCode;
    use axum::http::Method;

    /// Stamps its name into the `X-Middleware` request header and the
    /// `X-Response-Middleware` response header, so tests can observe which
    /// middleware ran last on each side.
    struct TestMiddleware {
        name: &'static str,
    }

    impl TestMiddleware {
        fn new(name: &'static str) -> Self {
            Self { name }
        }
    }

    impl Middleware for TestMiddleware {
        fn process_request<'a>(
            &'a self,
            mut request: Request,
        ) -> BoxFuture<'a, Result<Request, Response>> {
            Box::pin(async move {
                // `insert` overwrites any earlier value, so the header ends up
                // holding the name of the last middleware to run.
                let headers = request.headers_mut();
                headers.insert("X-Middleware", self.name.parse().unwrap());
                Ok(request)
            })
        }

        fn process_response<'a>(
            &'a self,
            mut response: Response,
        ) -> BoxFuture<'a, Response> {
            Box::pin(async move {
                let headers = response.headers_mut();
                headers.insert("X-Response-Middleware", self.name.parse().unwrap());
                response
            })
        }

        fn name(&self) -> &'static str {
            self.name
        }
    }

    /// Rejects every request with `403 Forbidden`, to exercise short-circuiting.
    struct RejectingMiddleware;

    impl Middleware for RejectingMiddleware {
        fn process_request<'a>(
            &'a self,
            _request: Request,
        ) -> BoxFuture<'a, Result<Request, Response>> {
            Box::pin(async move {
                Err(Response::builder()
                    .status(axum::http::StatusCode::FORBIDDEN)
                    .body(axum::body::Body::empty())
                    .unwrap())
            })
        }

        fn name(&self) -> &'static str {
            "Rejecting"
        }
    }

    #[tokio::test]
    async fn test_middleware_pipeline() {
        let pipeline = MiddlewarePipeline::new()
            .add(TestMiddleware::new("First"))
            .add(TestMiddleware::new("Second"));

        let request = Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(axum::body::Body::empty())
            .unwrap();

        let processed_request = pipeline.process_request(request).await.unwrap();

        // Requests run in insertion order, so "Second" writes last.
        assert_eq!(
            processed_request.headers().get("X-Middleware").unwrap(),
            "Second"
        );

        let response = Response::builder()
            .status(StatusCode::OK)
            .body(axum::body::Body::empty())
            .unwrap();

        let processed_response = pipeline.process_response(response).await;

        // Responses run in reverse order, so "First" writes last.
        assert_eq!(
            processed_response.headers().get("X-Response-Middleware").unwrap(),
            "First"
        );
    }

    #[tokio::test]
    async fn test_pipeline_short_circuits() {
        let pipeline = MiddlewarePipeline::new()
            .add(RejectingMiddleware)
            .add(TestMiddleware::new("Unreached"));

        let request = Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(axum::body::Body::empty())
            .unwrap();

        let rejection = pipeline.process_request(request).await.unwrap_err();
        assert_eq!(rejection.status(), axum::http::StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_error_handling_middleware_delegates() {
        let wrapped = ErrorHandlingMiddleware::new(TestMiddleware::new("Inner"));
        assert_eq!(wrapped.name(), "Inner");

        let request = Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(axum::body::Body::empty())
            .unwrap();
        let request = wrapped.process_request(request).await.unwrap();
        assert_eq!(request.headers().get("X-Middleware").unwrap(), "Inner");

        let response = Response::builder()
            .status(StatusCode::OK)
            .body(axum::body::Body::empty())
            .unwrap();
        let response = wrapped.process_response(response).await;
        assert_eq!(
            response.headers().get("X-Response-Middleware").unwrap(),
            "Inner"
        );
    }

    // Purely synchronous introspection: no async runtime needed.
    #[test]
    fn test_pipeline_info() {
        let pipeline = MiddlewarePipeline::new()
            .add(TestMiddleware::new("Test1"))
            .add(TestMiddleware::new("Test2"));

        assert_eq!(pipeline.len(), 2);
        assert!(!pipeline.is_empty());
        assert_eq!(pipeline.names(), vec!["Test1", "Test2"]);
    }

    #[tokio::test]
    async fn test_empty_pipeline() {
        let pipeline = MiddlewarePipeline::new();

        let request = Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(axum::body::Body::empty())
            .unwrap();

        let processed_request = pipeline.process_request(request).await.unwrap();

        assert_eq!(processed_request.method(), Method::GET);
        assert_eq!(processed_request.uri().path(), "/test");
    }
}