api_version/
lib.rs

1//! Axum middleware to rewrite a request such that a version prefix, e.g. `"/v0"`, is added to the
2//! path.
3
4use axum::{
5    RequestExt,
6    extract::Request,
7    http::{HeaderName, HeaderValue, StatusCode, Uri, uri::PathAndQuery},
8    response::{IntoResponse, Response},
9};
10use axum_extra::{
11    TypedHeader,
12    headers::{self, Header},
13};
14use futures::future::BoxFuture;
15use regex::Regex;
16use std::{
17    convert::Infallible,
18    error::Error as StdError,
19    fmt::Debug,
20    ops::Deref,
21    sync::LazyLock,
22    task::{Context, Poll},
23};
24use tower::{Layer, Service};
25use tracing::{debug, error};
26
27static VERSION: LazyLock<Regex> =
28    LazyLock::new(|| Regex::new(r#"^v(\d{1,4})$"#).expect("version regex is valid"));
29
/// Axum middleware to rewrite a request such that a version prefix is added to the path. This is
/// based on a set of API versions and an optional `"x-api-version"` custom HTTP header: if no such
/// header is present (or it cannot be decoded), the highest version is used. Yet this only applies
/// to requests the URIs of which pass a filter; others are not rewritten. Also, paths starting with
/// a valid/existing version prefix, e.g. `"/v0"`, are not rewritten.
///
/// A request naming a version which is not part of the configured API versions is answered with
/// `404 Not Found`; if the filter itself fails, the request is answered with
/// `500 Internal Server Error`.
///
/// # Examples
///
/// The middleware needs to be applied to the "root" router:
///
/// ```ignore
/// let app = Router::new()
///     .route("/", get(ok_0))
///     .route("/v0/test", get(ok_0))
///     .route("/v1/test", get(ok_1))
///     .route("/foo", get(ok_foo));
///
/// const API_VERSIONS: ApiVersions<2> = ApiVersions::new([0, 1]);
///
/// let mut app = ApiVersionLayer::new(API_VERSIONS, FooFilter).layer(app);
/// ```
#[derive(Clone)]
pub struct ApiVersionLayer<const N: usize, F> {
    // The supported API versions; the last one is used if a request does not name a version.
    versions: ApiVersions<N>,
    // Decides which requests are candidates for rewriting.
    filter: F,
}
56
57impl<const N: usize, F> ApiVersionLayer<N, F> {
58    /// Create a new API version layer.
59    pub fn new(versions: ApiVersions<N>, filter: F) -> Self {
60        Self { versions, filter }
61    }
62}
63
64impl<const N: usize, S, F> Layer<S> for ApiVersionLayer<N, F>
65where
66    F: ApiVersionFilter,
67{
68    type Service = ApiVersionService<N, S, F>;
69
70    fn layer(&self, inner: S) -> Self::Service {
71        ApiVersionService {
72            inner,
73            versions: self.versions,
74            filter: self.filter.clone(),
75        }
76    }
77}
78
/// API versions; a validated newtype for a `u16` array.
///
/// Values can only be created via [ApiVersions::new], which asserts that the versions are not
/// empty, strictly monotonically increasing and less than `10_000`. The wrapped array is
/// accessible via `Deref`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ApiVersions<const N: usize>([u16; N]);
82
83impl<const N: usize> ApiVersions<N> {
84    /// Create API versions. The given numbers must not be empty, must be strictly monotonically
85    /// increasing and less than `10_000`; otherwise `new` fails to compile in const contexts or
86    /// panics otherwise.
87    ///
88    /// # Examples
89    ///
90    /// Strictly monotonically versions `1` and `2` are valid:
91    ///
92    /// ```
93    /// # use api_version::ApiVersions;
94    /// const VERSIONS: ApiVersions<2> = ApiVersions::new([1, 2]);;
95    /// ```
96    ///
97    /// Empty versions or such that are not strictly monotonically increasing are invalid and fail
98    /// to compile in const contexts or panic otherwise.
99    ///
100    /// ```compile_fail
101    /// # use api_version::ApiVersions;
102    /// /// API versions must not be empty!
103    /// const VERSIONS: ApiVersions<0> = ApiVersions::new([]);
104    /// /// API versions must be strictly monotonically increasing!
105    /// const VERSIONS: ApiVersions<0> = ApiVersions::new([2, 1]);
106    /// /// API versions must be within 0u16..10_000!
107    /// const VERSIONS: ApiVersions<0> = ApiVersions::new([10_000]);
108    /// ```
109    pub const fn new(versions: [u16; N]) -> Self {
110        assert!(!versions.is_empty(), "API versions must not be empty");
111        assert!(
112            is_monotonically_increasing(versions),
113            "API versions must be strictly monotonically increasing"
114        );
115        assert!(
116            versions[N - 1] < 10_000,
117            "API versions must be within 0u16..10_000"
118        );
119
120        Self(versions)
121    }
122}
123
124impl<const N: usize> Deref for ApiVersions<N> {
125    type Target = [u16; N];
126
127    fn deref(&self) -> &Self::Target {
128        &self.0
129    }
130}
131
/// Filter to determine which requests are rewritten.
///
/// Implementations must be `Clone + Send + 'static`, because the filter is cloned for each request
/// and moved into the boxed future of the middleware.
// NOTE(review): `trait_variant::make(Send)` presumably makes the future returned by
// `should_rewrite` `Send`, which the boxed future of the middleware needs — confirm against the
// `trait_variant` documentation.
#[trait_variant::make(Send)]
pub trait ApiVersionFilter: Clone + Send + 'static {
    /// Error signalling that the filter could not be applied.
    type Error: std::error::Error;

    /// Should a request with the given URI be rewritten.
    ///
    /// # Errors
    ///
    /// An error makes the middleware log it and respond with `500 Internal Server Error`.
    async fn should_rewrite(&self, uri: &Uri) -> Result<bool, Self::Error>;
}
140
/// [ApiVersionFilter] making all requests be rewritten, i.e. a filter which lets every URI pass.
#[derive(Clone, Copy)]
pub struct All;
144
impl ApiVersionFilter for All {
    // This filter cannot fail.
    type Error = Infallible;

    /// Always returns `Ok(true)`, regardless of the given URI.
    async fn should_rewrite(&self, _uri: &Uri) -> Result<bool, Self::Error> {
        Ok(true)
    }
}
152
/// The service created by [ApiVersionLayer]: it possibly rewrites the path of a request and then
/// calls the wrapped service. See [ApiVersionLayer].
#[derive(Clone)]
pub struct ApiVersionService<const N: usize, S, F> {
    // The wrapped (downstream) service.
    inner: S,
    // The supported API versions; the last one is used if a request does not name a version.
    versions: ApiVersions<N>,
    // Decides which requests are candidates for rewriting.
    filter: F,
}
160
161impl<const N: usize, S, F> Service<Request> for ApiVersionService<N, S, F>
162where
163    S: Service<Request, Response = Response> + Clone + Send + 'static,
164    S::Future: Send + 'static,
165    F: ApiVersionFilter,
166{
167    type Response = S::Response;
168    type Error = S::Error;
169    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
170
171    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
172        self.inner.poll_ready(cx)
173    }
174
175    fn call(&mut self, mut request: Request) -> Self::Future {
176        let mut inner = self.inner.clone();
177        let versions = self.versions;
178        let filter = self.filter.clone();
179
180        Box::pin(async move {
181            // Return without rewriting for paths starting with a valid version prefix.
182            let has_version_prefix = versions
183                .iter()
184                .any(|version| request.uri().path().starts_with(&format!("/v{version}/")));
185            if has_version_prefix {
186                debug!(
187                    uri = %request.uri(),
188                    "not rewriting the path, because starts with valid version prefix"
189                );
190                return inner.call(request).await;
191            }
192
193            // Apply filter and possibly return without rewriting.
194            let should_rewrite = match filter.should_rewrite(request.uri()).await {
195                Ok(should_rewrite) => should_rewrite,
196
197                Err(error) => {
198                    error!(error = chained_sources(error), "cannot apply filter");
199                    return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response());
200                }
201            };
202            if !should_rewrite {
203                debug!(uri = %request.uri(), "not rewriting the path, because URI filtered");
204                return inner.call(request).await;
205            }
206
207            // Determine API version.
208            let version = request.extract_parts::<TypedHeader<XApiVersion>>().await;
209            let version = version
210                .as_ref()
211                .map(|TypedHeader(XApiVersion(v))| v)
212                .unwrap_or_else(|_| versions.last().expect("versions is not empty"));
213            if !versions.contains(version) {
214                let response = (
215                    StatusCode::NOT_FOUND,
216                    format!("unknown version '{version}'"),
217                );
218                return Ok(response.into_response());
219            }
220            debug!(?version, "using API version");
221
222            // Prepend the suitable prefix to the request URI.
223            let mut parts = request.uri().to_owned().into_parts();
224            let paq = parts.path_and_query.expect("uri has 'path and query'");
225            let mut paq_parts = paq.as_str().split('?');
226            let path = paq_parts.next().expect("uri has path");
227            let paq = match paq_parts.next() {
228                Some(query) => format!("/v{version}{path}?{query}"),
229                None if path != "/" => format!("/v{version}{path}"),
230                None => format!("/v{version}"),
231            };
232            let paq = PathAndQuery::from_maybe_shared(paq).expect("new 'path and query' is valid");
233            parts.path_and_query = Some(paq);
234            let uri = Uri::from_parts(parts).expect("parts are valid");
235
236            // Rewrite the request URI and run the downstream services.
237            debug!(original_uri = %request.uri(), %uri, "rewrote the path");
238            request.uri_mut().clone_from(&uri);
239            inner.call(request).await
240        })
241    }
242}
243
/// Header name for the [XApiVersion] custom HTTP header: `"x-api-version"`.
pub static X_API_VERSION: HeaderName = HeaderName::from_static("x-api-version");
246
/// Custom HTTP header conveying the API version, which is expected to be a version designator
/// starting with `'v'` followed by a number within `0u16..10_000`, e.g. `v0`.
///
/// Notice that the number is matched as one to four decimal digits, hence leading zeros are
/// currently accepted when decoding: `v01` is decoded as version `1`.
#[derive(Debug)]
pub struct XApiVersion(u16);
251
252impl Header for XApiVersion {
253    fn name() -> &'static HeaderName {
254        &X_API_VERSION
255    }
256
257    fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
258    where
259        Self: Sized,
260        I: Iterator<Item = &'i HeaderValue>,
261    {
262        values
263            .next()
264            .and_then(|v| v.to_str().ok())
265            .and_then(|s| VERSION.captures(s).and_then(|c| c.get(1)))
266            .and_then(|m| m.as_str().parse().ok())
267            .map(XApiVersion)
268            .ok_or_else(headers::Error::invalid)
269    }
270
271    fn encode<E: Extend<HeaderValue>>(&self, _values: &mut E) {
272        // We do not yet need to encode this header.
273        unimplemented!("not yet needed");
274    }
275}
276
277fn chained_sources<E>(error: E) -> String
278where
279    E: StdError,
280{
281    let mut sources = vec![];
282    sources.push(error.to_string());
283
284    let mut source = error.source();
285    while let Some(s) = source {
286        sources.push(s.to_string());
287        source = s.source();
288    }
289
290    sources.join(": ")
291}
292
/// Whether the given versions are strictly monotonically increasing, i.e. whether each version is
/// greater than its predecessor. Trivially `true` for less than two versions.
///
/// A plain `while` loop is used, because iterators are not available in `const fn`.
const fn is_monotonically_increasing<const N: usize>(versions: [u16; N]) -> bool {
    let mut idx = 0;
    // Compare each adjacent pair; there is none for `N < 2`.
    while idx + 1 < N {
        if versions[idx] >= versions[idx + 1] {
            return false;
        }
        idx += 1;
    }

    true
}
308
#[cfg(test)]
mod tests {
    use crate::{VERSION, is_monotonically_increasing};
    use assert_matches::assert_matches;

    /// Apply the version pattern to the given input and return the captured digits, if any.
    fn captured(input: &str) -> Option<&str> {
        let captures = VERSION.captures(input)?;
        captures.get(1).map(|m| m.as_str())
    }

    #[test]
    fn test_x_api_header() {
        // One to four digits are accepted.
        assert_matches!(captured("v0"), Some("0"));
        assert_matches!(captured("v1"), Some("1"));
        assert_matches!(captured("v99"), Some("99"));
        assert_matches!(captured("v9999"), Some("9999"));

        // Five digits or no digits at all are rejected.
        assert_matches!(captured("v10000"), None);
        assert_matches!(captured("vx"), None);
    }

    #[test]
    fn test_is_monotonically_increasing() {
        // Less than two versions are trivially increasing.
        assert!(is_monotonically_increasing([]));
        assert!(is_monotonically_increasing([0]));
        assert!(is_monotonically_increasing([0, 1]));

        // Neither equal nor decreasing neighbours are allowed.
        assert!(!is_monotonically_increasing([0, 0]));
        assert!(!is_monotonically_increasing([1, 0]));
    }
}