use axum::{body::Body, extract::Request, response::Response, routing::MethodRouter, Router};
/// A single route declared for one API version.
///
/// Values of this type are gathered through `inventory` and mounted by
/// [`build_version_router`] under `/api/<version>` followed by `path`.
#[derive(Debug, Clone)]
pub struct VersionedRoute {
    /// Version segment placed right after `/api/`, e.g. `"v1"`.
    pub version: String,
    /// Route path appended after the version segment.
    pub path: String,
    /// HTTP method recorded for this route.
    pub method: axum::http::Method,
    /// axum method router that serves requests matching this route.
    pub handler: MethodRouter,
}
/// Settings for the API version-routing middleware.
#[derive(Debug, Clone)]
pub struct VersionRouterConfig {
    /// Version used as the redirect target prefix for `/api/` requests without a version.
    pub default_version: String,
    /// Versions considered current; consulted when looking for a successor version.
    pub supported_versions: Vec<String>,
    /// Flag presumably controlling whether unversioned `/api/` requests are redirected.
    /// NOTE(review): confirm how the middleware consults this flag.
    pub redirect_unknown: bool,
    /// Deprecated version (e.g. `"v1"`) mapped to the value emitted as its sunset header.
    pub deprecated_versions: std::collections::HashMap<String, String>,
    /// Header name intended for the sunset date.
    /// NOTE(review): confirm the middleware uses this rather than a fixed name.
    pub sunset_header: String,
}

impl Default for VersionRouterConfig {
    /// Only `v1` exists, it is the default, and no version is deprecated.
    fn default() -> Self {
        let only_version = String::from("v1");
        let supported_versions = vec![only_version.clone()];
        Self {
            default_version: only_version,
            supported_versions,
            redirect_unknown: true,
            deprecated_versions: Default::default(),
            sunset_header: String::from("Sunset"),
        }
    }
}
// Registry of every `VersionedRoute` submitted through `inventory` in this binary.
inventory::collect!(VersionedRoute);

/// Builds a [`Router`] containing every [`VersionedRoute`] registered through `inventory`.
///
/// Each route is mounted at `/api/<version>` followed by its `path`. If `path` is
/// non-empty and lacks a leading `/`, one is inserted, so `"users"` and `"/users"` both
/// end up at `/api/v1/users` rather than the malformed `/api/v1users`.
///
/// # Panics
///
/// Presumably panics (inside axum's `Router::route`) if two registrations collide on the
/// same path and method. Confirm against the axum version in use.
pub fn build_version_router() -> Router {
    let mut router = Router::new();
    for route in inventory::iter::<VersionedRoute> {
        let path = versioned_api_path(&route.version, &route.path);
        router = router.route(&path, route.handler.clone());
    }
    router
}

/// Mount point for a route: `/api/{version}` followed by `path`. A `/` separator is
/// inserted when `path` is non-empty and does not already start with one.
fn versioned_api_path(version: &str, path: &str) -> String {
    if path.is_empty() || path.starts_with('/') {
        format!("/api/{}{}", version, path)
    } else {
        format!("/api/{}/{}", version, path)
    }
}
pub async fn version_redirect_middleware(
req: Request<Body>,
next: axum::middleware::Next,
) -> Response {
let uri = req.uri().path().to_string();
let config = VersionRouterConfig::default();
if let Some(path_after_api) = uri.strip_prefix("/api/") {
if path_after_api.starts_with("v") {
let end_of_version = path_after_api.find('/').unwrap_or(path_after_api.len());
let version_part = &path_after_api[..end_of_version];
if version_part
.chars()
.next()
.map(|c| c == 'v')
.unwrap_or(false)
&& version_part[1..].chars().all(|c| c.is_ascii_digit())
{
let mut response = next.run(req).await;
if let Some(sunset_date) = config.deprecated_versions.get(version_part) {
response.headers_mut().insert(
axum::http::header::HeaderName::from_static("deprecation"),
axum::http::HeaderValue::from_str("true").unwrap(),
);
response.headers_mut().insert(
axum::http::header::HeaderName::from_static("Sunset"),
axum::http::HeaderValue::from_str(sunset_date).unwrap(),
);
if let Some(newer_version) =
find_newer_version(version_part, &config.supported_versions)
{
let link_header =
format!("</api/{}>; rel=\"successor-version\"", newer_version);
response.headers_mut().insert(
axum::http::header::LINK,
axum::http::HeaderValue::from_str(&link_header).unwrap(),
);
}
}
return response;
}
}
let default_version = &config.default_version;
let path_without_version = if path_after_api.starts_with('/') {
path_after_api.to_string()
} else {
format!("/{}", path_after_api)
};
let new_uri = format!("/api/{}{}", default_version, path_without_version);
let mut response = Response::new(Body::empty());
*response.status_mut() = axum::http::StatusCode::MOVED_PERMANENTLY;
response.headers_mut().insert(
axum::http::header::LOCATION,
axum::http::HeaderValue::from_str(&new_uri)
.unwrap_or_else(|_| axum::http::HeaderValue::from_static("/")),
);
return response;
}
next.run(req).await
}
/// Finds the closest supported version that is strictly newer than `current`.
///
/// Versions are compared by the integer following a leading `v` (`"v2"` -> 2).
/// Entries of `supported` that do not have that form are ignored. If several entries
/// share the smallest newer number, the first one in `supported` wins.
///
/// Returns `None` when `current` is not of the form `v<number>` (including the empty
/// string, which previously panicked) or when no newer version is supported.
fn find_newer_version(current: &str, supported: &[String]) -> Option<String> {
    let current_num = parse_version_number(current)?;
    supported
        .iter()
        .filter_map(|version| parse_version_number(version).map(|num| (num, version)))
        .filter(|&(num, _)| num > current_num)
        // `min_by_key` returns the first of equal minima, matching the original tie-breaking.
        .min_by_key(|&(num, _)| num)
        .map(|(_, version)| version.clone())
}

