1pub mod logging;
7pub mod timing;
8pub mod tracing;
9pub mod timeout;
10pub mod body_limit;
11
12use std::future::Future;
13use std::pin::Pin;
14use axum::{
15 response::{Response, IntoResponse},
16 extract::Request,
17};
18
19use crate::{HttpResult, HttpError};
20
21pub type MiddlewareResult = HttpResult<Response>;
23
/// A heap-allocated, pinned, `Send` future that resolves to `Out` and may borrow
/// data for the lifetime `'fut`.
pub type BoxFuture<'fut, Out> = Pin<Box<dyn Future<Output = Out> + Send + 'fut>>;
26
27pub trait Middleware: Send + Sync {
30 fn process_request<'a>(
33 &'a self,
34 request: Request
35 ) -> BoxFuture<'a, Result<Request, Response>> {
36 Box::pin(async move { Ok(request) })
37 }
38
39 fn process_response<'a>(
42 &'a self,
43 response: Response
44 ) -> BoxFuture<'a, Response> {
45 Box::pin(async move { response })
46 }
47
48 fn name(&self) -> &'static str {
50 "Middleware"
51 }
52}
53
/// An ordered chain of boxed [`Middleware`] values.
///
/// `Default` produces an empty pipeline.
#[derive(Default)]
pub struct MiddlewarePipeline {
    // Kept in insertion order: `process_request` walks it front-to-back and
    // `process_response` walks it back-to-front.
    middleware: Vec<Box<dyn Middleware>>,
}
59
60impl MiddlewarePipeline {
61 pub fn new() -> Self {
63 Self {
64 middleware: Vec::new(),
65 }
66 }
67
68 pub fn add<M: Middleware + 'static>(mut self, middleware: M) -> Self {
70 self.middleware.push(Box::new(middleware));
71 self
72 }
73
74 pub async fn process_request(&self, mut request: Request) -> Result<Request, Response> {
76 for middleware in &self.middleware {
77 match middleware.process_request(request).await {
78 Ok(req) => request = req,
79 Err(response) => return Err(response),
80 }
81 }
82 Ok(request)
83 }
84
85 pub async fn process_response(&self, mut response: Response) -> Response {
87 for middleware in self.middleware.iter().rev() {
89 response = middleware.process_response(response).await;
90 }
91 response
92 }
93
94 pub fn len(&self) -> usize {
96 self.middleware.len()
97 }
98
99 pub fn is_empty(&self) -> bool {
101 self.middleware.is_empty()
102 }
103
104 pub fn names(&self) -> Vec<&'static str> {
106 self.middleware.iter().map(|m| m.name()).collect()
107 }
108}
109
/// Wrapper that owns another middleware value of type `M`.
///
/// NOTE(review): despite the name, the `Middleware` impl for this type
/// forwards every hook to `inner` without adding any error handling.
/// Confirm the intended behavior.
pub struct ErrorHandlingMiddleware<M> {
    // The wrapped middleware that hooks are forwarded to.
    inner: M,
}

