Skip to main content

oxidite_core/
router.rs

1use crate::error::{Error, Result};
2use crate::types::{OxiditeRequest, OxiditeResponse};
3use crate::extract::FromRequest;
4use hyper::Method;
5use std::collections::HashMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10use tower_service::Service;
11
12use regex::Regex;
13
14/// Trait for type-erased handlers stored in the router
15pub trait Endpoint: Send + Sync + 'static {
16    fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>>;
17}
18
19impl Endpoint for Arc<dyn Endpoint> {
20    fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
21        (**self).call(req)
22    }
23}
24
25pub struct EndpointService(pub Arc<dyn Endpoint>);
26
27impl Clone for EndpointService {
28    fn clone(&self) -> Self {
29        Self(self.0.clone())
30    }
31}
32
33impl Service<OxiditeRequest> for EndpointService {
34    type Response = OxiditeResponse;
35    type Error = Error;
36    type Future = Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>>;
37
38    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
39        Poll::Ready(Ok(()))
40    }
41
42    fn call(&mut self, req: OxiditeRequest) -> Self::Future {
43        self.0.call(req)
44    }
45}
46
47impl Endpoint for EndpointService {
48    fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
49        self.0.call(req)
50    }
51}
52
/// CORS configuration for the [`Router`].
///
/// Each list is joined with `", "` into the matching `Access-Control-Allow-*` response
/// header. An **empty list means the header is omitted**, not that everything is allowed;
/// without the header, browsers fall back to their CORS-safelisted defaults.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CorsConfig {
    /// Allowed origins (e.g. `"http://localhost:3000"`). Use `"*"` to allow any origin.
    /// When empty, no `Access-Control-Allow-Origin` header is sent.
    pub allowed_origins: Vec<String>,
    /// Allowed HTTP methods (e.g. `"GET"`, `"POST"`, `"PUT"`, `"DELETE"`).
    /// When empty, no `Access-Control-Allow-Methods` header is sent.
    pub allowed_methods: Vec<String>,
    /// Allowed request headers (e.g. `"Content-Type"`, `"Authorization"`); a `"*"` entry is
    /// passed through verbatim. When empty, no `Access-Control-Allow-Headers` header is sent.
    pub allowed_headers: Vec<String>,
    /// Whether to send `Access-Control-Allow-Credentials: true` (cookies, auth headers).
    pub allow_credentials: bool,
    /// Value of `Access-Control-Max-Age`: how long (in seconds) a preflight may be cached.
    pub max_age: u32,
}

impl Default for CorsConfig {
    /// Permissive policy: any origin, any header, the common REST methods, no credentials,
    /// one-hour preflight cache.
    fn default() -> Self {
        Self {
            allowed_origins: vec!["*".to_string()],
            allowed_methods: vec![
                "GET".to_string(),
                "POST".to_string(),
                "PUT".to_string(),
                "DELETE".to_string(),
                "OPTIONS".to_string(),
                "PATCH".to_string(),
            ],
            allowed_headers: vec!["*".to_string()],
            allow_credentials: false,
            max_age: 3600,
        }
    }
}

impl CorsConfig {
    /// Permissive configuration, identical to [`CorsConfig::default`] (useful for development).
    pub fn permissive() -> Self {
        Self::default()
    }

