oxidite_core/
router.rs

1use crate::error::{Error, Result};
2use crate::types::{OxiditeRequest, OxiditeResponse};
3use crate::extract::FromRequest;
4use hyper::Method;
5use std::collections::HashMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10use tower_service::Service;
11use regex::Regex;
12
13/// Trait for type-erased handlers stored in the router
14pub trait Endpoint: Send + Sync + 'static {
15    fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>>;
16}
17
18/// Trait for async functions that can be used as handlers
19pub trait Handler<Args>: Clone + Send + Sync + 'static {
20    fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>>;
21}
22
/// Adapter that erases a `Handler<Args>` into an `Endpoint`.
///
/// `Args` only selects which `Handler` impl is used; no `Args` value is ever stored. The
/// marker is `PhantomData<fn() -> Args>` rather than `PhantomData<Args>` so the wrapper does
/// not inherit `Args`'s auto traits (`Send`/`Sync`) or drop-check obligations. Its
/// thread-safety therefore depends only on `handler`.
struct HandlerService<H, Args> {
    handler: H,
    _marker: std::marker::PhantomData<fn() -> Args>,
}
28
29impl<H, Args> Endpoint for HandlerService<H, Args>
30where
31    H: Handler<Args>,
32    Args: Send + Sync + 'static,
33{
34    fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
35        self.handler.call(req)
36    }
37}
38
39// Implement Handler for Fn(OxiditeRequest) -> Fut
40impl<F, Fut> Handler<OxiditeRequest> for F
41where
42    F: Fn(OxiditeRequest) -> Fut + Clone + Send + Sync + 'static,
43    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
44{
45    fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
46        let fut = self(req);
47        Box::pin(async move { fut.await })
48    }
49}
50
51// Implement Handler for Fn() -> Fut
52impl<F, Fut> Handler<()> for F
53where
54    F: Fn() -> Fut + Clone + Send + Sync + 'static,
55    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
56{
57    fn call(&self, _req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
58        let fut = self();
59        Box::pin(async move { fut.await })
60    }
61}
62
63// Implement Handler for Fn(T1) -> Fut
64impl<F, Fut, T1> Handler<(T1,)> for F
65where
66    F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
67    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
68    T1: FromRequest + Send + 'static,
69{
70    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
71        let handler = self.clone();
72        Box::pin(async move {
73            let t1 = T1::from_request(&mut req).await?;
74            handler(t1).await
75        })
76    }
77}
78
79// Implement Handler for Fn(T1, T2) -> Fut
80impl<F, Fut, T1, T2> Handler<(T1, T2)> for F
81where
82    F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
83    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
84    T1: FromRequest + Send + 'static,
85    T2: FromRequest + Send + 'static,
86{
87    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
88        let handler = self.clone();
89        Box::pin(async move {
90            let t1 = T1::from_request(&mut req).await?;
91            let t2 = T2::from_request(&mut req).await?;
92            handler(t1, t2).await
93        })
94    }
95}
96
97// Implement Handler for Fn(T1, T2, T3) -> Fut
98impl<F, Fut, T1, T2, T3> Handler<(T1, T2, T3)> for F
99where
100    F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
101    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
102    T1: FromRequest + Send + 'static,
103    T2: FromRequest + Send + 'static,
104    T3: FromRequest + Send + 'static,
105{
106    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
107        let handler = self.clone();
108        Box::pin(async move {
109            let t1 = T1::from_request(&mut req).await?;
110            let t2 = T2::from_request(&mut req).await?;
111            let t3 = T3::from_request(&mut req).await?;
112            handler(t1, t2, t3).await
113        })
114    }
115}
116
117struct Route {
118    pattern: Regex,
119    param_names: Vec<String>,
120    handler: Arc<dyn Endpoint>,
121}
122
123#[derive(Clone)]
124pub struct Router {
125    routes: HashMap<Method, Vec<Arc<Route>>>,
126}
127
128impl Router {
129    pub fn new() -> Self {
130        Self {
131            routes: HashMap::new(),
132        }
133    }
134
135    pub fn get<H, Args>(&mut self, path: &str, handler: H)
136    where
137        H: Handler<Args>,
138        Args: Send + Sync + 'static,
139    {
140        self.add_route(Method::GET, path, handler);
141    }
142    
143    pub fn post<H, Args>(&mut self, path: &str, handler: H)
144    where
145        H: Handler<Args>,
146        Args: Send + Sync + 'static,
147    {
148        self.add_route(Method::POST, path, handler);
149    }
150
151    pub fn put<H, Args>(&mut self, path: &str, handler: H)
152    where
153        H: Handler<Args>,
154        Args: Send + Sync + 'static,
155    {
156        self.add_route(Method::PUT, path, handler);
157    }
158
159    pub fn delete<H, Args>(&mut self, path: &str, handler: H)
160    where
161        H: Handler<Args>,
162        Args: Send + Sync + 'static,
163    {
164        self.add_route(Method::DELETE, path, handler);
165    }
166
167    pub fn patch<H, Args>(&mut self, path: &str, handler: H)
168    where
169        H: Handler<Args>,
170        Args: Send + Sync + 'static,
171    {
172        self.add_route(Method::PATCH, path, handler);
173    }
174
175    fn add_route<H, Args>(&mut self, method: Method, path: &str, handler: H)
176    where
177        H: Handler<Args>,
178        Args: Send + Sync + 'static,
179    {
180        let (pattern, param_names) = compile_path(path);
181        let endpoint = HandlerService {
182            handler,
183            _marker: std::marker::PhantomData,
184        };
185        
186        let route = Arc::new(Route {
187            pattern,
188            param_names,
189            handler: Arc::new(endpoint),
190        });
191        
192        self.routes
193            .entry(method)
194            .or_insert_with(Vec::new)
195            .push(route);
196    }
197
198    pub async fn handle(&self, mut req: OxiditeRequest) -> Result<OxiditeResponse> {
199        let method = req.method().clone();
200        let path = req.uri().path().to_string();
201        let path_for_error = path.clone();
202
203        // Helper to try matching routes for a specific method
204        let try_match = |target_method: &Method, req: &mut OxiditeRequest| -> Option<Arc<Route>> {
205            if let Some(routes) = self.routes.get(target_method) {
206                for route in routes {
207                    if let Some(captures) = route.pattern.captures(&path) {
208                        // Extract path parameters
209                        let mut params = serde_json::Map::new();
210                        for (i, name) in route.param_names.iter().enumerate() {
211                            if let Some(value) = captures.get(i + 1) {
212                                params.insert(
213                                    name.clone(),
214                                    serde_json::Value::String(value.as_str().to_string()),
215                                );
216                            }
217                        }
218
219                        // Store params in request extensions
220                        if !params.is_empty() {
221                            req.extensions_mut().insert(crate::extract::PathParams(
222                                serde_json::Value::Object(params),
223                            ));
224                        }
225                        
226                        return Some(route.clone());
227                    }
228                }
229            }
230            None
231        };
232
233        // 1. Try exact method match
234        if let Some(route) = try_match(&method, &mut req) {
235            return route.handler.call(req).await;
236        }
237
238        // 2. If HEAD, try GET
239        if method == Method::HEAD {
240            if let Some(route) = try_match(&Method::GET, &mut req) {
241                // For HEAD requests, we execute the GET handler but the server/hyper 
242                // will strip the body automatically since it's a HEAD response.
243                return route.handler.call(req).await;
244            }
245        }
246
247        // Log which path was not found
248        eprintln!("🔍 Route not found: {} {}", method, path_for_error);
249        Err(Error::NotFound("Route not found".to_string()))
250    }
251}
252
253impl Service<OxiditeRequest> for Router {
254    type Response = OxiditeResponse;
255    type Error = Error;
256    type Future = Pin<Box<dyn Future<Output = Result<Self::Response>> + Send>>;
257
258    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
259        Poll::Ready(Ok(()))
260    }
261
262    fn call(&mut self, req: OxiditeRequest) -> Self::Future {
263        let router = self.clone();
264        Box::pin(async move {
265            router.handle(req).await
266        })
267    }
268}
269
270impl Default for Router {
271    fn default() -> Self {
272        Self::new()
273    }
274}
275
276/// Compile a route path pattern into a regex
277/// Converts `/users/:id` to `^/users/([^/]+)$` and returns param names
278fn compile_path(path: &str) -> (Regex, Vec<String>) {
279    let mut pattern = String::from("^");
280    let mut param_names = Vec::new();
281    let mut chars = path.chars().peekable();
282
283    while let Some(ch) = chars.next() {
284        match ch {
285            ':' => {
286                // Extract parameter name
287                let mut param_name = String::new();
288                while let Some(&next_ch) = chars.peek() {
289                    if next_ch.is_alphanumeric() || next_ch == '_' {
290                        param_name.push(next_ch);
291                        chars.next();
292                    } else {
293                        break;
294                    }
295                }
296                param_names.push(param_name);
297                pattern.push_str("([^/]+)");
298            }
299            '*' => {
300                // Wildcard
301                pattern.push_str("(.*)");
302            }
303            '.' | '+' | '?' | '^' | '$' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '\\' => {
304                // Escape regex special characters
305                pattern.push('\\');
306                pattern.push(ch);
307            }
308            _ => {
309                pattern.push(ch);
310            }
311        }
312    }
313
314    pattern.push('$');
315    let regex = Regex::new(&pattern).expect("Invalid route pattern");
316    (regex, param_names)
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Parameters match exactly one segment and are reported in path order.
    #[test]
    fn test_compile_path() {
        let (users, names) = compile_path("/users/:id");
        assert_eq!(names, ["id"]);
        assert!(users.is_match("/users/123"));
        assert!(!users.is_match("/users/123/posts"));

        let (nested, names) = compile_path("/users/:user_id/posts/:post_id");
        assert_eq!(names, ["user_id", "post_id"]);
        assert!(nested.is_match("/users/1/posts/2"));
    }

    /// A path without parameters matches only itself.
    #[test]
    fn test_exact_match() {
        let (literal, names) = compile_path("/users");
        assert!(names.is_empty());
        assert!(literal.is_match("/users"));
        assert!(!literal.is_match("/users/123"));
    }
}