1use async_trait::async_trait;
2use futures::future::BoxFuture;
3use lambda_runtime::{Error, LambdaEvent};
4use serde_json::Value;
5use std::sync::Arc;
6
7use crate::middleware::CorsMiddleware;
8use crate::{Context, Middleware, PathMatcher, Request, Response, Result, RouterError};
9
10pub type HandlerFn =
12 Arc<dyn Fn(Request, Context) -> BoxFuture<'static, Result<Response>> + Send + Sync>;
13
14#[async_trait]
16pub trait Handler: Send + Sync {
17 async fn handle(&self, req: Request, ctx: Context) -> Result<Response>;
18}
19
20struct Route {
22 method: String,
23 matcher: PathMatcher,
24 handler: HandlerFn,
25}
26
27pub struct Router {
29 routes: Vec<Route>,
30 middlewares: Vec<Arc<dyn Middleware>>,
31 not_found_handler: Option<HandlerFn>,
32}
33
34impl Router {
35 pub fn new() -> Self {
37 Self {
38 routes: Vec::new(),
39 middlewares: vec![Arc::new(CorsMiddleware::new())],
40 not_found_handler: None,
41 }
42 }
43
44 pub fn use_middleware(&mut self, middleware: impl Middleware + 'static) {
46 self.middlewares.push(Arc::new(middleware));
47 }
48
49 pub fn not_found<F>(&mut self, handler: F)
51 where
52 F: Fn(Request, Context) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static,
53 {
54 self.not_found_handler = Some(Arc::new(handler));
55 }
56
57 pub fn get<F>(&mut self, path: &str, handler: F)
59 where
60 F: Fn(Request, Context) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static,
61 {
62 self.add_route("GET", path, handler);
63 }
64
65 pub fn post<F>(&mut self, path: &str, handler: F)
67 where
68 F: Fn(Request, Context) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static,
69 {
70 self.add_route("POST", path, handler);
71 }
72
73 pub fn put<F>(&mut self, path: &str, handler: F)
75 where
76 F: Fn(Request, Context) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static,
77 {
78 self.add_route("PUT", path, handler);
79 }
80
81 pub fn delete<F>(&mut self, path: &str, handler: F)
83 where
84 F: Fn(Request, Context) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static,
85 {
86 self.add_route("DELETE", path, handler);
87 }
88
89 pub fn patch<F>(&mut self, path: &str, handler: F)
91 where
92 F: Fn(Request, Context) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static,
93 {
94 self.add_route("PATCH", path, handler);
95 }
96
97 pub fn add_route<F>(&mut self, method: &str, path: &str, handler: F)
99 where
100 F: Fn(Request, Context) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static,
101 {
102 self.routes.push(Route {
103 method: method.to_uppercase(),
104 matcher: PathMatcher::new(path),
105 handler: Arc::new(handler),
106 });
107 }
108
109 async fn handle_request(&self, mut req: Request) -> Result<Response> {
111 let route = self
113 .routes
114 .iter()
115 .find(|r| r.method == req.method && r.matcher.matches(&req.path).is_some());
116
117 match route {
118 Some(route) => {
119 if let Some(params) = route.matcher.matches(&req.path) {
121 req.set_path_params(params);
122 }
123
124 let handler = route.handler.clone();
126 let middlewares = self.middlewares.clone();
127
128 if middlewares.is_empty() {
130 let ctx = req.context.clone();
132 (handler)(req, ctx).await
133 } else {
134 self.execute_middleware_chain(req, middlewares, handler)
136 .await
137 }
138 }
139 None => {
140 if let Some(handler) = &self.not_found_handler {
141 let ctx = req.context.clone();
142 (handler)(req, ctx).await
143 } else {
144 Err(RouterError::RouteNotFound {
145 method: req.method.clone(),
146 path: req.path.clone(),
147 })
148 }
149 }
150 }
151 }
152
153 async fn execute_middleware_chain(
155 &self,
156 req: Request,
157 middlewares: Vec<Arc<dyn Middleware>>,
158 handler: Arc<
159 dyn Fn(Request, Context) -> BoxFuture<'static, Result<Response>> + Send + Sync,
160 >,
161 ) -> Result<Response> {
162 use std::sync::Arc as StdArc;
163
164 let final_handler: StdArc<
166 dyn Fn(Request) -> BoxFuture<'static, std::result::Result<Response, Error>>
167 + Send
168 + Sync,
169 > = StdArc::new(move |req: Request| {
170 let handler = handler.clone();
171 let ctx = req.context.clone();
172 Box::pin(async move {
173 handler(req, ctx)
174 .await
175 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
176 })
177 });
178
179 let mut current_handler = final_handler;
181
182 for middleware in middlewares.into_iter().rev() {
183 let next_handler = current_handler.clone();
184 current_handler = StdArc::new(move |req: Request| {
185 let middleware = middleware.clone();
186 let next = next_handler.clone();
187 let next_fn: Box<
188 dyn Fn(Request) -> BoxFuture<'static, std::result::Result<Response, Error>>
189 + Send
190 + Sync,
191 > = Box::new(move |req: Request| {
192 let next = next.clone();
193 (next)(req)
194 });
195 Box::pin(async move { middleware.handle(req, next_fn).await })
196 });
197 }
198
199 (current_handler)(req)
201 .await
202 .map_err(|e| RouterError::HandlerError(anyhow::anyhow!("{}", e)))
203 }
204
205 pub fn into_service(
207 self,
208 ) -> impl Fn(LambdaEvent<Value>) -> BoxFuture<'static, std::result::Result<Value, Error>> {
209 let router = Arc::new(self);
210
211 move |event: LambdaEvent<Value>| {
212 let router = router.clone();
213 Box::pin(async move {
214 let (event_payload, _context) = event.into_parts();
215
216 let req = Request::from_lambda_event(event_payload);
218
219 if req.is_preflight() {
221 return Ok(Response::cors_preflight().to_json());
222 }
223
224 let response = match router.handle_request(req).await {
226 Ok(resp) => resp,
227 Err(e) => e.to_response(),
228 };
229
230 Ok(response.to_json())
231 })
232 }
233 }
234}
235
236impl Default for Router {
237 fn default() -> Self {
238 Self::new()
239 }
240}
241
/// Adapts a callable `f(Request, Context) -> impl Future` into the boxed-future
/// closure shape expected by `Router::add_route` and the per-method helpers.
///
/// The argument types are named through `$crate`, so call sites do not need
/// `Request` or `Context` imported. A trailing comma is accepted.
///
/// ```ignore
/// router.get("/users/:id", handler!(get_user));
/// ```
#[macro_export]
macro_rules! handler {
    ($func:expr $(,)?) => {
        |req: $crate::Request, ctx: $crate::Context| ::std::boxed::Box::pin($func(req, ctx))
    };
}