api_version/
lib.rs

1//! Axum middleware to rewrite a request such that a version prefix, e.g. `"/v0"`, is added to the
2//! path.
3
4use axum::{
5    RequestExt,
6    extract::Request,
7    http::{HeaderName, HeaderValue, StatusCode, Uri, uri::PathAndQuery},
8    response::{IntoResponse, Response},
9};
10use axum_extra::{
11    TypedHeader,
12    headers::{self, Header},
13};
14use futures::future::BoxFuture;
15use regex::Regex;
16use std::{
17    fmt::Debug,
18    ops::Deref,
19    sync::LazyLock,
20    task::{Context, Poll},
21};
22use tower::{Layer, Service};
23use tracing::debug;
24
25static VERSION: LazyLock<Regex> =
26    LazyLock::new(|| Regex::new(r#"^v(\d{1,4})$"#).expect("version regex is valid"));
27
28/// Axum middleware to rewrite a request such that a version prefix is added to the path. This is
29/// based on a set of API versions and an optional `"x-api-version"` custom HTTP header: if no such
30/// header is present, the highest version is used. Yet this only applies to requests the URIs of
31/// which start with the given base path, e.g. "/api"; others are not rewritten.  Also, paths
32/// starting with a valid/existing version prefix, e.g. `"/api/v0"`, are not rewritten.
33///
34/// # Examples
35///
36/// The middleware needs to be applied to the "root" router:
37///
38/// ```ignore
39/// let app = Router::new()
40///     .route("/ready", get(ok_0))
41///     .route("/api/v0/test", get(ok_0))
42///     .route("/api/v1/test", get(ok_1));
43///
44/// const API_VERSIONS: ApiVersions<2> = ApiVersions::new([0, 1]);
45///
46/// let mut app = ApiVersionLayer::new("/api", API_VERSIONS).layer(app);
47/// ```
48#[derive(Clone)]
49pub struct ApiVersionLayer<const N: usize> {
50    base_path: String,
51    versions: ApiVersions<N>,
52}
53
54impl<const N: usize> ApiVersionLayer<N> {
55    /// Create a new API version layer with the given base path and api versions.
56    ///
57    /// # Panics
58    ///
59    /// Panics if base path does not start with "/" or is empty.
60    pub fn new(base_path: impl AsRef<str>, versions: ApiVersions<N>) -> Self {
61        let base_path = base_path.as_ref().trim_end_matches('/').to_string();
62        assert!(base_path.starts_with('/'), "base path must start with '/'");
63        assert!(!base_path.len() > 1, "base path must not be empty");
64
65        Self {
66            base_path,
67            versions,
68        }
69    }
70}
71
72impl<const N: usize, S> Layer<S> for ApiVersionLayer<N> {
73    type Service = ApiVersionService<N, S>;
74
75    fn layer(&self, inner: S) -> Self::Service {
76        ApiVersionService {
77            inner,
78            base_path: self.base_path.clone(),
79            versions: self.versions,
80        }
81    }
82}
83
/// Validated API versions: a newtype around a `u16` array.
///
/// Outside this crate, instances can only be created via [ApiVersions::new], which checks that
/// the versions are non-empty, strictly increasing and below `10_000`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ApiVersions<const N: usize>([u16; N]);
87
88impl<const N: usize> ApiVersions<N> {
89    /// Create API versions. The given numbers must not be empty, must be strictly monotonically
90    /// increasing and less than `10_000`; otherwise `new` fails to compile in const contexts or
91    /// panics otherwise.
92    ///
93    /// # Examples
94    ///
95    /// Strictly monotonically versions `1` and `2` are valid:
96    ///
97    /// ```
98    /// # use api_version::ApiVersions;
99    /// const VERSIONS: ApiVersions<2> = ApiVersions::new([1, 2]);;
100    /// ```
101    ///
102    /// # Panics
103    ///
104    /// Empty versions or such that are not strictly monotonically increasing are invalid and fail
105    /// to compile in const contexts or panic otherwise.
106    ///
107    /// ```compile_fail
108    /// # use api_version::ApiVersions;
109    /// /// API versions must not be empty!
110    /// const VERSIONS: ApiVersions<0> = ApiVersions::new([]);
111    /// /// API versions must be strictly monotonically increasing!
112    /// const VERSIONS: ApiVersions<0> = ApiVersions::new([2, 1]);
113    /// /// API versions must be within 0u16..10_000!
114    /// const VERSIONS: ApiVersions<0> = ApiVersions::new([10_000]);
115    /// ```
116    pub const fn new(versions: [u16; N]) -> Self {
117        assert!(!versions.is_empty(), "API versions must not be empty");
118        assert!(
119            is_monotonically_increasing(versions),
120            "API versions must be strictly monotonically increasing"
121        );
122        assert!(
123            versions[N - 1] < 10_000,
124            "API versions must be within 0u16..10_000"
125        );
126
127        Self(versions)
128    }
129}
130
131impl<const N: usize> Deref for ApiVersions<N> {
132    type Target = [u16; N];
133
134    fn deref(&self) -> &Self::Target {
135        &self.0
136    }
137}
138
139/// See [ApiVersionLayer].
140#[derive(Clone)]
141pub struct ApiVersionService<const N: usize, S> {
142    inner: S,
143    base_path: String,
144    versions: ApiVersions<N>,
145}
146
147impl<const N: usize, S> Service<Request> for ApiVersionService<N, S>
148where
149    S: Service<Request, Response = Response> + Clone + Send + 'static,
150    S::Future: Send + 'static,
151{
152    type Response = S::Response;
153    type Error = S::Error;
154    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
155
156    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
157        self.inner.poll_ready(cx)
158    }
159
160    fn call(&mut self, mut request: Request) -> Self::Future {
161        let mut inner = self.inner.clone();
162        let base_path = self.base_path.clone();
163        let versions = self.versions;
164
165        Box::pin(async move {
166            // Strip base path prefix or return without rewriting.
167            let path = if let Some(path) = request.uri().path().strip_prefix(&base_path)
168                && path.starts_with('/')
169            {
170                path.to_owned()
171            } else {
172                debug!(
173                    uri = %request.uri(),
174                    "not rewriting the path, because does not start with base path"
175                );
176                return inner.call(request).await;
177            };
178
179            // Return without rewriting if stripped path starts with valid version prefix.
180            let has_version_prefix = versions
181                .iter()
182                .any(|version| path.starts_with(&format!("/v{version}/")));
183            if has_version_prefix {
184                debug!(
185                    uri = %request.uri(),
186                    "not rewriting the path, because starts with valid version prefix"
187                );
188                return inner.call(request).await;
189            }
190
191            // Determine version.
192            let version = request.extract_parts::<TypedHeader<XApiVersion>>().await;
193            let version = version
194                .as_ref()
195                .map(|TypedHeader(XApiVersion(v))| v)
196                .unwrap_or_else(|_| versions.last().expect("versions is not empty"));
197            if !versions.contains(version) {
198                let response = (
199                    StatusCode::NOT_FOUND,
200                    format!("unknown version '{version}'"),
201                );
202                return Ok(response.into_response());
203            }
204            debug!(?version, "using API version");
205
206            // Insert version prefix into request URI.
207            let mut parts = request.uri().to_owned().into_parts();
208            let paq = parts.path_and_query.expect("uri has 'path and query'");
209            let mut paq_parts = paq.as_str().split('?').skip(1);
210            let paq = match paq_parts.next() {
211                Some(query) => format!("{base_path}/v{version}{path}?{query}"),
212                None => format!("{base_path}/v{version}{path}"),
213            };
214            let paq = PathAndQuery::from_maybe_shared(paq).expect("new 'path and query' is valid");
215            parts.path_and_query = Some(paq);
216            let uri = Uri::from_parts(parts).expect("parts are valid");
217
218            // Rewrite the request URI and run the downstream services.
219            debug!(original_uri = %request.uri(), %uri, "rewrote the path");
220            request.uri_mut().clone_from(&uri);
221            inner.call(request).await
222        })
223    }
224}
225
226/// Header name for the [XApiVersion] custom HTTP header.
227pub static X_API_VERSION: HeaderName = HeaderName::from_static("x-api-version");
228
/// Custom HTTP header conveying the API version, which is expected to be a version designator
/// starting with `'v'` followed by a number within `0u16..10_000` without leading zero, e.g. `v0`.
///
/// Outside this crate, values can only be obtained by decoding the header via its [Header]
/// implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct XApiVersion(u16);
233
234impl Header for XApiVersion {
235    fn name() -> &'static HeaderName {
236        &X_API_VERSION
237    }
238
239    fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
240    where
241        Self: Sized,
242        I: Iterator<Item = &'i HeaderValue>,
243    {
244        values
245            .next()
246            .and_then(|v| v.to_str().ok())
247            .and_then(|s| VERSION.captures(s).and_then(|c| c.get(1)))
248            .and_then(|m| m.as_str().parse().ok())
249            .map(XApiVersion)
250            .ok_or_else(headers::Error::invalid)
251    }
252
253    fn encode<E: Extend<HeaderValue>>(&self, _values: &mut E) {
254        // We do not yet need to encode this header.
255        unimplemented!("not yet needed");
256    }
257}
258
/// Returns whether each element of `versions` is strictly greater than its predecessor; arrays
/// with fewer than two elements trivially qualify. It is a `const fn` so that [ApiVersions::new]
/// can call it in const contexts.
const fn is_monotonically_increasing<const N: usize>(versions: [u16; N]) -> bool {
    // Walk the adjacent pairs from the back; iterators are not usable in a `const fn`.
    let mut idx = N;
    while idx > 1 {
        idx -= 1;
        if versions[idx - 1] >= versions[idx] {
            return false;
        }
    }
    true
}
274
#[cfg(test)]
mod tests {
    use crate::{VERSION, is_monotonically_increasing};
    use assert_matches::assert_matches;

    /// The version number captured by [VERSION] for the given header value, if it matches.
    fn captured_version(value: &str) -> Option<&str> {
        VERSION
            .captures(value)
            .and_then(|captures| captures.get(1))
            .map(|m| m.as_str())
    }

    #[test]
    fn test_x_api_header() {
        assert_matches!(captured_version("v0"), Some("0"));
        assert_matches!(captured_version("v1"), Some("1"));
        assert_matches!(captured_version("v99"), Some("99"));
        assert_matches!(captured_version("v9999"), Some("9999"));

        assert_matches!(captured_version("v10000"), None);
        assert_matches!(captured_version("vx"), None);
    }

    #[test]
    fn test_is_monotonically_increasing() {
        for increasing in [&[][..], &[0], &[0, 1]] {
            assert!(match increasing {
                [] => is_monotonically_increasing([]),
                [a] => is_monotonically_increasing([*a]),
                [a, b] => is_monotonically_increasing([*a, *b]),
                _ => unreachable!(),
            });
        }

        assert!(!is_monotonically_increasing([0, 0]));
        assert!(!is_monotonically_increasing([1, 0]));
    }
}