api_version/
lib.rs

1pub use array_macro;
2
3use axum::{
4    extract::Request,
5    http::{uri::PathAndQuery, HeaderName, HeaderValue, StatusCode, Uri},
6    response::{IntoResponse, Response},
7    RequestExt,
8};
9use axum_extra::{
10    headers::{self, Header},
11    TypedHeader,
12};
13use futures::future::BoxFuture;
14use regex::Regex;
15use std::{
16    convert::Infallible,
17    error::Error as StdError,
18    fmt::Debug,
19    future::Future,
20    sync::OnceLock,
21    task::{Context, Poll},
22};
23use thiserror::Error;
24use tower::{Layer, Service};
25use tracing::{debug, error};
26
/// Create an [ApiVersionLayer] whose versions are every number in the given inclusive range,
/// e.g. `api_version!(0..=2)` yields the versions `[0, 1, 2]`.
///
/// Without a second argument the [All] filter is used, i.e. every request is rewritten; pass a
/// filter as second argument to restrict that, e.g. `api_version!(0..=2, my_filter)`.
///
/// # Panics
///
/// Panics if [ApiVersionLayer::new] rejects the generated versions; for a non-empty ascending
/// range they are always non-empty and strictly increasing.
#[macro_export]
macro_rules! api_version {
    // Shorthand without filter: rewrite every request.
    ($from:literal..=$to:literal) => {{
        $crate::api_version!($from..=$to, $crate::All)
    }};

    // `$to - $from + 1` consecutive versions, counting up from `$from`.
    ($from:literal..=$to:literal, $filter:expr) => {{
        let versions = $crate::array_macro::array![i => i as u16 + $from; $to - $from + 1];
        $crate::ApiVersionLayer::new(versions, $filter).expect("versions are valid")
    }};
}
44
45/// Axum middleware to rewrite a request such that a version prefix is added to the path. This is
46/// based on a set of versions and an optional `"x-api-version"` custom HTTP header: if no such
47/// header is present, the highest version is used. Yet this only applies to requests the URIs of
48/// which pass a filter; others are not rewritten.
49///
50/// Paths must not start with a version prefix, e.g. `"/v0"`.
51#[derive(Clone)]
52pub struct ApiVersionLayer<const N: usize, F> {
53    versions: [u16; N],
54    filter: F,
55}
56
57impl<const N: usize, F> ApiVersionLayer<N, F> {
58    /// Create a new [ApiVersionLayer].
59    ///
60    /// The given versions must not be empty and must be strictly monotonically increasing, e.g.
61    /// `[0, 1, 2]`.
62    pub fn new(versions: [u16; N], filter: F) -> Result<Self, NewApiVersionLayerError> {
63        if versions.is_empty() {
64            return Err(NewApiVersionLayerError::Empty);
65        }
66
67        if versions.as_slice().windows(2).any(|w| w[0] >= w[1]) {
68            return Err(NewApiVersionLayerError::NotIncreasing);
69        }
70
71        Ok(Self { versions, filter })
72    }
73}
74
75impl<const N: usize, S, F> Layer<S> for ApiVersionLayer<N, F>
76where
77    F: ApiVersionFilter,
78{
79    type Service = ApiVersion<N, S, F>;
80
81    fn layer(&self, inner: S) -> Self::Service {
82        ApiVersion {
83            inner,
84            versions: self.versions,
85            filter: self.filter.clone(),
86        }
87    }
88}
89
90/// Determine which requests are rewritten.
91pub trait ApiVersionFilter: Clone + Send + 'static {
92    type Error: std::error::Error;
93
94    /// Requests are only rewritten, if the given URI passes, i.e. results in `true`.
95    fn filter(&self, uri: &Uri) -> impl Future<Output = Result<bool, Self::Error>> + Send;
96}
97
98/// [ApiVersionFilter] making all requests be rewritten.
99#[derive(Clone, Copy)]
100pub struct All;
101
102impl ApiVersionFilter for All {
103    type Error = Infallible;
104
105    async fn filter(&self, _uri: &Uri) -> Result<bool, Self::Error> {
106        Ok(true)
107    }
108}
109
110/// Error creating an [ApiVersionLayer].
111#[derive(Debug, Error)]
112pub enum NewApiVersionLayerError {
113    #[error("versions must not be empty")]
114    Empty,
115
116    #[error("versions must be strictly monotonically increasing")]
117    NotIncreasing,
118}
119
120/// See [ApiVersionLayer].
121#[derive(Clone)]
122pub struct ApiVersion<const N: usize, S, F> {
123    inner: S,
124    versions: [u16; N],
125    filter: F,
126}
127
128impl<const N: usize, S, F> Service<Request> for ApiVersion<N, S, F>
129where
130    S: Service<Request, Response = Response> + Clone + Send + 'static,
131    S::Future: Send + 'static,
132    F: ApiVersionFilter,
133{
134    type Response = S::Response;
135    type Error = S::Error;
136    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
137
138    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
139        self.inner.poll_ready(cx)
140    }
141
142    fn call(&mut self, mut request: Request) -> Self::Future {
143        let mut inner = self.inner.clone();
144        let versions = self.versions;
145        let filter = self.filter.clone();
146
147        Box::pin(async move {
148            // Do not allow the path to start with one of the valid version prefixes.
149            if versions
150                .iter()
151                .any(|version| request.uri().path().starts_with(&format!("/v{version}")))
152            {
153                let response = (
154                    StatusCode::BAD_REQUEST,
155                    "path must not start with version prefix like '/v0'",
156                );
157                return Ok(response.into_response());
158            }
159
160            let pass_through = match filter.filter(request.uri()).await {
161                Ok(pass_through) => pass_through,
162
163                Err(error) => {
164                    error!(error = error.as_chain(), "cannot apply filter");
165                    return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response());
166                }
167            };
168
169            if !pass_through {
170                debug!(uri = %request.uri(), "not rewriting the path");
171                return inner.call(request).await;
172            }
173
174            // Determine API version.
175            let version = request.extract_parts::<TypedHeader<XApiVersion>>().await;
176            let version = version
177                .as_ref()
178                .map(|TypedHeader(XApiVersion(v))| v)
179                .unwrap_or_else(|_| versions.last().expect("versions is not empty"));
180            if !versions.contains(version) {
181                let response = (
182                    StatusCode::NOT_FOUND,
183                    format!("unknown version '{version}'"),
184                );
185                return Ok(response.into_response());
186            }
187            debug!(?version, "using API version");
188
189            // Prepend the suitable prefix to the request URI.
190            let mut parts = request.uri().to_owned().into_parts();
191            let paq = parts.path_and_query.expect("uri has 'path and query'");
192            let mut paq_parts = paq.as_str().split('?');
193            let path = paq_parts.next().expect("uri has path");
194            let paq = match paq_parts.next() {
195                Some(query) => format!("/v{version}{path}?{query}"),
196                None if path != "/" => format!("/v{version}{path}"),
197                None => format!("/v{version}"),
198            };
199            let paq = PathAndQuery::from_maybe_shared(paq).expect("new 'path and query' is valid");
200            parts.path_and_query = Some(paq);
201            let uri = Uri::from_parts(parts).expect("parts are valid");
202
203            // Rewrite the request URI and run the downstream services.
204            debug!(original_uri = %request.uri(), %uri, "rewrote the path");
205            request.uri_mut().clone_from(&uri);
206            inner.call(request).await
207        })
208    }
209}
210
211/// Header name for the [XApiVersion] custom HTTP header.
212pub static X_API_VERSION: HeaderName = HeaderName::from_static("x-api-version");
213
214/// Custom HTTP header conveying the API version, which is expected to be a version designator
215/// starting with `'v'` followed by a number from 0..+99 without leading zero, e.g. `v0`.
216#[derive(Debug)]
217pub struct XApiVersion(u16);
218
219impl Header for XApiVersion {
220    fn name() -> &'static HeaderName {
221        &X_API_VERSION
222    }
223
224    fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
225    where
226        Self: Sized,
227        I: Iterator<Item = &'i HeaderValue>,
228    {
229        values
230            .next()
231            .and_then(|v| v.to_str().ok())
232            .and_then(|s| version().captures(s).and_then(|c| c.get(1)))
233            .and_then(|m| m.as_str().parse().ok())
234            .map(XApiVersion)
235            .ok_or_else(headers::Error::invalid)
236    }
237
238    fn encode<E: Extend<HeaderValue>>(&self, _values: &mut E) {
239        // We do not yet need to encode this header.
240        unimplemented!("not yet needed");
241    }
242}
243
244// TODO Use `LazyLock` once stabilized!
245fn version() -> &'static Regex {
246    static VERSION: OnceLock<Regex> = OnceLock::new();
247    VERSION.get_or_init(|| Regex::new(r#"^v(0|[1-9][0-9]?)$"#).expect("version regex is valid"))
248}
249
250trait StdErrorExt
251where
252    Self: StdError,
253{
254    fn as_chain(&self) -> String {
255        let mut sources = vec![];
256        sources.push(self.to_string());
257
258        let mut source = self.source();
259        while let Some(s) = source {
260            sources.push(s.to_string());
261            source = s.source();
262        }
263
264        sources.join(": ")
265    }
266}
267
268impl<T> StdErrorExt for T where T: StdError {}