use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::app::{BoxHandler, RouteEntry};
use crate::context::RequestContext;
use crate::request::{Method, Request};
use crate::response::Response;
/// Documentation metadata for a single response, stored per HTTP status code
/// in the response maps of [`IncludeConfig`] and [`APIRouter`].
#[derive(Debug, Clone)]
pub struct ResponseDef {
    /// Human-readable description of the response.
    pub description: String,
    /// Optional example payload.
    pub example: Option<serde_json::Value>,
    /// Optional content type, e.g. `"application/json"`.
    pub content_type: Option<String>,
}

impl ResponseDef {
    /// Creates a definition with `description` and neither an example nor a content type.
    #[must_use]
    pub fn new(description: impl Into<String>) -> Self {
        let description = description.into();
        Self {
            description,
            example: None,
            content_type: None,
        }
    }

    /// Returns the definition with its example set to `example` (replacing any previous one).
    #[must_use]
    pub fn with_example(self, example: serde_json::Value) -> Self {
        Self {
            example: Some(example),
            ..self
        }
    }

    /// Returns the definition with its content type set (replacing any previous one).
    #[must_use]
    pub fn with_content_type(self, content_type: impl Into<String>) -> Self {
        Self {
            content_type: Some(content_type.into()),
            ..self
        }
    }
}
/// Type-erased dependency callback shared between routes.
///
/// It receives the request context and a mutable request, and yields a boxed
/// future. `Err(response)` tells the caller to answer with `response` instead
/// of running the route handler (this is how `APIRouter::into_route_entries`
/// consumes it).
///
/// The boxed future has no explicit lifetime, so it defaults to `'static` and
/// cannot borrow the `RequestContext` or `Request` it was called with.
pub type BoxDependency = Arc<
dyn Fn(
&RequestContext,
&mut Request,
) -> Pin<Box<dyn Future<Output = Result<(), Response>> + Send>>
+ Send
+ Sync,
>;
/// A named dependency attached to a router, an include config, or a route.
///
/// When routes are turned into entries, dependencies run in order before the
/// handler. The first one that returns `Err(response)` stops the chain, and
/// `response` is returned instead of the handler's response.
#[derive(Clone)]
pub struct RouterDependency {
    pub(crate) handler: BoxDependency,
    pub(crate) name: String,
}

impl RouterDependency {
    /// Wraps the async callback `f` as a dependency called `name`.
    ///
    /// `Fut` must be `'static`, so the returned future cannot keep borrowing
    /// the context or request. Anything needed from them has to be read before
    /// the future is created.
    pub fn new<F, Fut>(name: impl Into<String>, f: F) -> Self
    where
        F: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<(), Response>> + Send + 'static,
    {
        let name = name.into();
        let handler: BoxDependency = Arc::new(move |ctx, req| {
            let fut = f(ctx, req);
            Box::pin(fut)
        });
        Self { handler, name }
    }

    /// Runs the dependency for one request.
    ///
    /// Returns `Err(response)` when the request should be answered with
    /// `response` instead of reaching the handler.
    pub async fn execute(&self, ctx: &RequestContext, req: &mut Request) -> Result<(), Response> {
        let pending = (self.handler)(ctx, req);
        pending.await
    }
}

impl std::fmt::Debug for RouterDependency {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The handler is an opaque closure, so only the name is printed.
        let mut builder = f.debug_struct("RouterDependency");
        builder.field("name", &self.name);
        builder.finish_non_exhaustive()
    }
}
/// Options applied when mounting a router with
/// [`APIRouter::include_router_with_config`].
///
/// How each setting combines with the included router:
/// - The prefix, tags and dependencies set here are placed in front of the
///   included router's own.
/// - `deprecated` and `include_in_schema`, when set, take precedence over the
///   included router's router-level values.
/// - Responses only fill status codes that neither the including nor the
///   included router defines.
#[derive(Debug, Default, Clone)]
pub struct IncludeConfig {
    prefix: Option<String>,
    tags: Vec<String>,
    dependencies: Vec<RouterDependency>,
    responses: HashMap<u16, ResponseDef>,
    deprecated: Option<bool>,
    include_in_schema: Option<bool>,
}

