// oxidite_core/router.rs

1use crate::error::{Error, Result};
2use crate::types::{OxiditeRequest, OxiditeResponse};
3use hyper::Method;
4use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9use tower_service::Service;
10use regex::Regex;
11
12pub trait Handler: Send + Sync + 'static {
13    fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>>;
14}
15
16impl<F, Fut> Handler for F
17where
18    F: Fn(OxiditeRequest) -> Fut + Send + Sync + 'static,
19    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
20{
21    fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
22        Box::pin(self(req))
23    }
24}
25
26struct Route {
27    pattern: Regex,
28    param_names: Vec<String>,
29    handler: Arc<dyn Handler>,
30}
31
32#[derive(Clone)]
33pub struct Router {
34    routes: HashMap<Method, Vec<Arc<Route>>>,
35}
36
37impl Router {
38    pub fn new() -> Self {
39        Self {
40            routes: HashMap::new(),
41        }
42    }
43
44    pub fn get<H>(&mut self, path: &str, handler: H)
45    where
46        H: Handler,
47    {
48        self.add_route(Method::GET, path, handler);
49    }
50    
51    pub fn post<H>(&mut self, path: &str, handler: H)
52    where
53        H: Handler,
54    {
55        self.add_route(Method::POST, path, handler);
56    }
57
58    pub fn put<H>(&mut self, path: &str, handler: H)
59    where
60        H: Handler,
61    {
62        self.add_route(Method::PUT, path, handler);
63    }
64
65    pub fn delete<H>(&mut self, path: &str, handler: H)
66    where
67        H: Handler,
68    {
69        self.add_route(Method::DELETE, path, handler);
70    }
71
72    pub fn patch<H>(&mut self, path: &str, handler: H)
73    where
74        H: Handler,
75    {
76        self.add_route(Method::PATCH, path, handler);
77    }
78
79    fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
80    where
81        H: Handler,
82    {
83        let (pattern, param_names) = compile_path(path);
84        let route = Arc::new(Route {
85            pattern,
86            param_names,
87            handler: Arc::new(handler),
88        });
89        
90        self.routes
91            .entry(method)
92            .or_insert_with(Vec::new)
93            .push(route);
94    }
95
96    pub async fn handle(&self, mut req: OxiditeRequest) -> Result<OxiditeResponse> {
97        let method = req.method().clone();
98        let path = req.uri().path().to_string();
99        let path_for_error = path.clone();
100
101        // Helper to try matching routes for a specific method
102        let try_match = |target_method: &Method, req: &mut OxiditeRequest| -> Option<Arc<Route>> {
103            if let Some(routes) = self.routes.get(target_method) {
104                for route in routes {
105                    if let Some(captures) = route.pattern.captures(&path) {
106                        // Extract path parameters
107                        let mut params = serde_json::Map::new();
108                        for (i, name) in route.param_names.iter().enumerate() {
109                            if let Some(value) = captures.get(i + 1) {
110                                params.insert(
111                                    name.clone(),
112                                    serde_json::Value::String(value.as_str().to_string()),
113                                );
114                            }
115                        }
116
117                        // Store params in request extensions
118                        if !params.is_empty() {
119                            req.extensions_mut().insert(crate::extract::PathParams(
120                                serde_json::Value::Object(params),
121                            ));
122                        }
123                        
124                        return Some(route.clone());
125                    }
126                }
127            }
128            None
129        };
130
131        // 1. Try exact method match
132        if let Some(route) = try_match(&method, &mut req) {
133            return route.handler.call(req).await;
134        }
135
136        // 2. If HEAD, try GET
137        if method == Method::HEAD {
138            if let Some(route) = try_match(&Method::GET, &mut req) {
139                // For HEAD requests, we execute the GET handler but the server/hyper 
140                // will strip the body automatically since it's a HEAD response.
141                return route.handler.call(req).await;
142            }
143        }
144
145        // Log which path was not found
146        eprintln!("🔍 Route not found: {} {}", method, path_for_error);
147        Err(Error::NotFound("Route not found".to_string()))
148    }
149}
150
151impl Service<OxiditeRequest> for Router {
152    type Response = OxiditeResponse;
153    type Error = Error;
154    type Future = Pin<Box<dyn Future<Output = Result<Self::Response>> + Send>>;
155
156    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
157        Poll::Ready(Ok(()))
158    }
159
160    fn call(&mut self, req: OxiditeRequest) -> Self::Future {
161        let router = self.clone();
162        Box::pin(async move {
163            router.handle(req).await
164        })
165    }
166}
167
168impl Default for Router {
169    fn default() -> Self {
170        Self::new()
171    }
172}
173
174/// Compile a route path pattern into a regex
175/// Converts `/users/:id` to `^/users/([^/]+)$` and returns param names
176fn compile_path(path: &str) -> (Regex, Vec<String>) {
177    let mut pattern = String::from("^");
178    let mut param_names = Vec::new();
179    let mut chars = path.chars().peekable();
180
181    while let Some(ch) = chars.next() {
182        match ch {
183            ':' => {
184                // Extract parameter name
185                let mut param_name = String::new();
186                while let Some(&next_ch) = chars.peek() {
187                    if next_ch.is_alphanumeric() || next_ch == '_' {
188                        param_name.push(next_ch);
189                        chars.next();
190                    } else {
191                        break;
192                    }
193                }
194                param_names.push(param_name);
195                pattern.push_str("([^/]+)");
196            }
197            '*' => {
198                // Wildcard
199                pattern.push_str("(.*)");
200            }
201            '.' | '+' | '?' | '^' | '$' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '\\' => {
202                // Escape regex special characters
203                pattern.push('\\');
204                pattern.push(ch);
205            }
206            _ => {
207                pattern.push(ch);
208            }
209        }
210    }
211
212    pattern.push('$');
213    let regex = Regex::new(&pattern).expect("Invalid route pattern");
214    (regex, param_names)
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Parameterised routes: names come back in order and each `:param`
    /// matches exactly one path segment.
    #[test]
    fn test_compile_path() {
        let (single, single_names) = compile_path("/users/:id");
        assert_eq!(single_names, ["id"]);
        for &(candidate, expected) in [("/users/123", true), ("/users/123/posts", false)].iter() {
            assert_eq!(single.is_match(candidate), expected);
        }

        let (nested, nested_names) = compile_path("/users/:user_id/posts/:post_id");
        assert_eq!(nested_names, ["user_id", "post_id"]);
        assert!(nested.is_match("/users/1/posts/2"));
    }

    /// Static routes: no parameters, and the match is anchored at both ends.
    #[test]
    fn test_exact_match() {
        let (literal, names) = compile_path("/users");
        assert!(names.is_empty());
        for &(candidate, expected) in [("/users", true), ("/users/123", false)].iter() {
            assert_eq!(literal.is_match(candidate), expected);
        }
    }
}