api_version/
lib.rs

1pub use array_macro;
2
3use axum::{
4    RequestExt,
5    extract::Request,
6    http::{HeaderName, HeaderValue, StatusCode, Uri, uri::PathAndQuery},
7    response::{IntoResponse, Response},
8};
9use axum_extra::{
10    TypedHeader,
11    headers::{self, Header},
12};
13use futures::future::BoxFuture;
14use regex::Regex;
15use std::{
16    convert::Infallible,
17    error::Error as StdError,
18    fmt::Debug,
19    future::Future,
20    sync::LazyLock,
21    task::{Context, Poll},
22};
23use thiserror::Error;
24use tower::{Layer, Service};
25use tracing::{debug, error};
26
/// Create an [ApiVersionLayer] for every version in the given inclusive range of integer
/// literals, e.g. `api_version!(0..=2)` for the versions `[0, 1, 2]`; such versions are always
/// non-empty and strictly monotonically increasing.
///
/// Without a second argument the [All] filter is used, i.e. every request is rewritten; a custom
/// [ApiVersionFilter] can be passed as second argument, e.g. `api_version!(0..=2, my_filter)`.
#[macro_export]
macro_rules! api_version {
    ($from:literal..=$to:literal) => {
        $crate::api_version!($from..=$to, $crate::All)
    };

    ($from:literal..=$to:literal, $filter:expr) => {{
        // Element `n` of the array is `$from + n`, so the array covers `$from..=$to`.
        let versions: [u16; $to - $from + 1] = ::core::array::from_fn(|n| n as u16 + $from);
        $crate::ApiVersionLayer::new(versions, $filter).expect("versions are valid")
    }};
}
44
45static VERSION: LazyLock<Regex> =
46    LazyLock::new(|| Regex::new(r#"^v(0|[1-9][0-9]?)$"#).expect("version regex is valid"));
47
/// Axum middleware to rewrite a request such that a version prefix is added to the path. This is
/// based on a set of versions and an optional `"x-api-version"` custom HTTP header: if no such
/// header is present, the highest version is used. Yet this only applies to requests the URIs of
/// which pass a filter; others are not rewritten.
///
/// Paths must not start with a version prefix, e.g. `"/v0"`.
#[derive(Clone, Debug)]
pub struct ApiVersionLayer<const N: usize, F> {
    /// Supported versions; non-empty and strictly monotonically increasing when created via
    /// [ApiVersionLayer::new].
    versions: [u16; N],

    /// Decides which requests are rewritten.
    filter: F,
}
59
60impl<const N: usize, F> ApiVersionLayer<N, F> {
61    /// Create a new [ApiVersionLayer].
62    ///
63    /// The given versions must not be empty and must be strictly monotonically increasing, e.g.
64    /// `[0, 1, 2]`.
65    pub fn new(versions: [u16; N], filter: F) -> Result<Self, NewApiVersionLayerError> {
66        if versions.is_empty() {
67            return Err(NewApiVersionLayerError::Empty);
68        }
69
70        if versions.as_slice().windows(2).any(|w| w[0] >= w[1]) {
71            return Err(NewApiVersionLayerError::NotIncreasing);
72        }
73
74        Ok(Self { versions, filter })
75    }
76}
77
78impl<const N: usize, S, F> Layer<S> for ApiVersionLayer<N, F>
79where
80    F: ApiVersionFilter,
81{
82    type Service = ApiVersion<N, S, F>;
83
84    fn layer(&self, inner: S) -> Self::Service {
85        ApiVersion {
86            inner,
87            versions: self.versions,
88            filter: self.filter.clone(),
89        }
90    }
91}
92
93/// Determine which requests are rewritten.
94pub trait ApiVersionFilter: Clone + Send + 'static {
95    type Error: std::error::Error;
96
97    /// Requests are only rewritten, if the given URI passes, i.e. results in `true`.
98    fn filter(&self, uri: &Uri) -> impl Future<Output = Result<bool, Self::Error>> + Send;
99}
100
/// [ApiVersionFilter] making all requests be rewritten.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct All;
104
105impl ApiVersionFilter for All {
106    type Error = Infallible;
107
108    async fn filter(&self, _uri: &Uri) -> Result<bool, Self::Error> {
109        Ok(true)
110    }
111}
112
113/// Error creating an [ApiVersionLayer].
114#[derive(Debug, Error)]
115pub enum NewApiVersionLayerError {
116    #[error("versions must not be empty")]
117    Empty,
118
119    #[error("versions must be strictly monotonically increasing")]
120    NotIncreasing,
121}
122
/// Middleware service created by [ApiVersionLayer]; see there for details.
#[derive(Clone, Debug)]
pub struct ApiVersion<const N: usize, S, F> {
    /// The wrapped downstream service.
    inner: S,

    /// Supported versions, copied from the [ApiVersionLayer].
    versions: [u16; N],

    /// Decides which requests are rewritten, cloned from the [ApiVersionLayer].
    filter: F,
}
130
131impl<const N: usize, S, F> Service<Request> for ApiVersion<N, S, F>
132where
133    S: Service<Request, Response = Response> + Clone + Send + 'static,
134    S::Future: Send + 'static,
135    F: ApiVersionFilter,
136{
137    type Response = S::Response;
138    type Error = S::Error;
139    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
140
141    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
142        self.inner.poll_ready(cx)
143    }
144
145    fn call(&mut self, mut request: Request) -> Self::Future {
146        let mut inner = self.inner.clone();
147        let versions = self.versions;
148        let filter = self.filter.clone();
149
150        Box::pin(async move {
151            // Do not allow the path to start with one of the valid version prefixes.
152            if versions
153                .iter()
154                .any(|version| request.uri().path().starts_with(&format!("/v{version}")))
155            {
156                let response = (
157                    StatusCode::BAD_REQUEST,
158                    "path must not start with version prefix like '/v0'",
159                );
160                return Ok(response.into_response());
161            }
162
163            let pass_through = match filter.filter(request.uri()).await {
164                Ok(pass_through) => pass_through,
165
166                Err(error) => {
167                    error!(error = error.as_chain(), "cannot apply filter");
168                    return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response());
169                }
170            };
171
172            if !pass_through {
173                debug!(uri = %request.uri(), "not rewriting the path");
174                return inner.call(request).await;
175            }
176
177            // Determine API version.
178            let version = request.extract_parts::<TypedHeader<XApiVersion>>().await;
179            let version = version
180                .as_ref()
181                .map(|TypedHeader(XApiVersion(v))| v)
182                .unwrap_or_else(|_| versions.last().expect("versions is not empty"));
183            if !versions.contains(version) {
184                let response = (
185                    StatusCode::NOT_FOUND,
186                    format!("unknown version '{version}'"),
187                );
188                return Ok(response.into_response());
189            }
190            debug!(?version, "using API version");
191
192            // Prepend the suitable prefix to the request URI.
193            let mut parts = request.uri().to_owned().into_parts();
194            let paq = parts.path_and_query.expect("uri has 'path and query'");
195            let mut paq_parts = paq.as_str().split('?');
196            let path = paq_parts.next().expect("uri has path");
197            let paq = match paq_parts.next() {
198                Some(query) => format!("/v{version}{path}?{query}"),
199                None if path != "/" => format!("/v{version}{path}"),
200                None => format!("/v{version}"),
201            };
202            let paq = PathAndQuery::from_maybe_shared(paq).expect("new 'path and query' is valid");
203            parts.path_and_query = Some(paq);
204            let uri = Uri::from_parts(parts).expect("parts are valid");
205
206            // Rewrite the request URI and run the downstream services.
207            debug!(original_uri = %request.uri(), %uri, "rewrote the path");
208            request.uri_mut().clone_from(&uri);
209            inner.call(request).await
210        })
211    }
212}
213
214/// Header name for the [XApiVersion] custom HTTP header.
215pub static X_API_VERSION: HeaderName = HeaderName::from_static("x-api-version");
216
217/// Custom HTTP header conveying the API version, which is expected to be a version designator
218/// starting with `'v'` followed by a number from 0..+99 without leading zero, e.g. `v0`.
219#[derive(Debug)]
220pub struct XApiVersion(u16);
221
222impl Header for XApiVersion {
223    fn name() -> &'static HeaderName {
224        &X_API_VERSION
225    }
226
227    fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
228    where
229        Self: Sized,
230        I: Iterator<Item = &'i HeaderValue>,
231    {
232        values
233            .next()
234            .and_then(|v| v.to_str().ok())
235            .and_then(|s| VERSION.captures(s).and_then(|c| c.get(1)))
236            .and_then(|m| m.as_str().parse().ok())
237            .map(XApiVersion)
238            .ok_or_else(headers::Error::invalid)
239    }
240
241    fn encode<E: Extend<HeaderValue>>(&self, _values: &mut E) {
242        // We do not yet need to encode this header.
243        unimplemented!("not yet needed");
244    }
245}
246
247trait StdErrorExt
248where
249    Self: StdError,
250{
251    fn as_chain(&self) -> String {
252        let mut sources = vec![];
253        sources.push(self.to_string());
254
255        let mut source = self.source();
256        while let Some(s) = source {
257            sources.push(s.to_string());
258            source = s.source();
259        }
260
261        sources.join(": ")
262    }
263}
264
265impl<T> StdErrorExt for T where T: StdError {}