impl IncludeConfig {
    /// Creates a config with nothing set.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the mount prefix.
    ///
    /// The prefix is normalized to start with `/` and to have no trailing `/`.
    /// A prefix made only of slashes becomes `/`. An empty string is ignored
    /// and leaves any previously set prefix in place.
    #[must_use]
    pub fn prefix(mut self, prefix: impl Into<String>) -> Self {
        let raw = prefix.into();
        if raw.is_empty() {
            return self;
        }
        // Strip every trailing slash. Slash-only input (e.g. "//") collapses
        // to the root "/" rather than to an empty prefix.
        let body = raw.trim_end_matches('/');
        let normalized = if body.is_empty() {
            "/".to_string()
        } else if body.starts_with('/') {
            body.to_string()
        } else {
            format!("/{}", body)
        };
        self.prefix = Some(normalized);
        self
    }

    /// Replaces the tags with `tags`.
    #[must_use]
    pub fn tags(mut self, tags: Vec<impl Into<String>>) -> Self {
        self.tags = tags.into_iter().map(Into::into).collect();
        self
    }

    /// Appends a single tag.
    #[must_use]
    pub fn tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }

    /// Appends a single dependency.
    #[must_use]
    pub fn dependency(mut self, dep: RouterDependency) -> Self {
        self.dependencies.push(dep);
        self
    }

    /// Replaces the dependency list with `deps`.
    ///
    /// Note that `APIRouter::dependencies` appends instead.
    #[must_use]
    pub fn dependencies(mut self, deps: Vec<RouterDependency>) -> Self {
        self.dependencies = deps;
        self
    }

    /// Adds or replaces the response documented for `status_code`.
    #[must_use]
    pub fn response(mut self, status_code: u16, def: ResponseDef) -> Self {
        self.responses.insert(status_code, def);
        self
    }

    /// Replaces all documented responses.
    #[must_use]
    pub fn responses(mut self, responses: HashMap<u16, ResponseDef>) -> Self {
        self.responses = responses;
        self
    }

    /// Marks the included routes as deprecated (or explicitly not deprecated).
    #[must_use]
    pub fn deprecated(mut self, deprecated: bool) -> Self {
        self.deprecated = Some(deprecated);
        self
    }

    /// Overrides whether the included routes are included in the schema.
    #[must_use]
    pub fn include_in_schema(mut self, include: bool) -> Self {
        self.include_in_schema = Some(include);
        self
    }

    /// Returns the normalized prefix, if one was set.
    #[must_use]
    pub fn get_prefix(&self) -> Option<&str> {
        self.prefix.as_deref()
    }

    /// Returns the configured tags.
    #[must_use]
    pub fn get_tags(&self) -> &[String] {
        &self.tags
    }

    /// Returns the configured dependencies, in execution order.
    #[must_use]
    pub fn get_dependencies(&self) -> &[RouterDependency] {
        &self.dependencies
    }

    /// Returns the configured responses.
    #[must_use]
    pub fn get_responses(&self) -> &HashMap<u16, ResponseDef> {
        &self.responses
    }

    /// Returns the deprecation override, if set.
    #[must_use]
    pub fn get_deprecated(&self) -> Option<bool> {
        self.deprecated
    }

    /// Returns the `include_in_schema` override, if set.
    #[must_use]
    pub fn get_include_in_schema(&self) -> Option<bool> {
        self.include_in_schema
    }
}
/// One route registered on an [`APIRouter`], together with the metadata
/// accumulated while routers are nested.
#[derive(Clone)]
pub struct RouterRoute {
    /// HTTP method the route answers.
    pub method: Method,
    /// Path relative to the owning router's prefix. After inclusion it already
    /// contains the included router's (and include config's) prefix.
    pub path: String,
    pub(crate) handler: Arc<BoxHandler>,
    /// Tags inherited from include configs and included routers.
    pub tags: Vec<String>,
    /// Dependencies run before the handler, in this order.
    pub dependencies: Vec<RouterDependency>,
    /// `None` means "unspecified". Inclusion fills it from the config or the included router.
    pub deprecated: Option<bool>,
    /// Whether the route is listed in the schema.
    pub include_in_schema: bool,
}

