1use bytes::Bytes;
4use http::{Method, StatusCode};
5use http_body_util::Full;
6use hyper::body::Incoming;
7use std::collections::HashMap;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12use oxihttp_core::OxiHttpError;
13
/// Type-erased closure that inserts a router's shared state into a request's
/// extensions. Built by `Router::with_state`, which wraps the state in an
/// `Arc<T>` once and inserts a clone of that `Arc` on every call.
type StateFn = Box<dyn Fn(&mut http::Extensions) + Send + Sync>;
19
/// Shared, type-erased request handler.
///
/// Takes the router's [`Request`] wrapper and returns a boxed `Send` future
/// that resolves to a fully buffered (`Full<Bytes>`) response or an
/// [`OxiHttpError`]. It is kept behind an `Arc` so routes can share it
/// cheaply, and it is `Send + Sync` so it can be called from any thread.
pub type HandlerFn = Arc<
    dyn Fn(
            Request,
        ) -> Pin<
            Box<dyn Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send>,
        > + Send
        + Sync,
>;
29
/// An incoming HTTP request together with the path parameters captured by the
/// matching route pattern (e.g. `id` for `/users/:id`).
#[derive(Debug)]
pub struct Request {
    /// The underlying hyper request (method, URI, headers, extensions, body).
    inner: hyper::Request<Incoming>,
    /// Values captured from `:param` and `*wildcard` pattern segments, keyed by
    /// the name given in the pattern.
    path_params: HashMap<String, String>,
}
36
37impl Request {
38 pub fn new(inner: hyper::Request<Incoming>, path_params: HashMap<String, String>) -> Self {
40 Self { inner, path_params }
41 }
42
43 pub fn method(&self) -> &Method {
45 self.inner.method()
46 }
47
48 pub fn uri(&self) -> &http::Uri {
50 self.inner.uri()
51 }
52
53 pub fn headers(&self) -> &http::HeaderMap {
55 self.inner.headers()
56 }
57
58 pub fn path(&self) -> &str {
60 self.inner.uri().path()
61 }
62
63 pub fn param(&self, name: &str) -> Option<&str> {
65 self.path_params.get(name).map(|s| s.as_str())
66 }
67
68 pub fn params(&self) -> &HashMap<String, String> {
70 &self.path_params
71 }
72
73 pub fn query_params(&self) -> HashMap<String, String> {
75 self.inner
76 .uri()
77 .query()
78 .map(|q| {
79 q.split('&')
80 .filter_map(|pair| {
81 let (k, v) = pair.split_once('=')?;
82 Some((percent_decode(k), percent_decode(v)))
83 })
84 .collect()
85 })
86 .unwrap_or_default()
87 }
88
89 pub fn query(&self, name: &str) -> Option<String> {
91 self.query_params().remove(name)
92 }
93
94 pub fn into_inner(self) -> hyper::Request<Incoming> {
96 self.inner
97 }
98
99 pub async fn body_bytes(self) -> Result<Bytes, OxiHttpError> {
101 use http_body_util::BodyExt;
102 self.inner
103 .into_body()
104 .collect()
105 .await
106 .map(|c| c.to_bytes())
107 .map_err(|e| OxiHttpError::Body(e.to_string()))
108 }
109
110 pub async fn body_text(self) -> Result<String, OxiHttpError> {
112 let bytes = self.body_bytes().await?;
113 String::from_utf8(bytes.to_vec())
114 .map_err(|e| OxiHttpError::Body(format!("invalid UTF-8: {e}")))
115 }
116
117 pub async fn body_json<T: serde::de::DeserializeOwned>(self) -> Result<T, OxiHttpError> {
119 let bytes = self.body_bytes().await?;
120 serde_json::from_slice(&bytes).map_err(|e| OxiHttpError::Json(e.to_string()))
121 }
122
123 pub fn state<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
129 self.inner.extensions().get::<Arc<T>>().cloned()
130 }
131
132 pub fn extension<T: Clone + Send + Sync + 'static>(&self) -> Option<T> {
138 self.inner.extensions().get::<T>().cloned()
139 }
140
141 pub fn extensions(&self) -> &http::Extensions {
143 self.inner.extensions()
144 }
145
146 pub fn extensions_mut(&mut self) -> &mut http::Extensions {
151 self.inner.extensions_mut()
152 }
153
154 pub fn parts(&self) -> crate::extractor::RequestParts<'_> {
159 crate::extractor::RequestParts {
160 method: self.inner.method(),
161 uri: self.inner.uri(),
162 headers: self.inner.headers(),
163 path_params: &self.path_params,
164 }
165 }
166
167 pub fn extract<T: crate::extractor::FromRequestParts>(&self) -> Result<T, T::Rejection> {
175 T::from_request_parts(&self.parts())
176 }
177
178 pub fn negotiate(
184 &self,
185 supported: &[oxihttp_core::ContentType],
186 ) -> Option<oxihttp_core::ContentType> {
187 negotiate_from_headers(self.headers(), supported)
188 }
189
190 #[cfg(feature = "tls")]
197 pub fn tls_info(&self) -> Option<Arc<crate::tls::PeerCertInfo>> {
198 self.inner
199 .extensions()
200 .get::<Arc<crate::tls::PeerCertInfo>>()
201 .cloned()
202 }
203
204 #[cfg(feature = "tls")]
211 pub fn peer_certificates(&self) -> Option<Vec<rustls_pki_types::CertificateDer<'static>>> {
212 self.tls_info().and_then(|info| {
213 if info.peer_certificates.is_empty() {
214 None
215 } else {
216 Some(info.peer_certificates.clone())
217 }
218 })
219 }
220
221 #[cfg(feature = "tls")]
232 pub fn tls_connection_info(&self) -> Option<oxitls::ConnectionInfo> {
233 self.tls_info().map(|info| {
234 let mut ci = oxitls::ConnectionInfo::new();
235 if let Some(v) = info.version {
236 ci = ci.with_version(v);
237 }
238 if let Some(cs) = info.cipher_suite {
239 ci = ci.with_cipher_suite(cs);
240 }
241 if let Some(ref alpn) = info.alpn_protocol {
242 ci = ci.with_alpn_protocol(alpn.clone());
243 }
244 if let Some(ref sni) = info.sni {
245 ci = ci.with_sni(sni.clone());
246 }
247 if !info.peer_certificates.is_empty() {
248 let der_vecs: Vec<Vec<u8>> = info
249 .peer_certificates
250 .iter()
251 .map(|c| c.as_ref().to_vec())
252 .collect();
253 ci = ci.with_peer_certificates(der_vecs);
254 }
255 ci
256 })
257 }
258}
259
260fn negotiate_from_headers(
265 headers: &http::HeaderMap,
266 supported: &[oxihttp_core::ContentType],
267) -> Option<oxihttp_core::ContentType> {
268 let accept = headers
269 .get(http::header::ACCEPT)
270 .and_then(|v| v.to_str().ok())
271 .unwrap_or("*/*");
272 oxihttp_core::content_type::negotiate_content_type(accept, supported)
273}
274
/// Decodes a form-urlencoded component (query key or value).
///
/// Behaviour:
/// - `%XY` escapes (two hex digits, either case) are decoded to raw bytes, and
///   the bytes are then interpreted as UTF-8. Multi-byte characters such as
///   `%C3%A9` become `é`; invalid sequences become U+FFFD.
/// - `+` decodes to a space.
/// - A `%` that does not start a valid escape is kept literally, and the
///   characters after it are processed normally (`100%zz` stays `100%zz`).
fn percent_decode(s: &str) -> String {
    // Strict hex-digit parser; `u8::from_str_radix` would also accept a sign.
    fn hex_value(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }

    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                let escape = match (bytes.get(i + 1), bytes.get(i + 2)) {
                    (Some(&hi), Some(&lo)) => hex_value(hi).zip(hex_value(lo)),
                    _ => None,
                };
                if let Some((hi, lo)) = escape {
                    decoded.push((hi << 4) | lo);
                    i += 3;
                    continue;
                }
                // Not a valid escape: keep the '%' and reprocess what follows.
                decoded.push(b'%');
            }
            b'+' => decoded.push(b' '),
            other => decoded.push(other),
        }
        i += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
