use axum::{
RequestExt,
extract::Request,
http::{HeaderName, HeaderValue, StatusCode, Uri, uri::PathAndQuery},
response::{IntoResponse, Response},
};
use axum_extra::{
TypedHeader,
headers::{self, Header},
};
use futures::future::BoxFuture;
use regex::Regex;
use std::{
fmt::Debug,
ops::Deref,
sync::LazyLock,
task::{Context, Poll},
};
use tower::{Layer, Service};
use tracing::debug;
static VERSION: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r#"^v(\d{1,4})$"#).expect("version regex is valid"));
#[derive(Clone)]
pub struct ApiVersionLayer<const N: usize> {
base_path: String,
versions: ApiVersions<N>,
}
impl<const N: usize> ApiVersionLayer<N> {
pub fn new(base_path: impl AsRef<str>, versions: ApiVersions<N>) -> Self {
let base_path = base_path.as_ref().trim_end_matches('/').to_string();
assert!(base_path.starts_with('/'), "base path must start with '/'");
assert!(!base_path.len() > 1, "base path must not be empty");
Self {
base_path,
versions,
}
}
}
impl<const N: usize, S> Layer<S> for ApiVersionLayer<N> {
type Service = ApiVersionService<N, S>;
fn layer(&self, inner: S) -> Self::Service {
ApiVersionService {
inner,
base_path: self.base_path.clone(),
versions: self.versions,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ApiVersions<const N: usize>([u16; N]);
impl<const N: usize> ApiVersions<N> {
pub const fn new(versions: [u16; N]) -> Self {
assert!(!versions.is_empty(), "API versions must not be empty");
assert!(
is_monotonically_increasing(versions),
"API versions must be strictly monotonically increasing"
);
assert!(
versions[N - 1] < 10_000,
"API versions must be within 0u16..10_000"
);
Self(versions)
}
}
impl<const N: usize> Deref for ApiVersions<N> {
type Target = [u16; N];
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Clone)]
pub struct ApiVersionService<const N: usize, S> {
inner: S,
base_path: String,
versions: ApiVersions<N>,
}
impl<const N: usize, S> Service<Request> for ApiVersionService<N, S>
where
S: Service<Request, Response = Response> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut request: Request) -> Self::Future {
let mut inner = self.inner.clone();
let base_path = self.base_path.clone();
let versions = self.versions;
Box::pin(async move {
let path = if let Some(path) = request.uri().path().strip_prefix(&base_path)
&& path.starts_with('/')
{
path.to_owned()
} else {
debug!(
uri = %request.uri(),
"not rewriting the path, because does not start with base path"
);
return inner.call(request).await;
};
let has_version_prefix = versions
.iter()
.any(|version| path.starts_with(&format!("/v{version}/")));
if has_version_prefix {
debug!(
uri = %request.uri(),
"not rewriting the path, because starts with valid version prefix"
);
return inner.call(request).await;
}
let version = request.extract_parts::<TypedHeader<XApiVersion>>().await;
let version = version
.as_ref()
.map(|TypedHeader(XApiVersion(v))| v)
.unwrap_or_else(|_| versions.last().expect("versions is not empty"));
if !versions.contains(version) {
let response = (
StatusCode::NOT_FOUND,
format!("unknown version '{version}'"),
);
return Ok(response.into_response());
}
debug!(?version, "using API version");
let mut parts = request.uri().to_owned().into_parts();
let paq = parts.path_and_query.expect("uri has 'path and query'");
let mut paq_parts = paq.as_str().split('?').skip(1);
let paq = match paq_parts.next() {
Some(query) => format!("{base_path}/v{version}{path}?{query}"),
None => format!("{base_path}/v{version}{path}"),
};
let paq = PathAndQuery::from_maybe_shared(paq).expect("new 'path and query' is valid");
parts.path_and_query = Some(paq);
let uri = Uri::from_parts(parts).expect("parts are valid");
debug!(original_uri = %request.uri(), %uri, "rewrote the path");
request.uri_mut().clone_from(&uri);
inner.call(request).await
})
}
}
pub static X_API_VERSION: HeaderName = HeaderName::from_static("x-api-version");
#[derive(Debug)]
pub struct XApiVersion(u16);
impl Header for XApiVersion {
fn name() -> &'static HeaderName {
&X_API_VERSION
}
fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
where
Self: Sized,
I: Iterator<Item = &'i HeaderValue>,
{
values
.next()
.and_then(|v| v.to_str().ok())
.and_then(|s| VERSION.captures(s).and_then(|c| c.get(1)))
.and_then(|m| m.as_str().parse().ok())
.map(XApiVersion)
.ok_or_else(headers::Error::invalid)
}
fn encode<E: Extend<HeaderValue>>(&self, _values: &mut E) {
unimplemented!("not yet needed");
}
}
const fn is_monotonically_increasing<const N: usize>(versions: [u16; N]) -> bool {
if N < 2 {
return true;
}
let mut n = 1;
while n < N {
if versions[n - 1] >= versions[n] {
return false;
}
n += 1;
}
true
}
#[cfg(test)]
mod tests {
use crate::{VERSION, is_monotonically_increasing};
use assert_matches::assert_matches;
#[test]
fn test_x_api_header() {
let version = VERSION
.captures("v0")
.and_then(|c| c.get(1))
.map(|m| m.as_str());
assert_matches!(version, Some("0"));
let version = VERSION
.captures("v1")
.and_then(|c| c.get(1))
.map(|m| m.as_str());
assert_matches!(version, Some("1"));
let version = VERSION
.captures("v99")
.and_then(|c| c.get(1))
.map(|m| m.as_str());
assert_matches!(version, Some("99"));
let version = VERSION
.captures("v9999")
.and_then(|c| c.get(1))
.map(|m| m.as_str());
assert_matches!(version, Some("9999"));
let version = VERSION
.captures("v10000")
.and_then(|c| c.get(1))
.map(|m| m.as_str());
assert_matches!(version, None);
let version = VERSION
.captures("vx")
.and_then(|c| c.get(1))
.map(|m| m.as_str());
assert_matches!(version, None);
}
#[test]
fn test_is_monotonically_increasing() {
assert!(is_monotonically_increasing([]));
assert!(is_monotonically_increasing([0]));
assert!(is_monotonically_increasing([0, 1]));
assert!(!is_monotonically_increasing([0, 0]));
assert!(!is_monotonically_increasing([1, 0]));
}
}