1pub mod logging;
7pub mod timing;
8
9use std::future::Future;
10use std::pin::Pin;
11use axum::{
12 response::{Response, IntoResponse},
13 extract::Request,
14};
15
16use crate::{HttpResult, HttpError};
17
18pub type MiddlewareResult = HttpResult<Response>;
20
/// A heap-allocated, type-erased future that is `Send` and may borrow data for `'a`.
///
/// Returning this from the [`Middleware`] hooks (instead of using `async fn`)
/// lets the trait be used as `dyn Middleware`.
pub type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + 'a + Send>>;
23
24pub trait Middleware: Send + Sync {
27 fn process_request<'a>(
30 &'a self,
31 request: Request
32 ) -> BoxFuture<'a, Result<Request, Response>> {
33 Box::pin(async move { Ok(request) })
34 }
35
36 fn process_response<'a>(
39 &'a self,
40 response: Response
41 ) -> BoxFuture<'a, Response> {
42 Box::pin(async move { response })
43 }
44
45 fn name(&self) -> &'static str {
47 "Middleware"
48 }
49}
50
51#[derive(Default)]
53pub struct MiddlewarePipeline {
54 middleware: Vec<Box<dyn Middleware>>,
55}
56
57impl MiddlewarePipeline {
58 pub fn new() -> Self {
60 Self {
61 middleware: Vec::new(),
62 }
63 }
64
65 pub fn add<M: Middleware + 'static>(mut self, middleware: M) -> Self {
67 self.middleware.push(Box::new(middleware));
68 self
69 }
70
71 pub async fn process_request(&self, mut request: Request) -> Result<Request, Response> {
73 for middleware in &self.middleware {
74 match middleware.process_request(request).await {
75 Ok(req) => request = req,
76 Err(response) => return Err(response),
77 }
78 }
79 Ok(request)
80 }
81
82 pub async fn process_response(&self, mut response: Response) -> Response {
84 for middleware in self.middleware.iter().rev() {
86 response = middleware.process_response(response).await;
87 }
88 response
89 }
90
91 pub fn len(&self) -> usize {
93 self.middleware.len()
94 }
95
96 pub fn is_empty(&self) -> bool {
98 self.middleware.is_empty()
99 }
100
101 pub fn names(&self) -> Vec<&'static str> {
103 self.middleware.iter().map(|m| m.name()).collect()
104 }
105}
106
107pub struct ErrorHandlingMiddleware<M> {
109 inner: M,
110}
111
112impl<M> ErrorHandlingMiddleware<M> {
113 pub fn new(middleware: M) -> Self {
114 Self { inner: middleware }
115 }
116}
117
118impl<M: Middleware> Middleware for ErrorHandlingMiddleware<M> {
119 fn process_request<'a>(
120 &'a self,
121 request: Request
122 ) -> BoxFuture<'a, Result<Request, Response>> {
123 Box::pin(async move {
124 match self.inner.process_request(request).await {
126 Ok(req) => Ok(req),
127 Err(response) => Err(response),
128 }
129 })
130 }
131
132 fn process_response<'a>(
133 &'a self,
134 response: Response
135 ) -> BoxFuture<'a, Response> {
136 Box::pin(async move {
137 self.inner.process_response(response).await
138 })
139 }
140
141 fn name(&self) -> &'static str {
142 self.inner.name()
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};

    /// Test middleware that tags requests and responses with its own name.
    struct TestMiddleware {
        name: &'static str,
    }

    impl TestMiddleware {
        fn new(name: &'static str) -> Self {
            Self { name }
        }
    }

    impl Middleware for TestMiddleware {
        fn process_request<'a>(
            &'a self,
            mut request: Request,
        ) -> BoxFuture<'a, Result<Request, Response>> {
            Box::pin(async move {
                // `append` (not `insert`) keeps every tag so tests can check ordering.
                request
                    .headers_mut()
                    .append("X-Middleware", HeaderValue::from_static(self.name));
                Ok(request)
            })
        }

        fn process_response<'a>(&'a self, mut response: Response) -> BoxFuture<'a, Response> {
            Box::pin(async move {
                response
                    .headers_mut()
                    .append("X-Response-Middleware", HeaderValue::from_static(self.name));
                response
            })
        }

        fn name(&self) -> &'static str {
            self.name
        }
    }

    /// Test middleware that rejects every request with 403 Forbidden.
    struct RejectMiddleware;

    impl Middleware for RejectMiddleware {
        fn process_request<'a>(
            &'a self,
            _request: Request,
        ) -> BoxFuture<'a, Result<Request, Response>> {
            Box::pin(async move { Err(empty_response(StatusCode::FORBIDDEN)) })
        }

        fn name(&self) -> &'static str {
            "Reject"
        }
    }

    fn get_request(path: &str) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri(path)
            .body(Body::empty())
            .unwrap()
    }

    fn empty_response(status: StatusCode) -> Response {
        Response::builder()
            .status(status)
            .body(Body::empty())
            .unwrap()
    }

    /// Collects every value of header `name`, in insertion order.
    fn header_values<'h>(headers: &'h HeaderMap, name: &str) -> Vec<&'h str> {
        headers
            .get_all(name)
            .iter()
            .map(|value| value.to_str().unwrap())
            .collect()
    }

    #[tokio::test]
    async fn test_middleware_pipeline() {
        let pipeline = MiddlewarePipeline::new()
            .add(TestMiddleware::new("First"))
            .add(TestMiddleware::new("Second"));

        let processed_request = pipeline.process_request(get_request("/test")).await.unwrap();
        // Requests visit middleware in insertion order.
        assert_eq!(
            header_values(processed_request.headers(), "X-Middleware"),
            ["First", "Second"]
        );

        let processed_response = pipeline
            .process_response(empty_response(StatusCode::OK))
            .await;
        // Responses visit middleware in reverse insertion order.
        assert_eq!(
            header_values(processed_response.headers(), "X-Response-Middleware"),
            ["Second", "First"]
        );
    }

    #[tokio::test]
    async fn test_rejection_short_circuits_pipeline() {
        let pipeline = MiddlewarePipeline::new()
            .add(TestMiddleware::new("First"))
            .add(RejectMiddleware)
            .add(TestMiddleware::new("Third"));

        let Err(response) = pipeline.process_request(get_request("/test")).await else {
            panic!("RejectMiddleware should stop the pipeline with an Err response");
        };
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_error_handling_middleware_delegates() {
        let wrapped = ErrorHandlingMiddleware::new(TestMiddleware::new("Inner"));
        assert_eq!(wrapped.name(), "Inner");

        let request = wrapped.process_request(get_request("/test")).await.unwrap();
        assert_eq!(header_values(request.headers(), "X-Middleware"), ["Inner"]);

        let response = wrapped
            .process_response(empty_response(StatusCode::OK))
            .await;
        assert_eq!(
            header_values(response.headers(), "X-Response-Middleware"),
            ["Inner"]
        );
    }

    #[test]
    fn test_pipeline_info() {
        let pipeline = MiddlewarePipeline::new()
            .add(TestMiddleware::new("Test1"))
            .add(TestMiddleware::new("Test2"));

        assert_eq!(pipeline.len(), 2);
        assert!(!pipeline.is_empty());
        assert_eq!(pipeline.names(), vec!["Test1", "Test2"]);
    }

    #[tokio::test]
    async fn test_empty_pipeline() {
        let pipeline = MiddlewarePipeline::new();
        assert!(pipeline.is_empty());
        assert_eq!(pipeline.len(), 0);

        let processed_request = pipeline.process_request(get_request("/test")).await.unwrap();
        assert_eq!(processed_request.method(), Method::GET);
        assert_eq!(processed_request.uri().path(), "/test");

        let processed_response = pipeline
            .process_response(empty_response(StatusCode::OK))
            .await;
        assert_eq!(processed_response.status(), StatusCode::OK);
    }
}