1use axum::{
5 RequestExt,
6 extract::Request,
7 http::{HeaderName, HeaderValue, StatusCode, Uri, uri::PathAndQuery},
8 response::{IntoResponse, Response},
9};
10use axum_extra::{
11 TypedHeader,
12 headers::{self, Header},
13};
14use futures::future::BoxFuture;
15use regex::Regex;
16use std::{
17 fmt::Debug,
18 ops::Deref,
19 sync::LazyLock,
20 task::{Context, Poll},
21};
22use tower::{Layer, Service};
23use tracing::debug;
24
25static VERSION: LazyLock<Regex> =
26 LazyLock::new(|| Regex::new(r#"^v(\d{1,4})$"#).expect("version regex is valid"));
27
28#[derive(Clone)]
49pub struct ApiVersionLayer<const N: usize> {
50 base_path: String,
51 versions: ApiVersions<N>,
52}
53
54impl<const N: usize> ApiVersionLayer<N> {
55 pub fn new(base_path: impl AsRef<str>, versions: ApiVersions<N>) -> Self {
61 let base_path = base_path.as_ref().trim_end_matches('/').to_string();
62 assert!(base_path.starts_with('/'), "base path must start with '/'");
63 assert!(!base_path.len() > 1, "base path must not be empty");
64
65 Self {
66 base_path,
67 versions,
68 }
69 }
70}
71
72impl<const N: usize, S> Layer<S> for ApiVersionLayer<N> {
73 type Service = ApiVersionService<N, S>;
74
75 fn layer(&self, inner: S) -> Self::Service {
76 ApiVersionService {
77 inner,
78 base_path: self.base_path.clone(),
79 versions: self.versions,
80 }
81 }
82}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq)]
86pub struct ApiVersions<const N: usize>([u16; N]);
87
88impl<const N: usize> ApiVersions<N> {
89 pub const fn new(versions: [u16; N]) -> Self {
117 assert!(!versions.is_empty(), "API versions must not be empty");
118 assert!(
119 is_monotonically_increasing(versions),
120 "API versions must be strictly monotonically increasing"
121 );
122 assert!(
123 versions[N - 1] < 10_000,
124 "API versions must be within 0u16..10_000"
125 );
126
127 Self(versions)
128 }
129}
130
131impl<const N: usize> Deref for ApiVersions<N> {
132 type Target = [u16; N];
133
134 fn deref(&self) -> &Self::Target {
135 &self.0
136 }
137}
138
139#[derive(Clone)]
141pub struct ApiVersionService<const N: usize, S> {
142 inner: S,
143 base_path: String,
144 versions: ApiVersions<N>,
145}
146
147impl<const N: usize, S> Service<Request> for ApiVersionService<N, S>
148where
149 S: Service<Request, Response = Response> + Clone + Send + 'static,
150 S::Future: Send + 'static,
151{
152 type Response = S::Response;
153 type Error = S::Error;
154 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
155
156 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
157 self.inner.poll_ready(cx)
158 }
159
160 fn call(&mut self, mut request: Request) -> Self::Future {
161 let mut inner = self.inner.clone();
162 let base_path = self.base_path.clone();
163 let versions = self.versions;
164
165 Box::pin(async move {
166 let path = if let Some(path) = request.uri().path().strip_prefix(&base_path)
168 && path.starts_with('/')
169 {
170 path.to_owned()
171 } else {
172 debug!(
173 uri = %request.uri(),
174 "not rewriting the path, because does not start with base path"
175 );
176 return inner.call(request).await;
177 };
178
179 let has_version_prefix = versions
181 .iter()
182 .any(|version| path.starts_with(&format!("/v{version}/")));
183 if has_version_prefix {
184 debug!(
185 uri = %request.uri(),
186 "not rewriting the path, because starts with valid version prefix"
187 );
188 return inner.call(request).await;
189 }
190
191 let version = request.extract_parts::<TypedHeader<XApiVersion>>().await;
193 let version = version
194 .as_ref()
195 .map(|TypedHeader(XApiVersion(v))| v)
196 .unwrap_or_else(|_| versions.last().expect("versions is not empty"));
197 if !versions.contains(version) {
198 let response = (
199 StatusCode::NOT_FOUND,
200 format!("unknown version '{version}'"),
201 );
202 return Ok(response.into_response());
203 }
204 debug!(?version, "using API version");
205
206 let mut parts = request.uri().to_owned().into_parts();
208 let paq = parts.path_and_query.expect("uri has 'path and query'");
209 let mut paq_parts = paq.as_str().split('?').skip(1);
210 let paq = match paq_parts.next() {
211 Some(query) => format!("{base_path}/v{version}{path}?{query}"),
212 None => format!("{base_path}/v{version}{path}"),
213 };
214 let paq = PathAndQuery::from_maybe_shared(paq).expect("new 'path and query' is valid");
215 parts.path_and_query = Some(paq);
216 let uri = Uri::from_parts(parts).expect("parts are valid");
217
218 debug!(original_uri = %request.uri(), %uri, "rewrote the path");
220 request.uri_mut().clone_from(&uri);
221 inner.call(request).await
222 })
223 }
224}
225
226pub static X_API_VERSION: HeaderName = HeaderName::from_static("x-api-version");
228
229#[derive(Debug)]
232pub struct XApiVersion(u16);
233
234impl Header for XApiVersion {
235 fn name() -> &'static HeaderName {
236 &X_API_VERSION
237 }
238
239 fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
240 where
241 Self: Sized,
242 I: Iterator<Item = &'i HeaderValue>,
243 {
244 values
245 .next()
246 .and_then(|v| v.to_str().ok())
247 .and_then(|s| VERSION.captures(s).and_then(|c| c.get(1)))
248 .and_then(|m| m.as_str().parse().ok())
249 .map(XApiVersion)
250 .ok_or_else(headers::Error::invalid)
251 }
252
253 fn encode<E: Extend<HeaderValue>>(&self, _values: &mut E) {
254 unimplemented!("not yet needed");
256 }
257}
258
/// Returns `true` if every element of `versions` is strictly greater than the one before it.
///
/// Arrays with fewer than two elements are trivially increasing. This is a `const fn` so that
/// `ApiVersions::new` can call it in const contexts, which is why it uses a manual `while` loop
/// rather than iterators.
const fn is_monotonically_increasing<const N: usize>(versions: [u16; N]) -> bool {
    let mut idx = 1;
    while idx < N {
        let (previous, current) = (versions[idx - 1], versions[idx]);
        if current <= previous {
            return false;
        }
        idx += 1;
    }
    true
}
274
#[cfg(test)]
mod tests {
    use crate::{VERSION, is_monotonically_increasing};
    use assert_matches::assert_matches;

    /// Returns the digits captured by [`VERSION`] for `input`, or `None` if it does not match.
    fn captured_digits(input: &str) -> Option<&str> {
        VERSION
            .captures(input)
            .and_then(|captures| captures.get(1))
            .map(|digits| digits.as_str())
    }

    #[test]
    fn test_x_api_header() {
        // Accepted: `v` followed by one to four digits.
        assert_matches!(captured_digits("v0"), Some("0"));
        assert_matches!(captured_digits("v1"), Some("1"));
        assert_matches!(captured_digits("v99"), Some("99"));
        assert_matches!(captured_digits("v9999"), Some("9999"));

        // Rejected: too many digits, or no digits at all.
        assert_matches!(captured_digits("v10000"), None);
        assert_matches!(captured_digits("vx"), None);
    }

    #[test]
    fn test_is_monotonically_increasing() {
        // Trivially increasing and strictly increasing inputs.
        assert!(is_monotonically_increasing([]));
        assert!(is_monotonically_increasing([0]));
        assert!(is_monotonically_increasing([0, 1]));

        // Duplicates and decreasing pairs are rejected.
        assert!(!is_monotonically_increasing([0, 0]));
        assert!(!is_monotonically_increasing([1, 0]));
    }
}