1use crate::error::{Error, Result};
2use crate::types::{OxiditeRequest, OxiditeResponse};
3use crate::extract::FromRequest;
4use hyper::Method;
5use std::collections::HashMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10use tower_service::Service;
11
12use regex::Regex;
13
14pub trait Endpoint: Send + Sync + 'static {
16 fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>>;
17}
18
19impl Endpoint for Arc<dyn Endpoint> {
20 fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
21 (**self).call(req)
22 }
23}
24
25pub struct EndpointService(pub Arc<dyn Endpoint>);
26
27impl Clone for EndpointService {
28 fn clone(&self) -> Self {
29 Self(self.0.clone())
30 }
31}
32
33impl Service<OxiditeRequest> for EndpointService {
34 type Response = OxiditeResponse;
35 type Error = Error;
36 type Future = Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>>;
37
38 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
39 Poll::Ready(Ok(()))
40 }
41
42 fn call(&mut self, req: OxiditeRequest) -> Self::Future {
43 self.0.call(req)
44 }
45}
46
47impl Endpoint for EndpointService {
48 fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
49 self.0.call(req)
50 }
51}
52
/// Cross-Origin Resource Sharing policy applied by the router.
#[derive(Debug, Clone)]
pub struct CorsConfig {
    /// Origins allowed to access resources; `"*"` allows any origin, empty allows none.
    pub allowed_origins: Vec<String>,
    /// HTTP methods advertised in `Access-Control-Allow-Methods`.
    pub allowed_methods: Vec<String>,
    /// Request headers advertised in `Access-Control-Allow-Headers`.
    pub allowed_headers: Vec<String>,
    /// Whether `Access-Control-Allow-Credentials: true` is sent.
    pub allow_credentials: bool,
    /// Value of `Access-Control-Max-Age`, in seconds.
    pub max_age: u32,
}

impl Default for CorsConfig {
    /// Any origin and header, the common REST methods, no credentials, one-hour max age.
    fn default() -> Self {
        let owned = |items: &[&str]| -> Vec<String> {
            items.iter().map(|item| (*item).to_owned()).collect()
        };
        Self {
            allowed_origins: owned(&["*"]),
            allowed_methods: owned(&["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"]),
            allowed_headers: owned(&["*"]),
            allow_credentials: false,
            max_age: 3600,
        }
    }
}

impl CorsConfig {
    /// Same as [`CorsConfig::default`]: allow everything except credentials.
    pub fn permissive() -> Self {
        Self::default()
    }

