use std::collections::BTreeMap;
use std::sync::OnceLock;
use axum::Router;
use axum::handler::Handler;
/// A single route as recorded by the [`Routes`] builder: a path pattern plus
/// the HTTP method names declared for it.
///
/// `methods` holds exactly what the caller declared; the path-only `From`
/// conversions produce an empty list.
#[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize)]
pub struct RouteSpec {
    /// Path pattern, the same string handed to the underlying router.
    pub path: String,
    /// Declared HTTP method names such as `"GET"` or `"POST"`.
    pub methods: Vec<&'static str>,
}
impl RouteSpec {
    /// Builds a spec from any string-like `path` and an explicit method list.
    pub fn new<P: Into<String>>(path: P, methods: Vec<&'static str>) -> Self {
        let path: String = path.into();
        Self { path, methods }
    }
}
/// Converts a bare borrowed path into a spec with no declared methods.
impl From<&str> for RouteSpec {
    fn from(path: &str) -> Self {
        Self {
            methods: Vec::new(),
            path: path.to_owned(),
        }
    }
}
/// Converts an owned path into a spec with no declared methods; the string is
/// moved in, not copied.
impl From<String> for RouteSpec {
    fn from(path: String) -> Self {
        let methods = Vec::new();
        Self { path, methods }
    }
}
/// Converts a `(method, path)` pair, e.g. `("GET", "/")`, into a
/// single-method spec.
impl From<(&'static str, &str)> for RouteSpec {
    fn from(pair: (&'static str, &str)) -> Self {
        let (method, path) = pair;
        Self {
            path: path.to_owned(),
            methods: Vec::from([method]),
        }
    }
}
/// Converts a `(method, owned path)` pair into a single-method spec; the path
/// is moved in.
impl From<(&'static str, String)> for RouteSpec {
    fn from(pair: (&'static str, String)) -> Self {
        let (method, path) = pair;
        let methods = vec![method];
        Self { path, methods }
    }
}
/// Converts a `(methods, path)` pair into a spec; the method slice is copied
/// into the spec.
impl From<(&[&'static str], &str)> for RouteSpec {
    fn from((methods, path): (&[&'static str], &str)) -> Self {
        Self {
            methods: Vec::from(methods),
            path: String::from(path),
        }
    }
}
/// Route specs grouped by an owning key, presumably the plugin name given the
/// field name (keys are chosen by whoever fills the map).
#[derive(Clone, Debug, Default)]
pub struct RouteRegistry {
    /// Specs per owner; `BTreeMap` keeps iteration in sorted key order.
    pub by_plugin: BTreeMap<String, Vec<RouteSpec>>,
}
impl RouteRegistry {
    /// Counts every recorded spec across all groups; empty groups contribute 0.
    pub fn total(&self) -> usize {
        self.by_plugin
            .values()
            .fold(0, |count, specs| count + specs.len())
    }
}
/// Builder that attaches handlers to an axum [`Router`] while recording a
/// [`RouteSpec`] for each tracked registration.
///
/// Split it with [`Routes::into_parts`] to get the router and the specs back.
#[must_use = "Routes must be passed to AppBuilder::routes to take effect"]
pub struct Routes {
    /// Router that every handler is attached to.
    inner: Router,
    /// One entry per tracked registration, in registration order.
    specs: Vec<RouteSpec>,
}
impl Routes {
    /// Creates an empty builder: no routes registered, no specs recorded.
    pub fn new() -> Self {
        Self {
            inner: Router::new(),
            specs: Vec::new(),
        }
    }

    /// Registers `handler` for `GET` on `path` and records a `["GET"]` spec.
    pub fn get<H, T>(self, path: &str, handler: H) -> Self
    where
        H: Handler<T, ()>,
        T: 'static,
    {
        self.with_method("GET", path, axum::routing::get(handler))
    }

    /// Registers `handler` for `POST` on `path` and records a `["POST"]` spec.
    pub fn post<H, T>(self, path: &str, handler: H) -> Self
    where
        H: Handler<T, ()>,
        T: 'static,
    {
        self.with_method("POST", path, axum::routing::post(handler))
    }

    /// Registers `handler` for `PUT` on `path` and records a `["PUT"]` spec.
    pub fn put<H, T>(self, path: &str, handler: H) -> Self
    where
        H: Handler<T, ()>,
        T: 'static,
    {
        self.with_method("PUT", path, axum::routing::put(handler))
    }

    /// Registers `handler` for `PATCH` on `path` and records a `["PATCH"]` spec.
    pub fn patch<H, T>(self, path: &str, handler: H) -> Self
    where
        H: Handler<T, ()>,
        T: 'static,
    {
        self.with_method("PATCH", path, axum::routing::patch(handler))
    }

    /// Registers `handler` for `DELETE` on `path` and records a `["DELETE"]` spec.
    pub fn delete<H, T>(self, path: &str, handler: H) -> Self
    where
        H: Handler<T, ()>,
        T: 'static,
    {
        self.with_method("DELETE", path, axum::routing::delete(handler))
    }

    /// Registers `handler` for `HEAD` on `path` and records a `["HEAD"]` spec.
    pub fn head<H, T>(self, path: &str, handler: H) -> Self
    where
        H: Handler<T, ()>,
        T: 'static,
    {
        self.with_method("HEAD", path, axum::routing::head(handler))
    }

    /// Registers `handler` for `OPTIONS` on `path` and records an `["OPTIONS"]` spec.
    pub fn options<H, T>(self, path: &str, handler: H) -> Self
    where
        H: Handler<T, ()>,
        T: 'static,
    {
        self.with_method("OPTIONS", path, axum::routing::options(handler))
    }

    /// Registers a prebuilt method router on `path`, recording it as serving
    /// the single `method`.
    ///
    /// NOTE(review): presumably intended for handlers that already carry
    /// layers, given the name. Confirm with callers.
    pub fn layered(
        self,
        method: &'static str,
        path: &str,
        handler: axum::routing::MethodRouter<()>,
    ) -> Self {
        self.route(&[method], path, handler)
    }

    /// Registers a prebuilt method router on `path` and records one spec that
    /// lists `methods`.
    ///
    /// `methods` is recorded as given. It is not checked against the methods
    /// `handler` actually serves. Every other registering method funnels
    /// through here, so this is the single place specs are recorded.
    pub fn route(
        mut self,
        methods: &[&'static str],
        path: &str,
        handler: axum::routing::MethodRouter<()>,
    ) -> Self {
        self.specs.push(RouteSpec {
            path: path.to_string(),
            methods: methods.to_vec(),
        });
        self.inner = self.inner.route(path, handler);
        self
    }

    /// Merges an external axum router into this one.
    ///
    /// Its routes are not tracked: no specs are recorded for them.
    pub fn with_router(mut self, router: Router) -> Self {
        self.inner = self.inner.merge(router);
        self
    }

    /// Consumes the builder, returning the assembled router and the recorded
    /// specs in registration order.
    pub fn into_parts(self) -> (Router, Vec<RouteSpec>) {
        (self.inner, self.specs)
    }

    /// Shared tail of the single-method helpers. It delegates to
    /// [`Routes::route`] so spec recording and router registration cannot
    /// drift apart.
    fn with_method(
        self,
        method: &'static str,
        path: &str,
        handler: axum::routing::MethodRouter<()>,
    ) -> Self {
        self.route(&[method], path, handler)
    }
}
impl Default for Routes {
fn default() -> Self {
Self::new()
}
}
/// Process-wide route registry, written at most once by [`init`].
static REGISTRY: OnceLock<RouteRegistry> = OnceLock::new();

/// Publishes `registry` process-wide.
///
/// Only the first call takes effect. Later calls are silently ignored, and
/// their argument is dropped.
pub fn init(registry: RouteRegistry) {
    // First writer wins; a rejected value is discarded on purpose.
    REGISTRY.set(registry).ok();
}

/// Returns the registry published by [`init`], or `None` if `init` has not run.
pub fn get() -> Option<&'static RouteRegistry> {
    REGISTRY.get()
}
/// Process-wide OpenAPI entries, written at most once by [`init_openapi`].
static OPENAPI_REGISTRY: OnceLock<Vec<(String, serde_json::Value)>> = OnceLock::new();

/// Publishes the OpenAPI entries; the `String` is presumably a path (per the
/// getter's name), TODO confirm.
///
/// Only the first call takes effect; later calls are silently ignored.
pub fn init_openapi(entries: Vec<(String, serde_json::Value)>) {
    OPENAPI_REGISTRY.set(entries).ok();
}

/// Returns the entries published by [`init_openapi`], or `None` if it has not run.
pub fn registered_openapi_paths() -> Option<&'static [(String, serde_json::Value)]> {
    OPENAPI_REGISTRY.get().map(Vec::as_slice)
}
/// Process-wide OpenAPI spec URL, written at most once by [`init_openapi_spec_url`].
static OPENAPI_SPEC_URL: OnceLock<String> = OnceLock::new();

/// Publishes the OpenAPI spec URL.
///
/// Only the first call takes effect; later calls are silently ignored.
pub fn init_openapi_spec_url(url: String) {
    OPENAPI_SPEC_URL.set(url).ok();
}

/// Returns the URL published by [`init_openapi_spec_url`], or `None` if it has not run.
pub fn registered_openapi_spec_url() -> Option<&'static str> {
    OPENAPI_SPEC_URL.get().map(String::as_str)
}
#[cfg(test)]
mod tests {
    use super::*;

    async fn dummy_get() -> &'static str {
        "ok"
    }

    async fn dummy_post() -> &'static str {
        "ok"
    }

    /// Flattens specs into `(path, methods)` pairs so a whole list can be
    /// compared in one assertion.
    fn summarize(specs: &[RouteSpec]) -> Vec<(&str, Vec<&'static str>)> {
        specs
            .iter()
            .map(|spec| (spec.path.as_str(), spec.methods.clone()))
            .collect()
    }

    #[test]
    fn routes_builder_records_one_spec_per_get_with_method_and_path() {
        let (_router, specs) = Routes::new()
            .get("/", dummy_get)
            .get("/articles", dummy_get)
            .post("/api/articles", dummy_post)
            .into_parts();
        assert_eq!(specs.len(), 3, "one spec per builder call: {specs:?}");
        assert_eq!(
            summarize(&specs),
            vec![
                ("/", vec!["GET"]),
                ("/articles", vec!["GET"]),
                ("/api/articles", vec!["POST"]),
            ]
        );
    }

    #[test]
    fn routes_builder_supports_multi_method_via_route() {
        use axum::routing::get;
        let handler = get(dummy_get).post(dummy_post);
        let (_router, specs) = Routes::new()
            .route(&["GET", "POST"], "/api/comments", handler)
            .into_parts();
        assert_eq!(
            summarize(&specs),
            vec![("/api/comments", vec!["GET", "POST"])]
        );
    }

    #[test]
    fn routes_with_router_merges_axum_router_silently() {
        use axum::Router;
        use axum::routing::get;
        let external = Router::new().route("/external", get(dummy_get));
        let (_router, specs) = Routes::new()
            .get("/tracked", dummy_get)
            .with_router(external)
            .into_parts();
        // Only the builder-registered route is tracked; the merged one is not.
        let tracked: Vec<&str> = specs.iter().map(|spec| spec.path.as_str()).collect();
        assert_eq!(tracked, ["/tracked"]);
    }

    #[test]
    fn total_sums_per_plugin_paths_and_handles_empty_groups() {
        let by_plugin: BTreeMap<String, Vec<RouteSpec>> = BTreeMap::from([
            ("app".to_string(), vec!["/".into(), "/articles".into()]),
            (
                "admin".to_string(),
                vec![
                    "/admin/".into(),
                    "/admin/login".into(),
                    "/admin/logout".into(),
                ],
            ),
            ("sessions".to_string(), Vec::new()),
        ]);
        let reg = RouteRegistry { by_plugin };
        assert_eq!(reg.total(), 5);
    }
}