    /// Restrictive configuration: no allowed origins (so no `Access-Control-Allow-Origin`
    /// header is sent), only `GET`/`POST`, and only the `Content-Type` request header.
    pub fn restrictive() -> Self {
        Self {
            allowed_origins: Vec::new(),
            allowed_methods: vec!["GET".to_string(), "POST".to_string()],
            allowed_headers: vec!["Content-Type".to_string()],
            allow_credentials: false,
            max_age: 3600,
        }
    }
}
100
101// Note: tower-http middleware like Cors and Compression change the response body type,
102// so they cannot be used as Endpoint-level middleware. They should be applied at the
103// Server level instead, after body type conversion.
104//
105// For CORS at the router level, use custom middleware that only modifies headers
106// without changing the body type.
107
108
109/// Trait for async functions that can be used as handlers
110pub trait Handler<Args>: Clone + Send + Sync + 'static {
111    fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>>;
112}
113
114// Wrapper to convert Handler<Args> into Endpoint
115struct HandlerService<H, Args> {
116    handler: H,
117    _marker: std::marker::PhantomData<Args>,
118}
119
120impl<H, Args> Endpoint for HandlerService<H, Args>
121where
122    H: Handler<Args>,
123    Args: Send + Sync + 'static,
124{
125    fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
126        self.handler.call(req)
127    }
128}
129
130// Implement Handler for Fn(OxiditeRequest) -> Fut
131impl<F, Fut> Handler<OxiditeRequest> for F
132where
133    F: Fn(OxiditeRequest) -> Fut + Clone + Send + Sync + 'static,
134    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
135{
136    fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
137        let fut = self(req);
138        Box::pin(async move { fut.await })
139    }
140}
141
142// Implement Handler for Fn() -> Fut
143impl<F, Fut> Handler<()> for F
144where
145    F: Fn() -> Fut + Clone + Send + Sync + 'static,
146    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
147{
148    fn call(&self, _req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
149        let fut = self();
150        Box::pin(async move { fut.await })
151    }
152}
153
154// Implement Handler for Fn(T1) -> Fut
155impl<F, Fut, T1> Handler<(T1,)> for F
156where
157    F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
158    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
159    T1: FromRequest + Send + 'static,
160{
161    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
162        let handler = self.clone();
163        Box::pin(async move {
164            let t1 = T1::from_request(&mut req).await?;
165            handler(t1).await
166        })
167    }
168}
169
170// Implement Handler for Fn(T1, T2, ..., T9) -> Fut
171impl<F, Fut, T1, T2, T3, T4, T5, T6, T7, T8, T9> Handler<(T1, T2, T3, T4, T5, T6, T7, T8, T9)> for F
172where
173    F: Fn(T1, T2, T3, T4, T5, T6, T7, T8, T9) -> Fut + Clone + Send + Sync + 'static,
174    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
175    T1: FromRequest + Send + 'static,
176    T2: FromRequest + Send + 'static,
177    T3: FromRequest + Send + 'static,
178    T4: FromRequest + Send + 'static,
179    T5: FromRequest + Send + 'static,
180    T6: FromRequest + Send + 'static,
181    T7: FromRequest + Send + 'static,
182    T8: FromRequest + Send + 'static,
183    T9: FromRequest + Send + 'static,
184{
185    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
186        let handler = self.clone();
187        Box::pin(async move {
188            let t1 = T1::from_request(&mut req).await?;
189            let t2 = T2::from_request(&mut req).await?;
190            let t3 = T3::from_request(&mut req).await?;
191            let t4 = T4::from_request(&mut req).await?;
192            let t5 = T5::from_request(&mut req).await?;
193            let t6 = T6::from_request(&mut req).await?;
194            let t7 = T7::from_request(&mut req).await?;
195            let t8 = T8::from_request(&mut req).await?;
196            let t9 = T9::from_request(&mut req).await?;
197            handler(t1, t2, t3, t4, t5, t6, t7, t8, t9).await
198        })
199    }
200}
201
202// Implement Handler for Fn(T1, T2, ..., T10) -> Fut
203impl<F, Fut, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10> Handler<(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10)> for F
204where
205    F: Fn(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) -> Fut + Clone + Send + Sync + 'static,
206    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
207    T1: FromRequest + Send + 'static,
208    T2: FromRequest + Send + 'static,
209    T3: FromRequest + Send + 'static,
210    T4: FromRequest + Send + 'static,
211    T5: FromRequest + Send + 'static,
212    T6: FromRequest + Send + 'static,
213    T7: FromRequest + Send + 'static,
214    T8: FromRequest + Send + 'static,
215    T9: FromRequest + Send + 'static,
216    T10: FromRequest + Send + 'static,
217{
218    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
219        let handler = self.clone();
220        Box::pin(async move {
221            let t1 = T1::from_request(&mut req).await?;
222            let t2 = T2::from_request(&mut req).await?;
223            let t3 = T3::from_request(&mut req).await?;
224            let t4 = T4::from_request(&mut req).await?;
225            let t5 = T5::from_request(&mut req).await?;
226            let t6 = T6::from_request(&mut req).await?;
227            let t7 = T7::from_request(&mut req).await?;
228            let t8 = T8::from_request(&mut req).await?;
229            let t9 = T9::from_request(&mut req).await?;
230            let t10 = T10::from_request(&mut req).await?;
231            handler(t1, t2, t3, t4, t5, t6, t7, t8, t9, t10).await
232        })
233    }
234}
235
236// Implement Handler for Fn(T1, T2, ..., T11) -> Fut
237impl<F, Fut, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11> Handler<(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11)> for F
238where
239    F: Fn(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) -> Fut + Clone + Send + Sync + 'static,
240    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
241    T1: FromRequest + Send + 'static,
242    T2: FromRequest + Send + 'static,
243    T3: FromRequest + Send + 'static,
244    T4: FromRequest + Send + 'static,
245    T5: FromRequest + Send + 'static,
246    T6: FromRequest + Send + 'static,
247    T7: FromRequest + Send + 'static,
248    T8: FromRequest + Send + 'static,
249    T9: FromRequest + Send + 'static,
250    T10: FromRequest + Send + 'static,
251    T11: FromRequest + Send + 'static,
252{
253    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
254        let handler = self.clone();
255        Box::pin(async move {
256            let t1 = T1::from_request(&mut req).await?;
257            let t2 = T2::from_request(&mut req).await?;
258            let t3 = T3::from_request(&mut req).await?;
259            let t4 = T4::from_request(&mut req).await?;
260            let t5 = T5::from_request(&mut req).await?;
261            let t6 = T6::from_request(&mut req).await?;
262            let t7 = T7::from_request(&mut req).await?;
263            let t8 = T8::from_request(&mut req).await?;
264            let t9 = T9::from_request(&mut req).await?;
265            let t10 = T10::from_request(&mut req).await?;
266            let t11 = T11::from_request(&mut req).await?;
267            handler(t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11).await
268        })
269    }
270}
271
272// Implement Handler for Fn(T1, T2, ..., T12) -> Fut
273impl<F, Fut, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12> Handler<(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12)> for F
274where
275    F: Fn(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) -> Fut + Clone + Send + Sync + 'static,
276    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
277    T1: FromRequest + Send + 'static,
278    T2: FromRequest + Send + 'static,
279    T3: FromRequest + Send + 'static,
280    T4: FromRequest + Send + 'static,
281    T5: FromRequest + Send + 'static,
282    T6: FromRequest + Send + 'static,
283    T7: FromRequest + Send + 'static,
284    T8: FromRequest + Send + 'static,
285    T9: FromRequest + Send + 'static,
286    T10: FromRequest + Send + 'static,
287    T11: FromRequest + Send + 'static,
288    T12: FromRequest + Send + 'static,
289{
290    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
291        let handler = self.clone();
292        Box::pin(async move {
293            let t1 = T1::from_request(&mut req).await?;
294            let t2 = T2::from_request(&mut req).await?;
295            let t3 = T3::from_request(&mut req).await?;
296            let t4 = T4::from_request(&mut req).await?;
297            let t5 = T5::from_request(&mut req).await?;
298            let t6 = T6::from_request(&mut req).await?;
299            let t7 = T7::from_request(&mut req).await?;
300            let t8 = T8::from_request(&mut req).await?;
301            let t9 = T9::from_request(&mut req).await?;
302            let t10 = T10::from_request(&mut req).await?;
303            let t11 = T11::from_request(&mut req).await?;
304            let t12 = T12::from_request(&mut req).await?;
305            handler(t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12).await
306        })
307    }
308}
309
310// Implement Handler for Fn(T1, T2) -> Fut
311impl<F, Fut, T1, T2> Handler<(T1, T2)> for F
312where
313    F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
314    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
315    T1: FromRequest + Send + 'static,
316    T2: FromRequest + Send + 'static,
317{
318    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
319        let handler = self.clone();
320        Box::pin(async move {
321            let t1 = T1::from_request(&mut req).await?;
322            let t2 = T2::from_request(&mut req).await?;
323            handler(t1, t2).await
324        })
325    }
326}
327
328// Implement Handler for Fn(T1, T2, T3) -> Fut
329impl<F, Fut, T1, T2, T3> Handler<(T1, T2, T3)> for F
330where
331    F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
332    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
333    T1: FromRequest + Send + 'static,
334    T2: FromRequest + Send + 'static,
335    T3: FromRequest + Send + 'static,
336{
337    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
338        let handler = self.clone();
339        Box::pin(async move {
340            let t1 = T1::from_request(&mut req).await?;
341            let t2 = T2::from_request(&mut req).await?;
342            let t3 = T3::from_request(&mut req).await?;
343            handler(t1, t2, t3).await
344        })
345    }
346}
347
348// Implement Handler for Fn(T1, T2, T3, T4) -> Fut
349impl<F, Fut, T1, T2, T3, T4> Handler<(T1, T2, T3, T4)> for F
350where
351    F: Fn(T1, T2, T3, T4) -> Fut + Clone + Send + Sync + 'static,
352    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
353    T1: FromRequest + Send + 'static,
354    T2: FromRequest + Send + 'static,
355    T3: FromRequest + Send + 'static,
356    T4: FromRequest + Send + 'static,
357{
358    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
359        let handler = self.clone();
360        Box::pin(async move {
361            let t1 = T1::from_request(&mut req).await?;
362            let t2 = T2::from_request(&mut req).await?;
363            let t3 = T3::from_request(&mut req).await?;
364            let t4 = T4::from_request(&mut req).await?;
365            handler(t1, t2, t3, t4).await
366        })
367    }
368}
369
370// Implement Handler for Fn(T1, T2, T3, T4, T5) -> Fut
371impl<F, Fut, T1, T2, T3, T4, T5> Handler<(T1, T2, T3, T4, T5)> for F
372where
373    F: Fn(T1, T2, T3, T4, T5) -> Fut + Clone + Send + Sync + 'static,
374    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
375    T1: FromRequest + Send + 'static,
376    T2: FromRequest + Send + 'static,
377    T3: FromRequest + Send + 'static,
378    T4: FromRequest + Send + 'static,
379    T5: FromRequest + Send + 'static,
380{
381    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
382        let handler = self.clone();
383        Box::pin(async move {
384            let t1 = T1::from_request(&mut req).await?;
385            let t2 = T2::from_request(&mut req).await?;
386            let t3 = T3::from_request(&mut req).await?;
387            let t4 = T4::from_request(&mut req).await?;
388            let t5 = T5::from_request(&mut req).await?;
389            handler(t1, t2, t3, t4, t5).await
390        })
391    }
392}
393
394// Implement Handler for Fn(T1, T2, T3, T4, T5, T6) -> Fut
395impl<F, Fut, T1, T2, T3, T4, T5, T6> Handler<(T1, T2, T3, T4, T5, T6)> for F
396where
397    F: Fn(T1, T2, T3, T4, T5, T6) -> Fut + Clone + Send + Sync + 'static,
398    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
399    T1: FromRequest + Send + 'static,
400    T2: FromRequest + Send + 'static,
401    T3: FromRequest + Send + 'static,
402    T4: FromRequest + Send + 'static,
403    T5: FromRequest + Send + 'static,
404    T6: FromRequest + Send + 'static,
405{
406    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
407        let handler = self.clone();
408        Box::pin(async move {
409            let t1 = T1::from_request(&mut req).await?;
410            let t2 = T2::from_request(&mut req).await?;
411            let t3 = T3::from_request(&mut req).await?;
412            let t4 = T4::from_request(&mut req).await?;
413            let t5 = T5::from_request(&mut req).await?;
414            let t6 = T6::from_request(&mut req).await?;
415            handler(t1, t2, t3, t4, t5, t6).await
416        })
417    }
418}
419
420// Implement Handler for Fn(T1, T2, T3, T4, T5, T6, T7) -> Fut
421impl<F, Fut, T1, T2, T3, T4, T5, T6, T7> Handler<(T1, T2, T3, T4, T5, T6, T7)> for F
422where
423    F: Fn(T1, T2, T3, T4, T5, T6, T7) -> Fut + Clone + Send + Sync + 'static,
424    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
425    T1: FromRequest + Send + 'static,
426    T2: FromRequest + Send + 'static,
427    T3: FromRequest + Send + 'static,
428    T4: FromRequest + Send + 'static,
429    T5: FromRequest + Send + 'static,
430    T6: FromRequest + Send + 'static,
431    T7: FromRequest + Send + 'static,
432{
433    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
434        let handler = self.clone();
435        Box::pin(async move {
436            let t1 = T1::from_request(&mut req).await?;
437            let t2 = T2::from_request(&mut req).await?;
438            let t3 = T3::from_request(&mut req).await?;
439            let t4 = T4::from_request(&mut req).await?;
440            let t5 = T5::from_request(&mut req).await?;
441            let t6 = T6::from_request(&mut req).await?;
442            let t7 = T7::from_request(&mut req).await?;
443            handler(t1, t2, t3, t4, t5, t6, t7).await
444        })
445    }
446}
447
448// Implement Handler for Fn(T1, T2, T3, T4, T5, T6, T7, T8) -> Fut
449impl<F, Fut, T1, T2, T3, T4, T5, T6, T7, T8> Handler<(T1, T2, T3, T4, T5, T6, T7, T8)> for F
450where
451    F: Fn(T1, T2, T3, T4, T5, T6, T7, T8) -> Fut + Clone + Send + Sync + 'static,
452    Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
453    T1: FromRequest + Send + 'static,
454    T2: FromRequest + Send + 'static,
455    T3: FromRequest + Send + 'static,
456    T4: FromRequest + Send + 'static,
457    T5: FromRequest + Send + 'static,
458    T6: FromRequest + Send + 'static,
459    T7: FromRequest + Send + 'static,
460    T8: FromRequest + Send + 'static,
461{
462    fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
463        let handler = self.clone();
464        Box::pin(async move {
465            let t1 = T1::from_request(&mut req).await?;
466            let t2 = T2::from_request(&mut req).await?;
467            let t3 = T3::from_request(&mut req).await?;
468            let t4 = T4::from_request(&mut req).await?;
469            let t5 = T5::from_request(&mut req).await?;
470            let t6 = T6::from_request(&mut req).await?;
471            let t7 = T7::from_request(&mut req).await?;
472            let t8 = T8::from_request(&mut req).await?;
473            handler(t1, t2, t3, t4, t5, t6, t7, t8).await
474        })
475    }
476}
477
478struct Route {
479    pattern: Regex,
480    param_names: Vec<String>,
481    handler: Arc<dyn Endpoint>,
482}
483
484#[derive(Clone)]
485pub struct Router {
486    routes: HashMap<Method, Vec<Arc<Route>>>,
487    extensions: Arc<std::sync::RwLock<http::Extensions>>,
488    middleware: Vec<Arc<dyn Fn(Arc<dyn Endpoint>) -> Arc<dyn Endpoint> + Send + Sync>>,
489    cors_config: Option<CorsConfig>,
490}
491
492impl Router {
493    pub fn new() -> Self {
494        Self {
495            routes: HashMap::new(),
496            extensions: Arc::new(std::sync::RwLock::new(http::Extensions::new())),
497            middleware: Vec::new(),
498            cors_config: None,
499        }
500    }
501
502    /// Add a shared state to the router that will be available in all handlers
503    pub fn with_state<T: Clone + Send + Sync + 'static>(&mut self, state: T) {
504        if let Ok(mut exts) = self.extensions.write() {
505            exts.insert(state);
506        }
507    }
508
509    pub fn get<H, Args>(&mut self, path: &str, handler: H)
510    where
511        H: Handler<Args>,
512        Args: Send + Sync + 'static,
513    {
514        self.add_route(Method::GET, path, handler);
515    }
516    
517    pub fn post<H, Args>(&mut self, path: &str, handler: H)
518    where
519        H: Handler<Args>,
520        Args: Send + Sync + 'static,
521    {
522        self.add_route(Method::POST, path, handler);
523    }
524
525    pub fn put<H, Args>(&mut self, path: &str, handler: H)
526    where
527        H: Handler<Args>,
528        Args: Send + Sync + 'static,
529    {
530        self.add_route(Method::PUT, path, handler);
531    }
532
533    pub fn delete<H, Args>(&mut self, path: &str, handler: H)
534    where
535        H: Handler<Args>,
536        Args: Send + Sync + 'static,
537    {
538        self.add_route(Method::DELETE, path, handler);
539    }
540
541    pub fn patch<H, Args>(&mut self, path: &str, handler: H)
542    where
543        H: Handler<Args>,
544        Args: Send + Sync + 'static,
545    {
546        self.add_route(Method::PATCH, path, handler);
547    }
548
549    /// Add a middleware layer to all routes in the router.
550    ///
551    /// The layer must implement `tower::Layer<EndpointService>` and return a new `Endpoint`.
552    ///
553    /// # Limitations
554    ///
555    /// Body-type-changing middleware (like `CorsLayer` and `CompressionLayer` from tower-http)
556    /// **cannot** be used with this method. These middleware change the HTTP response body type,
557    /// which is incompatible with the `Endpoint` trait that expects `OxiditeResponse`.
558    ///
559    /// For such middleware, use `ServiceBuilder` instead:
560    /// ```rust,ignore
561    /// let service = ServiceBuilder::new()
562    ///     .layer(CorsLayer::permissive())
563    ///     .layer(CompressionLayer::new())
564    ///     .service(router);
565    /// ```
566    pub fn layer<L>(mut self, layer: L) -> Self
567    where
568        L: tower::Layer<EndpointService> + Send + Sync + 'static,
569        L::Service: Endpoint,
570    {
571        let layer = Arc::new(layer);
572        self.middleware.push(Arc::new(move |endpoint| {
573            Arc::new(layer.layer(EndpointService(endpoint)))
574        }));
575        self
576    }
577
578    /// Alias for `.layer()`.
579    pub fn with_layer<L>(self, layer: L) -> Self
580    where
581        L: tower::Layer<EndpointService> + Send + Sync + 'static,
582        L::Service: Endpoint,
583    {
584        self.layer(layer)
585    }
586
587    /// Configure CORS for this router.
588    ///
589    /// This adds CORS headers to all responses, including preflight OPTIONS requests.
590    /// This is a framework-level CORS implementation that doesn't require tower-http.
591    ///
592    /// # Example
593    ///
594    /// ```rust,ignore
595    /// use oxidite::prelude::*;
596    ///
597    /// let router = Router::new()
598    ///     .with_cors(CorsConfig {
599    ///         allowed_origins: vec!["http://localhost:3000".to_string()],
600    ///         allowed_methods: vec!["GET".to_string(), "POST".to_string()],
601    ///         allowed_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
602    ///         allow_credentials: true,
603    ///         max_age: 3600,
604    ///     });
605    /// ```
606    ///
607    /// For development, you can use the permissive config:
608    ///
609    /// ```rust,ignore
610    /// let router = Router::new()
611    ///     .with_cors(CorsConfig::permissive());
612    /// ```
613    pub fn with_cors(mut self, config: CorsConfig) -> Self {
614        self.cors_config = Some(config);
615        self
616    }
617
618    fn add_route<H, Args>(&mut self, method: Method, path: &str, handler: H)
619    where
620        H: Handler<Args>,
621        Args: Send + Sync + 'static,
622    {
623        let (pattern, param_names) = compile_path(path);
624        let mut endpoint: Arc<dyn Endpoint> = Arc::new(HandlerService {
625            handler,
626            _marker: std::marker::PhantomData,
627        });
628
629        // Apply router-level middleware in reverse order (so the first added is the outermost)
630        for mw in self.middleware.iter().rev() {
631            endpoint = mw(endpoint);
632        }
633        
634        let route = Arc::new(Route {
635            pattern,
636            param_names,
637            handler: endpoint,
638        });
639        
640        self.routes
641            .entry(method)
642            .or_insert_with(Vec::new)
643            .push(route);
644    }
645
646    /// Add CORS headers to a response builder based on the configured CORS policy
647    fn add_cors_headers(&self, mut builder: http::response::Builder) -> http::response::Builder {
648        if let Some(cors) = &self.cors_config {
649            // Add Access-Control-Allow-Origin
650            if !cors.allowed_origins.is_empty() {
651                if cors.allowed_origins.contains(&"*".to_string()) {
652                    builder = builder.header("Access-Control-Allow-Origin", "*");
653                } else {
654                    // For specific origins, we'd need to check the Origin header
655                    // For now, we'll add the first origin (could be improved with request checking)
656                    if let Some(origin) = cors.allowed_origins.first() {
657                        builder = builder.header("Access-Control-Allow-Origin", origin);
658                    }
659                }
660            }
661
662            // Add Access-Control-Allow-Methods
663            if !cors.allowed_methods.is_empty() {
664                let methods = cors.allowed_methods.join(", ");
665                builder = builder.header("Access-Control-Allow-Methods", methods);
666            }
667
668            // Add Access-Control-Allow-Headers
669            if !cors.allowed_headers.is_empty() {
670                let headers = cors.allowed_headers.join(", ");
671                builder = builder.header("Access-Control-Allow-Headers", headers);
672            }
673
674            // Add Access-Control-Allow-Credentials
675            if cors.allow_credentials {
676                builder = builder.header("Access-Control-Allow-Credentials", "true");
677            }
678
679            // Add Access-Control-Max-Age
680            builder = builder.header("Access-Control-Max-Age", cors.max_age.to_string());
681        }
682        builder
683    }
684
685    pub async fn handle(&self, mut req: OxiditeRequest) -> Result<OxiditeResponse> {
686        req.extensions_mut().insert(self.extensions.clone());
687        let method = req.method().clone();
688        let path = req.uri().path().to_string();
689
690        // Helper to try matching routes for a specific method
691        let try_match = |target_method: &Method, req: &mut OxiditeRequest| -> Option<Arc<Route>> {
692            if let Some(routes) = self.routes.get(target_method) {
693                for route in routes {
694                    if let Some(captures) = route.pattern.captures(&path) {
695                        // Extract path parameters
696                        let mut params = serde_json::Map::new();
697                        for (i, name) in route.param_names.iter().enumerate() {
698                            if let Some(value) = captures.get(i + 1) {
699                                params.insert(
700                                    name.clone(),
701                                    serde_json::Value::String(value.as_str().to_string()),
702                                );
703                            }
704                        }
705
706                        // Store params in request extensions
707                        if !params.is_empty() {
708                            req.extensions_mut().insert(crate::extract::PathParams(
709                                serde_json::Value::Object(params),
710                            ));
711                        }
712                        
713                        return Some(route.clone());
714                    }
715                }
716            }
717            None
718        };
719
720        // 1. Try exact method match
721        if let Some(route) = try_match(&method, &mut req) {
722            // Add router extensions to request so State extractor can find global state
723            req.extensions_mut().insert(self.extensions.clone());
724            let response = route.handler.call(req).await?;
725            
726            // Add CORS headers to successful responses
727            if self.cors_config.is_some() {
728                let hyper_response: hyper::Response<crate::types::BoxBody> = response.into();
729                let (parts, body) = hyper_response.into_parts();
730                let mut builder = http::Response::builder()
731                    .status(parts.status);
732                
733                // Copy existing headers
734                for (key, value) in parts.headers {
735                    if let Some(key) = key {
736                        builder = builder.header(key, value);
737                    }
738                }
739                
740                // Add CORS headers
741                builder = self.add_cors_headers(builder);
742                
743                return Ok(OxiditeResponse::new(builder.body(body).unwrap()));
744            }
745            
746            return Ok(response);
747        }
748
749        // 2. If OPTIONS, return empty success response for CORS if no explicit handler
750        if method == Method::OPTIONS {
751            if let Some(_route) = try_match(&Method::OPTIONS, &mut req) {
752                // Explicit handler exists, will be handled by step 1
753            } else {
754                // Return 204 No Content for CORS preflight
755                let mut builder = http::Response::builder()
756                    .status(http::StatusCode::NO_CONTENT);
757                
758                // Add CORS headers to preflight response
759                builder = self.add_cors_headers(builder);
760                
761                return Ok(OxiditeResponse::new(builder
762                    .body(crate::types::BoxBody::default())
763                    .unwrap()));
764            }
765        }
766
767        // 3. If HEAD, try GET
768        if method == Method::HEAD {
769            if let Some(route) = try_match(&Method::GET, &mut req) {
770                // Add router extensions to request so State extractor can find global state
771                req.extensions_mut().insert(self.extensions.clone());
772                // For HEAD requests, we execute the GET handler but the server/hyper 
773                // will strip the body automatically since it's a HEAD response.
774                return route.handler.call(req).await;
775            }
776        }
777
778        // 3. Path exists for other methods => method not allowed
779        let allowed_methods: Vec<String> = self
780            .routes
781            .iter()
782            .filter(|(route_method, _)| **route_method != method)
783            .filter_map(|(route_method, routes)| {
784                if routes.iter().any(|route| route.pattern.is_match(&path)) {
785                    Some(route_method.as_str().to_string())
786                } else {
787                    None
788                }
789            })
790            .collect();
791        if !allowed_methods.is_empty() {
792            return Err(Error::MethodNotAllowed(format!(
793                "{} {} (allowed: {})",
794                method,
795                path,
796                allowed_methods.join(", ")
797            )));
798        }
799
800        Err(Error::NotFound("Route not found".to_string()))
801    }
802}
803
804impl Service<OxiditeRequest> for Router {
805    type Response = OxiditeResponse;
806    type Error = Error;
807    type Future = Pin<Box<dyn Future<Output = Result<Self::Response>> + Send>>;
808
809    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
810        Poll::Ready(Ok(()))
811    }
812
813    fn call(&mut self, req: OxiditeRequest) -> Self::Future {
814        let router = self.clone();
815        Box::pin(async move {
816            router.handle(req).await
817        })
818    }
819}
820
821impl Default for Router {
822    fn default() -> Self {
823        Self::new()
824    }
825}
826
827/// Compile a route path pattern into a regex
828/// Converts `/users/:id` to `^/users/([^/]+)$` and returns param names
829fn compile_path(path: &str) -> (Regex, Vec<String>) {
830    let mut pattern = String::from("^");
831    let mut param_names = Vec::new();
832    let mut chars = path.chars().peekable();
833
834    while let Some(ch) = chars.next() {
835        match ch {
836            ':' => {
837                // Extract parameter name
838                let mut param_name = String::new();
839                while let Some(&next_ch) = chars.peek() {
840                    if next_ch.is_alphanumeric() || next_ch == '_' {
841                        param_name.push(next_ch);
842                        chars.next();
843                    } else {
844                        break;
845                    }
846                }
847                param_names.push(param_name);
848                pattern.push_str("([^/]+)");
849            }
850            '*' => {
851                // Wildcard
852                pattern.push_str("(.*)");
853            }
854            '.' | '+' | '?' | '^' | '$' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '\\' => {
855                // Escape regex special characters
856                pattern.push('\\');
857                pattern.push(ch);
858            }
859            _ => {
860                pattern.push(ch);
861            }
862        }
863    }
864
865    pattern.push('$');
866    let regex = Regex::new(&pattern).expect("Invalid route pattern");
867    (regex, param_names)
868}
869
870/// Trait that provides compile-time verification that a function is a valid handler.
871///
872/// This is used by the [`handler_fn`] function to give clear, readable error messages
873/// when a function doesn't satisfy the handler requirements, rather than the cryptic
874/// trait-bound errors that would otherwise surface from the router.
875///
876/// # Example
877///
878/// ```rust,ignore
879/// use oxidite::prelude::*;
880///
881/// // This compiles because the function matches Handler<(State<Arc<AppState>>,)>
882/// let h = handler_fn(my_handler);
883/// router.get("/users", h);
884///
885/// // This would fail at compile time with a clear error if the function
886/// // has extractors that don't implement FromRequest.
887/// ```
888pub trait IntoHandler<Args>: Handler<Args> + Sized {
889    fn into_handler(self) -> Self {
890        self
891    }
892}
893
894impl<H, Args> IntoHandler<Args> for H where H: Handler<Args> {}
895
896/// Compile-time handler verification helper.
897///
898/// Wraps a handler function and ensures at compile time that all its extractors
899/// implement `FromRequest`. Provides clearer error messages than raw trait bounds.
900///
901/// # Example
902///
903/// ```rust,ignore
904/// use oxidite::prelude::*;
905///
906/// async fn index(State(s): State<Arc<AppState>>) -> Result<OxiditeResponse> {
907///     Ok(response::json(serde_json::json!({"ok": true})))
908/// }
909///
910/// // Verified at compile time:
911/// router.get("/", handler_fn(index));
912/// ```
913pub fn handler_fn<H, Args>(handler: H) -> H
914where
915    H: IntoHandler<Args>,
916    Args: Send + Sync + 'static,
917{
918    handler
919}
920
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::BoxBody;

    /// Builds an empty-bodied request with the given method and URI.
    fn request(method: Method, uri: &str) -> http::Request<BoxBody> {
        http::Request::builder()
            .method(method)
            .uri(uri)
            .body(BoxBody::default())
            .expect("request")
    }

    /// Converts a router response into a plain `hyper::Response` for header/body inspection.
    fn into_hyper(res: OxiditeResponse) -> hyper::Response<BoxBody> {
        res.into()
    }

    /// Collects the whole response body and decodes it as UTF-8.
    async fn body_text(res: OxiditeResponse) -> String {
        use http_body_util::BodyExt;
        let bytes = into_hyper(res)
            .into_body()
            .collect()
            .await
            .expect("collect body")
            .to_bytes();
        String::from_utf8(bytes.to_vec()).expect("utf-8 body")
    }

    #[test]
    fn test_compile_path() {
        let (regex, params) = compile_path("/users/:id");
        assert_eq!(params, vec!["id"]);
        assert!(regex.is_match("/users/123"));
        assert!(!regex.is_match("/users/123/posts"));

        let (regex, params) = compile_path("/users/:user_id/posts/:post_id");
        assert_eq!(params, vec!["user_id", "post_id"]);
        assert!(regex.is_match("/users/1/posts/2"));
    }

    #[test]
    fn test_exact_match() {
        let (regex, params) = compile_path("/users");
        assert_eq!(params.len(), 0);
        assert!(regex.is_match("/users"));
        assert!(!regex.is_match("/users/123"));
    }

    #[test]
    fn test_compile_path_wildcard_and_escaping() {
        // `*` matches any remainder, including further slashes, and records no name.
        let (regex, params) = compile_path("/files/*");
        assert!(params.is_empty());
        assert!(regex.is_match("/files/a/b.txt"));

        // `.` must be matched literally, not as "any character".
        let (regex, _) = compile_path("/report.json");
        assert!(regex.is_match("/report.json"));
        assert!(!regex.is_match("/reportXjson"));
    }

    #[tokio::test]
    async fn test_handler_with_12_extractors() {
        use crate::extract::State;

        #[derive(Clone)]
        struct AppState;

        async fn h12(
            _s1: State<AppState>,
            _s2: State<AppState>,
            _s3: State<AppState>,
            _s4: State<AppState>,
            _s5: State<AppState>,
            _s6: State<AppState>,
            _s7: State<AppState>,
            _s8: State<AppState>,
            _s9: State<AppState>,
            _s10: State<AppState>,
            _s11: State<AppState>,
            _s12: State<AppState>,
        ) -> Result<OxiditeResponse> {
            Ok(OxiditeResponse::text("ok"))
        }

        let mut router = Router::new();
        router.with_state(AppState);
        router.get("/", h12);

        let result = router.handle(request(Method::GET, "/")).await.expect("handle");
        assert_eq!(result.status(), http::StatusCode::OK);
    }

    #[tokio::test]
    async fn test_handler_with_8_extractors() {
        use crate::extract::State;

        #[derive(Clone)]
        struct AppState;

        async fn h8(
            _s1: State<AppState>,
            _s2: State<AppState>,
            _s3: State<AppState>,
            _s4: State<AppState>,
            _s5: State<AppState>,
            _s6: State<AppState>,
            _s7: State<AppState>,
            _s8: State<AppState>,
        ) -> Result<OxiditeResponse> {
            Ok(OxiditeResponse::text("ok"))
        }

        let mut router = Router::new();
        router.with_state(AppState);
        router.get("/", h8);

        let result = router.handle(request(Method::GET, "/")).await.expect("handle");
        assert_eq!(result.status(), http::StatusCode::OK);
    }

    #[tokio::test]
    async fn test_method_not_allowed_when_path_exists() {
        let mut router = Router::new();
        router.get("/users", || async { Ok(OxiditeResponse::text("ok")) });

        match router.handle(request(Method::POST, "/users")).await {
            Err(Error::MethodNotAllowed(msg)) => {
                // The message should advertise the method that *is* registered for the path.
                assert!(msg.contains("GET"), "unexpected message: {}", msg);
            }
            _ => panic!("expected Error::MethodNotAllowed"),
        }
    }

    #[tokio::test]
    async fn test_not_found_when_path_missing() {
        let router = Router::new();
        let result = router.handle(request(Method::GET, "/missing")).await;
        assert!(matches!(result, Err(Error::NotFound(_))));
    }

    #[tokio::test]
    async fn test_head_falls_back_to_get_handler() {
        let mut router = Router::new();
        router.get("/test", || async { Ok(OxiditeResponse::text("ok")) });

        let res = router.handle(request(Method::HEAD, "/test")).await.expect("handle");
        assert_eq!(res.status(), http::StatusCode::OK);
    }

    #[tokio::test]
    async fn test_router_layer() {
        use tower::Layer;

        /// Middleware that tags the request so the handler can prove it ran.
        struct MyMiddleware<S>(S);
        impl<S: Endpoint> Endpoint for MyMiddleware<S> {
            fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
                req.extensions_mut().insert("middleware_called".to_string());
                self.0.call(req)
            }
        }

        struct MyLayer;
        impl<S> Layer<S> for MyLayer {
            type Service = MyMiddleware<S>;
            fn layer(&self, inner: S) -> Self::Service {
                MyMiddleware(inner)
            }
        }

        // The handler reports whether the middleware's extension reached it.
        let mut router = Router::new().layer(MyLayer);
        router.get("/", |req: OxiditeRequest| async move {
            if req.extensions().get::<String>().map(|s| s == "middleware_called").unwrap_or(false) {
                Ok(OxiditeResponse::text("middleware_ok"))
            } else {
                Ok(OxiditeResponse::text("middleware_fail"))
            }
        });

        let res = router.handle(request(Method::GET, "/")).await.expect("handle");
        assert_eq!(body_text(res).await, "middleware_ok");
    }

    #[test]
    fn test_cors_config_default() {
        let config = CorsConfig::default();
        assert_eq!(config.allowed_origins, vec!["*"]);
        assert!(!config.allowed_methods.is_empty());
        assert_eq!(config.allowed_headers, vec!["*"]);
        assert!(!config.allow_credentials);
        assert_eq!(config.max_age, 3600);
    }

    #[test]
    fn test_cors_config_permissive() {
        let config = CorsConfig::permissive();
        assert_eq!(config.allowed_origins, vec!["*"]);
        assert_eq!(config.allowed_headers, vec!["*"]);
    }

    #[test]
    fn test_cors_config_restrictive() {
        let config = CorsConfig::restrictive();
        assert!(config.allowed_origins.is_empty());
        assert_eq!(config.allowed_methods, vec!["GET", "POST"]);
        assert_eq!(config.allowed_headers, vec!["Content-Type"]);
    }

    #[tokio::test]
    async fn test_cors_preflight_response() {
        let mut router = Router::new()
            .with_cors(CorsConfig {
                allowed_origins: vec!["http://localhost:3000".to_string()],
                allowed_methods: vec!["GET".to_string(), "POST".to_string()],
                allowed_headers: vec!["Content-Type".to_string()],
                allow_credentials: true,
                max_age: 7200,
            });
        router.get("/test", || async { Ok(OxiditeResponse::text("ok")) });

        let res = router.handle(request(Method::OPTIONS, "/test")).await.expect("handle");
        let hyper_res = into_hyper(res);

        // Should be 204 No Content for preflight
        assert_eq!(hyper_res.status(), http::StatusCode::NO_CONTENT);

        let headers = hyper_res.headers();
        assert_eq!(
            headers.get("Access-Control-Allow-Origin").unwrap(),
            "http://localhost:3000"
        );
        assert_eq!(
            headers.get("Access-Control-Allow-Methods").unwrap(),
            "GET, POST"
        );
        assert_eq!(
            headers.get("Access-Control-Allow-Headers").unwrap(),
            "Content-Type"
        );
        assert_eq!(
            headers.get("Access-Control-Allow-Credentials").unwrap(),
            "true"
        );
        assert_eq!(
            headers.get("Access-Control-Max-Age").unwrap(),
            "7200"
        );
    }

    #[tokio::test]
    async fn test_cors_on_successful_response() {
        use http_body_util::BodyExt;

        let mut router = Router::new()
            .with_cors(CorsConfig::permissive());
        router.get("/test", || async { Ok(OxiditeResponse::text("ok")) });

        let res = router.handle(request(Method::GET, "/test")).await.expect("handle");
        let hyper_res = into_hyper(res);

        // The router rebuilds the response to inject CORS headers. Status and body must survive.
        assert_eq!(hyper_res.status(), http::StatusCode::OK);
        assert_eq!(
            hyper_res.headers().get("Access-Control-Allow-Origin").unwrap(),
            "*"
        );
        let bytes = hyper_res
            .into_body()
            .collect()
            .await
            .expect("collect body")
            .to_bytes();
        assert_eq!(bytes, "ok");
    }

    #[tokio::test]
    async fn test_cors_wildcard_origin() {
        let mut router = Router::new()
            .with_cors(CorsConfig::permissive());
        router.get("/test", || async { Ok(OxiditeResponse::text("ok")) });

        // The wildcard origin must also be emitted on the implicit preflight path.
        let res = router.handle(request(Method::OPTIONS, "/test")).await.expect("handle");
        let hyper_res = into_hyper(res);

        assert_eq!(hyper_res.status(), http::StatusCode::NO_CONTENT);
        assert_eq!(
            hyper_res.headers().get("Access-Control-Allow-Origin").unwrap(),
            "*"
        );
    }

    #[tokio::test]
    async fn test_no_cors_when_not_configured() {
        let mut router = Router::new();
        router.get("/test", || async { Ok(OxiditeResponse::text("ok")) });

        let res = router.handle(request(Method::GET, "/test")).await.expect("handle");
        let hyper_res = into_hyper(res);

        // Should NOT have CORS headers
        let headers = hyper_res.headers();
        assert!(headers.get("Access-Control-Allow-Origin").is_none());
        assert!(headers.get("Access-Control-Allow-Methods").is_none());
    }
}