    /// No allowed origins, only `GET`/`POST`, only the `Content-Type` header.
    ///
    /// Credentials and max age keep their default values (off, 3600 seconds).
    pub fn restrictive() -> Self {
        Self {
            allowed_origins: Vec::new(),
            allowed_methods: ["GET", "POST"].iter().copied().map(String::from).collect(),
            allowed_headers: vec![String::from("Content-Type")],
            ..Self::default()
        }
    }
}
100
101pub trait Handler<Args>: Clone + Send + Sync + 'static {
111 fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>>;
112}
113
114struct HandlerService<H, Args> {
116 handler: H,
117 _marker: std::marker::PhantomData<Args>,
118}
119
120impl<H, Args> Endpoint for HandlerService<H, Args>
121where
122 H: Handler<Args>,
123 Args: Send + Sync + 'static,
124{
125 fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
126 self.handler.call(req)
127 }
128}
129
130impl<F, Fut> Handler<OxiditeRequest> for F
132where
133 F: Fn(OxiditeRequest) -> Fut + Clone + Send + Sync + 'static,
134 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
135{
136 fn call(&self, req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
137 let fut = self(req);
138 Box::pin(async move { fut.await })
139 }
140}
141
142impl<F, Fut> Handler<()> for F
144where
145 F: Fn() -> Fut + Clone + Send + Sync + 'static,
146 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
147{
148 fn call(&self, _req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
149 let fut = self();
150 Box::pin(async move { fut.await })
151 }
152}
153
154impl<F, Fut, T1> Handler<(T1,)> for F
156where
157 F: Fn(T1) -> Fut + Clone + Send + Sync + 'static,
158 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
159 T1: FromRequest + Send + 'static,
160{
161 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
162 let handler = self.clone();
163 Box::pin(async move {
164 let t1 = T1::from_request(&mut req).await?;
165 handler(t1).await
166 })
167 }
168}
169
170impl<F, Fut, T1, T2, T3, T4, T5, T6, T7, T8, T9> Handler<(T1, T2, T3, T4, T5, T6, T7, T8, T9)> for F
172where
173 F: Fn(T1, T2, T3, T4, T5, T6, T7, T8, T9) -> Fut + Clone + Send + Sync + 'static,
174 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
175 T1: FromRequest + Send + 'static,
176 T2: FromRequest + Send + 'static,
177 T3: FromRequest + Send + 'static,
178 T4: FromRequest + Send + 'static,
179 T5: FromRequest + Send + 'static,
180 T6: FromRequest + Send + 'static,
181 T7: FromRequest + Send + 'static,
182 T8: FromRequest + Send + 'static,
183 T9: FromRequest + Send + 'static,
184{
185 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
186 let handler = self.clone();
187 Box::pin(async move {
188 let t1 = T1::from_request(&mut req).await?;
189 let t2 = T2::from_request(&mut req).await?;
190 let t3 = T3::from_request(&mut req).await?;
191 let t4 = T4::from_request(&mut req).await?;
192 let t5 = T5::from_request(&mut req).await?;
193 let t6 = T6::from_request(&mut req).await?;
194 let t7 = T7::from_request(&mut req).await?;
195 let t8 = T8::from_request(&mut req).await?;
196 let t9 = T9::from_request(&mut req).await?;
197 handler(t1, t2, t3, t4, t5, t6, t7, t8, t9).await
198 })
199 }
200}
201
202impl<F, Fut, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10> Handler<(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10)> for F
204where
205 F: Fn(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) -> Fut + Clone + Send + Sync + 'static,
206 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
207 T1: FromRequest + Send + 'static,
208 T2: FromRequest + Send + 'static,
209 T3: FromRequest + Send + 'static,
210 T4: FromRequest + Send + 'static,
211 T5: FromRequest + Send + 'static,
212 T6: FromRequest + Send + 'static,
213 T7: FromRequest + Send + 'static,
214 T8: FromRequest + Send + 'static,
215 T9: FromRequest + Send + 'static,
216 T10: FromRequest + Send + 'static,
217{
218 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
219 let handler = self.clone();
220 Box::pin(async move {
221 let t1 = T1::from_request(&mut req).await?;
222 let t2 = T2::from_request(&mut req).await?;
223 let t3 = T3::from_request(&mut req).await?;
224 let t4 = T4::from_request(&mut req).await?;
225 let t5 = T5::from_request(&mut req).await?;
226 let t6 = T6::from_request(&mut req).await?;
227 let t7 = T7::from_request(&mut req).await?;
228 let t8 = T8::from_request(&mut req).await?;
229 let t9 = T9::from_request(&mut req).await?;
230 let t10 = T10::from_request(&mut req).await?;
231 handler(t1, t2, t3, t4, t5, t6, t7, t8, t9, t10).await
232 })
233 }
234}
235
236impl<F, Fut, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11> Handler<(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11)> for F
238where
239 F: Fn(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) -> Fut + Clone + Send + Sync + 'static,
240 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
241 T1: FromRequest + Send + 'static,
242 T2: FromRequest + Send + 'static,
243 T3: FromRequest + Send + 'static,
244 T4: FromRequest + Send + 'static,
245 T5: FromRequest + Send + 'static,
246 T6: FromRequest + Send + 'static,
247 T7: FromRequest + Send + 'static,
248 T8: FromRequest + Send + 'static,
249 T9: FromRequest + Send + 'static,
250 T10: FromRequest + Send + 'static,
251 T11: FromRequest + Send + 'static,
252{
253 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
254 let handler = self.clone();
255 Box::pin(async move {
256 let t1 = T1::from_request(&mut req).await?;
257 let t2 = T2::from_request(&mut req).await?;
258 let t3 = T3::from_request(&mut req).await?;
259 let t4 = T4::from_request(&mut req).await?;
260 let t5 = T5::from_request(&mut req).await?;
261 let t6 = T6::from_request(&mut req).await?;
262 let t7 = T7::from_request(&mut req).await?;
263 let t8 = T8::from_request(&mut req).await?;
264 let t9 = T9::from_request(&mut req).await?;
265 let t10 = T10::from_request(&mut req).await?;
266 let t11 = T11::from_request(&mut req).await?;
267 handler(t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11).await
268 })
269 }
270}
271
272impl<F, Fut, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12> Handler<(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12)> for F
274where
275 F: Fn(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) -> Fut + Clone + Send + Sync + 'static,
276 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
277 T1: FromRequest + Send + 'static,
278 T2: FromRequest + Send + 'static,
279 T3: FromRequest + Send + 'static,
280 T4: FromRequest + Send + 'static,
281 T5: FromRequest + Send + 'static,
282 T6: FromRequest + Send + 'static,
283 T7: FromRequest + Send + 'static,
284 T8: FromRequest + Send + 'static,
285 T9: FromRequest + Send + 'static,
286 T10: FromRequest + Send + 'static,
287 T11: FromRequest + Send + 'static,
288 T12: FromRequest + Send + 'static,
289{
290 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
291 let handler = self.clone();
292 Box::pin(async move {
293 let t1 = T1::from_request(&mut req).await?;
294 let t2 = T2::from_request(&mut req).await?;
295 let t3 = T3::from_request(&mut req).await?;
296 let t4 = T4::from_request(&mut req).await?;
297 let t5 = T5::from_request(&mut req).await?;
298 let t6 = T6::from_request(&mut req).await?;
299 let t7 = T7::from_request(&mut req).await?;
300 let t8 = T8::from_request(&mut req).await?;
301 let t9 = T9::from_request(&mut req).await?;
302 let t10 = T10::from_request(&mut req).await?;
303 let t11 = T11::from_request(&mut req).await?;
304 let t12 = T12::from_request(&mut req).await?;
305 handler(t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12).await
306 })
307 }
308}
309
310impl<F, Fut, T1, T2> Handler<(T1, T2)> for F
312where
313 F: Fn(T1, T2) -> Fut + Clone + Send + Sync + 'static,
314 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
315 T1: FromRequest + Send + 'static,
316 T2: FromRequest + Send + 'static,
317{
318 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
319 let handler = self.clone();
320 Box::pin(async move {
321 let t1 = T1::from_request(&mut req).await?;
322 let t2 = T2::from_request(&mut req).await?;
323 handler(t1, t2).await
324 })
325 }
326}
327
328impl<F, Fut, T1, T2, T3> Handler<(T1, T2, T3)> for F
330where
331 F: Fn(T1, T2, T3) -> Fut + Clone + Send + Sync + 'static,
332 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
333 T1: FromRequest + Send + 'static,
334 T2: FromRequest + Send + 'static,
335 T3: FromRequest + Send + 'static,
336{
337 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
338 let handler = self.clone();
339 Box::pin(async move {
340 let t1 = T1::from_request(&mut req).await?;
341 let t2 = T2::from_request(&mut req).await?;
342 let t3 = T3::from_request(&mut req).await?;
343 handler(t1, t2, t3).await
344 })
345 }
346}
347
348impl<F, Fut, T1, T2, T3, T4> Handler<(T1, T2, T3, T4)> for F
350where
351 F: Fn(T1, T2, T3, T4) -> Fut + Clone + Send + Sync + 'static,
352 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
353 T1: FromRequest + Send + 'static,
354 T2: FromRequest + Send + 'static,
355 T3: FromRequest + Send + 'static,
356 T4: FromRequest + Send + 'static,
357{
358 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
359 let handler = self.clone();
360 Box::pin(async move {
361 let t1 = T1::from_request(&mut req).await?;
362 let t2 = T2::from_request(&mut req).await?;
363 let t3 = T3::from_request(&mut req).await?;
364 let t4 = T4::from_request(&mut req).await?;
365 handler(t1, t2, t3, t4).await
366 })
367 }
368}
369
370impl<F, Fut, T1, T2, T3, T4, T5> Handler<(T1, T2, T3, T4, T5)> for F
372where
373 F: Fn(T1, T2, T3, T4, T5) -> Fut + Clone + Send + Sync + 'static,
374 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
375 T1: FromRequest + Send + 'static,
376 T2: FromRequest + Send + 'static,
377 T3: FromRequest + Send + 'static,
378 T4: FromRequest + Send + 'static,
379 T5: FromRequest + Send + 'static,
380{
381 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
382 let handler = self.clone();
383 Box::pin(async move {
384 let t1 = T1::from_request(&mut req).await?;
385 let t2 = T2::from_request(&mut req).await?;
386 let t3 = T3::from_request(&mut req).await?;
387 let t4 = T4::from_request(&mut req).await?;
388 let t5 = T5::from_request(&mut req).await?;
389 handler(t1, t2, t3, t4, t5).await
390 })
391 }
392}
393
394impl<F, Fut, T1, T2, T3, T4, T5, T6> Handler<(T1, T2, T3, T4, T5, T6)> for F
396where
397 F: Fn(T1, T2, T3, T4, T5, T6) -> Fut + Clone + Send + Sync + 'static,
398 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
399 T1: FromRequest + Send + 'static,
400 T2: FromRequest + Send + 'static,
401 T3: FromRequest + Send + 'static,
402 T4: FromRequest + Send + 'static,
403 T5: FromRequest + Send + 'static,
404 T6: FromRequest + Send + 'static,
405{
406 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
407 let handler = self.clone();
408 Box::pin(async move {
409 let t1 = T1::from_request(&mut req).await?;
410 let t2 = T2::from_request(&mut req).await?;
411 let t3 = T3::from_request(&mut req).await?;
412 let t4 = T4::from_request(&mut req).await?;
413 let t5 = T5::from_request(&mut req).await?;
414 let t6 = T6::from_request(&mut req).await?;
415 handler(t1, t2, t3, t4, t5, t6).await
416 })
417 }
418}
419
420impl<F, Fut, T1, T2, T3, T4, T5, T6, T7> Handler<(T1, T2, T3, T4, T5, T6, T7)> for F
422where
423 F: Fn(T1, T2, T3, T4, T5, T6, T7) -> Fut + Clone + Send + Sync + 'static,
424 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
425 T1: FromRequest + Send + 'static,
426 T2: FromRequest + Send + 'static,
427 T3: FromRequest + Send + 'static,
428 T4: FromRequest + Send + 'static,
429 T5: FromRequest + Send + 'static,
430 T6: FromRequest + Send + 'static,
431 T7: FromRequest + Send + 'static,
432{
433 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
434 let handler = self.clone();
435 Box::pin(async move {
436 let t1 = T1::from_request(&mut req).await?;
437 let t2 = T2::from_request(&mut req).await?;
438 let t3 = T3::from_request(&mut req).await?;
439 let t4 = T4::from_request(&mut req).await?;
440 let t5 = T5::from_request(&mut req).await?;
441 let t6 = T6::from_request(&mut req).await?;
442 let t7 = T7::from_request(&mut req).await?;
443 handler(t1, t2, t3, t4, t5, t6, t7).await
444 })
445 }
446}
447
448impl<F, Fut, T1, T2, T3, T4, T5, T6, T7, T8> Handler<(T1, T2, T3, T4, T5, T6, T7, T8)> for F
450where
451 F: Fn(T1, T2, T3, T4, T5, T6, T7, T8) -> Fut + Clone + Send + Sync + 'static,
452 Fut: Future<Output = Result<OxiditeResponse>> + Send + 'static,
453 T1: FromRequest + Send + 'static,
454 T2: FromRequest + Send + 'static,
455 T3: FromRequest + Send + 'static,
456 T4: FromRequest + Send + 'static,
457 T5: FromRequest + Send + 'static,
458 T6: FromRequest + Send + 'static,
459 T7: FromRequest + Send + 'static,
460 T8: FromRequest + Send + 'static,
461{
462 fn call(&self, mut req: OxiditeRequest) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
463 let handler = self.clone();
464 Box::pin(async move {
465 let t1 = T1::from_request(&mut req).await?;
466 let t2 = T2::from_request(&mut req).await?;
467 let t3 = T3::from_request(&mut req).await?;
468 let t4 = T4::from_request(&mut req).await?;
469 let t5 = T5::from_request(&mut req).await?;
470 let t6 = T6::from_request(&mut req).await?;
471 let t7 = T7::from_request(&mut req).await?;
472 let t8 = T8::from_request(&mut req).await?;
473 handler(t1, t2, t3, t4, t5, t6, t7, t8).await
474 })
475 }
476}
477
478struct Route {
479 pattern: Regex,
480 param_names: Vec<String>,
481 handler: Arc<dyn Endpoint>,
482}
483
484#[derive(Clone)]
485pub struct Router {
486 routes: HashMap<Method, Vec<Arc<Route>>>,
487 extensions: Arc<std::sync::RwLock<http::Extensions>>,
488 middleware: Vec<Arc<dyn Fn(Arc<dyn Endpoint>) -> Arc<dyn Endpoint> + Send + Sync>>,
489 cors_config: Option<CorsConfig>,
490}
491
492impl Router {
493 pub fn new() -> Self {
494 Self {
495 routes: HashMap::new(),
496 extensions: Arc::new(std::sync::RwLock::new(http::Extensions::new())),
497 middleware: Vec::new(),
498 cors_config: None,
499 }
500 }
501
502 pub fn with_state<T: Clone + Send + Sync + 'static>(&mut self, state: T) {
504 if let Ok(mut exts) = self.extensions.write() {
505 exts.insert(state);
506 }
507 }
508
509 pub fn get<H, Args>(&mut self, path: &str, handler: H)
510 where
511 H: Handler<Args>,
512 Args: Send + Sync + 'static,
513 {
514 self.add_route(Method::GET, path, handler);
515 }
516
517 pub fn post<H, Args>(&mut self, path: &str, handler: H)
518 where
519 H: Handler<Args>,
520 Args: Send + Sync + 'static,
521 {
522 self.add_route(Method::POST, path, handler);
523 }
524
525 pub fn put<H, Args>(&mut self, path: &str, handler: H)
526 where
527 H: Handler<Args>,
528 Args: Send + Sync + 'static,
529 {
530 self.add_route(Method::PUT, path, handler);
531 }
532
533 pub fn delete<H, Args>(&mut self, path: &str, handler: H)
534 where
535 H: Handler<Args>,
536 Args: Send + Sync + 'static,
537 {
538 self.add_route(Method::DELETE, path, handler);
539 }
540
541 pub fn patch<H, Args>(&mut self, path: &str, handler: H)
542 where
543 H: Handler<Args>,
544 Args: Send + Sync + 'static,
545 {
546 self.add_route(Method::PATCH, path, handler);
547 }
548
549 pub fn layer<L>(mut self, layer: L) -> Self
567 where
568 L: tower::Layer<EndpointService> + Send + Sync + 'static,
569 L::Service: Endpoint,
570 {
571 let layer = Arc::new(layer);
572 self.middleware.push(Arc::new(move |endpoint| {
573 Arc::new(layer.layer(EndpointService(endpoint)))
574 }));
575 self
576 }
577
578 pub fn with_layer<L>(self, layer: L) -> Self
580 where
581 L: tower::Layer<EndpointService> + Send + Sync + 'static,
582 L::Service: Endpoint,
583 {
584 self.layer(layer)
585 }
586
587 pub fn with_cors(mut self, config: CorsConfig) -> Self {
614 self.cors_config = Some(config);
615 self
616 }
617
618 fn add_route<H, Args>(&mut self, method: Method, path: &str, handler: H)
619 where
620 H: Handler<Args>,
621 Args: Send + Sync + 'static,
622 {
623 let (pattern, param_names) = compile_path(path);
624 let mut endpoint: Arc<dyn Endpoint> = Arc::new(HandlerService {
625 handler,
626 _marker: std::marker::PhantomData,
627 });
628
629 for mw in self.middleware.iter().rev() {
631 endpoint = mw(endpoint);
632 }
633
634 let route = Arc::new(Route {
635 pattern,
636 param_names,
637 handler: endpoint,
638 });
639
640 self.routes
641 .entry(method)
642 .or_insert_with(Vec::new)
643 .push(route);
644 }
645
646 fn add_cors_headers(&self, mut builder: http::response::Builder) -> http::response::Builder {
648 if let Some(cors) = &self.cors_config {
649 if !cors.allowed_origins.is_empty() {
651 if cors.allowed_origins.contains(&"*".to_string()) {
652 builder = builder.header("Access-Control-Allow-Origin", "*");
653 } else {
654 if let Some(origin) = cors.allowed_origins.first() {
657 builder = builder.header("Access-Control-Allow-Origin", origin);
658 }
659 }
660 }
661
662 if !cors.allowed_methods.is_empty() {
664 let methods = cors.allowed_methods.join(", ");
665 builder = builder.header("Access-Control-Allow-Methods", methods);
666 }
667
668 if !cors.allowed_headers.is_empty() {
670 let headers = cors.allowed_headers.join(", ");
671 builder = builder.header("Access-Control-Allow-Headers", headers);
672 }
673
674 if cors.allow_credentials {
676 builder = builder.header("Access-Control-Allow-Credentials", "true");
677 }
678
679 builder = builder.header("Access-Control-Max-Age", cors.max_age.to_string());
681 }
682 builder
683 }
684
685 pub async fn handle(&self, mut req: OxiditeRequest) -> Result<OxiditeResponse> {
686 req.extensions_mut().insert(self.extensions.clone());
687 let method = req.method().clone();
688 let path = req.uri().path().to_string();
689
690 let try_match = |target_method: &Method, req: &mut OxiditeRequest| -> Option<Arc<Route>> {
692 if let Some(routes) = self.routes.get(target_method) {
693 for route in routes {
694 if let Some(captures) = route.pattern.captures(&path) {
695 let mut params = serde_json::Map::new();
697 for (i, name) in route.param_names.iter().enumerate() {
698 if let Some(value) = captures.get(i + 1) {
699 params.insert(
700 name.clone(),
701 serde_json::Value::String(value.as_str().to_string()),
702 );
703 }
704 }
705
706 if !params.is_empty() {
708 req.extensions_mut().insert(crate::extract::PathParams(
709 serde_json::Value::Object(params),
710 ));
711 }
712
713 return Some(route.clone());
714 }
715 }
716 }
717 None
718 };
719
720 if let Some(route) = try_match(&method, &mut req) {
722 req.extensions_mut().insert(self.extensions.clone());
724 let response = route.handler.call(req).await?;
725
726 if self.cors_config.is_some() {
728 let hyper_response: hyper::Response<crate::types::BoxBody> = response.into();
729 let (parts, body) = hyper_response.into_parts();
730 let mut builder = http::Response::builder()
731 .status(parts.status);
732
733 for (key, value) in parts.headers {
735 if let Some(key) = key {
736 builder = builder.header(key, value);
737 }
738 }
739
740 builder = self.add_cors_headers(builder);
742
743 return Ok(OxiditeResponse::new(builder.body(body).unwrap()));
744 }
745
746 return Ok(response);
747 }
748
749 if method == Method::OPTIONS {
751 if let Some(_route) = try_match(&Method::OPTIONS, &mut req) {
752 } else {
754 let mut builder = http::Response::builder()
756 .status(http::StatusCode::NO_CONTENT);
757
758 builder = self.add_cors_headers(builder);
760
761 return Ok(OxiditeResponse::new(builder
762 .body(crate::types::BoxBody::default())
763 .unwrap()));
764 }
765 }
766
767 if method == Method::HEAD {
769 if let Some(route) = try_match(&Method::GET, &mut req) {
770 req.extensions_mut().insert(self.extensions.clone());
772 return route.handler.call(req).await;
775 }
776 }
777
778 let allowed_methods: Vec<String> = self
780 .routes
781 .iter()
782 .filter(|(route_method, _)| **route_method != method)
783 .filter_map(|(route_method, routes)| {
784 if routes.iter().any(|route| route.pattern.is_match(&path)) {
785 Some(route_method.as_str().to_string())
786 } else {
787 None
788 }
789 })
790 .collect();
791 if !allowed_methods.is_empty() {
792 return Err(Error::MethodNotAllowed(format!(
793 "{} {} (allowed: {})",
794 method,
795 path,
796 allowed_methods.join(", ")
797 )));
798 }
799
800 Err(Error::NotFound("Route not found".to_string()))
801 }
802}
803
804impl Service<OxiditeRequest> for Router {
805 type Response = OxiditeResponse;
806 type Error = Error;
807 type Future = Pin<Box<dyn Future<Output = Result<Self::Response>> + Send>>;
808
809 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
810 Poll::Ready(Ok(()))
811 }
812
813 fn call(&mut self, req: OxiditeRequest) -> Self::Future {
814 let router = self.clone();
815 Box::pin(async move {
816 router.handle(req).await
817 })
818 }
819}
820
821impl Default for Router {
822 fn default() -> Self {
823 Self::new()
824 }
825}
826
827fn compile_path(path: &str) -> (Regex, Vec<String>) {
830 let mut pattern = String::from("^");
831 let mut param_names = Vec::new();
832 let mut chars = path.chars().peekable();
833
834 while let Some(ch) = chars.next() {
835 match ch {
836 ':' => {
837 let mut param_name = String::new();
839 while let Some(&next_ch) = chars.peek() {
840 if next_ch.is_alphanumeric() || next_ch == '_' {
841 param_name.push(next_ch);
842 chars.next();
843 } else {
844 break;
845 }
846 }
847 param_names.push(param_name);
848 pattern.push_str("([^/]+)");
849 }
850 '*' => {
851 pattern.push_str("(.*)");
853 }
854 '.' | '+' | '?' | '^' | '$' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '\\' => {
855 pattern.push('\\');
857 pattern.push(ch);
858 }
859 _ => {
860 pattern.push(ch);
861 }
862 }
863 }
864
865 pattern.push('$');
866 let regex = Regex::new(&pattern).expect("Invalid route pattern");
867 (regex, param_names)
868}
869
870pub trait IntoHandler<Args>: Handler<Args> + Sized {
889 fn into_handler(self) -> Self {
890 self
891 }
892}
893
894impl<H, Args> IntoHandler<Args> for H where H: Handler<Args> {}
895
896pub fn handler_fn<H, Args>(handler: H) -> H
914where
915 H: IntoHandler<Args>,
916 Args: Send + Sync + 'static,
917{
918 handler
919}
920
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extract::State;
    use crate::types::BoxBody;

    /// Unit state type used by the extractor-arity tests.
    #[derive(Clone)]
    struct AppState;

    /// Builds an empty-bodied request for `method` and `uri`.
    fn request(method: Method, uri: &str) -> http::Request<BoxBody> {
        http::Request::builder()
            .method(method)
            .uri(uri)
            .body(BoxBody::default())
            .expect("request")
    }

    /// Unwraps a router response into the underlying hyper response.
    fn into_hyper(response: OxiditeResponse) -> hyper::Response<BoxBody> {
        response.into()
    }

    /// Registers a `GET` route at `path` that always answers "ok".
    fn add_ok_route(router: &mut Router, path: &str) {
        router.get(path, || async { Ok(OxiditeResponse::text("ok")) });
    }

    /// Serves `GET /` with `handler` on a router holding `AppState`.
    async fn serve_root<H, Args>(handler: H) -> OxiditeResponse
    where
        H: Handler<Args>,
        Args: Send + Sync + 'static,
    {
        let mut router = Router::new();
        router.with_state(AppState);
        router.get("/", handler);
        router
            .handle(request(Method::GET, "/"))
            .await
            .expect("handle")
    }

    /// Checks that a permissive-CORS router answers `GET /test` with a wildcard origin.
    async fn assert_wildcard_origin_on_get() {
        let mut router = Router::new().with_cors(CorsConfig::permissive());
        add_ok_route(&mut router, "/test");

        let response = into_hyper(
            router
                .handle(request(Method::GET, "/test"))
                .await
                .expect("handle"),
        );
        assert_eq!(
            response.headers().get("Access-Control-Allow-Origin").unwrap(),
            "*"
        );
    }

    #[test]
    fn test_compile_path() {
        let (single, names) = compile_path("/users/:id");
        assert_eq!(names, vec!["id"]);
        assert!(single.is_match("/users/123"));
        assert!(!single.is_match("/users/123/posts"));

        let (nested, names) = compile_path("/users/:user_id/posts/:post_id");
        assert_eq!(names, vec!["user_id", "post_id"]);
        assert!(nested.is_match("/users/1/posts/2"));
    }

    #[test]
    fn test_exact_match() {
        let (exact, names) = compile_path("/users");
        assert!(names.is_empty());
        assert!(exact.is_match("/users"));
        assert!(!exact.is_match("/users/123"));
    }

    #[tokio::test]
    async fn test_handler_with_12_extractors() {
        async fn h12(
            _s1: State<AppState>,
            _s2: State<AppState>,
            _s3: State<AppState>,
            _s4: State<AppState>,
            _s5: State<AppState>,
            _s6: State<AppState>,
            _s7: State<AppState>,
            _s8: State<AppState>,
            _s9: State<AppState>,
            _s10: State<AppState>,
            _s11: State<AppState>,
            _s12: State<AppState>,
        ) -> Result<OxiditeResponse> {
            Ok(OxiditeResponse::text("ok"))
        }

        let response = serve_root(h12).await;
        assert_eq!(response.status(), http::StatusCode::OK);
    }

    #[tokio::test]
    async fn test_method_not_allowed_when_path_exists() {
        let mut router = Router::new();
        add_ok_route(&mut router, "/users");

        let outcome = router.handle(request(Method::POST, "/users")).await;
        assert!(matches!(outcome, Err(Error::MethodNotAllowed(_))));
    }

    #[tokio::test]
    async fn test_not_found_when_path_missing() {
        let router = Router::new();

        let outcome = router.handle(request(Method::GET, "/missing")).await;
        assert!(matches!(outcome, Err(Error::NotFound(_))));
    }

    #[tokio::test]
    async fn test_handler_with_8_extractors() {
        async fn h8(
            _s1: State<AppState>,
            _s2: State<AppState>,
            _s3: State<AppState>,
            _s4: State<AppState>,
            _s5: State<AppState>,
            _s6: State<AppState>,
            _s7: State<AppState>,
            _s8: State<AppState>,
        ) -> Result<OxiditeResponse> {
            Ok(OxiditeResponse::text("ok"))
        }

        let response = serve_root(h8).await;
        assert_eq!(response.status(), http::StatusCode::OK);
    }

    #[tokio::test]
    async fn test_router_layer() {
        use http_body_util::BodyExt;
        use tower::Layer;

        /// Endpoint wrapper that tags the request before delegating to the inner endpoint.
        struct MyMiddleware<S>(S);

        impl<S: Endpoint> Endpoint for MyMiddleware<S> {
            fn call(
                &self,
                mut req: OxiditeRequest,
            ) -> Pin<Box<dyn Future<Output = Result<OxiditeResponse>> + Send>> {
                req.extensions_mut()
                    .insert("middleware_called".to_string());
                self.0.call(req)
            }
        }

        struct MyLayer;

        impl<S> Layer<S> for MyLayer {
            type Service = MyMiddleware<S>;

            fn layer(&self, inner: S) -> Self::Service {
                MyMiddleware(inner)
            }
        }

        let mut router = Router::new().layer(MyLayer);
        router.get("/", |req: OxiditeRequest| async move {
            let tagged = req
                .extensions()
                .get::<String>()
                .map_or(false, |marker| marker == "middleware_called");
            let body = if tagged { "middleware_ok" } else { "middleware_fail" };
            Ok(OxiditeResponse::text(body))
        });

        let response = router
            .handle(request(Method::GET, "/"))
            .await
            .expect("handle");
        let bytes = into_hyper(response)
            .into_body()
            .collect()
            .await
            .unwrap()
            .to_bytes();
        assert_eq!(bytes, "middleware_ok");
    }

    #[tokio::test]
    async fn test_cors_config_default() {
        let config = CorsConfig::default();
        assert_eq!(config.allowed_origins, vec!["*"]);
        assert!(!config.allowed_methods.is_empty());
        assert_eq!(config.allowed_headers, vec!["*"]);
        assert!(!config.allow_credentials);
        assert_eq!(config.max_age, 3600);
    }

    #[tokio::test]
    async fn test_cors_config_permissive() {
        let config = CorsConfig::permissive();
        assert_eq!(config.allowed_origins, vec!["*"]);
        assert_eq!(config.allowed_headers, vec!["*"]);
    }

    #[tokio::test]
    async fn test_cors_config_restrictive() {
        let config = CorsConfig::restrictive();
        assert!(config.allowed_origins.is_empty());
        assert_eq!(config.allowed_methods, vec!["GET", "POST"]);
        assert_eq!(config.allowed_headers, vec!["Content-Type"]);
    }

    #[tokio::test]
    async fn test_cors_preflight_response() {
        let mut router = Router::new().with_cors(CorsConfig {
            allowed_origins: vec!["http://localhost:3000".to_string()],
            allowed_methods: vec!["GET".to_string(), "POST".to_string()],
            allowed_headers: vec!["Content-Type".to_string()],
            allow_credentials: true,
            max_age: 7200,
        });
        add_ok_route(&mut router, "/test");

        let response = into_hyper(
            router
                .handle(request(Method::OPTIONS, "/test"))
                .await
                .expect("handle"),
        );
        assert_eq!(response.status(), http::StatusCode::NO_CONTENT);

        let expected = [
            ("Access-Control-Allow-Origin", "http://localhost:3000"),
            ("Access-Control-Allow-Methods", "GET, POST"),
            ("Access-Control-Allow-Headers", "Content-Type"),
            ("Access-Control-Allow-Credentials", "true"),
            ("Access-Control-Max-Age", "7200"),
        ];
        for (name, value) in expected {
            assert_eq!(response.headers().get(name).unwrap(), value, "header {}", name);
        }
    }

    #[tokio::test]
    async fn test_cors_on_successful_response() {
        assert_wildcard_origin_on_get().await;
    }

    #[tokio::test]
    async fn test_cors_wildcard_origin() {
        assert_wildcard_origin_on_get().await;
    }

    #[tokio::test]
    async fn test_no_cors_when_not_configured() {
        let mut router = Router::new();
        add_ok_route(&mut router, "/test");

        let response = into_hyper(
            router
                .handle(request(Method::GET, "/test"))
                .await
                .expect("handle"),
        );
        let headers = response.headers();
        assert!(headers.get("Access-Control-Allow-Origin").is_none());
        assert!(headers.get("Access-Control-Allow-Methods").is_none());
    }
}
1230}