/// Parses the numeric part of a `v<number>` version string.
fn parse_version_number(version: &str) -> Option<u32> {
    version.strip_prefix('v')?.parse().ok()
}
/// Registers a `VersionedRoute` with `inventory` so `build_version_router` mounts it.
///
/// `method` is the lowercase axum routing method name (`get`, `post`, ...). It selects
/// the `MethodRouter` method and, upper-cased, becomes the recorded `axum::http::Method`.
/// `handler` may be any handler expression, such as a path like `handlers::list_users`.
///
/// ```ignore
/// define_versioned_route!(version: "v1", path: "/users", method: get, handler: list_users);
/// ```
///
/// NOTE(review): inventory 0.3+ requires the submitted value to be a const expression.
/// The `String`/`MethodRouter` construction below only works with a runtime-registering
/// inventory (0.1/0.2). Confirm the inventory version in Cargo.toml.
#[macro_export]
macro_rules! define_versioned_route {
    (version: $version:expr, path: $path:expr, method: $method:ident, handler: $handler:expr) => {
        ::inventory::submit!($crate::http::version_routing::VersionedRoute {
            version: $version.to_string(),
            path: $path.to_string(),
            // `Method::$method` would need `GET` while `MethodRouter::$method` needs `get`;
            // derive the Method from the lowercase ident instead.
            method: ::axum::http::Method::from_bytes(
                stringify!($method).to_ascii_uppercase().as_bytes(),
            )
            .expect("an identifier is always a valid HTTP method token"),
            handler: ::axum::routing::MethodRouter::new().$method($handler),
        });
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use axum::routing::get;
    use tower::ServiceExt;

    /// Handler served at the only versioned route in these tests.
    async fn test_handler() -> &'static str {
        "test response"
    }

    /// Router exposing `/api/v1/test` behind the version redirect middleware.
    fn versioned_app() -> Router {
        Router::new()
            .route("/api/v1/test", get(test_handler))
            .layer(axum::middleware::from_fn(version_redirect_middleware))
    }

    /// Sends one empty-bodied request for `uri` through a freshly built router.
    async fn send(uri: &str) -> Response {
        let request = Request::builder()
            .uri(uri)
            .body(Body::empty())
            .expect("Failed to build request");
        versioned_app()
            .oneshot(request)
            .await
            .expect("Failed to handle request")
    }

    #[tokio::test]
    async fn test_version_redirect() {
        let response = send("/api/test").await;
        assert_eq!(response.status(), StatusCode::MOVED_PERMANENTLY);

        let location = response
            .headers()
            .get("location")
            .expect("Location header not found");
        assert_eq!(location, "/api/v1/test");
    }

    #[tokio::test]
    async fn test_valid_version_passes() {
        let response = send("/api/v1/test").await;
        assert_eq!(response.status(), StatusCode::OK);
    }
}