300
/// A single registered route: the method it answers, its parsed path pattern,
/// and the handler invoked when both match.
struct Route {
    /// Method the route answers; other methods on a matching path yield 405.
    method: Method,
    /// Pattern produced by `parse_pattern`.
    segments: Vec<Segment>,
    /// Handler invoked with the captured path parameters.
    handler: HandlerFn,
}
307
/// One `/`-separated component of a route pattern, as produced by `parse_pattern`.
#[derive(Debug, Clone)]
enum Segment {
    /// Must equal the corresponding path component exactly (case-sensitive).
    Literal(String),
    /// `:name`: matches any single path component and captures it under `name`.
    Param(String),
    /// `*name`: matches one or more remaining path components and captures them,
    /// joined by `/`, under `name`.
    Wildcard(String),
}
318
/// Boxed `Send` future returned by `Router::dispatch`. It borrows the router for `'a`.
pub type DispatchFuture<'a> =
    Pin<Box<dyn Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'a>>;
322
/// Method- and path-based request router. It supports nested routers,
/// name-based virtual hosts, custom 404/405 handlers and shared state.
pub struct Router {
    /// This router's own routes, tried in registration order.
    routes: Vec<Route>,
    /// Sub-routers mounted under a path prefix (stored without a trailing `/`).
    nested: Vec<(String, Router)>,
    /// Sub-routers selected by host name.
    vhosts: Vec<(String, Router)>,
    /// Handler for requests whose path matches no route (default: plain 404).
    fallback: Option<HandlerFn>,
    /// Handler for requests whose path matches but whose method does not
    /// (default: plain 405).
    method_not_allowed_handler: Option<HandlerFn>,
    /// Inserts this router's shared state into request extensions, if set.
    state: Option<StateFn>,
}
335
336impl Router {
337 pub fn new() -> Self {
339 Self {
340 routes: Vec::new(),
341 nested: Vec::new(),
342 vhosts: Vec::new(),
343 fallback: None,
344 method_not_allowed_handler: None,
345 state: None,
346 }
347 }
348
349 pub fn with_state<T: Clone + Send + Sync + 'static>(mut self, state: T) -> Self {
373 let arc = Arc::new(state);
374 self.state = Some(Box::new(move |ext: &mut http::Extensions| {
375 ext.insert(Arc::clone(&arc));
376 }));
377 self
378 }
379
380 pub fn route<F, Fut>(mut self, method: Method, path: &str, handler: F) -> Self
387 where
388 F: Fn(Request) -> Fut + Send + Sync + 'static,
389 Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
390 {
391 let segments = parse_pattern(path);
392 let handler: HandlerFn = Arc::new(move |req| Box::pin(handler(req)));
393 self.routes.push(Route {
394 method,
395 segments,
396 handler,
397 });
398 self
399 }
400
401 pub fn get<F, Fut>(self, path: &str, handler: F) -> Self
403 where
404 F: Fn(Request) -> Fut + Send + Sync + 'static,
405 Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
406 {
407 self.route(Method::GET, path, handler)
408 }
409
410 pub fn post<F, Fut>(self, path: &str, handler: F) -> Self
412 where
413 F: Fn(Request) -> Fut + Send + Sync + 'static,
414 Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
415 {
416 self.route(Method::POST, path, handler)
417 }
418
419 pub fn put<F, Fut>(self, path: &str, handler: F) -> Self
421 where
422 F: Fn(Request) -> Fut + Send + Sync + 'static,
423 Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
424 {
425 self.route(Method::PUT, path, handler)
426 }
427
428 pub fn delete<F, Fut>(self, path: &str, handler: F) -> Self
430 where
431 F: Fn(Request) -> Fut + Send + Sync + 'static,
432 Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
433 {
434 self.route(Method::DELETE, path, handler)
435 }
436
437 pub fn patch<F, Fut>(self, path: &str, handler: F) -> Self
439 where
440 F: Fn(Request) -> Fut + Send + Sync + 'static,
441 Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
442 {
443 self.route(Method::PATCH, path, handler)
444 }
445
446 pub fn head<F, Fut>(self, path: &str, handler: F) -> Self
448 where
449 F: Fn(Request) -> Fut + Send + Sync + 'static,
450 Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
451 {
452 self.route(Method::HEAD, path, handler)
453 }
454
455 pub fn nest(mut self, prefix: &str, router: Router) -> Self {
457 let prefix = prefix.trim_end_matches('/').to_string();
458 self.nested.push((prefix, router));
459 self
460 }
461
462 pub fn host(mut self, host: &str, router: Router) -> Self {
484 self.vhosts.push((host.to_owned(), router));
485 self
486 }
487
488 pub fn fallback<F, Fut>(mut self, handler: F) -> Self
490 where
491 F: Fn(Request) -> Fut + Send + Sync + 'static,
492 Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
493 {
494 self.fallback = Some(Arc::new(move |req| Box::pin(handler(req))));
495 self
496 }
497
498 pub fn method_not_allowed<F, Fut>(mut self, handler: F) -> Self
500 where
501 F: Fn(Request) -> Fut + Send + Sync + 'static,
502 Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
503 {
504 self.method_not_allowed_handler = Some(Arc::new(move |req| Box::pin(handler(req))));
505 self
506 }
507
508 pub fn health(self, path: &str) -> Self {
510 self.get(path, |_req| async {
511 hyper::Response::builder()
512 .status(StatusCode::OK)
513 .body(Full::new(Bytes::from("OK")))
514 .map_err(|e| OxiHttpError::Http(Arc::new(e)))
515 })
516 }
517
518 pub fn resolve(&self, method: &Method, path: &str) -> Option<HashMap<String, String>> {
527 for (prefix, sub_router) in &self.nested {
529 if let Some(stripped) = path.strip_prefix(prefix.as_str()) {
530 let sub_path = if stripped.is_empty() { "/" } else { stripped };
531 return sub_router.resolve(method, sub_path);
532 }
533 }
534
535 let mut path_matched = false;
537 for route in &self.routes {
538 if let Some(params) = match_pattern(&route.segments, path) {
539 path_matched = true;
540 if route.method == *method {
541 return Some(params);
542 }
543 }
544 }
545
546 if path_matched {
548 return Some(HashMap::new());
549 }
550
551 None
552 }
553
554 pub fn dispatch(&self, req: hyper::Request<Incoming>) -> DispatchFuture<'_> {
556 Box::pin(self.dispatch_inner(req))
557 }
558
559 async fn dispatch_inner(
560 &self,
561 mut req: hyper::Request<Incoming>,
562 ) -> Result<hyper::Response<Full<Bytes>>, OxiHttpError> {
563 let method = req.method().clone();
564 let path = req.uri().path().to_string();
565
566 if let Some(host_hdr) = req.headers().get(http::header::HOST) {
568 if let Ok(host_str) = host_hdr.to_str() {
569 let host_bare = host_str.split(':').next().unwrap_or(host_str);
570 for (vhost_name, sub_router) in &self.vhosts {
571 if vhost_name.eq_ignore_ascii_case(host_bare) {
572 if sub_router.state.is_none() {
574 if let Some(ref inject_fn) = self.state {
575 inject_fn(req.extensions_mut());
576 }
577 }
578 return sub_router.dispatch(req).await;
579 }
580 }
581 }
582 }
583
584 for (prefix, sub_router) in &self.nested {
589 if path.starts_with(prefix.as_str()) {
590 let sub_path = &path[prefix.len()..];
591 let sub_path = if sub_path.is_empty() { "/" } else { sub_path };
592
593 let new_uri = http::Uri::builder()
595 .path_and_query(sub_path)
596 .build()
597 .map_err(|e| OxiHttpError::Http(Arc::new(e)))?;
598
599 let (mut parts, body) = req.into_parts();
600 parts.uri = new_uri;
601 let mut new_req = hyper::Request::from_parts(parts, body);
602
603 if sub_router.state.is_none() {
607 if let Some(ref inject_fn) = self.state {
608 inject_fn(new_req.extensions_mut());
609 }
610 }
611
612 return sub_router.dispatch(new_req).await;
613 }
614 }
615
616 let mut path_matched = false;
618 for route in &self.routes {
619 if let Some(params) = match_pattern(&route.segments, &path) {
620 path_matched = true;
621 if route.method == method {
622 let mut inner = req;
623 if let Some(ref inject_fn) = self.state {
624 inject_fn(inner.extensions_mut());
625 }
626 let request = Request::new(inner, params);
627 return (route.handler)(request).await;
628 }
629 }
630 }
631
632 if path_matched {
634 if let Some(ref handler) = self.method_not_allowed_handler {
635 let mut inner = req;
636 if let Some(ref inject_fn) = self.state {
637 inject_fn(inner.extensions_mut());
638 }
639 let request = Request::new(inner, HashMap::new());
640 return (handler)(request).await;
641 }
642 return hyper::Response::builder()
643 .status(StatusCode::METHOD_NOT_ALLOWED)
644 .body(Full::new(Bytes::from("Method Not Allowed")))
645 .map_err(|e| OxiHttpError::Http(Arc::new(e)));
646 }
647
648 if let Some(ref handler) = self.fallback {
650 let mut inner = req;
651 if let Some(ref inject_fn) = self.state {
652 inject_fn(inner.extensions_mut());
653 }
654 let request = Request::new(inner, HashMap::new());
655 return (handler)(request).await;
656 }
657
658 hyper::Response::builder()
659 .status(StatusCode::NOT_FOUND)
660 .body(Full::new(Bytes::from("Not Found")))
661 .map_err(|e| OxiHttpError::Http(Arc::new(e)))
662 }
663
664 pub fn route_count(&self) -> usize {
666 self.routes.len()
667 }
668}
669
670impl Default for Router {
671 fn default() -> Self {
672 Self::new()
673 }
674}
675
676impl std::fmt::Debug for Router {
677 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
678 f.debug_struct("Router")
679 .field("routes", &self.routes.len())
680 .field("nested", &self.nested.len())
681 .field("vhosts", &self.vhosts.len())
682 .field("has_state", &self.state.is_some())
683 .finish()
684 }
685}
686
687impl std::fmt::Display for Router {
688 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
689 for (host, sub) in &self.vhosts {
691 writeln!(f, "vhost: {host}")?;
692 for route in &sub.routes {
693 writeln!(f, " {} /<vhost-path>", route.method)?;
694 }
695 }
696 for route in &self.routes {
698 let pattern = route
699 .segments
700 .iter()
701 .map(|s| match s {
702 Segment::Literal(l) => format!("/{l}"),
703 Segment::Param(p) => format!("/:{p}"),
704 Segment::Wildcard(w) => format!("/*{w}"),
705 })
706 .collect::<String>();
707 writeln!(f, "{} {pattern}", route.method)?;
708 }
709 for (prefix, sub) in &self.nested {
711 writeln!(f, "nested: {prefix}")?;
712 for route in &sub.routes {
713 writeln!(f, " {} {prefix}<path>", route.method)?;
714 }
715 }
716 Ok(())
717 }
718}
719
#[cfg(feature = "tower")]
impl Router {
    /// Moves the router into an `Arc` and wraps it in
    /// `tower_compat::RouterMakeService`.
    pub fn into_make_service(self) -> crate::tower_compat::RouterMakeService {
        let shared = Arc::new(self);
        crate::tower_compat::RouterMakeService(shared)
    }
}
728
729fn parse_pattern(pattern: &str) -> Vec<Segment> {
731 pattern
732 .split('/')
733 .filter(|s| !s.is_empty())
734 .map(|s| {
735 if let Some(param) = s.strip_prefix(':') {
736 Segment::Param(param.to_string())
737 } else if let Some(wildcard) = s.strip_prefix('*') {
738 Segment::Wildcard(wildcard.to_string())
739 } else {
740 Segment::Literal(s.to_string())
741 }
742 })
743 .collect()
744}
745
746fn match_pattern(segments: &[Segment], path: &str) -> Option<HashMap<String, String>> {
748 let path_segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
749 let mut params = HashMap::new();
750 let mut path_idx = 0;
751
752 for seg in segments {
753 match seg {
754 Segment::Literal(expected) => {
755 if path_idx >= path_segments.len() || path_segments[path_idx] != expected.as_str() {
756 return None;
757 }
758 path_idx += 1;
759 }
760 Segment::Param(name) => {
761 if path_idx >= path_segments.len() {
762 return None;
763 }
764 params.insert(name.clone(), path_segments[path_idx].to_string());
765 path_idx += 1;
766 }
767 Segment::Wildcard(name) => {
768 if path_idx >= path_segments.len() {
769 return None;
770 }
771 let rest = path_segments[path_idx..].join("/");
772 params.insert(name.clone(), rest);
773 return Some(params);
774 }
775 }
776 }
777
778 if path_idx == path_segments.len() {
780 Some(params)
781 } else {
782 None
783 }
784}
785
#[cfg(test)]
mod tests {
    use super::*;

