1use crate::error::{Error, Result};
2use crate::types::{OxiditeRequest, OxiditeResponse};
3use crate::extract::FromRequest;
4use hyper::Method;
5use std::collections::HashMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10use tower_service::Service;
11use regex::Regex;
12
13pub trait Endpoint: Send + Sync + 'static {
15 fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>>;
16}
17
18pub trait Handler<Args>: Clone + Send + Sync + 'static {
20 fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>>;
21}
22
/// Adapts a generic [`Handler`] into the object-safe [`Endpoint`] trait so it
/// can be stored as `Arc<dyn Endpoint>`.
///
/// `Args` is never stored. It lives in `PhantomData` only so the `Endpoint`
/// impl can name the `Handler<Args>` impl it forwards to.
struct HandlerService<H, Args> {
    /// The wrapped handler.
    handler: H,
    /// Type-level record of the handler's argument shape.
    _marker: std::marker::PhantomData<Args>,
}
28
29impl<H, Args> Endpoint for HandlerService<H, Args>
30where
31 H: Handler<Args>,
32 Args: Send + Sync + 'static,
33{
34 fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
35 self.handler.call(req)
36 }
37}
38
39impl<F, Fut> Handler<OxiditeRequest> for F
41where
42 F: Fn(OxiditeRequest) -> Fut + Clone + Send + Sync + 'static,
43 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
44{
45 fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
46 let fut = self(req);
47 Box::pin(async move { fut.await })
48 }
49}
50
51impl<F, Fut> Handler<()> for F
53where
54 F: Fn() -> Fut + Clone + Send + Sync + 'static,
55 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
56{
57 fn call(&self, _req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
58 let fut = self();
59 Box::pin(async move { fut.await })
60 }
61}
62
63impl<F, Fut, T1> Handler<(T1,)> for F
65where
66 F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
67 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
68 T1: FromRequest + Send + 'static,
69{
70 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
71 let handler = self.clone();
72 Box::pin(async move {
73 let t1 = T1::from_request(&mut req).await?;
74 handler(t1).await
75 })
76 }
77}
78
79impl<F, Fut, T1, T2> Handler<(T1, T2)> for F
81where
82 F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
83 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
84 T1: FromRequest + Send + 'static,
85 T2: FromRequest + Send + 'static,
86{
87 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
88 let handler = self.clone();
89 Box::pin(async move {
90 let t1 = T1::from_request(&mut req).await?;
91 let t2 = T2::from_request(&mut req).await?;
92 handler(t1, t2).await
93 })
94 }
95}
96
97impl<F, Fut, T1, T2, T3> Handler<(T1, T2, T3)> for F
99where
100 F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
101 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
102 T1: FromRequest + Send + 'static,
103 T2: FromRequest + Send + 'static,
104 T3: FromRequest + Send + 'static,
105{
106 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
107 let handler = self.clone();
108 Box::pin(async move {
109 let t1 = T1::from_request(&mut req).await?;
110 let t2 = T2::from_request(&mut req).await?;
111 let t3 = T3::from_request(&mut req).await?;
112 handler(t1, t2, t3).await
113 })
114 }
115}
116
117struct Route {
118 pattern: Regex,
119 param_names: Vec<String>,
120 handler: Arc<dyn Endpoint>,
121}
122
123#[derive(Clone)]
124pub struct Router {
125 routes: HashMap<Method, Vec<Arc<Route>>>,
126 extensions: Arc<std::sync::RwLock<http::Extensions>>,
127}
128
129impl Router {
130 pub fn new() -> Self {
131 Self {
132 routes: HashMap::new(),
133 extensions: Arc::new(std::sync::RwLock::new(http::Extensions::new())),
134 }
135 }
136
137 pub fn with_state<T: Clone + Send + Sync + 'static>(&mut self, state: T) {
139 if let Ok(mut exts) = self.extensions.write() {
140 exts.insert(state);
141 }
142 }
143
144 pub fn get<H, Args>(&mut self, path: &str, handler: H)
145 where
146 H: Handler<Args>,
147 Args: Send + Sync + 'static,
148 {
149 self.add_route(Method::GET, path, handler);
150 }
151
152 pub fn post<H, Args>(&mut self, path: &str, handler: H)
153 where
154 H: Handler<Args>,
155 Args: Send + Sync + 'static,
156 {
157 self.add_route(Method::POST, path, handler);
158 }
159
160 pub fn put<H, Args>(&mut self, path: &str, handler: H)
161 where
162 H: Handler<Args>,
163 Args: Send + Sync + 'static,
164 {
165 self.add_route(Method::PUT, path, handler);
166 }
167
168 pub fn delete<H, Args>(&mut self, path: &str, handler: H)
169 where
170 H: Handler<Args>,
171 Args: Send + Sync + 'static,
172 {
173 self.add_route(Method::DELETE, path, handler);
174 }
175
176 pub fn patch<H, Args>(&mut self, path: &str, handler: H)
177 where
178 H: Handler<Args>,
179 Args: Send + Sync + 'static,
180 {
181 self.add_route(Method::PATCH, path, handler);
182 }
183
184 fn add_route<H, Args>(&mut self, method: Method, path: &str, handler: H)
185 where
186 H: Handler<Args>,
187 Args: Send + Sync + 'static,
188 {
189 let (pattern, param_names) = compile_path(path);
190 let endpoint = HandlerService {
191 handler,
192 _marker: std::marker::PhantomData,
193 };
194
195 let route = Arc::new(Route {
196 pattern,
197 param_names,
198 handler: Arc::new(endpoint),
199 });
200
201 self.routes
202 .entry(method)
203 .or_insert_with(Vec::new)
204 .push(route);
205 }
206
207 pub async fn handle(&self, mut req: OxiditeRequest) -> Result<OxiditeResponse> {
208 req.extensions_mut().insert(self.extensions.clone());
209 let method = req.method().clone();
210 let path = req.uri().path().to_string();
211
212 let try_match = |target_method: &Method, req: &mut OxiditeRequest| -> Option<Arc<Route>> {
214 if let Some(routes) = self.routes.get(target_method) {
215 for route in routes {
216 if let Some(captures) = route.pattern.captures(&path) {
217 let mut params = serde_json::Map::new();
219 for (i, name) in route.param_names.iter().enumerate() {
220 if let Some(value) = captures.get(i + 1) {
221 params.insert(
222 name.clone(),
223 serde_json::Value::String(value.as_str().to_string()),
224 );
225 }
226 }
227
228 if !params.is_empty() {
230 req.extensions_mut().insert(crate::extract::PathParams(
231 serde_json::Value::Object(params),
232 ));
233 }
234
235 return Some(route.clone());
236 }
237 }
238 }
239 None
240 };
241
242 if let Some(route) = try_match(&method, &mut req) {
244 return route.handler.call(req).await;
245 }
246
247 if method == Method::HEAD {
249 if let Some(route) = try_match(&Method::GET, &mut req) {
250 return route.handler.call(req).await;
253 }
254 }
255
256 let allowed_methods: Vec<String> = self
258 .routes
259 .iter()
260 .filter(|(route_method, _)| **route_method != method)
261 .filter_map(|(route_method, routes)| {
262 if routes.iter().any(|route| route.pattern.is_match(&path)) {
263 Some(route_method.as_str().to_string())
264 } else {
265 None
266 }
267 })
268 .collect();
269 if !allowed_methods.is_empty() {
270 return Err(Error::MethodNotAllowed(format!(
271 "{} {} (allowed: {})",
272 method,
273 path,
274 allowed_methods.join(", ")
275 )));
276 }
277
278 Err(Error::NotFound("Route not found".to_string()))
279 }
280}
281
282impl Service<OxiditeRequest> for Router {
283 type Response = OxiditeResponse;
284 type Error = Error;
285 type Future = Pin<Box<dyn Future<Output = Result<Self::Response>> + Send>>;
286
287 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
288 Poll::Ready(Ok(()))
289 }
290
291 fn call(&mut self, req: OxiditeRequest) -> Self::Future {
292 let router = self.clone();
293 Box::pin(async move {
294 router.handle(req).await
295 })
296 }
297}
298
299impl Default for Router {
300 fn default() -> Self {
301 Self::new()
302 }
303}
304
305fn compile_path(path: &str) -> (Regex, Vec<String>) {
308 let mut pattern = String::from("^");
309 let mut param_names = Vec::new();
310 let mut chars = path.chars().peekable();
311
312 while let Some(ch) = chars.next() {
313 match ch {
314 ':' => {
315 let mut param_name = String::new();
317 while let Some(&next_ch) = chars.peek() {
318 if next_ch.is_alphanumeric() || next_ch == '_' {
319 param_name.push(next_ch);
320 chars.next();
321 } else {
322 break;
323 }
324 }
325 param_names.push(param_name);
326 pattern.push_str("([^/]+)");
327 }
328 '*' => {
329 pattern.push_str("(.*)");
331 }
332 '.' | '+' | '?' | '^' | '$' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '\\' => {
333 pattern.push('\\');
335 pattern.push(ch);
336 }
337 _ => {
338 pattern.push(ch);
339 }
340 }
341 }
342
343 pattern.push('$');
344 let regex = Regex::new(&pattern).expect("Invalid route pattern");
345 (regex, param_names)
346}
347
348pub trait IntoHandler<Args>: Handler<Args> + Sized {
367 fn into_handler(self) -> Self {
368 self
369 }
370}
371
372impl<H, Args> IntoHandler<Args> for H where H: Handler<Args> {}
373
374pub fn handler_fn<H, Args>(handler: H) -> H
392where
393 H: IntoHandler<Args>,
394 Args: Send + Sync + 'static,
395{
396 handler
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::BoxBody;

    /// Builds an empty-bodied request with the given method and URI.
    fn request(method: Method, uri: &str) -> OxiditeRequest {
        http::Request::builder()
            .method(method)
            .uri(uri)
            .body(BoxBody::default())
            .expect("request")
    }

    #[test]
    fn test_compile_path() {
        let (single, names) = compile_path("/users/:id");
        assert_eq!(names, vec!["id"]);
        assert!(single.is_match("/users/123"));
        assert!(!single.is_match("/users/123/posts"));

        let (nested, names) = compile_path("/users/:user_id/posts/:post_id");
        assert_eq!(names, vec!["user_id", "post_id"]);
        assert!(nested.is_match("/users/1/posts/2"));
    }

    #[test]
    fn test_exact_match() {
        let (exact, names) = compile_path("/users");
        assert!(names.is_empty());
        assert!(exact.is_match("/users"));
        assert!(!exact.is_match("/users/123"));
    }

    #[tokio::test]
    async fn test_method_not_allowed_when_path_exists() {
        let mut router = Router::new();
        router.get("/users", || async { Ok(crate::OxiditeResponse::text("ok")) });

        let outcome = router.handle(request(Method::POST, "/users")).await;
        assert!(matches!(outcome, Err(Error::MethodNotAllowed(_))));
    }

    #[tokio::test]
    async fn test_not_found_when_path_missing() {
        let outcome = Router::new()
            .handle(request(Method::GET, "/missing"))
            .await;
        assert!(matches!(outcome, Err(Error::NotFound(_))));
    }
}