Skip to main content

product_os_router/
lib.rs

1//! # Product OS Router
2//!
3//! A fully-featured HTTP router built on top of [Axum](https://github.com/tokio-rs/axum) and [Tower](https://github.com/tower-rs/tower),
4//! providing convenient helper methods for creating HTTP, HTTPS, WebSocket, and Server-Sent Events (SSE) servers.
5//!
6//! ## Features
7//!
8//! - **Easy Routing**: Simple, chainable API for defining routes
9//! - **HTTP Methods**: Support for GET, POST, PUT, PATCH, DELETE, and more
10//! - **State Management**: Both stateful and stateless routing options
11//! - **WebSockets**: Built-in WebSocket support (with `ws` feature)
12//! - **Server-Sent Events**: SSE support (with `sse` feature)
13//! - **CORS**: Cross-Origin Resource Sharing support (with `cors` feature)
14//! - **Middleware**: Custom middleware support (with `middleware` feature)
15//! - **Default Headers**: Apply default headers to all responses
16//! - **Protocol Upgrade**: HTTP to HTTPS automatic redirection
17//! - **Path Parameters**: Express-style path parameter syntax (`:param`, `*wildcard`)
18//! - **no_std Support**: Can be used in no_std environments with alloc
19//!
20//! ## Quick Start
21//!
22//! ```rust
23//! use product_os_router::{ProductOSRouter, IntoResponse};
24//!
25//! async fn hello_world() -> &'static str {
26//!     "Hello, World!"
27//! }
28//!
29//! # async fn example() {
30//! let mut router = ProductOSRouter::new();
31//! router.add_get_no_state("/", hello_world);
32//!
33//! let app = router.get_router();
//! // Use app with your server (e.g., axum::serve)
35//! # }
36//! ```
37//!
38//! ## Path Parameter Syntax
39//!
40//! The router automatically converts Express-style path parameters to Axum format:
41//! - `:param` becomes `{param}` (single segment)
42//! - `*wildcard` becomes `{wildcard}` (catch-all)
43//!
44//! ```rust
45//! # use product_os_router::ProductOSRouter;
46//! # use axum::extract::Path;
47//! async fn user_handler(Path(id): Path<u32>) -> String {
48//!     format!("User ID: {}", id)
49//! }
50//!
51//! # async fn example() {
52//! let mut router = ProductOSRouter::new();
53//! router.add_get_no_state("/users/:id", user_handler);  // Converted to /users/{id}
54//! # }
55//! ```
56//!
57//! ## With State
58//!
59//! ```rust
60//! # use product_os_router::ProductOSRouter;
61//! # use axum::extract::State;
62//! #[derive(Clone)]
63//! struct AppState {
64//!     counter: u32,
65//! }
66//!
67//! async fn handler(State(state): State<AppState>) -> String {
68//!     format!("Counter: {}", state.counter)
69//! }
70//!
71//! # async fn example() {
72//! let state = AppState { counter: 0 };
73//! let mut router = ProductOSRouter::new_with_state(state);
74//! router.add_get("/", handler);
75//! # }
76//! ```
77//!
78//! ## Feature Flags
79//!
80//! - `default`: Includes `core` and `std` features
81//! - `std`: Standard library support with Axum
82//! - `core`: Core routing functionality
83//! - `ws`: WebSocket support
84//! - `sse`: Server-Sent Events support
85//! - `cors`: CORS middleware support
86//! - `sessions`: Session management support
87//! - `middleware`: Custom middleware traits
88//! - `debug`: Debug handler macros
89
90#![no_std]
91extern crate no_std_compat as std;
92
93use std::prelude::v1::*;
94
95use core::mem;
96use std::collections::BTreeMap;
97
98pub use product_os_http::{
99    Request, Response,
100    request::Parts as RequestParts,
101    response::Parts as ResponseParts,
102    StatusCode,
103    header::{ HeaderMap, HeaderName, HeaderValue }
104};
105
106pub use product_os_http_body::{ BodyBytes, Bytes };
107
108#[cfg(feature = "middleware")]
109pub use product_os_http_body::*;
110
111pub use axum::{
112    routing::{get, post, put, patch, delete, head, trace, options, any, MethodRouter},
113    middleware::{from_fn, from_fn_with_state, Next},
114    handler::Handler,
115    response::{ IntoResponse, Redirect },
116    body::{ HttpBody, Body },
117    Router,
118    Json,
119    Form,
120    extract::{
121        State, Path, Query, FromRef, FromRequest, FromRequestParts,
122        ws::{ WebSocketUpgrade, WebSocket, Message }
123    },
124    BoxError,
125    http::Uri
126};
127
128// Import Route for internal use
129use axum::routing::Route;
130
131
132pub use crate::extractors::RequestMethod;
133
134
135
136#[cfg(feature = "debug")]
137pub use axum_macros::debug_handler;
138
139pub use tower::{
140    Layer, Service, ServiceBuilder, ServiceExt, MakeService,
141    make::AsService, make::MakeConnection, make::IntoService,
142    make::Shared, make::future::SharedFuture,
143    util::service_fn, util::ServiceFn
144};
145
146pub use product_os_request::Method;
147use regex::Regex;
148
149
150#[cfg(feature = "middleware")]
151pub mod middleware;
152
153mod extractors;
154mod default_headers;
155mod service_wrapper;
156pub mod dual_protocol;
157
158pub use crate::default_headers::DefaultHeadersLayer;
159pub use crate::dual_protocol::{UpgradeHttpLayer, Protocol};
160
161#[cfg(feature = "middleware")]
162pub use crate::middleware::*;
163
164#[cfg(feature = "cors")]
165use tower_http::cors::{CorsLayer, Any};
166
167use crate::service_wrapper::{WrapperLayer, WrapperService};
168
169
/// Errors that handlers and middleware can report through the router.
///
/// Every variant carries a human-readable message and is associated with a
/// single HTTP status code (noted on each variant below).
///
/// # Examples
///
/// ```rust
/// use product_os_router::{ProductOSRouterError, StatusCode};
///
/// let error = ProductOSRouterError::Authentication("Invalid token".to_string());
/// let response = ProductOSRouterError::handle_error(&error);
/// assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
/// ```
#[derive(Debug)]
pub enum ProductOSRouterError {
    /// A request header was rejected — 400 Bad Request.
    Headers(String),
    /// A query parameter was rejected — 400 Bad Request.
    Query(String),
    /// The request body was rejected — 400 Bad Request.
    Body(String),
    /// The caller could not be authenticated — 401 Unauthorized.
    Authentication(String),
    /// The caller is not allowed to perform the action — 403 Forbidden.
    Authorization(String),
    /// Processing failed or timed out — 504 Gateway Timeout.
    Process(String),
    /// The service cannot handle the request right now — 503 Service Unavailable.
    Unavailable(String),
}
201
202impl ProductOSRouterError {
203    /// Extract the error message from any variant.
204    fn message(&self) -> &str {
205        match self {
206            ProductOSRouterError::Headers(m) => m,
207            ProductOSRouterError::Query(m) => m,
208            ProductOSRouterError::Body(m) => m,
209            ProductOSRouterError::Authentication(m) => m,
210            ProductOSRouterError::Authorization(m) => m,
211            ProductOSRouterError::Process(m) => m,
212            ProductOSRouterError::Unavailable(m) => m,
213        }
214    }
215
216    /// Return the appropriate HTTP status code for this error variant.
217    fn status_code(&self) -> StatusCode {
218        match self {
219            ProductOSRouterError::Headers(_) => StatusCode::BAD_REQUEST,
220            ProductOSRouterError::Query(_) => StatusCode::BAD_REQUEST,
221            ProductOSRouterError::Body(_) => StatusCode::BAD_REQUEST,
222            ProductOSRouterError::Authentication(_) => StatusCode::UNAUTHORIZED,
223            ProductOSRouterError::Authorization(_) => StatusCode::FORBIDDEN,
224            ProductOSRouterError::Process(_) => StatusCode::GATEWAY_TIMEOUT,
225            ProductOSRouterError::Unavailable(_) => StatusCode::SERVICE_UNAVAILABLE,
226        }
227    }
228
229    /// Build a JSON error response body from a message and status code.
230    ///
231    /// Quotes in the error message are replaced with single quotes to
232    /// ensure valid JSON output.
233    fn build_error_response(message: &str, status_code: &StatusCode) -> Response<Body> {
234        let safe_message = message.replace("\"", "'");
235        let status_code_u16 = status_code.as_u16();
236
237        Response::builder()
238            .status(status_code_u16)
239            .header("content-type", "application/json")
240            .body(Body::from(format!("{{\"error\": \"{safe_message}\"}}")))
241            .unwrap_or_else(|_| {
242                Response::builder()
243                    .status(StatusCode::INTERNAL_SERVER_ERROR.as_u16())
244                    .body(Body::from("{\"error\": \"internal error\"}"))
245                    .expect("fallback response must be valid")
246            })
247    }
248
249    /// Create an HTTP response from an error with a specific status code.
250    ///
251    /// Consumes the error and converts it into a JSON response with the error
252    /// message. Quotes in the error message are replaced with single quotes.
253    ///
254    /// # Arguments
255    ///
256    /// * `error` - The error to convert (consumed)
257    /// * `status_code` - The HTTP status code to use
258    ///
259    /// # Returns
260    ///
261    /// An HTTP response with JSON body containing the error message
262    ///
263    /// # Examples
264    ///
265    /// ```rust
266    /// use product_os_router::{ProductOSRouterError, StatusCode};
267    ///
268    /// let error = ProductOSRouterError::Headers("Invalid header".to_string());
269    /// let response = ProductOSRouterError::error_response(error, &StatusCode::BAD_REQUEST);
270    /// ```
271    pub fn error_response(error: ProductOSRouterError, status_code: &StatusCode) -> Response<Body> {
272        Self::build_error_response(error.message(), status_code)
273    }
274
275    /// Handle an error and return an appropriate HTTP response.
276    ///
277    /// Automatically determines the appropriate HTTP status code based on the
278    /// error variant and returns a JSON response with the error message.
279    ///
280    /// # Arguments
281    ///
282    /// * `error` - The error to handle
283    ///
284    /// # Returns
285    ///
286    /// An HTTP response with the appropriate status code and error message
287    ///
288    /// # Examples
289    ///
290    /// ```rust
291    /// use product_os_router::ProductOSRouterError;
292    ///
293    /// let error = ProductOSRouterError::Authentication("Login required".to_string());
294    /// let response = ProductOSRouterError::handle_error(&error);
295    /// ```
296    pub fn handle_error(error: &ProductOSRouterError) -> Response<Body> {
297        Self::build_error_response(error.message(), &error.status_code())
298    }
299}
300
301
302/// Main router struct for Product OS Router
303///
304/// Provides a high-level API for defining routes, handlers, and middleware.
305/// Supports both stateful and stateless routing.
306///
307/// # Type Parameters
308///
309/// * `S` - The state type (defaults to `()` for stateless routing)
310///
311/// # Examples
312///
313/// ## Stateless Router
314///
315/// ```rust
316/// use product_os_router::ProductOSRouter;
317///
318/// async fn handler() -> &'static str {
319///     "Hello!"
320/// }
321///
322/// let mut router = ProductOSRouter::new();
323/// router.add_get_no_state("/", handler);
324/// ```
325///
326/// ## Stateful Router
327///
328/// ```rust
329/// use product_os_router::ProductOSRouter;
330/// use axum::extract::State;
331///
332/// #[derive(Clone)]
333/// struct AppState {
334///     message: String,
335/// }
336///
337/// async fn handler(State(state): State<AppState>) -> String {
338///     state.message
339/// }
340///
341/// let state = AppState { message: "Hello!".to_string() };
342/// let mut router = ProductOSRouter::new_with_state(state);
343/// router.add_get("/", handler);
344/// ```
345pub struct ProductOSRouter<S = ()> {
346    router: Router<S>,
347    state: S
348}
349
350impl<S: Clone + Send + Sync + 'static> ProductOSRouter<S> {
351    /// Take the inner router, replacing it with an empty default.
352    ///
353    /// This avoids cloning on every mutation. Axum's `Router` methods consume
354    /// `self`, so we need ownership; previously the code called `.clone()`.
355    fn take_router(&mut self) -> Router<S> {
356        mem::replace(&mut self.router, Router::new())
357    }
358}
359
360impl ProductOSRouter<()> {
361    /// Create a new stateless router
362    ///
363    /// Creates a router without any shared state. Use [`new_with_state`](Self::new_with_state)
364    /// if you need to share state between handlers.
365    ///
366    /// # Examples
367    ///
368    /// ```rust
369    /// use product_os_router::ProductOSRouter;
370    ///
371    /// let router = ProductOSRouter::new();
372    /// ```
373    pub fn new() -> Self {
374        ProductOSRouter::new_with_state(())
375    }
376}
377
378impl Default for ProductOSRouter<()> {
379    fn default() -> Self {
380        Self::new()
381    }
382}
383
384impl ProductOSRouter<()> {
385    /// Convert path parameters from Express-style to Axum format
386    ///
387    /// Transforms route paths with `:param` and `*wildcard` syntax into Axum's `{param}` format.
388    ///
389    /// # Arguments
390    ///
391    /// * `path` - The path with Express-style parameters
392    ///
393    /// # Returns
394    ///
395    /// The path converted to Axum format
396    ///
397    /// # Examples
398    ///
399    /// ```rust
400    /// use product_os_router::ProductOSRouter;
401    ///
402    /// assert_eq!(ProductOSRouter::param_to_field("/users/:id"), "/users/{id}");
403    /// assert_eq!(ProductOSRouter::param_to_field("/files/*path"), "/files/{path}");
404    /// ```
405    pub fn param_to_field(path: &str) -> String {
406        use std::sync::OnceLock;
407
408        static PARAM_MATCHER: OnceLock<Regex> = OnceLock::new();
409        static SYMBOL_MATCHER: OnceLock<Regex> = OnceLock::new();
410
411        let param_matcher = PARAM_MATCHER.get_or_init(|| {
412            Regex::new("([:*])([a-zA-Z][-_a-zA-Z0-9]*)").unwrap()
413        });
414        let symbol_matcher = SYMBOL_MATCHER.get_or_init(|| {
415            Regex::new("[*?&]").unwrap()
416        });
417
418        let mut path_new = path.to_owned();
419        for (value, [_prefix, variable_name]) in param_matcher.captures_iter(path).map(|c| c.extract()) {
420            let mut value_sanitized = value.to_owned();
421            for (prefix, []) in symbol_matcher.captures_iter(value).map(|c| c.extract()) {
422                let mut symbol = String::from("[");
423                symbol.push(prefix.chars().next().unwrap());
424                symbol.push(']');
425
426                value_sanitized = value_sanitized.replace(prefix, symbol.as_str());
427            }
428
429            let exact_matcher = Regex::new(value_sanitized.as_str()).unwrap();
430
431            let mut variable = String::from("{");
432            variable.push_str(variable_name);
433            variable.push('}');
434
435            path_new = exact_matcher.replace(path_new.as_str(), variable).to_string()
436        }
437
438        path_new
439    }
440
441    /// Add a Tower service as a route handler without state.
442    ///
443    /// The service must accept `Request<Body>` and return `Response<Body>`.
444    pub fn add_service_no_state<Z>(&mut self, path: &str, service: Z)
445        where
446            Z: Service<Request<Body>, Response = Response<Body>, Error = BoxError> + Clone + Send + Sync + 'static,
447            Z::Future: Send + 'static
448    {
449        let wrapper = WrapperService::new(service);
450        let path = ProductOSRouter::param_to_field(path);
451        self.router = self.take_router().route_service(path.as_str(), wrapper);
452    }
453
454    /// Add a Tower middleware layer to the stateless router.
455    ///
456    /// The layer wraps every route and is applied in LIFO order (last added runs first).
457    pub fn add_middleware_no_state<L, ResBody>(&mut self, middleware: L)
458        where
459            L: Layer<Route> + Clone + Send + Sync + 'static,
460            L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + Sync + 'static,
461            ResBody: HttpBody<Data = Bytes> + Send + 'static,
462            ResBody::Error: Into<BoxError>,
463            <L::Service as Service<Request<Body>>>::Future: Send + 'static,
464            <L::Service as Service<Request<Body>>>::Error: Into<BoxError> + 'static,
465    {
466        let wrapper: WrapperLayer<L, Route, ResBody> = WrapperLayer::new(middleware);
467        self.router = self.take_router().layer(wrapper)
468    }
469
470    /// Add a default header to all responses from this stateless router.
471    ///
472    /// Headers added this way will not overwrite headers already present in the response.
473    pub fn add_default_header_no_state(&mut self, header_name: &str, header_value: &str) {
474        let mut default_headers = HeaderMap::new();
475        default_headers.insert(HeaderName::from_bytes(header_name.as_bytes()).unwrap(), HeaderValue::from_str(header_value).unwrap());
476
477        self.add_middleware_no_state(DefaultHeadersLayer::new(default_headers));
478    }
479
480    /// Returns a 501 Not Implemented JSON response.
481    ///
482    /// Use this to indicate that the server does not support the functionality
483    /// required to fulfill the request.
484    pub fn not_implemented() -> Response<Body> {
485        Response::builder()
486            .status(StatusCode::NOT_IMPLEMENTED.as_u16())
487            .header("content-type", "application/json")
488            .body(Body::from("{}"))
489            .unwrap()
490    }
491
492    /// Returns a 404 Not Found JSON response.
493    ///
494    /// This was the original behavior of `not_implemented()` prior to v0.0.41.
495    /// Use [`not_implemented`](Self::not_implemented) for the correct 501 status code,
496    /// or use this method if you specifically need a 404 response.
497    #[deprecated(since = "0.0.41", note = "Use not_implemented() which correctly returns 501, or build a 404 response explicitly")]
498    pub fn not_implemented_legacy() -> Response<Body> {
499        Response::builder()
500            .status(StatusCode::NOT_FOUND.as_u16())
501            .header("content-type", "application/json")
502            .body(Body::from("{}"))
503            .unwrap()
504    }
505
506
507
508    /// Add a pre-built [`MethodRouter`] at the given path without state.
509    pub fn add_route_no_state(&mut self, path: &str, service_handler: MethodRouter)
510    {
511        let service= ServiceBuilder::new().service(service_handler);
512        let path = ProductOSRouter::param_to_field(path);
513        self.router = self.take_router().route(path.as_str(), service);
514    }
515
516    /// Set a fallback [`MethodRouter`] for unmatched routes without state.
517    pub fn set_fallback_no_state(&mut self, service_handler: MethodRouter)
518    {
519        let service= ServiceBuilder::new().service(service_handler);
520        self.router = self.take_router().fallback(service);
521    }
522
523    /// Add a GET route handler without state
524    ///
525    /// # Arguments
526    ///
527    /// * `path` - The route path (supports `:param` and `*wildcard`)
528    /// * `handler` - The async function to handle requests
529    ///
530    /// # Examples
531    ///
532    /// ```rust
533    /// use product_os_router::ProductOSRouter;
534    ///
535    /// async fn hello() -> &'static str {
536    ///     "Hello, World!"
537    /// }
538    ///
539    /// let mut router = ProductOSRouter::new();
540    /// router.add_get_no_state("/", hello);
541    /// ```
542    pub fn add_get_no_state<H, T>(&mut self, path: &str, handler: H)
543        where
544            H: Handler<T, ()>,
545            T: 'static
546    {
547        let method_router = ProductOSRouter::convert_handler_to_method_router(Method::GET, handler, ());
548        let path = ProductOSRouter::param_to_field(path);
549        self.router = self.take_router().route(path.as_str(), method_router);
550    }
551
552    /// Add a POST route handler without state
553    ///
554    /// # Arguments
555    ///
556    /// * `path` - The route path
557    /// * `handler` - The async function to handle requests
558    ///
559    /// # Examples
560    ///
561    /// ```rust
562    /// use product_os_router::{ProductOSRouter, Json};
563    /// use serde::{Deserialize, Serialize};
564    ///
565    /// #[derive(Deserialize, Serialize)]
566    /// struct User {
567    ///     name: String,
568    /// }
569    ///
570    /// async fn create_user(Json(user): Json<User>) -> Json<User> {
571    ///     Json(user)
572    /// }
573    ///
574    /// let mut router = ProductOSRouter::new();
575    /// router.add_post_no_state("/users", create_user);
576    /// ```
577    pub fn add_post_no_state<H, T>(&mut self, path: &str, handler: H)
578        where
579            H: Handler<T, ()>,
580            T: 'static
581    {
582        let method_router = ProductOSRouter::convert_handler_to_method_router(Method::POST, handler, ());
583        let path = ProductOSRouter::param_to_field(path);
584        self.router = self.take_router().route(path.as_str(), method_router);
585    }
586
587    /// Add a handler for the given HTTP method and path without state.
588    pub fn add_handler_no_state<H, T>(&mut self, path: &str, method: Method, handler: H)
589        where
590            H: Handler<T, ()>,
591            T: 'static
592    {
593        let method_router = ProductOSRouter::convert_handler_to_method_router(method, handler, ());
594        let path = ProductOSRouter::param_to_field(path);
595        self.router = self.take_router().route(path.as_str(), method_router);
596    }
597
598    /// Set a fallback handler for unmatched routes without state.
599    pub fn set_fallback_handler_no_state<H, T>(&mut self, handler: H)
600        where
601            H: Handler<T, ()>,
602            T: 'static
603    {
604        self.router = self.take_router().fallback(handler).with_state(());
605    }
606
607    /// Add a CORS-enabled handler for the given HTTP method and path without state.
608    #[cfg(feature = "cors")]
609    pub fn add_cors_handler_no_state<H, T>(&mut self, path: &str, method: Method, handler: H)
610        where
611            H: Handler<T, ()>,
612            T: 'static
613    {
614        let method_router = ProductOSRouter::add_cors_method_router(method, handler, ());
615        let path = ProductOSRouter::param_to_field(path);
616        self.router = self.take_router().route(path.as_str(), method_router);
617    }
618
619    /// Add a WebSocket handler at the given path without state.
620    #[cfg(feature = "ws")]
621    pub fn add_ws_handler_no_state<H, T>(&mut self, path: &str, ws_handler: H)
622        where
623            H: Handler<T, ()>,
624            T: 'static
625    {
626        // https://docs.rs/axum/latest/axum/extract/ws/index.html
627        let service: MethodRouter = ProductOSRouter::convert_handler_to_method_router(Method::GET, ws_handler, ());
628        let path = ProductOSRouter::param_to_field(path);
629        self.router = self.take_router().route(path.as_str(), service);
630    }
631
632    /// Add a Server-Sent Events handler at the given path without state.
633    #[cfg(feature = "sse")]
634    pub fn add_sse_handler_no_state<H, T>(&mut self, path: &str, sse_handler: H)
635        where
636            H: Handler<T, ()>,
637            T: 'static
638    {
639        // https://docs.rs/axum/latest/axum/response/sse/index.html
640        let service: MethodRouter = ProductOSRouter::convert_handler_to_method_router(Method::GET, sse_handler, ());
641        let path = ProductOSRouter::param_to_field(path);
642        self.router = self.take_router().route(path.as_str(), service);
643    }
644
645    /// Add multiple HTTP method handlers at a single path without state.
646    pub fn add_handlers_no_state<H, T>(&mut self, path: &str, handlers: BTreeMap<Method, H>)
647        where
648            H: Handler<T, ()>,
649            T: 'static
650    {
651        let mut method_router: MethodRouter = MethodRouter::new();
652
653        for (method, handler) in handlers {
654            method_router = ProductOSRouter::add_handler_to_method_router(method_router, method, handler);
655        }
656
657        let path = ProductOSRouter::param_to_field(path);
658        self.router = self.take_router().route(path.as_str(), method_router);
659    }
660
661    /// Adds a global CORS middleware layer that applies to all routes.
662    ///
663    /// This permits any origin, any HTTP method, and any request header,
664    /// which is appropriate for development and for APIs consumed by
665    /// third-party front-ends.  For production you may want to use
666    /// [`add_middleware`] with a more restrictive [`CorsLayer`] instead.
667    #[cfg(feature = "cors")]
668    pub fn add_cors_middleware_no_state(&mut self) {
669        self.add_middleware_no_state(
670            CorsLayer::new()
671                .allow_origin(Any)
672                .allow_methods(Any)
673                .allow_headers(Any),
674        );
675    }
676
677    /// Add multiple CORS-enabled HTTP method handlers at a single path without state.
678    #[cfg(feature = "cors")]
679    pub fn add_cors_handlers_no_state<H, T>(&mut self, path: &str, handlers: BTreeMap<Method, H>)
680        where
681            H: Handler<T, ()>,
682            T: 'static
683    {
684        let mut service: MethodRouter = MethodRouter::new();
685
686        for (method, handler) in handlers {
687            service = ProductOSRouter::add_cors_handler_to_method_router(service, method, handler);
688        }
689
690        let path = ProductOSRouter::param_to_field(path);
691        self.router = self.take_router().route(path.as_str(), service);
692    }
693
694}
695
696impl<S> ProductOSRouter<S>
697where
698    S: Clone + Send + Sync + 'static
699{
700    /// Build and return the finalized [`Router`] with the stored state applied.
701    ///
702    /// The returned `Router` can be used with a server (e.g., `axum::serve`).
703    pub fn get_router(&self) -> Router {
704        self.router.clone().with_state(self.state.clone())
705    }
706
707    /// Return a mutable reference to the underlying [`Router`] for advanced configuration.
708    pub fn get_router_mut(&mut self) -> &mut Router<S> {
709        &mut self.router
710    }
711
712    /// Create a new router with the given shared state.
713    ///
714    /// The state is made available to handlers via the [`State`] extractor.
715    pub fn new_with_state(state: S) -> Self
716    where
717        S: Clone + Send + 'static
718    {
719        Self {
720            router: Router::new().with_state(state.clone()),
721            state
722        }
723    }
724
725    /// Add a Tower service as a route handler with state.
726    pub fn add_service<Z>(&mut self, path: &str, service: Z)
727        where
728            Z: Service<Request<Body>, Response = Response<Body>, Error = BoxError> + Clone + Send + Sync + 'static,
729            Z::Response: IntoResponse,
730            Z::Future: Send + 'static
731    {
732        let wrapper = WrapperService::new(service);
733        let path = ProductOSRouter::param_to_field(path);
734        self.router = self.take_router().route_service(path.as_str(), wrapper);
735    }
736
737    /// Add a Tower middleware layer to the router.
738    pub fn add_middleware<L, ResBody>(&mut self, middleware: L)
739        where
740            L: Layer<Route> + Clone + Send + Sync + 'static,
741            L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + Sync + 'static,
742            ResBody: HttpBody<Data = Bytes> + Send + 'static,
743            ResBody::Error: Into<BoxError>,
744            <L::Service as Service<Request<Body>>>::Future: Send + 'static,
745            <L::Service as Service<Request<Body>>>::Error: Into<BoxError> + 'static,
746    {
747        let wrapper: WrapperLayer<L, Route, ResBody> = WrapperLayer::new(middleware);
748        self.router = self.take_router().layer(wrapper)
749    }
750
751    /// Add a default header to all responses from this router.
752    ///
753    /// Headers added this way will not overwrite headers already present in the response.
754    ///
755    /// # Arguments
756    ///
757    /// * `header_name` - The header name
758    /// * `header_value` - The header value
759    pub fn add_default_header(&mut self, header_name: &str, header_value: &str) {
760        let mut default_headers = HeaderMap::new();
761        default_headers.insert(HeaderName::from_bytes(header_name.as_bytes()).unwrap(), HeaderValue::from_str(header_value).unwrap());
762
763        self.add_middleware(DefaultHeadersLayer::new(default_headers));
764    }
765
766    /// Add a default header to all responses from this router.
767    ///
768    /// This method accepts owned `String` parameters for backward compatibility.
769    /// Prefer [`add_default_header`](Self::add_default_header) which takes `&str` instead.
770    #[deprecated(since = "0.0.41", note = "Use add_default_header(&str, &str) instead")]
771    pub fn add_default_header_owned(&mut self, header_name: String, header_value: String) {
772        self.add_default_header(header_name.as_str(), header_value.as_str());
773    }
774
775
776
777    /// Add a pre-built [`MethodRouter`] at the given path.
778    pub fn add_route(&mut self, path: &str, service_handler: MethodRouter<S>)
779    where
780        S: Clone + Send + 'static,
781    {
782        let service= ServiceBuilder::new().service(service_handler);
783        let path = ProductOSRouter::param_to_field(path);
784        self.router = self.take_router().route(path.as_str(), service);
785    }
786
787    /// Set a fallback [`MethodRouter`] for unmatched routes.
788    pub fn set_fallback(&mut self, service_handler: MethodRouter<S>)
789    where
790        S: Clone + Send + 'static
791    {
792        let service= ServiceBuilder::new().service(service_handler);
793        self.router = self.take_router().fallback(service);
794    }
795
796    /// Add a GET route handler with state.
797    pub fn add_get<H, T>(&mut self, path: &str, handler: H)
798    where
799        H: Handler<T, S>,
800        T: 'static
801    {
802        let method_router = ProductOSRouter::convert_handler_to_method_router(Method::GET, handler, self.state.clone());
803        let path = ProductOSRouter::param_to_field(path);
804        self.router = self.take_router().route(path.as_str(), method_router);
805    }
806
807    /// Add a POST route handler with state.
808    pub fn add_post<H, T>(&mut self, path: &str, handler: H)
809    where
810        H: Handler<T, S>,
811        T: 'static
812    {
813        let method_router = ProductOSRouter::convert_handler_to_method_router(Method::POST, handler, self.state.clone());
814        let path = ProductOSRouter::param_to_field(path);
815        self.router = self.take_router().route(path.as_str(), method_router);
816    }
817
818    /// Add a handler for the given HTTP method and path with state.
819    pub fn add_handler<H, T>(&mut self, path: &str, method: Method, handler: H)
820    where
821        H: Handler<T, S>,
822        T: 'static
823    {
824        let method_router = ProductOSRouter::convert_handler_to_method_router(method, handler, self.state.clone());
825        let path = ProductOSRouter::param_to_field(path);
826        self.router = self.take_router().route(path.as_str(), method_router);
827    }
828
829    /// Replace the router's shared state.
830    pub fn add_state(&mut self, state: S) {
831        self.router = self.take_router().with_state(state);
832    }
833
834    /// Set a fallback handler for unmatched routes with state.
835    pub fn set_fallback_handler<H, T>(&mut self, handler: H)
836    where
837        H: Handler<T, S>,
838        T: 'static
839    {
840        self.router = self.take_router().fallback(handler).with_state(self.state.clone());
841    }
842
/// Register a stateful, CORS-wrapped handler for `method` on `path`.
///
/// The method router is built by `add_cors_method_router` with a clone of the
/// router's current state; `path` is passed through `param_to_field` first.
#[cfg(feature = "cors")]
pub fn add_cors_handler<H, T>(&mut self, path: &str, method: Method, handler: H)
where
    H: Handler<T, S>,
    T: 'static,
{
    let route_path = ProductOSRouter::param_to_field(path);
    let endpoint = Self::add_cors_method_router(method, handler, self.state.clone());
    self.router = self.take_router().route(route_path.as_str(), endpoint);
}
854
/// Register a stateful WebSocket handler on `path`.
///
/// The handler is mounted as a `GET` route bound to a clone of the router's
/// current state. See <https://docs.rs/axum/latest/axum/extract/ws/index.html>.
#[cfg(feature = "ws")]
pub fn add_ws_handler<H, T>(&mut self, path: &str, ws_handler: H)
where
    H: Handler<T, S>,
    T: 'static,
{
    let route_path = ProductOSRouter::param_to_field(path);
    let upgrade_route: MethodRouter<S> =
        Self::convert_handler_to_method_router(Method::GET, ws_handler, self.state.clone());
    self.router = self.take_router().route(route_path.as_str(), upgrade_route);
}
867
/// Register a stateful Server-Sent Events handler on `path`.
///
/// The handler is mounted as a `GET` route bound to a clone of the router's
/// current state. See <https://docs.rs/axum/latest/axum/response/sse/index.html>.
#[cfg(feature = "sse")]
pub fn add_sse_handler<H, T>(&mut self, path: &str, sse_handler: H)
where
    H: Handler<T, S>,
    T: 'static,
{
    let route_path = ProductOSRouter::param_to_field(path);
    let stream_route: MethodRouter<S> =
        Self::convert_handler_to_method_router(Method::GET, sse_handler, self.state.clone());
    self.router = self.take_router().route(route_path.as_str(), stream_route);
}
880
881    /// Add multiple HTTP method handlers at a single path with state.
882    pub fn add_handlers<H, T>(&mut self, path: &str, handlers: BTreeMap<Method, H>)
883    where
884        H: Handler<T, S>,
885        T: 'static
886    {
887        let mut method_router: MethodRouter<S> = MethodRouter::new();
888
889        for (method, handler) in handlers {
890            method_router = ProductOSRouter::add_handler_to_method_router(method_router, method, handler);
891        }
892
893        let path = ProductOSRouter::param_to_field(path);
894        self.router = self.take_router().route(path.as_str(), method_router);
895    }
896
/// Register several CORS-wrapped handlers, keyed by HTTP method, on a single `path`.
///
/// Each handler is added through `add_cors_handler_to_method_router`, which
/// applies its CORS layer to the accumulated method router on every call.
/// No state is bound to the method router here.
#[cfg(feature = "cors")]
pub fn add_cors_handlers<H, T>(&mut self, path: &str, handlers: BTreeMap<Method, H>)
where
    H: Handler<T, S>,
    T: 'static,
{
    let empty: MethodRouter<S> = MethodRouter::new();
    let combined = handlers.into_iter().fold(empty, |acc, (method, handler)| {
        Self::add_cors_handler_to_method_router(acc, method, handler)
    });

    let route_path = ProductOSRouter::param_to_field(path);
    self.router = self.take_router().route(route_path.as_str(), combined);
}
913
/// Installs a permissive CORS layer through `add_middleware`.
///
/// The layer allows any origin, any HTTP method, and any request header.
/// That is convenient during development and for APIs consumed by
/// third-party front-ends; for production, consider calling `add_middleware`
/// directly with a more restrictive `CorsLayer`.
#[cfg(feature = "cors")]
pub fn add_cors_middleware(&mut self) {
    let allow_everything = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);

    self.add_middleware(allow_everything);
}
929
930    fn add_handler_to_method_router<H, T>(method_router: MethodRouter<S>, method: Method, handler: H) -> MethodRouter<S>
931    where
932        H: Handler<T, S>,
933        T: 'static
934    {
935        match method {
936            Method::GET => method_router.get(handler),
937            Method::POST => method_router.post(handler),
938            Method::PUT => method_router.put(handler),
939            Method::PATCH => method_router.patch(handler),
940            Method::DELETE => method_router.delete(handler),
941            Method::TRACE => method_router.trace(handler),
942            Method::HEAD => method_router.head(handler),
943            // Axum has no native CONNECT method router; fall back to GET
944            Method::CONNECT => method_router.get(handler),
945            Method::OPTIONS => method_router.options(handler),
946            Method::ANY => method_router.fallback(handler),
947        }
948    }
949
950    fn convert_handler_to_method_router<H, T>(method: Method, handler: H, state: S) -> MethodRouter<S>
951    where
952        H: Handler<T, S>,
953        T: 'static,
954    {
955        match method {
956            Method::GET => get(handler).with_state(state),
957            Method::POST => post(handler).with_state(state),
958            Method::PUT => put(handler).with_state(state),
959            Method::PATCH => patch(handler).with_state(state),
960            Method::DELETE => delete(handler).with_state(state),
961            Method::TRACE => trace(handler).with_state(state),
962            Method::HEAD => head(handler).with_state(state),
963            // Axum has no native CONNECT method router; fall back to GET
964            Method::CONNECT => get(handler).with_state(state),
965            Method::OPTIONS => options(handler).with_state(state),
966            Method::ANY => any(handler).with_state(state)
967        }
968    }
969
/// Attach `handler` to `method_router` for `method`, then wrap the resulting
/// method router in a CORS layer that allows any origin and any method.
#[cfg(feature = "cors")]
fn add_cors_handler_to_method_router<H, T>(method_router: MethodRouter<S>, method: Method, handler: H) -> MethodRouter<S>
where
    H: Handler<T, S>,
    T: 'static,
{
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any);

    Self::add_handler_to_method_router(method_router, method, handler).layer(cors)
}
980
/// Build a state-bound method router for `method` and `handler`, then wrap it
/// in a CORS layer that allows any origin and any method.
#[cfg(feature = "cors")]
fn add_cors_method_router<H, T>(method: Method, handler: H, state: S) -> MethodRouter<S>
where
    H: Handler<T, S>,
    T: 'static,
{
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any);

    Self::convert_handler_to_method_router(method, handler, state).layer(cors)
}
991}
992
993
994// Extension extractor for backward compatibility with code that used axum <0.7
995// Since axum 0.7+ removed Extension in favor of State, we reimplement it here
996// using http::Extensions
997use std::ops::{Deref, DerefMut};
998
/// Extractor that gets a value from request extensions.
///
/// The wrapped value is public, so it can be destructured directly in a
/// handler signature (`Extension(value): Extension<T>`).
///
/// This is commonly used with middleware that inserts data into extensions.
/// In axum 0.7+, prefer using `State` for application-wide state.
///
/// # Example
///
/// ```ignore
/// use product_os_router::{extension_layer, Extension, MethodRouter};
/// use std::sync::Arc;
///
/// async fn handler(Extension(value): Extension<Arc<String>>) -> String {
///     (*value).clone()
/// }
///
/// let shared_data = Arc::new("Hello".to_string());
/// let route = MethodRouter::new()
///     .get(handler)
///     // `extension_layer` inserts a clone of the value into every request.
///     .layer(extension_layer(shared_data));
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct Extension<T>(pub T);
1021
1022impl<T> Deref for Extension<T> {
1023    type Target = T;
1024
1025    fn deref(&self) -> &Self::Target {
1026        &self.0
1027    }
1028}
1029
1030impl<T> DerefMut for Extension<T> {
1031    fn deref_mut(&mut self) -> &mut Self::Target {
1032        &mut self.0
1033    }
1034}
1035
1036#[async_trait::async_trait]
1037impl<T, S> FromRequestParts<S> for Extension<T>
1038where
1039    T: Clone + Send + Sync + 'static,
1040    S: Send + Sync,
1041{
1042    type Rejection = (product_os_http::StatusCode, &'static str);
1043
1044    fn from_request_parts(parts: &mut product_os_http::request::Parts, _state: &S) -> impl core::future::Future<Output = Result<Self, Self::Rejection>> + Send {
1045        let value = parts
1046            .extensions
1047            .get::<T>()
1048            .cloned()
1049            .map(Extension)
1050            .ok_or((
1051                product_os_http::StatusCode::INTERNAL_SERVER_ERROR,
1052                "Extension not found in request",
1053            ));
1054        
1055        async move { value }
1056    }
1057}
1058
1059// Helper function to create an extension layer 
1060// Uses tower's MapRequestLayer which properly implements all required traits
1061pub fn extension_layer<T>(value: T) -> tower::util::MapRequestLayer<impl Fn(Request<Body>) -> Request<Body> + Clone>
1062where
1063    T: Clone + Send + Sync + 'static,
1064{
1065    tower::util::MapRequestLayer::new(move |mut req: Request<Body>| {
1066        req.extensions_mut().insert(value.clone());
1067        req
1068    })
1069}
1070