1use crate::server::{Request, Response};
7use crate::error::Result;
8use async_trait::async_trait;
9use http::{HeaderName, HeaderValue, StatusCode};
10use std::sync::Arc;
11
12#[async_trait]
14pub trait Middleware: Send + Sync {
15 async fn handle(&self, request: Request, next: Next) -> Result<Response>;
22}
23
24pub struct Next {
26 middleware: Vec<Arc<dyn Middleware>>,
27 index: usize,
28 handler: Option<Arc<dyn Handler>>,
29}
30
31impl Next {
32 pub fn new(middleware: Vec<Arc<dyn Middleware>>, handler: Arc<dyn Handler>) -> Self {
34 Self {
35 middleware,
36 index: 0,
37 handler: Some(handler),
38 }
39 }
40
41 pub async fn run(mut self, request: Request) -> Result<Response> {
43 if self.index < self.middleware.len() {
44 let middleware = self.middleware[self.index].clone();
45 self.index += 1;
46 middleware.handle(request, self).await
47 } else if let Some(handler) = &self.handler {
48 handler.handle(request).await
49 } else {
50 Ok(Response::text("Not Found").with_status(StatusCode::NOT_FOUND))
51 }
52 }
53}
54
55#[async_trait]
57pub trait Handler: Send + Sync {
58 async fn handle(&self, request: Request) -> Result<Response>;
60}
61
/// Middleware that prints each request and its outcome to stdout.
pub struct Logger;

impl Logger {
    /// Creates a logger; it carries no configuration.
    pub fn new() -> Self {
        Logger
    }
}
71
72#[async_trait]
73impl Middleware for Logger {
74 async fn handle(&self, request: Request, next: Next) -> Result<Response> {
75 let method = request.method.clone();
76 let path = request.path.clone();
77 let start = std::time::Instant::now();
78
79 println!("[REQUEST] {} {}", method, path);
80
81 let response = next.run(request).await?;
82
83 let duration = start.elapsed();
84 println!(
85 "[RESPONSE] {} {} - {} ({:?})",
86 method, path, response.status(), duration
87 );
88
89 Ok(response)
90 }
91}
92
93impl Default for Logger {
94 fn default() -> Self {
95 Self::new()
96 }
97}
98
/// Cross-Origin Resource Sharing (CORS) middleware configuration.
///
/// `OPTIONS` requests are answered directly with `204 No Content` and the
/// configured `Access-Control-Allow-*` / `Access-Control-Max-Age` headers.
/// Every other response gets `Access-Control-Allow-Origin` added.
#[derive(Debug, Clone)]
pub struct Cors {
    /// Value sent in `Access-Control-Allow-Origin` (e.g. `*` or one origin).
    allow_origin: String,
    /// Methods sent, comma-separated, in `Access-Control-Allow-Methods`.
    allow_methods: Vec<String>,
    /// Headers sent, comma-separated, in `Access-Control-Allow-Headers`.
    allow_headers: Vec<String>,
    /// Preflight cache lifetime in seconds (`Access-Control-Max-Age`).
    max_age: u32,
}

