1use axum::{
5 RequestExt,
6 extract::Request,
7 http::{HeaderName, HeaderValue, StatusCode, Uri, uri::PathAndQuery},
8 response::{IntoResponse, Response},
9};
10use axum_extra::{
11 TypedHeader,
12 headers::{self, Header},
13};
14use futures::future::BoxFuture;
15use regex::Regex;
16use std::{
17 convert::Infallible,
18 error::Error as StdError,
19 fmt::Debug,
20 ops::Deref,
21 sync::LazyLock,
22 task::{Context, Poll},
23};
24use tower::{Layer, Service};
25use tracing::{debug, error};
26
27static VERSION: LazyLock<Regex> =
28 LazyLock::new(|| Regex::new(r#"^v(\d{1,4})$"#).expect("version regex is valid"));
29
30#[derive(Clone)]
52pub struct ApiVersionLayer<const N: usize, F> {
53 versions: ApiVersions<N>,
54 filter: F,
55}
56
57impl<const N: usize, F> ApiVersionLayer<N, F> {
58 pub fn new(versions: ApiVersions<N>, filter: F) -> Self {
60 Self { versions, filter }
61 }
62}
63
64impl<const N: usize, S, F> Layer<S> for ApiVersionLayer<N, F>
65where
66 F: ApiVersionFilter,
67{
68 type Service = ApiVersionService<N, S, F>;
69
70 fn layer(&self, inner: S) -> Self::Service {
71 ApiVersionService {
72 inner,
73 versions: self.versions,
74 filter: self.filter.clone(),
75 }
76 }
77}
78
/// The API versions a service supports, in strictly increasing order; the last entry is the
/// newest version.
///
/// Outside this module it can only be built through [`ApiVersions::new`], which validates the
/// list.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ApiVersions<const N: usize>([u16; N]);
82
83impl<const N: usize> ApiVersions<N> {
84 pub const fn new(versions: [u16; N]) -> Self {
110 assert!(!versions.is_empty(), "API versions must not be empty");
111 assert!(
112 is_monotonically_increasing(versions),
113 "API versions must be strictly monotonically increasing"
114 );
115 assert!(
116 versions[N - 1] < 10_000,
117 "API versions must be within 0u16..10_000"
118 );
119
120 Self(versions)
121 }
122}
123
124impl<const N: usize> Deref for ApiVersions<N> {
125 type Target = [u16; N];
126
127 fn deref(&self) -> &Self::Target {
128 &self.0
129 }
130}
131
/// Decides which requests [`ApiVersionService`] rewrites.
///
/// The filter is only consulted for request paths that do not already start with a supported
/// version prefix. A clone of the filter is moved into the boxed `'static + Send` future of
/// every request, which is why implementors must be `Clone + Send + 'static`.
// NOTE(review): `trait_variant::make(Send)` is expected to add a `Send` bound to the future
// returned by `should_rewrite`, which that boxed future requires — confirm against the
// trait_variant version in use.
#[trait_variant::make(Send)]
pub trait ApiVersionFilter: Clone + Send + 'static {
    /// Error returned by `should_rewrite`. The service logs it together with its `source()`
    /// chain and answers the request with `500 Internal Server Error`.
    type Error: std::error::Error;

    /// Returns whether the path of `uri` should be prefixed with an API version.
    ///
    /// `Ok(false)` forwards the request to the inner service unchanged.
    async fn should_rewrite(&self, uri: &Uri) -> Result<bool, Self::Error>;
}
140
/// An [`ApiVersionFilter`] that accepts every URI, so every request without a version prefix
/// is rewritten.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct All;
144
145impl ApiVersionFilter for All {
146 type Error = Infallible;
147
148 async fn should_rewrite(&self, _uri: &Uri) -> Result<bool, Self::Error> {
149 Ok(true)
150 }
151}
152
153#[derive(Clone)]
155pub struct ApiVersionService<const N: usize, S, F> {
156 inner: S,
157 versions: ApiVersions<N>,
158 filter: F,
159}
160
161impl<const N: usize, S, F> Service<Request> for ApiVersionService<N, S, F>
162where
163 S: Service<Request, Response = Response> + Clone + Send + 'static,
164 S::Future: Send + 'static,
165 F: ApiVersionFilter,
166{
167 type Response = S::Response;
168 type Error = S::Error;
169 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
170
171 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
172 self.inner.poll_ready(cx)
173 }
174
175 fn call(&mut self, mut request: Request) -> Self::Future {
176 let mut inner = self.inner.clone();
177 let versions = self.versions;
178 let filter = self.filter.clone();
179
180 Box::pin(async move {
181 let has_version_prefix = versions
183 .iter()
184 .any(|version| request.uri().path().starts_with(&format!("/v{version}/")));
185 if has_version_prefix {
186 debug!(
187 uri = %request.uri(),
188 "not rewriting the path, because starts with valid version prefix"
189 );
190 return inner.call(request).await;
191 }
192
193 let should_rewrite = match filter.should_rewrite(request.uri()).await {
195 Ok(should_rewrite) => should_rewrite,
196
197 Err(error) => {
198 error!(error = chained_sources(error), "cannot apply filter");
199 return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response());
200 }
201 };
202 if !should_rewrite {
203 debug!(uri = %request.uri(), "not rewriting the path, because URI filtered");
204 return inner.call(request).await;
205 }
206
207 let version = request.extract_parts::<TypedHeader<XApiVersion>>().await;
209 let version = version
210 .as_ref()
211 .map(|TypedHeader(XApiVersion(v))| v)
212 .unwrap_or_else(|_| versions.last().expect("versions is not empty"));
213 if !versions.contains(version) {
214 let response = (
215 StatusCode::NOT_FOUND,
216 format!("unknown version '{version}'"),
217 );
218 return Ok(response.into_response());
219 }
220 debug!(?version, "using API version");
221
222 let mut parts = request.uri().to_owned().into_parts();
224 let paq = parts.path_and_query.expect("uri has 'path and query'");
225 let mut paq_parts = paq.as_str().split('?');
226 let path = paq_parts.next().expect("uri has path");
227 let paq = match paq_parts.next() {
228 Some(query) => format!("/v{version}{path}?{query}"),
229 None if path != "/" => format!("/v{version}{path}"),
230 None => format!("/v{version}"),
231 };
232 let paq = PathAndQuery::from_maybe_shared(paq).expect("new 'path and query' is valid");
233 parts.path_and_query = Some(paq);
234 let uri = Uri::from_parts(parts).expect("parts are valid");
235
236 debug!(original_uri = %request.uri(), %uri, "rewrote the path");
238 request.uri_mut().clone_from(&uri);
239 inner.call(request).await
240 })
241 }
242}
243
244pub static X_API_VERSION: HeaderName = HeaderName::from_static("x-api-version");
246
/// Typed `x-api-version` header carrying the API version requested by the client.
///
/// On the wire the value has the form `v{version}`, e.g. `v2`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct XApiVersion(u16);
251
252impl Header for XApiVersion {
253 fn name() -> &'static HeaderName {
254 &X_API_VERSION
255 }
256
257 fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
258 where
259 Self: Sized,
260 I: Iterator<Item = &'i HeaderValue>,
261 {
262 values
263 .next()
264 .and_then(|v| v.to_str().ok())
265 .and_then(|s| VERSION.captures(s).and_then(|c| c.get(1)))
266 .and_then(|m| m.as_str().parse().ok())
267 .map(XApiVersion)
268 .ok_or_else(headers::Error::invalid)
269 }
270
271 fn encode<E: Extend<HeaderValue>>(&self, _values: &mut E) {
272 unimplemented!("not yet needed");
274 }
275}
276
/// Renders `error` followed by its whole `source()` chain as a single string.
///
/// The outermost error comes first and the entries are separated by `": "`, e.g.
/// `"outer: inner"`.
fn chained_sources<E>(error: E) -> String
where
    E: StdError,
{
    let mut rendered = error.to_string();
    let mut next = error.source();
    while let Some(cause) = next {
        rendered.push_str(": ");
        rendered.push_str(&cause.to_string());
        next = cause.source();
    }
    rendered
}
292
/// Returns `true` when every element of `versions` is strictly greater than its predecessor.
///
/// Empty and single-element arrays are trivially increasing. This is a `const fn` so that it
/// can be called from the `const fn` `ApiVersions::new`; iterator adapters are not usable in
/// `const` contexts, hence the manual loop.
const fn is_monotonically_increasing<const N: usize>(versions: [u16; N]) -> bool {
    let mut i = N;
    while i > 1 {
        i -= 1;
        if versions[i] <= versions[i - 1] {
            return false;
        }
    }
    true
}
308
#[cfg(test)]
mod tests {
    use crate::{VERSION, is_monotonically_increasing};
    use assert_matches::assert_matches;

    /// Runs the `VERSION` regex against `input` and returns the captured digits, if any.
    fn captured_version(input: &str) -> Option<&str> {
        VERSION
            .captures(input)
            .and_then(|captures| captures.get(1))
            .map(|m| m.as_str())
    }

    #[test]
    fn test_x_api_header() {
        // One to four digits after a leading `v` are accepted.
        assert_matches!(captured_version("v0"), Some("0"));
        assert_matches!(captured_version("v1"), Some("1"));
        assert_matches!(captured_version("v99"), Some("99"));
        assert_matches!(captured_version("v9999"), Some("9999"));

        // Too many digits, or no digits at all, are rejected.
        assert_matches!(captured_version("v10000"), None);
        assert_matches!(captured_version("vx"), None);
    }

    #[test]
    fn test_is_monotonically_increasing() {
        // Fewer than two elements, or strictly increasing pairs.
        assert!(is_monotonically_increasing([]));
        assert!(is_monotonically_increasing([0]));
        assert!(is_monotonically_increasing([0, 1]));

        // Equal or decreasing neighbours are not strictly increasing.
        assert!(!is_monotonically_increasing([0, 0]));
        assert!(!is_monotonically_increasing([1, 0]));
    }
}