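//! HTTP request router with path parameters and method-based routing.
//!
//! Supports literal, `:param` and `*wildcard` path segments, nested
//! sub-routers, `Host`-based virtual hosts, custom 404/405 handlers and typed
//! shared application state.
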
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use bytes::Bytes;
use http::{Method, StatusCode};
use http_body_util::Full;
use hyper::body::Incoming;
use oxihttp_core::OxiHttpError;

14/// Type alias for the state injection function stored in `Router`.
15///
16/// The closure receives mutable access to `http::Extensions` and inserts the
17/// typed `Arc<T>` so that `Request::state::<T>()` can retrieve it later.
18type StateFn = Box<dyn Fn(&mut http::Extensions) + Send + Sync>;
19
20/// Type alias for the handler function signature.
21pub type HandlerFn = Arc<
22    dyn Fn(
23            Request,
24        ) -> Pin<
25            Box<dyn Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send>,
26        > + Send
27        + Sync,
28>;
29
30/// A request with parsed path parameters and query string.
31#[derive(Debug)]
32pub struct Request {
33    inner: hyper::Request<Incoming>,
34    path_params: HashMap<String, String>,
35}
36
37impl Request {
38    /// Create a new `Request` wrapping a hyper request.
39    pub fn new(inner: hyper::Request<Incoming>, path_params: HashMap<String, String>) -> Self {
40        Self { inner, path_params }
41    }
42
43    /// The HTTP method.
44    pub fn method(&self) -> &Method {
45        self.inner.method()
46    }
47
48    /// The request URI.
49    pub fn uri(&self) -> &http::Uri {
50        self.inner.uri()
51    }
52
53    /// The request headers.
54    pub fn headers(&self) -> &http::HeaderMap {
55        self.inner.headers()
56    }
57
58    /// The path portion of the URI.
59    pub fn path(&self) -> &str {
60        self.inner.uri().path()
61    }
62
63    /// Get a path parameter by name (e.g. from `/users/:id`).
64    pub fn param(&self, name: &str) -> Option<&str> {
65        self.path_params.get(name).map(|s| s.as_str())
66    }
67
68    /// Get all path parameters.
69    pub fn params(&self) -> &HashMap<String, String> {
70        &self.path_params
71    }
72
73    /// Parse query parameters from the URI.
74    pub fn query_params(&self) -> HashMap<String, String> {
75        self.inner
76            .uri()
77            .query()
78            .map(|q| {
79                q.split('&')
80                    .filter_map(|pair| {
81                        let (k, v) = pair.split_once('=')?;
82                        Some((percent_decode(k), percent_decode(v)))
83                    })
84                    .collect()
85            })
86            .unwrap_or_default()
87    }
88
89    /// Get a single query parameter by name.
90    pub fn query(&self, name: &str) -> Option<String> {
91        self.query_params().remove(name)
92    }
93
94    /// Consume the request and return the inner hyper request.
95    pub fn into_inner(self) -> hyper::Request<Incoming> {
96        self.inner
97    }
98
99    /// Consume the body and return raw bytes.
100    pub async fn body_bytes(self) -> Result<Bytes, OxiHttpError> {
101        use http_body_util::BodyExt;
102        self.inner
103            .into_body()
104            .collect()
105            .await
106            .map(|c| c.to_bytes())
107            .map_err(|e| OxiHttpError::Body(e.to_string()))
108    }
109
110    /// Consume the body and return it as a UTF-8 string.
111    pub async fn body_text(self) -> Result<String, OxiHttpError> {
112        let bytes = self.body_bytes().await?;
113        String::from_utf8(bytes.to_vec())
114            .map_err(|e| OxiHttpError::Body(format!("invalid UTF-8: {e}")))
115    }
116
117    /// Consume the body and deserialize from JSON.
118    pub async fn body_json<T: serde::de::DeserializeOwned>(self) -> Result<T, OxiHttpError> {
119        let bytes = self.body_bytes().await?;
120        serde_json::from_slice(&bytes).map_err(|e| OxiHttpError::Json(e.to_string()))
121    }
122
123    /// Retrieve the shared application state of type `T`.
124    ///
125    /// Returns `Some(Arc<T>)` when `Router::with_state::<T>()` was used to
126    /// register a value of type `T` before starting the server.  Returns `None`
127    /// when no state of that type was injected.
128    pub fn state<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
129        self.inner.extensions().get::<Arc<T>>().cloned()
130    }
131
132    /// Retrieve a per-request extension of type `T`.
133    ///
134    /// Handlers and middleware can store arbitrary values in request extensions
135    /// via `req.extensions_mut().insert(value)`.  This accessor clones the
136    /// stored value and returns it.
137    pub fn extension<T: Clone + Send + Sync + 'static>(&self) -> Option<T> {
138        self.inner.extensions().get::<T>().cloned()
139    }
140
141    /// Access the raw request extensions map (read-only).
142    pub fn extensions(&self) -> &http::Extensions {
143        self.inner.extensions()
144    }
145
146    /// Access the raw request extensions map (mutable).
147    ///
148    /// Useful in middleware or handlers that need to attach data for downstream
149    /// consumers.
150    pub fn extensions_mut(&mut self) -> &mut http::Extensions {
151        self.inner.extensions_mut()
152    }
153
154    /// Borrow the non-body request parts as a [`RequestParts`][crate::extractor::RequestParts].
155    ///
156    /// Used internally by [`Request::extract`] and available for manual
157    /// extraction via [`FromRequestParts`][crate::extractor::FromRequestParts].
158    pub fn parts(&self) -> crate::extractor::RequestParts<'_> {
159        crate::extractor::RequestParts {
160            method: self.inner.method(),
161            uri: self.inner.uri(),
162            headers: self.inner.headers(),
163            path_params: &self.path_params,
164        }
165    }
166
167    /// Extract a value implementing [`FromRequestParts`][crate::extractor::FromRequestParts]
168    /// from this request.
169    ///
170    /// # Errors
171    ///
172    /// Returns the extractor's `Rejection` type (which converts to [`OxiHttpError`])
173    /// when extraction fails.
174    pub fn extract<T: crate::extractor::FromRequestParts>(&self) -> Result<T, T::Rejection> {
175        T::from_request_parts(&self.parts())
176    }
177
178    /// Negotiate the best [`ContentType`][oxihttp_core::ContentType] from the
179    /// request's `Accept` header and the supplied list of supported types.
180    ///
181    /// Returns `None` when no supported type satisfies the client's `Accept`
182    /// header.  Falls back to `*/*` matching when the header is absent.
183    pub fn negotiate(
184        &self,
185        supported: &[oxihttp_core::ContentType],
186    ) -> Option<oxihttp_core::ContentType> {
187        negotiate_from_headers(self.headers(), supported)
188    }
189
190    /// Get TLS peer certificate information for the current connection.
191    ///
192    /// Returns `Some(Arc<PeerCertInfo>)` when the request arrived over a TLS
193    /// connection.  For plain-TLS connections (no client auth) the returned
194    /// struct will have an empty `peer_certificates` vec.  Returns `None` on
195    /// non-TLS connections.
196    #[cfg(feature = "tls")]
197    pub fn tls_info(&self) -> Option<Arc<crate::tls::PeerCertInfo>> {
198        self.inner
199            .extensions()
200            .get::<Arc<crate::tls::PeerCertInfo>>()
201            .cloned()
202    }
203
204    /// Get the peer certificate chain (DER-encoded, leaf first) from the mTLS handshake.
205    ///
206    /// Returns `Some(Vec<CertificateDer>)` only when the server required client
207    /// authentication (`TlsConfig::with_client_auth`) and the client presented a
208    /// valid certificate chain.  Returns `None` for non-TLS connections or when
209    /// no client certificate was presented.
210    #[cfg(feature = "tls")]
211    pub fn peer_certificates(&self) -> Option<Vec<rustls_pki_types::CertificateDer<'static>>> {
212        self.tls_info().and_then(|info| {
213            if info.peer_certificates.is_empty() {
214                None
215            } else {
216                Some(info.peer_certificates.clone())
217            }
218        })
219    }
220
221    /// Get typed TLS connection information (version, cipher suite, SNI) for the current request.
222    ///
223    /// Returns `Some` when the request arrived over a TLS connection, `None` otherwise.
224    ///
225    /// The returned [`oxitls::ConnectionInfo`] contains:
226    /// - `version` — negotiated TLS version (`TlsVersion::Tls13`, etc.)
227    /// - `cipher_suite` — negotiated cipher suite
228    /// - `alpn_protocol` — negotiated ALPN protocol bytes
229    /// - `sni` — SNI hostname sent by the client
230    /// - `peer_certificates` — DER-encoded client certificate chain (mTLS only)
231    #[cfg(feature = "tls")]
232    pub fn tls_connection_info(&self) -> Option<oxitls::ConnectionInfo> {
233        self.tls_info().map(|info| {
234            let mut ci = oxitls::ConnectionInfo::new();
235            if let Some(v) = info.version {
236                ci = ci.with_version(v);
237            }
238            if let Some(cs) = info.cipher_suite {
239                ci = ci.with_cipher_suite(cs);
240            }
241            if let Some(ref alpn) = info.alpn_protocol {
242                ci = ci.with_alpn_protocol(alpn.clone());
243            }
244            if let Some(ref sni) = info.sni {
245                ci = ci.with_sni(sni.clone());
246            }
247            if !info.peer_certificates.is_empty() {
248                let der_vecs: Vec<Vec<u8>> = info
249                    .peer_certificates
250                    .iter()
251                    .map(|c| c.as_ref().to_vec())
252                    .collect();
253                ci = ci.with_peer_certificates(der_vecs);
254            }
255            ci
256        })
257    }
258}
259
260/// Negotiate the best content type from the request's `Accept` header.
261///
262/// Extracted as a free function so unit tests can call it without constructing
263/// a full `hyper::Request` (which requires a live hyper body).
264fn negotiate_from_headers(
265    headers: &http::HeaderMap,
266    supported: &[oxihttp_core::ContentType],
267) -> Option<oxihttp_core::ContentType> {
268    let accept = headers
269        .get(http::header::ACCEPT)
270        .and_then(|v| v.to_str().ok())
271        .unwrap_or("*/*");
272    oxihttp_core::content_type::negotiate_content_type(accept, supported)
273}
274
/// Decode one `application/x-www-form-urlencoded` component (used for query
/// parameter keys and values).
///
/// - `+` becomes a space.
/// - `%XX` with two hexadecimal digits becomes the byte `0xXX`.  All decoded
///   bytes are then interpreted as UTF-8, so multi-byte escapes such as
///   `%C3%A9` yield `é`.  Raw non-ASCII input is passed through unchanged.
///   Any invalid UTF-8 is replaced with U+FFFD.
/// - A `%` that is not followed by two hex digits is kept literally.  The
///   characters after it are processed normally rather than discarded.
fn percent_decode(s: &str) -> String {
    let input = s.as_bytes();
    let mut out = Vec::with_capacity(input.len());
    // Value of an ASCII hex digit, or `None` for anything else (including
    // signs, which `u8::from_str_radix` would have accepted).
    let hex = |slot: Option<&u8>| slot.and_then(|&byte| char::from(byte).to_digit(16));

    let mut i = 0;
    while i < input.len() {
        match input[i] {
            b'%' => {
                if let (Some(hi), Some(lo)) = (hex(input.get(i + 1)), hex(input.get(i + 2))) {
                    // Both digits are < 16, so the combined value fits in a u8.
                    out.push((hi * 16 + lo) as u8);
                    i += 3;
                    continue;
                }
                out.push(b'%');
            }
            b'+' => out.push(b' '),
            byte => out.push(byte),
        }
        i += 1;
    }

    // Work on bytes and decode once, so UTF-8 sequences are not split into
    // individual Latin-1 characters.
    String::from_utf8(out)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
}

301/// A single route definition: pattern + method + handler.
302struct Route {
303    method: Method,
304    segments: Vec<Segment>,
305    handler: HandlerFn,
306}
307
308/// A segment in a route pattern.
309#[derive(Debug, Clone)]
310enum Segment {
311    /// A literal path segment (e.g. "users").
312    Literal(String),
313    /// A parameter segment (e.g. ":id").
314    Param(String),
315    /// A wildcard segment (e.g. "*path") that matches the rest.
316    Wildcard(String),
317}
318
319/// Future type returned by `Router::dispatch`.
320pub type DispatchFuture<'a> =
321    Pin<Box<dyn Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'a>>;
322
323/// HTTP request router with path-parameter extraction and method-based dispatch.
324pub struct Router {
325    routes: Vec<Route>,
326    nested: Vec<(String, Router)>,
327    vhosts: Vec<(String, Router)>,
328    fallback: Option<HandlerFn>,
329    method_not_allowed_handler: Option<HandlerFn>,
330    /// Optional state injection function.  When `Some`, it is called with the
331    /// request's `Extensions` map immediately before dispatching to a handler,
332    /// inserting the typed `Arc<T>` for later retrieval via `Request::state::<T>()`.
333    state: Option<StateFn>,
334}
335
336impl Router {
337    /// Create a new empty router.
338    pub fn new() -> Self {
339        Self {
340            routes: Vec::new(),
341            nested: Vec::new(),
342            vhosts: Vec::new(),
343            fallback: None,
344            method_not_allowed_handler: None,
345            state: None,
346        }
347    }
348
349    /// Attach application state of type `T` to this router.
350    ///
351    /// The state is wrapped in an `Arc<T>` and injected into every request's
352    /// extensions map just before the handler is invoked.  Handlers retrieve
353    /// it with `req.state::<T>()`.
354    ///
355    /// Nested routers that do not have their own state automatically inherit
356    /// this router's state during dispatch.
357    ///
358    /// ```rust,no_run
359    /// # use oxihttp_server::{Router, router::Request};
360    /// # use std::sync::Arc;
361    /// #[derive(Clone)]
362    /// struct AppState { db_url: String }
363    ///
364    /// let state = AppState { db_url: "postgres://localhost/mydb".into() };
365    /// let router = Router::new()
366    ///     .with_state(state)
367    ///     .get("/", |req: Request| async move {
368    ///         let s = req.state::<AppState>().expect("state present");
369    ///         oxihttp_server::response::text_response(&s.db_url)
370    ///     });
371    /// ```
372    pub fn with_state<T: Clone + Send + Sync + 'static>(mut self, state: T) -> Self {
373        let arc = Arc::new(state);
374        self.state = Some(Box::new(move |ext: &mut http::Extensions| {
375            ext.insert(Arc::clone(&arc));
376        }));
377        self
378    }
379
380    /// Register a route for the given method and path pattern.
381    ///
382    /// Path patterns support:
383    /// - Literal segments: `/users/list`
384    /// - Parameters: `/users/:id`
385    /// - Wildcards: `/static/*path`
386    pub fn route<F, Fut>(mut self, method: Method, path: &str, handler: F) -> Self
387    where
388        F: Fn(Request) -> Fut + Send + Sync + 'static,
389        Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
390    {
391        let segments = parse_pattern(path);
392        let handler: HandlerFn = Arc::new(move |req| Box::pin(handler(req)));
393        self.routes.push(Route {
394            method,
395            segments,
396            handler,
397        });
398        self
399    }
400
401    /// Register a GET route.
402    pub fn get<F, Fut>(self, path: &str, handler: F) -> Self
403    where
404        F: Fn(Request) -> Fut + Send + Sync + 'static,
405        Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
406    {
407        self.route(Method::GET, path, handler)
408    }
409
410    /// Register a POST route.
411    pub fn post<F, Fut>(self, path: &str, handler: F) -> Self
412    where
413        F: Fn(Request) -> Fut + Send + Sync + 'static,
414        Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
415    {
416        self.route(Method::POST, path, handler)
417    }
418
419    /// Register a PUT route.
420    pub fn put<F, Fut>(self, path: &str, handler: F) -> Self
421    where
422        F: Fn(Request) -> Fut + Send + Sync + 'static,
423        Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
424    {
425        self.route(Method::PUT, path, handler)
426    }
427
428    /// Register a DELETE route.
429    pub fn delete<F, Fut>(self, path: &str, handler: F) -> Self
430    where
431        F: Fn(Request) -> Fut + Send + Sync + 'static,
432        Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
433    {
434        self.route(Method::DELETE, path, handler)
435    }
436
437    /// Register a PATCH route.
438    pub fn patch<F, Fut>(self, path: &str, handler: F) -> Self
439    where
440        F: Fn(Request) -> Fut + Send + Sync + 'static,
441        Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
442    {
443        self.route(Method::PATCH, path, handler)
444    }
445
446    /// Register a HEAD route.
447    pub fn head<F, Fut>(self, path: &str, handler: F) -> Self
448    where
449        F: Fn(Request) -> Fut + Send + Sync + 'static,
450        Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
451    {
452        self.route(Method::HEAD, path, handler)
453    }
454
455    /// Nest a sub-router under the given prefix.
456    pub fn nest(mut self, prefix: &str, router: Router) -> Self {
457        let prefix = prefix.trim_end_matches('/').to_string();
458        self.nested.push((prefix, router));
459        self
460    }
461
462    /// Route requests with the given `Host` header value to `router`.
463    ///
464    /// The `host` value is matched case-insensitively against the bare hostname
465    /// (port suffix stripped).  When a match is found the request is forwarded
466    /// to `router` without any path rewriting.  Virtual-host dispatch happens
467    /// before nested-prefix dispatch.
468    ///
469    /// # Example
470    ///
471    /// ```rust,no_run
472    /// # use oxihttp_server::Router;
473    /// let api = Router::new().get("/v1", |_req| async {
474    ///     oxihttp_server::response::text_response("api")
475    /// });
476    /// let web = Router::new().get("/", |_req| async {
477    ///     oxihttp_server::response::text_response("web")
478    /// });
479    /// let router = Router::new()
480    ///     .host("api.example.com", api)
481    ///     .host("example.com", web);
482    /// ```
483    pub fn host(mut self, host: &str, router: Router) -> Self {
484        self.vhosts.push((host.to_owned(), router));
485        self
486    }
487
488    /// Set a fallback handler for routes that don't match (custom 404).
489    pub fn fallback<F, Fut>(mut self, handler: F) -> Self
490    where
491        F: Fn(Request) -> Fut + Send + Sync + 'static,
492        Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
493    {
494        self.fallback = Some(Arc::new(move |req| Box::pin(handler(req))));
495        self
496    }
497
498    /// Set a handler for method-not-allowed (405) responses.
499    pub fn method_not_allowed<F, Fut>(mut self, handler: F) -> Self
500    where
501        F: Fn(Request) -> Fut + Send + Sync + 'static,
502        Fut: Future<Output = Result<hyper::Response<Full<Bytes>>, OxiHttpError>> + Send + 'static,
503    {
504        self.method_not_allowed_handler = Some(Arc::new(move |req| Box::pin(handler(req))));
505        self
506    }
507
508    /// A simple health-check route returning 200 OK.
509    pub fn health(self, path: &str) -> Self {
510        self.get(path, |_req| async {
511            hyper::Response::builder()
512                .status(StatusCode::OK)
513                .body(Full::new(Bytes::from("OK")))
514                .map_err(|e| OxiHttpError::Http(Arc::new(e)))
515        })
516    }
517
518    /// Match a request path against registered routes without dispatching.
519    ///
520    /// Replicates the O(n) dispatch scan for use in benchmarks and introspection.
521    /// Returns extracted path parameters on a successful match, `None` on no match.
522    ///
523    /// When the path is found but the method is not registered the method returns
524    /// `Some(HashMap::new())` — an empty map — to signal a 405 situation without
525    /// actually dispatching.
526    pub fn resolve(&self, method: &Method, path: &str) -> Option<HashMap<String, String>> {
527        // Check nested prefixes first (delegate to sub-router if matched).
528        for (prefix, sub_router) in &self.nested {
529            if let Some(stripped) = path.strip_prefix(prefix.as_str()) {
530                let sub_path = if stripped.is_empty() { "/" } else { stripped };
531                return sub_router.resolve(method, sub_path);
532            }
533        }
534
535        // Scan routes O(n).
536        let mut path_matched = false;
537        for route in &self.routes {
538            if let Some(params) = match_pattern(&route.segments, path) {
539                path_matched = true;
540                if route.method == *method {
541                    return Some(params);
542                }
543            }
544        }
545
546        // Path existed but method not allowed: return empty params to signal 405.
547        if path_matched {
548            return Some(HashMap::new());
549        }
550
551        None
552    }
553
554    /// Dispatch an incoming request through the router.
555    pub fn dispatch(&self, req: hyper::Request<Incoming>) -> DispatchFuture<'_> {
556        Box::pin(self.dispatch_inner(req))
557    }
558
559    async fn dispatch_inner(
560        &self,
561        mut req: hyper::Request<Incoming>,
562    ) -> Result<hyper::Response<Full<Bytes>>, OxiHttpError> {
563        let method = req.method().clone();
564        let path = req.uri().path().to_string();
565
566        // Virtual host dispatch (runs before nested-prefix dispatch).
567        if let Some(host_hdr) = req.headers().get(http::header::HOST) {
568            if let Ok(host_str) = host_hdr.to_str() {
569                let host_bare = host_str.split(':').next().unwrap_or(host_str);
570                for (vhost_name, sub_router) in &self.vhosts {
571                    if vhost_name.eq_ignore_ascii_case(host_bare) {
572                        // Inherit parent state when the vhost sub-router has none.
573                        if sub_router.state.is_none() {
574                            if let Some(ref inject_fn) = self.state {
575                                inject_fn(req.extensions_mut());
576                            }
577                        }
578                        return sub_router.dispatch(req).await;
579                    }
580                }
581            }
582        }
583
584        // Try nested routers.
585        // If the nested router has no state of its own, inject the parent's
586        // state before forwarding the request so handlers in the sub-router
587        // can still call `req.state::<T>()`.
588        for (prefix, sub_router) in &self.nested {
589            if path.starts_with(prefix.as_str()) {
590                let sub_path = &path[prefix.len()..];
591                let sub_path = if sub_path.is_empty() { "/" } else { sub_path };
592
593                // Rebuild URI with sub-path.
594                let new_uri = http::Uri::builder()
595                    .path_and_query(sub_path)
596                    .build()
597                    .map_err(|e| OxiHttpError::Http(Arc::new(e)))?;
598
599                let (mut parts, body) = req.into_parts();
600                parts.uri = new_uri;
601                let mut new_req = hyper::Request::from_parts(parts, body);
602
603                // Inherit parent state only when the nested router does not
604                // define its own state (the nested router's state takes
605                // precedence when set).
606                if sub_router.state.is_none() {
607                    if let Some(ref inject_fn) = self.state {
608                        inject_fn(new_req.extensions_mut());
609                    }
610                }
611
612                return sub_router.dispatch(new_req).await;
613            }
614        }
615
616        // Try matching routes.
617        let mut path_matched = false;
618        for route in &self.routes {
619            if let Some(params) = match_pattern(&route.segments, &path) {
620                path_matched = true;
621                if route.method == method {
622                    let mut inner = req;
623                    if let Some(ref inject_fn) = self.state {
624                        inject_fn(inner.extensions_mut());
625                    }
626                    let request = Request::new(inner, params);
627                    return (route.handler)(request).await;
628                }
629            }
630        }
631
632        // Path matched but method didn't -> 405.
633        if path_matched {
634            if let Some(ref handler) = self.method_not_allowed_handler {
635                let mut inner = req;
636                if let Some(ref inject_fn) = self.state {
637                    inject_fn(inner.extensions_mut());
638                }
639                let request = Request::new(inner, HashMap::new());
640                return (handler)(request).await;
641            }
642            return hyper::Response::builder()
643                .status(StatusCode::METHOD_NOT_ALLOWED)
644                .body(Full::new(Bytes::from("Method Not Allowed")))
645                .map_err(|e| OxiHttpError::Http(Arc::new(e)));
646        }
647
648        // No match at all -> fallback or 404.
649        if let Some(ref handler) = self.fallback {
650            let mut inner = req;
651            if let Some(ref inject_fn) = self.state {
652                inject_fn(inner.extensions_mut());
653            }
654            let request = Request::new(inner, HashMap::new());
655            return (handler)(request).await;
656        }
657
658        hyper::Response::builder()
659            .status(StatusCode::NOT_FOUND)
660            .body(Full::new(Bytes::from("Not Found")))
661            .map_err(|e| OxiHttpError::Http(Arc::new(e)))
662    }
663
664    /// Return the number of registered routes (not including nested).
665    pub fn route_count(&self) -> usize {
666        self.routes.len()
667    }
668}
669
670impl Default for Router {
671    fn default() -> Self {
672        Self::new()
673    }
674}
675
676impl std::fmt::Debug for Router {
677    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
678        f.debug_struct("Router")
679            .field("routes", &self.routes.len())
680            .field("nested", &self.nested.len())
681            .field("vhosts", &self.vhosts.len())
682            .field("has_state", &self.state.is_some())
683            .finish()
684    }
685}
686
687impl std::fmt::Display for Router {
688    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
689        // List virtual hosts and their routes.
690        for (host, sub) in &self.vhosts {
691            writeln!(f, "vhost: {host}")?;
692            for route in &sub.routes {
693                writeln!(f, "  {} /<vhost-path>", route.method)?;
694            }
695        }
696        // List top-level routes.
697        for route in &self.routes {
698            let pattern = route
699                .segments
700                .iter()
701                .map(|s| match s {
702                    Segment::Literal(l) => format!("/{l}"),
703                    Segment::Param(p) => format!("/:{p}"),
704                    Segment::Wildcard(w) => format!("/*{w}"),
705                })
706                .collect::<String>();
707            writeln!(f, "{} {pattern}", route.method)?;
708        }
709        // List nested prefixes.
710        for (prefix, sub) in &self.nested {
711            writeln!(f, "nested: {prefix}")?;
712            for route in &sub.routes {
713                writeln!(f, "  {} {prefix}<path>", route.method)?;
714            }
715        }
716        Ok(())
717    }
718}
719
#[cfg(feature = "tower")]
impl Router {
    /// Wrap this router in a [`RouterMakeService`][crate::tower_compat::RouterMakeService]
    /// factory for use with tower-compatible runtimes or test harnesses.
    ///
    /// The router is moved into an `Arc` owned by the factory.
    pub fn into_make_service(self) -> crate::tower_compat::RouterMakeService {
        let shared = Arc::new(self);
        crate::tower_compat::RouterMakeService(shared)
    }
}

729/// Parse a route pattern string into segments.
730fn parse_pattern(pattern: &str) -> Vec<Segment> {
731    pattern
732        .split('/')
733        .filter(|s| !s.is_empty())
734        .map(|s| {
735            if let Some(param) = s.strip_prefix(':') {
736                Segment::Param(param.to_string())
737            } else if let Some(wildcard) = s.strip_prefix('*') {
738                Segment::Wildcard(wildcard.to_string())
739            } else {
740                Segment::Literal(s.to_string())
741            }
742        })
743        .collect()
744}
745
746/// Try to match a path against a route pattern, extracting parameters.
747fn match_pattern(segments: &[Segment], path: &str) -> Option<HashMap<String, String>> {
748    let path_segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
749    let mut params = HashMap::new();
750    let mut path_idx = 0;
751
752    for seg in segments {
753        match seg {
754            Segment::Literal(expected) => {
755                if path_idx >= path_segments.len() || path_segments[path_idx] != expected.as_str() {
756                    return None;
757                }
758                path_idx += 1;
759            }
760            Segment::Param(name) => {
761                if path_idx >= path_segments.len() {
762                    return None;
763                }
764                params.insert(name.clone(), path_segments[path_idx].to_string());
765                path_idx += 1;
766            }
767            Segment::Wildcard(name) => {
768                if path_idx >= path_segments.len() {
769                    return None;
770                }
771                let rest = path_segments[path_idx..].join("/");
772                params.insert(name.clone(), rest);
773                return Some(params);
774            }
775        }
776    }
777
778    // All segments consumed; check that path is also fully consumed
779    if path_idx == path_segments.len() {
780        Some(params)
781    } else {
782        None
783    }
784}
785
#[cfg(test)]
mod tests {
    use super::*;
    use oxihttp_core::ContentType;

    /// Build a header map containing only the given `Accept` value.
    fn accept_headers(value: &'static str) -> http::HeaderMap {
        let mut headers = http::HeaderMap::new();
        headers.insert(http::header::ACCEPT, http::HeaderValue::from_static(value));
        headers
    }

    // ---- negotiate_from_headers tests ----------------------------------------

    #[test]
    fn test_negotiate_returns_json_for_json_accept() {
        let supported = [ContentType::Json, ContentType::Html(None)];
        assert_eq!(
            negotiate_from_headers(&accept_headers("application/json"), &supported),
            Some(ContentType::Json)
        );
    }

    #[test]
    fn test_negotiate_returns_none_for_unsupported() {
        let supported = [ContentType::Json, ContentType::Html(None)];
        assert_eq!(
            negotiate_from_headers(&accept_headers("image/png"), &supported),
            None
        );
    }

    // ---- Pattern parse tests -------------------------------------------------

    #[test]
    fn test_parse_literal_pattern() {
        match parse_pattern("/users/list").as_slice() {
            [Segment::Literal(first), Segment::Literal(second)] => {
                assert_eq!(first, "users");
                assert_eq!(second, "list");
            }
            other => panic!("unexpected segments: {other:?}"),
        }
    }

    #[test]
    fn test_parse_param_pattern() {
        match parse_pattern("/users/:id").as_slice() {
            [Segment::Literal(first), Segment::Param(name)] => {
                assert_eq!(first, "users");
                assert_eq!(name, "id");
            }
            other => panic!("unexpected segments: {other:?}"),
        }
    }

    #[test]
    fn test_parse_wildcard_pattern() {
        match parse_pattern("/static/*path").as_slice() {
            [Segment::Literal(first), Segment::Wildcard(name)] => {
                assert_eq!(first, "static");
                assert_eq!(name, "path");
            }
            other => panic!("unexpected segments: {other:?}"),
        }
    }

    #[test]
    fn test_match_literal() {
        let segments = parse_pattern("/users/list");
        let params = match_pattern(&segments, "/users/list").expect("literal route should match");
        assert!(params.is_empty());
    }

    #[test]
    fn test_match_literal_no_match() {
        let segments = parse_pattern("/users/list");
        for path in ["/users/other", "/users"] {
            assert!(match_pattern(&segments, path).is_none(), "{path} must not match");
        }
    }

    #[test]
    fn test_match_param() {
        let segments = parse_pattern("/users/:id");
        let params = match_pattern(&segments, "/users/42").expect("should match");
        assert_eq!(params.get("id").map(String::as_str), Some("42"));
    }

    #[test]
    fn test_match_wildcard() {
        let segments = parse_pattern("/static/*path");
        let params = match_pattern(&segments, "/static/css/style.css").expect("should match");
        assert_eq!(params.get("path").map(String::as_str), Some("css/style.css"));
    }

    #[test]
    fn test_no_match_extra_segments() {
        let segments = parse_pattern("/users");
        assert!(match_pattern(&segments, "/users/extra").is_none());
    }

    #[test]
    fn test_percent_decode() {
        let cases = [("hello%20world", "hello world"), ("a+b", "a b"), ("plain", "plain")];
        for (input, expected) in cases {
            assert_eq!(percent_decode(input), expected);
        }
    }
}

#[cfg(test)]
mod resolve_tests {
    use super::*;

    /// Handler that always answers with an empty 200 response.
    async fn dummy(_req: Request) -> Result<hyper::Response<Full<Bytes>>, OxiHttpError> {
        Ok(hyper::Response::new(Full::new(Bytes::new())))
    }

    #[tokio::test]
    async fn test_resolve_match_and_miss() {
        let router = Router::new().get("/hello", dummy).get("/users/:id", dummy);
        let get = Method::GET;

        // Exact literal match.
        assert!(router.resolve(&get, "/hello").is_some());

        // Parameter capture.
        let params = router.resolve(&get, "/users/42").expect("should match");
        assert_eq!(params.get("id").map(String::as_str), Some("42"));

        // Unknown path.
        assert!(router.resolve(&get, "/nonexistent").is_none());

        // Known path, wrong method: Some(empty map) signals 405.
        let params = router
            .resolve(&Method::POST, "/hello")
            .expect("known path with wrong method should signal 405");
        assert!(params.is_empty());
    }
}