aws_lambda_router/
middleware.rs1use async_trait::async_trait;
2use lambda_runtime::Error;
3use crate::{Request, Response};
4
5pub type Next = Box<dyn Fn(Request) -> futures::future::BoxFuture<'static, Result<Response, Error>> + Send + Sync>;
7
8#[async_trait]
10pub trait Middleware: Send + Sync {
11 async fn handle(&self, req: Request, next: Next) -> Result<Response, Error>;
13}
14
15pub struct MiddlewareFn<F>
17where
18 F: Fn(Request, Next) -> futures::future::BoxFuture<'static, Result<Response, Error>> + Send + Sync,
19{
20 func: F,
21}
22
23impl<F> MiddlewareFn<F>
24where
25 F: Fn(Request, Next) -> futures::future::BoxFuture<'static, Result<Response, Error>> + Send + Sync,
26{
27 pub fn new(func: F) -> Self {
28 Self { func }
29 }
30}
31
32#[async_trait]
33impl<F> Middleware for MiddlewareFn<F>
34where
35 F: Fn(Request, Next) -> futures::future::BoxFuture<'static, Result<Response, Error>> + Send + Sync,
36{
37 async fn handle(&self, req: Request, next: Next) -> Result<Response, Error> {
38 (self.func)(req, next).await
39 }
40}
41
/// Middleware that logs each request and its response status to stdout via `println!`.
#[derive(Debug, Clone, Copy, Default)]
pub struct LoggingMiddleware;

45#[async_trait]
46impl Middleware for LoggingMiddleware {
47 async fn handle(&self, req: Request, next: Next) -> Result<Response, Error> {
48 println!("→ {} {}", req.method, req.path);
49 let response = next(req).await?;
50 println!("← {}", response.status_code);
51 Ok(response)
52 }
53}
54
/// Middleware that adds CORS (`Access-Control-*`) headers to responses and
/// answers preflight requests.
///
/// Configure it with the builder methods on top of [`CorsMiddleware::new`].
#[derive(Debug, Clone)]
pub struct CorsMiddleware {
    allow_origin: String,
    allow_methods: String,
    allow_headers: String,
    max_age: String,
}

impl CorsMiddleware {
    /// Creates the default configuration:
    /// - any origin (`*`);
    /// - methods `GET, POST, PUT, DELETE, OPTIONS`;
    /// - headers `Content-Type, Authorization`;
    /// - a max age of `3600` seconds.
    #[must_use]
    pub fn new() -> Self {
        Self {
            allow_origin: "*".to_string(),
            allow_methods: "GET, POST, PUT, DELETE, OPTIONS".to_string(),
            allow_headers: "Content-Type, Authorization".to_string(),
            max_age: "3600".to_string(),
        }
    }

    /// Sets the `Access-Control-Allow-Origin` value.
    #[must_use]
    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
        self.allow_origin = origin.into();
        self
    }

    /// Sets the `Access-Control-Allow-Methods` value, e.g. `"GET, POST"`.
    #[must_use]
    pub fn allow_methods(mut self, methods: impl Into<String>) -> Self {
        self.allow_methods = methods.into();
        self
    }

    /// Sets the `Access-Control-Allow-Headers` value, e.g. `"Content-Type"`.
    #[must_use]
    pub fn allow_headers(mut self, headers: impl Into<String>) -> Self {
        self.allow_headers = headers.into();
        self
    }

    /// Sets the `Access-Control-Max-Age` value, in seconds.
    ///
    /// This value previously had no setter and was fixed at `3600`.
    #[must_use]
    pub fn max_age(mut self, seconds: u32) -> Self {
        self.max_age = seconds.to_string();
        self
    }
}

impl Default for CorsMiddleware {
    /// Equivalent to [`CorsMiddleware::new`].
    fn default() -> Self {
        Self::new()
    }
}

95#[async_trait]
96impl Middleware for CorsMiddleware {
97 async fn handle(&self, req: Request, next: Next) -> Result<Response, Error> {
98 if req.is_preflight() {
100 return Ok(Response::cors_preflight());
101 }
102
103 let mut response = next(req).await?;
105 response = response
106 .header("Access-Control-Allow-Origin", &self.allow_origin)
107 .header("Access-Control-Allow-Methods", &self.allow_methods)
108 .header("Access-Control-Allow-Headers", &self.allow_headers)
109 .header("Access-Control-Max-Age", &self.max_age);
110
111 Ok(response)
112 }
113}