impl std::fmt::Debug for RouterRoute {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `handler` and `dependencies` are left out; `finish_non_exhaustive` prints `..` for them.
        let mut out = f.debug_struct("RouterRoute");
        out.field("method", &self.method)
            .field("path", &self.path)
            .field("tags", &self.tags)
            .field("deprecated", &self.deprecated)
            .field("include_in_schema", &self.include_in_schema);
        out.finish_non_exhaustive()
    }
}
#[derive(Debug, Default)]
pub struct APIRouter {
prefix: String,
tags: Vec<String>,
dependencies: Vec<RouterDependency>,
responses: HashMap<u16, ResponseDef>,
deprecated: Option<bool>,
include_in_schema: bool,
routes: Vec<RouterRoute>,
}
impl APIRouter {
    /// Creates an empty router with no prefix that is included in the schema.
    #[must_use]
    pub fn new() -> Self {
        Self {
            prefix: String::new(),
            tags: Vec::new(),
            dependencies: Vec::new(),
            responses: HashMap::new(),
            deprecated: None,
            include_in_schema: true,
            routes: Vec::new(),
        }
    }

    /// Creates an empty router with `prefix`, normalized as by [`APIRouter::prefix`].
    #[must_use]
    pub fn with_prefix(prefix: impl Into<String>) -> Self {
        Self::new().prefix(prefix)
    }

    /// Sets the router prefix.
    ///
    /// A non-empty prefix is normalized to start with `/` and to have no
    /// trailing `/`. A prefix made only of slashes becomes `/`. An empty string
    /// clears the prefix.
    #[must_use]
    pub fn prefix(mut self, prefix: impl Into<String>) -> Self {
        let raw = prefix.into();
        // Strip every trailing slash so "/api//" becomes "/api", as
        // `IncludeConfig::prefix` does.
        let body = raw.trim_end_matches('/');
        self.prefix = if raw.is_empty() {
            String::new()
        } else if body.is_empty() {
            "/".to_string()
        } else if body.starts_with('/') {
            body.to_string()
        } else {
            format!("/{}", body)
        };
        self
    }

    /// Replaces the router tags.
    #[must_use]
    pub fn tags(mut self, tags: Vec<impl Into<String>>) -> Self {
        self.tags = tags.into_iter().map(Into::into).collect();
        self
    }