impl Cors {
    /// Allows any origin (`*`) and the common REST methods.
    ///
    /// Also allows the `Content-Type`, `Authorization` and `X-Requested-With`
    /// headers, and lets clients cache preflight results for 86400 seconds.
    #[must_use]
    pub fn permissive() -> Self {
        Self {
            allow_origin: "*".to_string(),
            allow_methods: Self::owned(&["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]),
            allow_headers: Self::owned(&["Content-Type", "Authorization", "X-Requested-With"]),
            max_age: 86400,
        }
    }

    /// Allows only `allow_origin`, the `GET`/`POST` methods and the
    /// `Content-Type` header, with a 3600-second preflight cache lifetime.
    #[must_use]
    pub fn new(allow_origin: impl Into<String>) -> Self {
        Self {
            allow_origin: allow_origin.into(),
            allow_methods: Self::owned(&["GET", "POST"]),
            allow_headers: Self::owned(&["Content-Type"]),
            max_age: 3600,
        }
    }

    /// Replaces the list of allowed methods.
    #[must_use]
    pub fn allow_methods(mut self, methods: Vec<String>) -> Self {
        self.allow_methods = methods;
        self
    }

    /// Replaces the list of allowed request headers.
    #[must_use]
    pub fn allow_headers(mut self, headers: Vec<String>) -> Self {
        self.allow_headers = headers;
        self
    }

    /// Sets the preflight cache lifetime in seconds.
    #[must_use]
    pub fn max_age(mut self, seconds: u32) -> Self {
        self.max_age = seconds;
        self
    }

    /// Converts a slice of string literals into owned `String`s.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|&item| item.to_owned()).collect()
    }
}
157
158#[async_trait]
159impl Middleware for Cors {
160 async fn handle(&self, request: Request, next: Next) -> Result<Response> {
161 if request.method.as_str() == "OPTIONS" {
163 let response = Response::text("")
164 .with_header(
165 HeaderName::from_static("access-control-allow-origin"),
166 HeaderValue::from_str(&self.allow_origin).unwrap(),
167 )
168 .with_header(
169 HeaderName::from_static("access-control-allow-methods"),
170 HeaderValue::from_str(&self.allow_methods.join(", ")).unwrap(),
171 )
172 .with_header(
173 HeaderName::from_static("access-control-allow-headers"),
174 HeaderValue::from_str(&self.allow_headers.join(", ")).unwrap(),
175 )
176 .with_header(
177 HeaderName::from_static("access-control-max-age"),
178 HeaderValue::from_str(&self.max_age.to_string()).unwrap(),
179 )
180 .with_status(StatusCode::NO_CONTENT);
181 return Ok(response);
182 }
183
184 let response = next.run(request).await?;
186 let response = response.with_header(
187 HeaderName::from_static("access-control-allow-origin"),
188 HeaderValue::from_str(&self.allow_origin).unwrap(),
189 );
190
191 Ok(response)
192 }
193}
194
195#[allow(dead_code)]
197pub struct RateLimit {
198 max_requests: usize,
199 window_secs: u64,
200 }
202
203impl RateLimit {
204 pub fn new(max_requests: usize, window_secs: u64) -> Self {
206 Self {
207 max_requests,
208 window_secs,
209 }
210 }
211}
212
213#[async_trait]
214impl Middleware for RateLimit {
215 async fn handle(&self, request: Request, next: Next) -> Result<Response> {
216 next.run(request).await
219 }
220}
221
/// Middleware that only forwards requests whose `Authorization` header
/// carries the configured bearer token.
pub struct Auth {
    /// The bearer token a request must present.
    token: String,
}

impl Auth {
    /// Creates an authenticator that accepts `token` as a bearer credential.
    pub fn bearer(token: impl Into<String>) -> Self {
        let token = token.into();
        Auth { token }
    }
}
235
236#[async_trait]
237impl Middleware for Auth {
238 async fn handle(&self, request: Request, next: Next) -> Result<Response> {
239 if let Some(auth_header) = request.headers.get("authorization") {
241 if let Ok(auth_str) = auth_header.to_str() {
242 if auth_str.starts_with("Bearer ") {
243 let token = &auth_str[7..];
244 if token == self.token {
245 return next.run(request).await;
246 }
247 }
248 }
249 }
250
251 Ok(Response::text("Unauthorized")
252 .with_status(StatusCode::UNAUTHORIZED)
253 .with_header(
254 HeaderName::from_static("www-authenticate"),
255 HeaderValue::from_static("Bearer"),
256 ))
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;

    /// Terminal handler that always answers with `Response::text("Hello")`.
    struct TestHandler;

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::text("Hello"))
        }
    }

    /// Builds a request with an empty body.
    fn request(method: Method, path: &str, headers: http::HeaderMap) -> Request {
        Request {
            method,
            path: path.to_string(),
            headers,
            body: bytes::Bytes::new(),
        }
    }

    /// Builds a header map containing only `Authorization: <value>`.
    fn authorization(value: &'static str) -> http::HeaderMap {
        let mut headers = http::HeaderMap::new();
        headers.insert("authorization", http::HeaderValue::from_static(value));
        headers
    }

    /// A chain with no middleware that ends at [`TestHandler`].
    fn handler_only() -> Next {
        Next::new(vec![], Arc::new(TestHandler))
    }

    /// A chain running `Logger`, then `Auth` (token `secret-token`), then [`TestHandler`].
    fn logged_auth_chain() -> Next {
        let middleware: Vec<Arc<dyn Middleware>> = vec![
            Arc::new(Logger::new()),
            Arc::new(Auth::bearer("secret-token")),
        ];
        Next::new(middleware, Arc::new(TestHandler))
    }

    #[tokio::test]
    async fn test_logger_middleware() {
        let response = Logger::new()
            .handle(request(Method::GET, "/test", http::HeaderMap::new()), handler_only())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_cors_middleware() {
        let response = Cors::permissive()
            .handle(request(Method::GET, "/test", http::HeaderMap::new()), handler_only())
            .await
            .unwrap();

        assert!(response.headers().contains_key("access-control-allow-origin"));
    }

    #[tokio::test]
    async fn test_cors_preflight() {
        let response = Cors::permissive()
            .handle(request(Method::OPTIONS, "/test", http::HeaderMap::new()), handler_only())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::NO_CONTENT);
        assert!(response.headers().contains_key("access-control-allow-methods"));
    }

    #[tokio::test]
    async fn test_auth_middleware_success() {
        let auth = Auth::bearer("secret-token");
        let response = auth
            .handle(
                request(Method::GET, "/protected", authorization("Bearer secret-token")),
                handler_only(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_failure() {
        let auth = Auth::bearer("secret-token");
        let response = auth
            .handle(
                request(Method::GET, "/protected", http::HeaderMap::new()),
                handler_only(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_middleware_rejects_bad_credentials() {
        let auth = Auth::bearer("secret-token");
        for value in ["Bearer wrong-token", "Basic secret-token", "Bearer"] {
            let response = auth
                .handle(
                    request(Method::GET, "/protected", authorization(value)),
                    handler_only(),
                )
                .await
                .unwrap();
            assert_eq!(
                response.status(),
                StatusCode::UNAUTHORIZED,
                "header value {:?}",
                value
            );
        }
    }

    #[tokio::test]
    async fn test_next_runs_full_chain() {
        let allowed = logged_auth_chain()
            .run(request(Method::GET, "/", authorization("Bearer secret-token")))
            .await
            .unwrap();
        assert_eq!(allowed.status(), StatusCode::OK);

        let denied = logged_auth_chain()
            .run(request(Method::GET, "/", http::HeaderMap::new()))
            .await
            .unwrap();
        assert_eq!(denied.status(), StatusCode::UNAUTHORIZED);
    }
}