    /// Header map carrying a single `Accept` value.
    fn accept(value: &'static str) -> http::HeaderMap {
        let mut headers = http::HeaderMap::new();
        headers.insert(http::header::ACCEPT, http::HeaderValue::from_static(value));
        headers
    }

    /// Content types offered by the negotiation tests.
    fn json_or_html() -> Vec<oxihttp_core::ContentType> {
        vec![
            oxihttp_core::ContentType::Json,
            oxihttp_core::ContentType::Html(None),
        ]
    }

    #[test]
    fn test_negotiate_returns_json_for_json_accept() {
        let chosen = negotiate_from_headers(&accept("application/json"), &json_or_html());
        assert_eq!(chosen, Some(oxihttp_core::ContentType::Json));
    }

    #[test]
    fn test_negotiate_returns_none_for_unsupported() {
        let chosen = negotiate_from_headers(&accept("image/png"), &json_or_html());
        assert_eq!(chosen, None);
    }

    #[test]
    fn test_parse_literal_pattern() {
        match parse_pattern("/users/list").as_slice() {
            [Segment::Literal(first), Segment::Literal(second)] => {
                assert_eq!(first, "users");
                assert_eq!(second, "list");
            }
            other => panic!("unexpected segments: {other:?}"),
        }
    }

    #[test]
    fn test_parse_param_pattern() {
        match parse_pattern("/users/:id").as_slice() {
            [Segment::Literal(first), Segment::Param(second)] => {
                assert_eq!(first, "users");
                assert_eq!(second, "id");
            }
            other => panic!("unexpected segments: {other:?}"),
        }
    }

