api_version/
lib.rs

1//! Axum middleware to rewrite a request such that a version prefix, e.g. `"/v0"`, is added to the
2//! path.
3
4use axum::{
5    RequestExt,
6    extract::Request,
7    http::{HeaderName, HeaderValue, StatusCode, Uri, uri::PathAndQuery},
8    response::{IntoResponse, Response},
9};
10use axum_extra::{
11    TypedHeader,
12    headers::{self, Header},
13};
14use futures::future::BoxFuture;
15use regex::Regex;
16use std::{
17    fmt::Debug,
18    ops::Deref,
19    sync::LazyLock,
20    task::{Context, Poll},
21};
22use tower::{Layer, Service};
23use tracing::debug;
24
25static VERSION: LazyLock<Regex> =
26    LazyLock::new(|| Regex::new(r#"^v(\d{1,4})$"#).expect("version regex is valid"));
27
28/// Axum middleware to rewrite a request such that a version prefix is added to the path. This is
29/// based on a set of API versions and an optional `"x-api-version"` custom HTTP header: if no such
30/// header is present, the highest version is used. Yet this only applies to requests the URIs of
31/// which pass a filter; others are not rewritten.  Also, paths starting with a valid/existing
32/// version prefix, e.g. `"/v0"`, are not rewritten.
33///
34/// # Examples
35///
36/// The middleware needs to be applied to the "root" router:
37///
38/// ```ignore
39/// let app = Router::new()
40///     .route("/", get(ok_0))
41///     .route("/v0/test", get(ok_0))
42///     .route("/v1/test", get(ok_1))
43///     .route("/foo", get(ok_foo));
44///
45/// const API_VERSIONS: ApiVersions<2> = ApiVersions::new([0, 1]);
46///
47/// let mut app = ApiVersionLayer::new(API_VERSIONS, FooFilter).layer(app);
48/// ```
49#[derive(Clone)]
50pub struct ApiVersionLayer<const N: usize> {
51    base_path: String,
52    versions: ApiVersions<N>,
53}
54
55impl<const N: usize> ApiVersionLayer<N> {
56    /// Create a new API version layer.
57    pub fn new(base_path: impl AsRef<str>, versions: ApiVersions<N>) -> Self {
58        let base_path = base_path.as_ref().trim_end_matches('/').to_string();
59        assert!(base_path.starts_with('/'), "base path must start with '/'");
60        assert!(!base_path.len() > 1, "base path must not be empty");
61
62        Self {
63            base_path,
64            versions,
65        }
66    }
67}
68
69impl<const N: usize, S> Layer<S> for ApiVersionLayer<N> {
70    type Service = ApiVersionService<N, S>;
71
72    fn layer(&self, inner: S) -> Self::Service {
73        ApiVersionService {
74            inner,
75            base_path: self.base_path.clone(),
76            versions: self.versions,
77        }
78    }
79}
80
81/// API versions; a validated newtype for a `u16` array.
82#[derive(Debug, Clone, Copy, PartialEq, Eq)]
83pub struct ApiVersions<const N: usize>([u16; N]);
84
85impl<const N: usize> ApiVersions<N> {
86    /// Create API versions. The given numbers must not be empty, must be strictly monotonically
87    /// increasing and less than `10_000`; otherwise `new` fails to compile in const contexts or
88    /// panics otherwise.
89    ///
90    /// # Examples
91    ///
92    /// Strictly monotonically versions `1` and `2` are valid:
93    ///
94    /// ```
95    /// # use api_version::ApiVersions;
96    /// const VERSIONS: ApiVersions<2> = ApiVersions::new([1, 2]);;
97    /// ```
98    ///
99    /// Empty versions or such that are not strictly monotonically increasing are invalid and fail
100    /// to compile in const contexts or panic otherwise.
101    ///
102    /// ```compile_fail
103    /// # use api_version::ApiVersions;
104    /// /// API versions must not be empty!
105    /// const VERSIONS: ApiVersions<0> = ApiVersions::new([]);
106    /// /// API versions must be strictly monotonically increasing!
107    /// const VERSIONS: ApiVersions<0> = ApiVersions::new([2, 1]);
108    /// /// API versions must be within 0u16..10_000!
109    /// const VERSIONS: ApiVersions<0> = ApiVersions::new([10_000]);
110    /// ```
111    pub const fn new(versions: [u16; N]) -> Self {
112        assert!(!versions.is_empty(), "API versions must not be empty");
113        assert!(
114            is_monotonically_increasing(versions),
115            "API versions must be strictly monotonically increasing"
116        );
117        assert!(
118            versions[N - 1] < 10_000,
119            "API versions must be within 0u16..10_000"
120        );
121
122        Self(versions)
123    }
124}
125
126impl<const N: usize> Deref for ApiVersions<N> {
127    type Target = [u16; N];
128
129    fn deref(&self) -> &Self::Target {
130        &self.0
131    }
132}
133
134/// See [ApiVersionLayer].
135#[derive(Clone)]
136pub struct ApiVersionService<const N: usize, S> {
137    inner: S,
138    base_path: String,
139    versions: ApiVersions<N>,
140}
141
142impl<const N: usize, S> Service<Request> for ApiVersionService<N, S>
143where
144    S: Service<Request, Response = Response> + Clone + Send + 'static,
145    S::Future: Send + 'static,
146{
147    type Response = S::Response;
148    type Error = S::Error;
149    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
150
151    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
152        self.inner.poll_ready(cx)
153    }
154
155    fn call(&mut self, mut request: Request) -> Self::Future {
156        let mut inner = self.inner.clone();
157        let base_path = self.base_path.clone();
158        let versions = self.versions;
159
160        Box::pin(async move {
161            // Return without rewriting for paths not starting with base path.
162            let Some(path) = request.uri().path().strip_prefix(&base_path) else {
163                debug!(
164                    uri = %request.uri(),
165                    "not rewriting the path, because does not start with base path"
166                );
167                return inner.call(request).await;
168            };
169            let path = path.to_owned();
170
171            // Return without rewriting for paths starting with a valid version prefix.
172            let has_version_prefix = versions
173                .iter()
174                .any(|version| path.starts_with(&format!("/v{version}/")));
175            if has_version_prefix {
176                debug!(
177                    uri = %request.uri(),
178                    "not rewriting the path, because starts with valid version prefix"
179                );
180                return inner.call(request).await;
181            }
182
183            // Determine API version.
184            let version = request.extract_parts::<TypedHeader<XApiVersion>>().await;
185            let version = version
186                .as_ref()
187                .map(|TypedHeader(XApiVersion(v))| v)
188                .unwrap_or_else(|_| versions.last().expect("versions is not empty"));
189            if !versions.contains(version) {
190                let response = (
191                    StatusCode::NOT_FOUND,
192                    format!("unknown version '{version}'"),
193                );
194                return Ok(response.into_response());
195            }
196            debug!(?version, "using API version");
197
198            // Prepend the suitable prefix to the request URI.
199            let mut parts = request.uri().to_owned().into_parts();
200            let paq = parts.path_and_query.expect("uri has 'path and query'");
201            let mut paq_parts = paq.as_str().split('?').skip(1);
202            let paq = match paq_parts.next() {
203                Some(query) => format!("{base_path}/v{version}{path}?{query}"),
204                None => format!("{base_path}/v{version}{path}"),
205            };
206            let paq = PathAndQuery::from_maybe_shared(paq).expect("new 'path and query' is valid");
207            parts.path_and_query = Some(paq);
208            let uri = Uri::from_parts(parts).expect("parts are valid");
209
210            // Rewrite the request URI and run the downstream services.
211            debug!(original_uri = %request.uri(), %uri, "rewrote the path");
212            request.uri_mut().clone_from(&uri);
213            inner.call(request).await
214        })
215    }
216}
217
218/// Header name for the [XApiVersion] custom HTTP header.
219pub static X_API_VERSION: HeaderName = HeaderName::from_static("x-api-version");
220
221/// Custom HTTP header conveying the API version, which is expected to be a version designator
222/// starting with `'v'` followed by a number within `0u16..10_000` without leading zero, e.g. `v0`.
223#[derive(Debug)]
224pub struct XApiVersion(u16);
225
226impl Header for XApiVersion {
227    fn name() -> &'static HeaderName {
228        &X_API_VERSION
229    }
230
231    fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
232    where
233        Self: Sized,
234        I: Iterator<Item = &'i HeaderValue>,
235    {
236        values
237            .next()
238            .and_then(|v| v.to_str().ok())
239            .and_then(|s| VERSION.captures(s).and_then(|c| c.get(1)))
240            .and_then(|m| m.as_str().parse().ok())
241            .map(XApiVersion)
242            .ok_or_else(headers::Error::invalid)
243    }
244
245    fn encode<E: Extend<HeaderValue>>(&self, _values: &mut E) {
246        // We do not yet need to encode this header.
247        unimplemented!("not yet needed");
248    }
249}
250
/// Returns whether `versions` is strictly monotonically increasing, i.e. whether every element is
/// less than its successor; empty and single-element arrays trivially are.
///
/// This is a `const fn` so that it can be used in const assertions, hence the manual `while` loop
/// instead of iterators.
const fn is_monotonically_increasing<const N: usize>(versions: [u16; N]) -> bool {
    let mut i = 0;
    while i + 1 < N {
        if versions[i] >= versions[i + 1] {
            return false;
        }
        i += 1;
    }

    true
}
266
#[cfg(test)]
mod tests {
    use crate::{VERSION, is_monotonically_increasing};
    use assert_matches::assert_matches;

    /// Returns the number captured by [VERSION] for the given designator, if it matches.
    fn captured_number(designator: &str) -> Option<&str> {
        VERSION
            .captures(designator)
            .and_then(|captures| captures.get(1))
            .map(|number| number.as_str())
    }

    #[test]
    fn test_x_api_header() {
        assert_matches!(captured_number("v0"), Some("0"));
        assert_matches!(captured_number("v1"), Some("1"));
        assert_matches!(captured_number("v99"), Some("99"));
        assert_matches!(captured_number("v9999"), Some("9999"));
        assert_matches!(captured_number("v10000"), None);
        assert_matches!(captured_number("vx"), None);
    }

    #[test]
    fn test_is_monotonically_increasing() {
        // Trivially or strictly increasing.
        assert!(is_monotonically_increasing([]));
        assert!(is_monotonically_increasing([0]));
        assert!(is_monotonically_increasing([0, 1]));

        // Equal or decreasing neighbors.
        assert!(!is_monotonically_increasing([0, 0]));
        assert!(!is_monotonically_increasing([1, 0]));
    }
}