1use crate::error::{Error, Result};
2use crate::types::{OxiditeRequest, OxiditeResponse};
3use crate::extract::FromRequest;
4use hyper::Method;
5use std::collections::HashMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10use tower_service::Service;
11use regex::Regex;
12
13pub trait Endpoint: Send + Sync + 'static {
15 fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>>;
16}
17
18pub trait Handler<Args>: Clone + Send + Sync + 'static {
20 fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>>;
21}
22
/// Adapter that turns a [`Handler<Args>`] into an [`Endpoint`].
///
/// The `Args` marker only selects the `Handler` impl. No value of that type is
/// stored, hence the zero-sized `PhantomData`.
struct HandlerService<H, Args> {
    /// The wrapped user handler.
    handler: H,
    /// Zero-sized marker tying this adapter to its `Args` selection.
    _marker: std::marker::PhantomData<Args>,
}
28
29impl<H, Args> Endpoint for HandlerService<H, Args>
30where
31 H: Handler<Args>,
32 Args: Send + Sync + 'static,
33{
34 fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
35 self.handler.call(req)
36 }
37}
38
39impl<F, Fut> Handler<OxiditeRequest> for F
41where
42 F: Fn(OxiditeRequest) -> Fut + Clone + Send + Sync + 'static,
43 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
44{
45 fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
46 let fut = self(req);
47 Box::pin(async move { fut.await })
48 }
49}
50
51impl<F, Fut> Handler<()> for F
53where
54 F: Fn() -> Fut + Clone + Send + Sync + 'static,
55 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
56{
57 fn call(&self, _req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
58 let fut = self();
59 Box::pin(async move { fut.await })
60 }
61}
62
63impl<F, Fut, T1> Handler<(T1,)> for F
65where
66 F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
67 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
68 T1: FromRequest + Send + 'static,
69{
70 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
71 let handler = self.clone();
72 Box::pin(async move {
73 let t1 = T1::from_request(&mut req).await?;
74 handler(t1).await
75 })
76 }
77}
78
79impl<F, Fut, T1, T2> Handler<(T1, T2)> for F
81where
82 F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
83 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
84 T1: FromRequest + Send + 'static,
85 T2: FromRequest + Send + 'static,
86{
87 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
88 let handler = self.clone();
89 Box::pin(async move {
90 let t1 = T1::from_request(&mut req).await?;
91 let t2 = T2::from_request(&mut req).await?;
92 handler(t1, t2).await
93 })
94 }
95}
96
97impl<F, Fut, T1, T2, T3> Handler<(T1, T2, T3)> for F
99where
100 F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
101 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
102 T1: FromRequest + Send + 'static,
103 T2: FromRequest + Send + 'static,
104 T3: FromRequest + Send + 'static,
105{
106 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
107 let handler = self.clone();
108 Box::pin(async move {
109 let t1 = T1::from_request(&mut req).await?;
110 let t2 = T2::from_request(&mut req).await?;
111 let t3 = T3::from_request(&mut req).await?;
112 handler(t1, t2, t3).await
113 })
114 }
115}
116
117struct Route {
118 pattern: Regex,
119 param_names: Vec<String>,
120 handler: Arc<dyn Endpoint>,
121}
122
123#[derive(Clone)]
124pub struct Router {
125 routes: HashMap<Method, Vec<Arc<Route>>>,
126}
127
128impl Router {
129 pub fn new() -> Self {
130 Self {
131 routes: HashMap::new(),
132 }
133 }
134
135 pub fn get<H, Args>(&mut self, path: &str, handler: H)
136 where
137 H: Handler<Args>,
138 Args: Send + Sync + 'static,
139 {
140 self.add_route(Method::GET, path, handler);
141 }
142
143 pub fn post<H, Args>(&mut self, path: &str, handler: H)
144 where
145 H: Handler<Args>,
146 Args: Send + Sync + 'static,
147 {
148 self.add_route(Method::POST, path, handler);
149 }
150
151 pub fn put<H, Args>(&mut self, path: &str, handler: H)
152 where
153 H: Handler<Args>,
154 Args: Send + Sync + 'static,
155 {
156 self.add_route(Method::PUT, path, handler);
157 }
158
159 pub fn delete<H, Args>(&mut self, path: &str, handler: H)
160 where
161 H: Handler<Args>,
162 Args: Send + Sync + 'static,
163 {
164 self.add_route(Method::DELETE, path, handler);
165 }
166
167 pub fn patch<H, Args>(&mut self, path: &str, handler: H)
168 where
169 H: Handler<Args>,
170 Args: Send + Sync + 'static,
171 {
172 self.add_route(Method::PATCH, path, handler);
173 }
174
175 fn add_route<H, Args>(&mut self, method: Method, path: &str, handler: H)
176 where
177 H: Handler<Args>,
178 Args: Send + Sync + 'static,
179 {
180 let (pattern, param_names) = compile_path(path);
181 let endpoint = HandlerService {
182 handler,
183 _marker: std::marker::PhantomData,
184 };
185
186 let route = Arc::new(Route {
187 pattern,
188 param_names,
189 handler: Arc::new(endpoint),
190 });
191
192 self.routes
193 .entry(method)
194 .or_insert_with(Vec::new)
195 .push(route);
196 }
197
198 pub async fn handle(&self, mut req: OxiditeRequest) -> Result<OxiditeResponse> {
199 let method = req.method().clone();
200 let path = req.uri().path().to_string();
201
202 let try_match = |target_method: &Method, req: &mut OxiditeRequest| -> Option<Arc<Route>> {
204 if let Some(routes) = self.routes.get(target_method) {
205 for route in routes {
206 if let Some(captures) = route.pattern.captures(&path) {
207 let mut params = serde_json::Map::new();
209 for (i, name) in route.param_names.iter().enumerate() {
210 if let Some(value) = captures.get(i + 1) {
211 params.insert(
212 name.clone(),
213 serde_json::Value::String(value.as_str().to_string()),
214 );
215 }
216 }
217
218 if !params.is_empty() {
220 req.extensions_mut().insert(crate::extract::PathParams(
221 serde_json::Value::Object(params),
222 ));
223 }
224
225 return Some(route.clone());
226 }
227 }
228 }
229 None
230 };
231
232 if let Some(route) = try_match(&method, &mut req) {
234 return route.handler.call(req).await;
235 }
236
237 if method == Method::HEAD {
239 if let Some(route) = try_match(&Method::GET, &mut req) {
240 return route.handler.call(req).await;
243 }
244 }
245
246 let allowed_methods: Vec<String> = self
248 .routes
249 .iter()
250 .filter(|(route_method, _)| **route_method != method)
251 .filter_map(|(route_method, routes)| {
252 if routes.iter().any(|route| route.pattern.is_match(&path)) {
253 Some(route_method.as_str().to_string())
254 } else {
255 None
256 }
257 })
258 .collect();
259 if !allowed_methods.is_empty() {
260 return Err(Error::MethodNotAllowed(format!(
261 "{} {} (allowed: {})",
262 method,
263 path,
264 allowed_methods.join(", ")
265 )));
266 }
267
268 Err(Error::NotFound("Route not found".to_string()))
269 }
270}
271
272impl Service<OxiditeRequest> for Router {
273 type Response = OxiditeResponse;
274 type Error = Error;
275 type Future = Pin<Box<dyn Future<Output = Result<Self::Response>> + Send>>;
276
277 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
278 Poll::Ready(Ok(()))
279 }
280
281 fn call(&mut self, req: OxiditeRequest) -> Self::Future {
282 let router = self.clone();
283 Box::pin(async move {
284 router.handle(req).await
285 })
286 }
287}
288
289impl Default for Router {
290 fn default() -> Self {
291 Self::new()
292 }
293}
294
295fn compile_path(path: &str) -> (Regex, Vec<String>) {
298 let mut pattern = String::from("^");
299 let mut param_names = Vec::new();
300 let mut chars = path.chars().peekable();
301
302 while let Some(ch) = chars.next() {
303 match ch {
304 ':' => {
305 let mut param_name = String::new();
307 while let Some(&next_ch) = chars.peek() {
308 if next_ch.is_alphanumeric() || next_ch == '_' {
309 param_name.push(next_ch);
310 chars.next();
311 } else {
312 break;
313 }
314 }
315 param_names.push(param_name);
316 pattern.push_str("([^/]+)");
317 }
318 '*' => {
319 pattern.push_str("(.*)");
321 }
322 '.' | '+' | '?' | '^' | '$' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '\\' => {
323 pattern.push('\\');
325 pattern.push(ch);
326 }
327 _ => {
328 pattern.push(ch);
329 }
330 }
331 }
332
333 pattern.push('$');
334 let regex = Regex::new(&pattern).expect("Invalid route pattern");
335 (regex, param_names)
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::BoxBody;

    /// Builds a request with an empty body for the given method and URI.
    fn empty_request(method: Method, uri: &str) -> http::Request<BoxBody> {
        http::Request::builder()
            .method(method)
            .uri(uri)
            .body(BoxBody::default())
            .expect("request")
    }

    #[test]
    fn test_compile_path() {
        let (single, single_names) = compile_path("/users/:id");
        assert_eq!(single_names, vec!["id"]);
        assert!(single.is_match("/users/123"));
        assert!(!single.is_match("/users/123/posts"));

        let (nested, nested_names) = compile_path("/users/:user_id/posts/:post_id");
        assert_eq!(nested_names, vec!["user_id", "post_id"]);
        assert!(nested.is_match("/users/1/posts/2"));
    }

    #[test]
    fn test_exact_match() {
        let (literal, names) = compile_path("/users");
        assert!(names.is_empty());
        assert!(literal.is_match("/users"));
        assert!(!literal.is_match("/users/123"));
    }

    #[tokio::test]
    async fn test_method_not_allowed_when_path_exists() {
        let mut router = Router::new();
        router.get("/users", || async { Ok(crate::OxiditeResponse::text("ok")) });

        let outcome = router.handle(empty_request(Method::POST, "/users")).await;
        assert!(matches!(outcome, Err(Error::MethodNotAllowed(_))));
    }

    #[tokio::test]
    async fn test_not_found_when_path_missing() {
        let router = Router::new();

        let outcome = router.handle(empty_request(Method::GET, "/missing")).await;
        assert!(matches!(outcome, Err(Error::NotFound(_))));
    }
}