    #[test]
    fn test_parse_wildcard_pattern() {
        match parse_pattern("/static/*path").as_slice() {
            [Segment::Literal(first), Segment::Wildcard(second)] => {
                assert_eq!(first, "static");
                assert_eq!(second, "path");
            }
            other => panic!("unexpected segments: {other:?}"),
        }
    }

    #[test]
    fn test_match_literal() {
        let pattern = parse_pattern("/users/list");
        assert_eq!(match_pattern(&pattern, "/users/list"), Some(HashMap::new()));
    }

    #[test]
    fn test_match_literal_no_match() {
        let pattern = parse_pattern("/users/list");
        for path in ["/users/other", "/users"] {
            assert!(match_pattern(&pattern, path).is_none(), "{path} should not match");
        }
    }

    #[test]
    fn test_match_param() {
        let params =
            match_pattern(&parse_pattern("/users/:id"), "/users/42").expect("should match");
        assert_eq!(params.get("id").map(String::as_str), Some("42"));
    }

    #[test]
    fn test_match_wildcard() {
        let params = match_pattern(&parse_pattern("/static/*path"), "/static/css/style.css")
            .expect("should match");
        assert_eq!(params.get("path").map(String::as_str), Some("css/style.css"));
    }

    #[test]
    fn test_no_match_extra_segments() {
        let pattern = parse_pattern("/users");
        assert!(match_pattern(&pattern, "/users/extra").is_none());
    }