impl<M> ErrorHandlingMiddleware<M> {
    /// Wraps `middleware`, taking ownership of it.
    pub fn new(middleware: M) -> Self {
        let inner = middleware;
        ErrorHandlingMiddleware { inner }
    }
}
120
121impl<M: Middleware> Middleware for ErrorHandlingMiddleware<M> {
122 fn process_request<'a>(
123 &'a self,
124 request: Request
125 ) -> BoxFuture<'a, Result<Request, Response>> {
126 Box::pin(async move {
127 match self.inner.process_request(request).await {
129 Ok(req) => Ok(req),
130 Err(response) => Err(response),
131 }
132 })
133 }
134
135 fn process_response<'a>(
136 &'a self,
137 response: Response
138 ) -> BoxFuture<'a, Response> {
139 Box::pin(async move {
140 self.inner.process_response(response).await
141 })
142 }
143
144 fn name(&self) -> &'static str {
145 self.inner.name()
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Method, StatusCode};
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    /// Builds the `GET /test` request with an empty body used by these tests.
    fn test_request() -> Request {
        Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(axum::body::Body::empty())
            .unwrap()
    }

    /// Stamps its name into `X-Middleware` on requests and into
    /// `X-Response-Middleware` on responses, overwriting earlier values.
    struct TestMiddleware {
        name: &'static str,
    }

    impl TestMiddleware {
        fn new(name: &'static str) -> Self {
            Self { name }
        }
    }

    impl Middleware for TestMiddleware {
        fn process_request<'a>(
            &'a self,
            mut request: Request,
        ) -> BoxFuture<'a, Result<Request, Response>> {
            Box::pin(async move {
                let headers = request.headers_mut();
                headers.insert("X-Middleware", self.name.parse().unwrap());
                Ok(request)
            })
        }

        fn process_response<'a>(&'a self, mut response: Response) -> BoxFuture<'a, Response> {
            Box::pin(async move {
                let headers = response.headers_mut();
                headers.insert("X-Response-Middleware", self.name.parse().unwrap());
                response
            })
        }

        fn name(&self) -> &'static str {
            self.name
        }
    }

    /// Short-circuits every request with `403 Forbidden`.
    struct RejectingMiddleware;

    impl Middleware for RejectingMiddleware {
        fn process_request<'a>(
            &'a self,
            _request: Request,
        ) -> BoxFuture<'a, Result<Request, Response>> {
            Box::pin(async { Err::<Request, _>(StatusCode::FORBIDDEN.into_response()) })
        }

        fn name(&self) -> &'static str {
            "Rejecting"
        }
    }

    /// Counts how many requests reach it, then passes them through unchanged.
    struct CountingMiddleware {
        calls: Arc<AtomicUsize>,
    }

    impl Middleware for CountingMiddleware {
        fn process_request<'a>(
            &'a self,
            request: Request,
        ) -> BoxFuture<'a, Result<Request, Response>> {
            Box::pin(async move {
                self.calls.fetch_add(1, Ordering::SeqCst);
                Ok::<Request, Response>(request)
            })
        }
    }

    #[tokio::test]
    async fn test_middleware_pipeline() {
        let pipeline = MiddlewarePipeline::new()
            .add(TestMiddleware::new("First"))
            .add(TestMiddleware::new("Second"));

        // Requests run in insertion order, so the last middleware wins.
        let processed_request = pipeline.process_request(test_request()).await.unwrap();
        assert_eq!(
            processed_request.headers().get("X-Middleware").unwrap(),
            "Second"
        );

        // Responses run in reverse order, so the first middleware wins.
        let response = Response::builder()
            .status(StatusCode::OK)
            .body(axum::body::Body::empty())
            .unwrap();
        let processed_response = pipeline.process_response(response).await;
        assert_eq!(
            processed_response.headers().get("X-Response-Middleware").unwrap(),
            "First"
        );
    }

    #[test]
    fn test_pipeline_info() {
        let pipeline = MiddlewarePipeline::new()
            .add(TestMiddleware::new("Test1"))
            .add(TestMiddleware::new("Test2"));

        assert_eq!(pipeline.len(), 2);
        assert!(!pipeline.is_empty());
        assert_eq!(pipeline.names(), vec!["Test1", "Test2"]);
    }

    #[tokio::test]
    async fn test_empty_pipeline() {
        let pipeline = MiddlewarePipeline::new();
        assert!(pipeline.is_empty());

        let processed_request = pipeline.process_request(test_request()).await.unwrap();

        assert_eq!(processed_request.method(), Method::GET);
        assert_eq!(processed_request.uri().path(), "/test");
    }

    #[tokio::test]
    async fn test_short_circuit_skips_remaining_middleware() {
        let calls = Arc::new(AtomicUsize::new(0));
        let pipeline = MiddlewarePipeline::new()
            .add(CountingMiddleware { calls: Arc::clone(&calls) })
            .add(RejectingMiddleware)
            .add(CountingMiddleware { calls: Arc::clone(&calls) });

        let rejection = pipeline.process_request(test_request()).await.unwrap_err();

        assert_eq!(rejection.status(), StatusCode::FORBIDDEN);
        // Only the middleware ahead of the rejection ran.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_error_handling_middleware_delegates() {
        let wrapped = ErrorHandlingMiddleware::new(TestMiddleware::new("Inner"));
        assert_eq!(wrapped.name(), "Inner");

        let processed = wrapped.process_request(test_request()).await.unwrap();
        assert_eq!(processed.headers().get("X-Middleware").unwrap(), "Inner");

        let response = Response::builder()
            .status(StatusCode::OK)
            .body(axum::body::Body::empty())
            .unwrap();
        let processed_response = wrapped.process_response(response).await;
        assert_eq!(
            processed_response.headers().get("X-Response-Middleware").unwrap(),
            "Inner"
        );
    }
}