    /// Appends a single router tag.
    #[must_use]
    pub fn tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }

    /// Appends a dependency that runs before every route of this router.
    #[must_use]
    pub fn dependency(mut self, dep: RouterDependency) -> Self {
        self.dependencies.push(dep);
        self
    }

    /// Appends `deps` to the router dependencies.
    ///
    /// Existing dependencies are kept, unlike `IncludeConfig::dependencies`,
    /// which replaces them.
    #[must_use]
    pub fn dependencies(mut self, deps: Vec<RouterDependency>) -> Self {
        self.dependencies.extend(deps);
        self
    }

    /// Adds or replaces the response documented for `status_code`.
    #[must_use]
    pub fn response(mut self, status_code: u16, def: ResponseDef) -> Self {
        self.responses.insert(status_code, def);
        self
    }

    /// Replaces all documented responses.
    #[must_use]
    pub fn responses(mut self, responses: HashMap<u16, ResponseDef>) -> Self {
        self.responses = responses;
        self
    }

    /// Sets the router-level deprecation flag.
    #[must_use]
    pub fn deprecated(mut self, deprecated: bool) -> Self {
        self.deprecated = Some(deprecated);
        self
    }

    /// Sets whether this router's routes are included in the schema.
    #[must_use]
    pub fn include_in_schema(mut self, include: bool) -> Self {
        self.include_in_schema = include;
        self
    }

    /// Registers `handler` for `method` at `path`, relative to this router's prefix.
    ///
    /// The route starts with no tags, no dependencies and no deprecation flag,
    /// and is included in the schema. `Fut` must be `'static`, so the handler's
    /// future cannot keep borrowing the context or request.
    #[must_use]
    pub fn route<H, Fut>(mut self, path: impl Into<String>, method: Method, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        let boxed: BoxHandler = Box::new(move |ctx, req| {
            let fut = handler(ctx, req);
            Box::pin(fut)
        });
        self.routes.push(RouterRoute {
            method,
            path: path.into(),
            handler: Arc::new(boxed),
            tags: Vec::new(),
            dependencies: Vec::new(),
            deprecated: None,
            include_in_schema: true,
        });
        self
    }

    /// Registers a `GET` route; see [`APIRouter::route`].
    #[must_use]
    pub fn get<H, Fut>(self, path: impl Into<String>, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.route(path, Method::Get, handler)
    }

    /// Registers a `POST` route; see [`APIRouter::route`].
    #[must_use]
    pub fn post<H, Fut>(self, path: impl Into<String>, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.route(path, Method::Post, handler)
    }

    /// Registers a `PUT` route; see [`APIRouter::route`].
    #[must_use]
    pub fn put<H, Fut>(self, path: impl Into<String>, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.route(path, Method::Put, handler)
    }

    /// Registers a `DELETE` route; see [`APIRouter::route`].
    #[must_use]
    pub fn delete<H, Fut>(self, path: impl Into<String>, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.route(path, Method::Delete, handler)
    }

    /// Registers a `PATCH` route; see [`APIRouter::route`].
    #[must_use]
    pub fn patch<H, Fut>(self, path: impl Into<String>, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.route(path, Method::Patch, handler)
    }

    /// Registers an `OPTIONS` route; see [`APIRouter::route`].
    #[must_use]
    pub fn options<H, Fut>(self, path: impl Into<String>, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.route(path, Method::Options, handler)
    }

    /// Registers a `HEAD` route; see [`APIRouter::route`].
    #[must_use]
    pub fn head<H, Fut>(self, path: impl Into<String>, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.route(path, Method::Head, handler)
    }

    /// Mounts `other` into this router with a default [`IncludeConfig`].
    #[must_use]
    pub fn include_router(self, other: APIRouter) -> Self {
        self.include_router_with_config(other, IncludeConfig::default())
    }

    /// Mounts `other`'s routes into this router, applying `config`.
    ///
    /// For every route of `other`:
    /// - **path**: `config` prefix + `other` prefix + route path. This router's
    ///   own prefix is applied later, by `into_route_entries`.
    /// - **tags / dependencies**: `config` first, then `other`'s router-level
    ///   values, then the route's own.
    /// - **deprecated**: the route's own value wins; otherwise `config`'s, then `other`'s.
    /// - **include_in_schema**: `config`'s value if set, else `other`'s. When
    ///   the result is `false`, every included route is marked `false`.
    ///
    /// Responses are merged into this router's map: `other`'s responses
    /// overwrite existing codes, while `config`'s only fill codes not present yet.
    #[must_use]
    pub fn include_router_with_config(mut self, other: APIRouter, config: IncludeConfig) -> Self {
        let APIRouter {
            prefix: other_prefix,
            tags: other_tags,
            dependencies: other_deps,
            responses: other_responses,
            deprecated: other_deprecated,
            include_in_schema: other_include_in_schema,
            routes: other_routes,
        } = other;
        let IncludeConfig {
            prefix: config_prefix,
            tags: config_tags,
            dependencies: config_deps,
            responses: config_responses,
            deprecated: config_deprecated,
            include_in_schema: config_include_in_schema,
        } = config;

        let effective_deprecated = config_deprecated.or(other_deprecated);
        let effective_include_in_schema =
            config_include_in_schema.unwrap_or(other_include_in_schema);
        let full_prefix = match config_prefix.as_deref() {
            Some(config_prefix) => combine_paths(config_prefix, &other_prefix),
            None => other_prefix,
        };

        // Same for every route, so build these once instead of per route.
        let inherited_tags: Vec<String> = config_tags.into_iter().chain(other_tags).collect();
        let inherited_deps: Vec<RouterDependency> =
            config_deps.into_iter().chain(other_deps).collect();

        self.routes.reserve(other_routes.len());
        for mut route in other_routes {
            route.path = combine_paths(&full_prefix, &route.path);
            route.tags = inherited_tags.iter().cloned().chain(route.tags).collect();
            route.dependencies = inherited_deps
                .iter()
                .cloned()
                .chain(route.dependencies)
                .collect();
            route.deprecated = route.deprecated.or(effective_deprecated);
            // Exclusion is sticky: an included route can be hidden, never re-shown.
            route.include_in_schema &= effective_include_in_schema;
            self.routes.push(route);
        }

        for (code, def) in config_responses {
            self.responses.entry(code).or_insert(def);
        }
        self.responses.extend(other_responses);
        self
    }

    /// Returns the normalized prefix ("" when none is set).
    #[must_use]
    pub fn get_prefix(&self) -> &str {
        &self.prefix
    }

    /// Returns the router-level tags.
    #[must_use]
    pub fn get_tags(&self) -> &[String] {
        &self.tags
    }

    /// Returns the router-level dependencies, in execution order.
    #[must_use]
    pub fn get_dependencies(&self) -> &[RouterDependency] {
        &self.dependencies
    }

    /// Returns the documented responses.
    #[must_use]
    pub fn get_responses(&self) -> &HashMap<u16, ResponseDef> {
        &self.responses
    }

    /// Returns the router-level deprecation flag, if set.
    #[must_use]
    pub fn is_deprecated(&self) -> Option<bool> {
        self.deprecated
    }

    /// Returns whether this router's routes are included in the schema.
    #[must_use]
    pub fn get_include_in_schema(&self) -> bool {
        self.include_in_schema
    }

    /// Returns the routes registered or included so far.
    #[must_use]
    pub fn get_routes(&self) -> &[RouterRoute] {
        &self.routes
    }

    /// Flattens the router into [`RouteEntry`] values.
    ///
    /// - Each entry's path is this router's prefix joined with the route path.
    /// - Each handler first runs this router's dependencies, then the route's
    ///   own. The first `Err(response)` is returned without calling the handler.
    /// - This router's tags, deprecation flag and responses are not carried into
    ///   the entries.
    /// - Routes with `include_in_schema == false` are omitted. If the router
    ///   itself has `include_in_schema == false`, every route is omitted.
    #[must_use]
    pub fn into_route_entries(self) -> Vec<RouteEntry> {
        let Self {
            prefix,
            dependencies: router_deps,
            include_in_schema: router_include_in_schema,
            routes,
            ..
        } = self;
        routes
            .into_iter()
            // NOTE(review): this removes the route from *serving*, not just
            // from the schema. Confirm this is intended (FastAPI still serves
            // routes with include_in_schema=False).
            .filter(|route| router_include_in_schema && route.include_in_schema)
            .map(move |route| {
                let full_path = combine_paths(&prefix, &route.path);
                let deps: Vec<RouterDependency> = router_deps
                    .iter()
                    .cloned()
                    .chain(route.dependencies)
                    .collect();
                let handler = route.handler;
                if deps.is_empty() {
                    RouteEntry::new(route.method, full_path, move |ctx, req| {
                        (handler)(ctx, req)
                    })
                } else {
                    let deps = Arc::new(deps);
                    RouteEntry::new(route.method, full_path, move |ctx, req| {
                        // The async block outlives this call, so it needs its own handles.
                        let handler = Arc::clone(&handler);
                        let deps = Arc::clone(&deps);
                        Box::pin(async move {
                            for dep in deps.iter() {
                                if let Err(response) = dep.execute(ctx, req).await {
                                    return response;
                                }
                            }
                            (handler)(ctx, req).await
                        })
                    })
                }
            })
            .collect()
    }
}
/// Joins a route `prefix` and `path` with exactly one `/` between them.
///
/// - When both parts are non-empty, trailing slashes on `prefix` and leading
///   slashes on `path` are collapsed into a single separator.
/// - When only one part is non-empty, it is returned as-is, with a leading `/`
///   added to a bare path.
/// - The result is never empty: `("", "")` and `("/", "/")` both yield `"/"`.
fn combine_paths(prefix: &str, path: &str) -> String {
    match (prefix.is_empty(), path.is_empty()) {
        (true, true) => "/".to_string(),
        (true, false) => {
            if path.starts_with('/') {
                path.to_string()
            } else {
                format!("/{}", path)
            }
        }
        (false, true) => prefix.to_string(),
        (false, false) => {
            let prefix = prefix.trim_end_matches('/');
            let path = path.trim_start_matches('/');
            match (prefix.is_empty(), path.is_empty()) {
                // Both parts were only slashes (e.g. "/" + "/"): stay at the root
                // instead of returning an empty path.
                (true, true) => "/".to_string(),
                (_, true) => prefix.to_string(),
                _ => format!("{}/{}", prefix, path),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Adds a GET route whose handler is never polled; these tests only
    /// inspect route metadata.
    fn with_get(router: APIRouter, path: &str) -> APIRouter {
        router.get(path, |_ctx, _req| std::future::pending::<Response>())
    }

    /// Names of a route's dependencies, in execution order.
    fn dep_names(route: &RouterRoute) -> Vec<&str> {
        route.dependencies.iter().map(|d| d.name.as_str()).collect()
    }

    #[test]
    fn test_combine_paths() {
        assert_eq!(combine_paths("", ""), "/");
        assert_eq!(combine_paths("", "/users"), "/users");
        assert_eq!(combine_paths("", "users"), "/users");
        assert_eq!(combine_paths("/api", ""), "/api");
        assert_eq!(combine_paths("/api", "/"), "/api");
        assert_eq!(combine_paths("/api", "/users"), "/api/users");
        assert_eq!(combine_paths("/api", "users"), "/api/users");
        assert_eq!(combine_paths("/api/", "/users"), "/api/users");
        assert_eq!(combine_paths("/api/", "users"), "/api/users");
    }

    #[test]
    fn test_router_prefix_normalization() {
        let router = APIRouter::new().prefix("api");
        assert_eq!(router.get_prefix(), "/api");
        let router = APIRouter::new().prefix("/api/");
        assert_eq!(router.get_prefix(), "/api");
        let router = APIRouter::new().prefix("/api/v1");
        assert_eq!(router.get_prefix(), "/api/v1");
        let router = APIRouter::new().prefix("/");
        assert_eq!(router.get_prefix(), "/");
        let router = APIRouter::new().prefix("/api").prefix("");
        assert_eq!(router.get_prefix(), "");
    }

    #[test]
    fn test_router_tags() {
        let router = APIRouter::new().tags(vec!["users", "admin"]).tag("api");
        assert_eq!(router.get_tags(), &["users", "admin", "api"]);
    }

    #[test]
    fn test_router_deprecated() {
        let router = APIRouter::new().deprecated(true);
        assert_eq!(router.is_deprecated(), Some(true));
        let router = APIRouter::new();
        assert_eq!(router.is_deprecated(), None);
    }

    #[test]
    fn test_response_def() {
        let def = ResponseDef::new("Success")
            .with_example(serde_json::json!({"id": 1}))
            .with_content_type("application/json");
        assert_eq!(def.description, "Success");
        assert_eq!(def.example, Some(serde_json::json!({"id": 1})));
        assert_eq!(def.content_type, Some("application/json".to_string()));
    }

    #[test]
    fn test_include_in_schema() {
        let router = APIRouter::new().include_in_schema(false);
        assert!(!router.get_include_in_schema());
        let router = APIRouter::new();
        assert!(router.get_include_in_schema());
    }

    #[test]
    fn test_nested_routers_prefix_combination() {
        let inner = APIRouter::new().prefix("/items");
        assert_eq!(inner.get_prefix(), "/items");
        let outer = APIRouter::new().prefix("/api/v1").include_router(inner);
        assert_eq!(outer.get_prefix(), "/api/v1");
    }

    #[test]
    fn test_router_with_responses() {
        let router = APIRouter::new()
            .response(200, ResponseDef::new("Success"))
            .response(404, ResponseDef::new("Not Found"));
        let responses = router.get_responses();
        assert_eq!(responses.len(), 2);
        assert!(responses.contains_key(&200));
        assert!(responses.contains_key(&404));
    }

    #[test]
    fn test_router_dependency_creation() {
        let dep = RouterDependency::new("auth", |_ctx, _req| async { Ok(()) });
        assert_eq!(dep.name, "auth");
    }

    #[test]
    fn test_router_with_dependency() {
        let dep = RouterDependency::new("auth", |_ctx, _req| async { Ok(()) });
        let router = APIRouter::new().dependency(dep);
        assert_eq!(router.get_dependencies().len(), 1);
        assert_eq!(router.get_dependencies()[0].name, "auth");
    }

    #[test]
    fn test_router_multiple_dependencies() {
        let dep1 = RouterDependency::new("auth", |_ctx, _req| async { Ok(()) });
        let dep2 = RouterDependency::new("rate_limit", |_ctx, _req| async { Ok(()) });
        let router = APIRouter::new().dependencies(vec![dep1, dep2]);
        assert_eq!(router.get_dependencies().len(), 2);
    }

    #[test]
    fn test_tag_merging_with_nested_routers() {
        let inner = with_get(APIRouter::new().tags(vec!["items"]), "/");
        let outer = APIRouter::new().tags(vec!["api"]).include_router(inner);
        // Router-level tags are untouched; the included router's tags land on its routes.
        assert_eq!(outer.get_tags(), &["api"]);
        assert_eq!(outer.get_routes()[0].tags, vec!["items"]);
    }

    #[test]
    fn test_with_prefix_constructor() {
        let router = APIRouter::with_prefix("/api/v1");
        assert_eq!(router.get_prefix(), "/api/v1");
    }

    #[test]
    fn test_empty_router() {
        let router = APIRouter::new();
        assert_eq!(router.get_prefix(), "");
        assert!(router.get_tags().is_empty());
        assert!(router.get_dependencies().is_empty());
        assert!(router.get_responses().is_empty());
        assert_eq!(router.is_deprecated(), None);
        assert!(router.get_include_in_schema());
        assert!(router.get_routes().is_empty());
    }

    #[test]
    fn test_include_config_default() {
        let config = IncludeConfig::new();
        assert!(config.get_prefix().is_none());
        assert!(config.get_tags().is_empty());
        assert!(config.get_dependencies().is_empty());
        assert!(config.get_responses().is_empty());
        assert!(config.get_deprecated().is_none());
        assert!(config.get_include_in_schema().is_none());
    }

    #[test]
    fn test_include_config_prefix() {
        let config = IncludeConfig::new().prefix("/api/v1");
        assert_eq!(config.get_prefix(), Some("/api/v1"));
        let config = IncludeConfig::new().prefix("api/v1");
        assert_eq!(config.get_prefix(), Some("/api/v1"));
        let config = IncludeConfig::new().prefix("/api/v1/");
        assert_eq!(config.get_prefix(), Some("/api/v1"));
        let config = IncludeConfig::new().prefix("");
        assert_eq!(config.get_prefix(), None);
    }

    #[test]
    fn test_include_config_tags() {
        let config = IncludeConfig::new().tags(vec!["api", "v1"]).tag("extra");
        assert_eq!(config.get_tags(), &["api", "v1", "extra"]);
    }

    #[test]
    fn test_include_config_dependencies() {
        let dep1 = RouterDependency::new("auth", |_ctx, _req| async { Ok(()) });
        let dep2 = RouterDependency::new("rate_limit", |_ctx, _req| async { Ok(()) });
        let config = IncludeConfig::new().dependency(dep1).dependency(dep2);
        assert_eq!(config.get_dependencies().len(), 2);
    }

    #[test]
    fn test_include_config_responses() {
        let config = IncludeConfig::new()
            .response(401, ResponseDef::new("Unauthorized"))
            .response(500, ResponseDef::new("Server Error"));
        assert_eq!(config.get_responses().len(), 2);
    }

    #[test]
    fn test_include_config_deprecated() {
        let config = IncludeConfig::new().deprecated(true);
        assert_eq!(config.get_deprecated(), Some(true));
        let config = IncludeConfig::new().deprecated(false);
        assert_eq!(config.get_deprecated(), Some(false));
    }

    #[test]
    fn test_include_config_include_in_schema() {
        let config = IncludeConfig::new().include_in_schema(false);
        assert_eq!(config.get_include_in_schema(), Some(false));
    }

    #[test]
    fn test_merge_rule_prefix_prepending() {
        let inner_router = with_get(APIRouter::new().prefix("/users"), "/{id}");
        let config = IncludeConfig::new().prefix("/api/v1");
        let outer = APIRouter::new().include_router_with_config(inner_router, config);
        let routes = outer.get_routes();
        assert_eq!(routes.len(), 1);
        assert_eq!(routes[0].path, "/api/v1/users/{id}");
    }

    #[test]
    fn test_merge_rule_tags_prepending() {
        let inner = with_get(APIRouter::new().tags(vec!["users"]), "/");
        let config = IncludeConfig::new().tags(vec!["api", "v1"]);
        let outer = APIRouter::new()
            .tags(vec!["outer"])
            .include_router_with_config(inner, config);
        assert_eq!(outer.get_tags(), &["outer"]);
        assert_eq!(outer.get_routes()[0].tags, vec!["api", "v1", "users"]);
    }

    #[test]
    fn test_merge_rule_dependencies_order() {
        let inner = with_get(
            APIRouter::new()
                .dependency(RouterDependency::new("inner", |_ctx, _req| async { Ok(()) })),
            "/",
        );
        let config = IncludeConfig::new()
            .dependency(RouterDependency::new("config", |_ctx, _req| async { Ok(()) }));
        let outer = APIRouter::new().include_router_with_config(inner, config);
        assert_eq!(dep_names(&outer.get_routes()[0]), vec!["config", "inner"]);
    }

    #[test]
    fn test_merge_rule_deprecated_override() {
        let inner = with_get(APIRouter::new().deprecated(false), "/");
        let config = IncludeConfig::new().deprecated(true);
        let outer = APIRouter::new().include_router_with_config(inner, config);
        assert_eq!(outer.is_deprecated(), None);
        assert_eq!(outer.get_routes()[0].deprecated, Some(true));
    }

    #[test]
    fn test_merge_rule_deprecated_falls_back_to_router() {
        let inner = with_get(APIRouter::new().deprecated(true), "/");
        let outer = APIRouter::new().include_router(inner);
        assert_eq!(outer.get_routes()[0].deprecated, Some(true));
    }

    #[test]
    fn test_merge_rule_include_in_schema_override() {
        let inner = with_get(APIRouter::new().include_in_schema(true), "/");
        let config = IncludeConfig::new().include_in_schema(false);
        let outer = APIRouter::new().include_router_with_config(inner, config);
        assert!(outer.get_include_in_schema());
        assert!(!outer.get_routes()[0].include_in_schema);
    }

    #[test]
    fn test_merge_rule_responses_priority() {
        let inner = APIRouter::new()
            .response(200, ResponseDef::new("Router Success"))
            .response(404, ResponseDef::new("Router Not Found"));
        let config = IncludeConfig::new()
            .response(200, ResponseDef::new("Config Success"))
            .response(500, ResponseDef::new("Config Error"));
        let outer = APIRouter::new().include_router_with_config(inner, config);
        let responses = outer.get_responses();
        assert_eq!(responses.get(&200).unwrap().description, "Router Success");
        assert_eq!(responses.get(&500).unwrap().description, "Config Error");
        assert_eq!(responses.get(&404).unwrap().description, "Router Not Found");
    }

    #[test]
    fn test_recursive_router_inclusion() {
        let level3 = with_get(APIRouter::new().prefix("/items"), "/{id}");
        let level2 = APIRouter::new().prefix("/users").include_router(level3);
        let level1 = APIRouter::new().prefix("/api").include_router(level2);
        assert_eq!(level1.get_prefix(), "/api");
        // level1's own prefix is applied later, by into_route_entries.
        assert_eq!(level1.get_routes()[0].path, "/users/items/{id}");
    }

    #[test]
    fn test_recursive_config_merging() {
        let inner = with_get(APIRouter::new().tags(vec!["items"]), "/list");
        let middle_config = IncludeConfig::new().tags(vec!["users"]).prefix("/users");
        let outer_config = IncludeConfig::new().tags(vec!["api"]).prefix("/api");
        let middle = APIRouter::new().include_router_with_config(inner, middle_config);
        let outer = APIRouter::new().include_router_with_config(middle, outer_config);
        assert!(outer.get_tags().is_empty());
        let route = &outer.get_routes()[0];
        assert_eq!(route.path, "/api/users/list");
        assert_eq!(route.tags, vec!["api", "users", "items"]);
    }

    #[test]
    fn test_include_config_empty_prefix() {
        let inner = with_get(APIRouter::new().prefix("/users"), "/me");
        let config = IncludeConfig::new();
        let outer = APIRouter::new()
            .prefix("/api")
            .include_router_with_config(inner, config);
        assert_eq!(outer.get_prefix(), "/api");
        assert_eq!(outer.get_routes()[0].path, "/users/me");
    }

    #[test]
    fn test_multi_level_path_construction() {
        let level1 = "/api";
        let level2 = "/v1";
        let level3 = "/users";
        let level4 = "/{id}";
        let combined_12 = combine_paths(level1, level2);
        assert_eq!(combined_12, "/api/v1");
        let combined_123 = combine_paths(&combined_12, level3);
        assert_eq!(combined_123, "/api/v1/users");
        let combined_1234 = combine_paths(&combined_123, level4);
        assert_eq!(combined_1234, "/api/v1/users/{id}");
    }
}