    #[test]
    fn test_percent_decode() {
        let cases = [
            ("hello%20world", "hello world"),
            ("a+b", "a b"),
            ("plain", "plain"),
        ];
        for (input, expected) in cases {
            assert_eq!(percent_decode(input), expected, "decoding {input:?}");
        }
    }
}
894
#[cfg(test)]
mod resolve_tests {
    use super::*;

    /// Placeholder handler. `resolve` never invokes it; the router only needs
    /// something to register.
    async fn dummy(_req: Request) -> Result<hyper::Response<Full<Bytes>>, OxiHttpError> {
        Ok(hyper::Response::new(Full::new(Bytes::new())))
    }

    // `Router::resolve` is synchronous, so a plain `#[test]` is enough and no
    // async runtime is needed.
    #[test]
    fn test_resolve_match_and_miss() {
        let router = Router::new().get("/hello", dummy).get("/users/:id", dummy);
        let get = Method::GET;

        assert!(router.resolve(&get, "/hello").is_some());

        let params = router.resolve(&get, "/users/42").expect("should match");
        assert_eq!(params.get("id").map(String::as_str), Some("42"));

        assert!(router.resolve(&get, "/nonexistent").is_none());

        // The path exists but only for GET: resolve reports Some(empty map).
        let for_post = router
            .resolve(&Method::POST, "/hello")
            .expect("path matches under another method");
        assert!(for_post.is_empty());
    }
}