Skip to main content

oxidite_core/
router.rs

1use crate::error::{Error, Result};
2use crate::types::{OxiditeRequest, OxiditeResponse};
3use crate::extract::FromRequest;
4use hyper::Method;
5use std::collections::HashMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10use tower_service::Service;
11use regex::Regex;
12
13/// Trait for type-erased handlers stored in the router
14pub trait Endpoint: Send + Sync + 'static {
15    fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>>;
16}
17
18/// Trait for async functions that can be used as handlers
19pub trait Handler<Args>: Clone + Send + Sync + 'static {
20    fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>>;
21}
22
/// Adapter that pairs a [`Handler`] with its `Args` marker so the pair can be
/// exposed as a type-erased [`Endpoint`].
struct HandlerService<H, Args> {
    /// The wrapped handler value.
    handler: H,
    /// Zero-sized marker recording `Args`; no `Args` value is ever stored.
    _marker: std::marker::PhantomData<Args>,
}
28
29impl<H, Args> Endpoint for HandlerService<H, Args>
30where
31    H: Handler<Args>,
32    Args: Send + Sync + 'static,
33{
34    fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
35        self.handler.call(req)
36    }
37}
38
39// Implement Handler for Fn(OxiditeRequest) -> Fut
40impl<F, Fut> Handler<OxiditeRequest> for F
41where
42    F: Fn(OxiditeRequest) -> Fut + Clone + Send + Sync + 'static,
43    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
44{
45    fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
46        let fut = self(req);
47        Box::pin(async move { fut.await })
48    }
49}
50
51// Implement Handler for Fn() -> Fut
52impl<F, Fut> Handler<()> for F
53where
54    F: Fn() -> Fut + Clone + Send + Sync + 'static,
55    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
56{
57    fn call(&self, _req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
58        let fut = self();
59        Box::pin(async move { fut.await })
60    }
61}
62
63// Implement Handler for Fn(T1) -> Fut
64impl<F, Fut, T1> Handler<(T1,)> for F
65where
66    F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
67    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
68    T1: FromRequest + Send + 'static,
69{
70    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
71        let handler = self.clone();
72        Box::pin(async move {
73            let t1 = T1::from_request(&mut req).await?;
74            handler(t1).await
75        })
76    }
77}
78
79// Implement Handler for Fn(T1, T2) -> Fut
80impl<F, Fut, T1, T2> Handler<(T1, T2)> for F
81where
82    F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
83    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
84    T1: FromRequest + Send + 'static,
85    T2: FromRequest + Send + 'static,
86{
87    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
88        let handler = self.clone();
89        Box::pin(async move {
90            let t1 = T1::from_request(&mut req).await?;
91            let t2 = T2::from_request(&mut req).await?;
92            handler(t1, t2).await
93        })
94    }
95}
96
97// Implement Handler for Fn(T1, T2, T3) -> Fut
98impl<F, Fut, T1, T2, T3> Handler<(T1, T2, T3)> for F
99where
100    F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
101    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
102    T1: FromRequest + Send + 'static,
103    T2: FromRequest + Send + 'static,
104    T3: FromRequest + Send + 'static,
105{
106    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
107        let handler = self.clone();
108        Box::pin(async move {
109            let t1 = T1::from_request(&mut req).await?;
110            let t2 = T2::from_request(&mut req).await?;
111            let t3 = T3::from_request(&mut req).await?;
112            handler(t1, t2, t3).await
113        })
114    }
115}
116
117struct Route {
118    pattern: Regex,
119    param_names: Vec<String>,
120    handler: Arc<dyn Endpoint>,
121}
122
123#[derive(Clone)]
124pub struct Router {
125    routes: HashMap<Method, Vec<Arc<Route>>>,
126    extensions: Arc<std::sync::RwLock<http::Extensions>>,
127}
128
129impl Router {
130    pub fn new() -> Self {
131        Self {
132            routes: HashMap::new(),
133            extensions: Arc::new(std::sync::RwLock::new(http::Extensions::new())),
134        }
135    }
136
137    /// Add a shared state to the router that will be available in all handlers
138    pub fn with_state<T: Clone + Send + Sync + 'static>(&mut self, state: T) {
139        if let Ok(mut exts) = self.extensions.write() {
140            exts.insert(state);
141        }
142    }
143
144    pub fn get<H, Args>(&mut self, path: &str, handler: H)
145    where
146        H: Handler<Args>,
147        Args: Send + Sync + 'static,
148    {
149        self.add_route(Method::GET, path, handler);
150    }
151    
152    pub fn post<H, Args>(&mut self, path: &str, handler: H)
153    where
154        H: Handler<Args>,
155        Args: Send + Sync + 'static,
156    {
157        self.add_route(Method::POST, path, handler);
158    }
159
160    pub fn put<H, Args>(&mut self, path: &str, handler: H)
161    where
162        H: Handler<Args>,
163        Args: Send + Sync + 'static,
164    {
165        self.add_route(Method::PUT, path, handler);
166    }
167
168    pub fn delete<H, Args>(&mut self, path: &str, handler: H)
169    where
170        H: Handler<Args>,
171        Args: Send + Sync + 'static,
172    {
173        self.add_route(Method::DELETE, path, handler);
174    }
175
176    pub fn patch<H, Args>(&mut self, path: &str, handler: H)
177    where
178        H: Handler<Args>,
179        Args: Send + Sync + 'static,
180    {
181        self.add_route(Method::PATCH, path, handler);
182    }
183
184    fn add_route<H, Args>(&mut self, method: Method, path: &str, handler: H)
185    where
186        H: Handler<Args>,
187        Args: Send + Sync + 'static,
188    {
189        let (pattern, param_names) = compile_path(path);
190        let endpoint = HandlerService {
191            handler,
192            _marker: std::marker::PhantomData,
193        };
194        
195        let route = Arc::new(Route {
196            pattern,
197            param_names,
198            handler: Arc::new(endpoint),
199        });
200        
201        self.routes
202            .entry(method)
203            .or_insert_with(Vec::new)
204            .push(route);
205    }
206
207    pub async fn handle(&self, mut req: OxiditeRequest) -> Result<OxiditeResponse> {
208        req.extensions_mut().insert(self.extensions.clone());
209        let method = req.method().clone();
210        let path = req.uri().path().to_string();
211
212        // Helper to try matching routes for a specific method
213        let try_match = |target_method: &Method, req: &mut OxiditeRequest| -> Option<Arc<Route>> {
214            if let Some(routes) = self.routes.get(target_method) {
215                for route in routes {
216                    if let Some(captures) = route.pattern.captures(&path) {
217                        // Extract path parameters
218                        let mut params = serde_json::Map::new();
219                        for (i, name) in route.param_names.iter().enumerate() {
220                            if let Some(value) = captures.get(i + 1) {
221                                params.insert(
222                                    name.clone(),
223                                    serde_json::Value::String(value.as_str().to_string()),
224                                );
225                            }
226                        }
227
228                        // Store params in request extensions
229                        if !params.is_empty() {
230                            req.extensions_mut().insert(crate::extract::PathParams(
231                                serde_json::Value::Object(params),
232                            ));
233                        }
234                        
235                        return Some(route.clone());
236                    }
237                }
238            }
239            None
240        };
241
242        // 1. Try exact method match
243        if let Some(route) = try_match(&method, &mut req) {
244            return route.handler.call(req).await;
245        }
246
247        // 2. If HEAD, try GET
248        if method == Method::HEAD {
249            if let Some(route) = try_match(&Method::GET, &mut req) {
250                // For HEAD requests, we execute the GET handler but the server/hyper 
251                // will strip the body automatically since it's a HEAD response.
252                return route.handler.call(req).await;
253            }
254        }
255
256        // 3. Path exists for other methods => method not allowed
257        let allowed_methods: Vec<String> = self
258            .routes
259            .iter()
260            .filter(|(route_method, _)| **route_method != method)
261            .filter_map(|(route_method, routes)| {
262                if routes.iter().any(|route| route.pattern.is_match(&path)) {
263                    Some(route_method.as_str().to_string())
264                } else {
265                    None
266                }
267            })
268            .collect();
269        if !allowed_methods.is_empty() {
270            return Err(Error::MethodNotAllowed(format!(
271                "{} {} (allowed: {})",
272                method,
273                path,
274                allowed_methods.join(", ")
275            )));
276        }
277
278        Err(Error::NotFound("Route not found".to_string()))
279    }
280}
281
282impl Service<OxiditeRequest> for Router {
283    type Response = OxiditeResponse;
284    type Error = Error;
285    type Future = Pin<Box<dyn Future<Output = Result<Self::Response>> + Send>>;
286
287    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
288        Poll::Ready(Ok(()))
289    }
290
291    fn call(&mut self, req: OxiditeRequest) -> Self::Future {
292        let router = self.clone();
293        Box::pin(async move {
294            router.handle(req).await
295        })
296    }
297}
298
299impl Default for Router {
300    fn default() -> Self {
301        Self::new()
302    }
303}
304
305/// Compile a route path pattern into a regex
306/// Converts `/users/:id` to `^/users/([^/]+)$` and returns param names
307fn compile_path(path: &str) -> (Regex, Vec<String>) {
308    let mut pattern = String::from("^");
309    let mut param_names = Vec::new();
310    let mut chars = path.chars().peekable();
311
312    while let Some(ch) = chars.next() {
313        match ch {
314            ':' => {
315                // Extract parameter name
316                let mut param_name = String::new();
317                while let Some(&next_ch) = chars.peek() {
318                    if next_ch.is_alphanumeric() || next_ch == '_' {
319                        param_name.push(next_ch);
320                        chars.next();
321                    } else {
322                        break;
323                    }
324                }
325                param_names.push(param_name);
326                pattern.push_str("([^/]+)");
327            }
328            '*' => {
329                // Wildcard
330                pattern.push_str("(.*)");
331            }
332            '.' | '+' | '?' | '^' | '$' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '\\' => {
333                // Escape regex special characters
334                pattern.push('\\');
335                pattern.push(ch);
336            }
337            _ => {
338                pattern.push(ch);
339            }
340        }
341    }
342
343    pattern.push('$');
344    let regex = Regex::new(&pattern).expect("Invalid route pattern");
345    (regex, param_names)
346}
347
348/// Trait that provides compile-time verification that a function is a valid handler.
349///
350/// This is used by the [`handler_fn`] function to give clear, readable error messages
351/// when a function doesn't satisfy the handler requirements, rather than the cryptic
352/// trait-bound errors that would otherwise surface from the router.
353///
354/// # Example
355///
356/// ```rust,ignore
357/// use oxidite::prelude::*;
358///
359/// // This compiles because the function matches Handler<(State<Arc<AppState>>,)>
360/// let h = handler_fn(my_handler);
361/// router.get("/users", h);
362///
363/// // This would fail at compile time with a clear error if the function
364/// // has extractors that don't implement FromRequest.
365/// ```
366pub trait IntoHandler<Args>: Handler<Args> + Sized {
367    fn into_handler(self) -> Self {
368        self
369    }
370}
371
372impl<H, Args> IntoHandler<Args> for H where H: Handler<Args> {}
373
374/// Compile-time handler verification helper.
375///
376/// Wraps a handler function and ensures at compile time that all its extractors
377/// implement `FromRequest`. Provides clearer error messages than raw trait bounds.
378///
379/// # Example
380///
381/// ```rust,ignore
382/// use oxidite::prelude::*;
383///
384/// async fn index(State(s): State<Arc<AppState>>) -> Result<OxiditeResponse> {
385///     Ok(response::json(serde_json::json!({"ok": true})))
386/// }
387///
388/// // Verified at compile time:
389/// router.get("/", handler_fn(index));
390/// ```
391pub fn handler_fn<H, Args>(handler: H) -> H
392where
393    H: IntoHandler<Args>,
394    Args: Send + Sync + 'static,
395{
396    handler
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::BoxBody;

    /// Builds a request with a default (empty) body for the given method and URI.
    fn empty_request(method: Method, uri: &str) -> http::Request<BoxBody> {
        http::Request::builder()
            .method(method)
            .uri(uri)
            .body(BoxBody::default())
            .expect("request")
    }

    /// Parameterised templates capture names in order and match whole paths only.
    #[test]
    fn test_compile_path() {
        let (single_re, single_names) = compile_path("/users/:id");
        assert_eq!(single_names, vec!["id"]);
        assert!(single_re.is_match("/users/123"));
        assert!(!single_re.is_match("/users/123/posts"));

        let (nested_re, nested_names) = compile_path("/users/:user_id/posts/:post_id");
        assert_eq!(nested_names, vec!["user_id", "post_id"]);
        assert!(nested_re.is_match("/users/1/posts/2"));
    }

    /// A template without parameters matches exactly that path and nothing longer.
    #[test]
    fn test_exact_match() {
        let (literal_re, names) = compile_path("/users");
        assert!(names.is_empty());
        assert!(literal_re.is_match("/users"));
        assert!(!literal_re.is_match("/users/123"));
    }

    /// A known path requested with an unregistered method yields `MethodNotAllowed`.
    #[tokio::test]
    async fn test_method_not_allowed_when_path_exists() {
        let mut router = Router::new();
        router.get("/users", || async { Ok(crate::OxiditeResponse::text("ok")) });

        let outcome = router.handle(empty_request(Method::POST, "/users")).await;
        assert!(matches!(outcome, Err(Error::MethodNotAllowed(_))));
    }

    /// A path that matches no route under any method yields `NotFound`.
    #[tokio::test]
    async fn test_not_found_when_path_missing() {
        let router = Router::new();

        let outcome = router.handle(empty_request(Method::GET, "/missing")).await;
        assert!(matches!(outcome, Err(Error::NotFound(_))));
    }
}