1use axum::{
5 RequestExt,
6 extract::Request,
7 http::{HeaderName, HeaderValue, StatusCode, Uri, uri::PathAndQuery},
8 response::{IntoResponse, Response},
9};
10use axum_extra::{
11 TypedHeader,
12 headers::{self, Header},
13};
14use futures::future::BoxFuture;
15use regex::Regex;
16use std::{
17 fmt::Debug,
18 ops::Deref,
19 sync::LazyLock,
20 task::{Context, Poll},
21};
22use tower::{Layer, Service};
23use tracing::debug;
24
25static VERSION: LazyLock<Regex> =
26 LazyLock::new(|| Regex::new(r#"^v(\d{1,4})$"#).expect("version regex is valid"));
27
28#[derive(Clone)]
50pub struct ApiVersionLayer<const N: usize> {
51 base_path: String,
52 versions: ApiVersions<N>,
53}
54
55impl<const N: usize> ApiVersionLayer<N> {
56 pub fn new(base_path: impl AsRef<str>, versions: ApiVersions<N>) -> Self {
58 let base_path = base_path.as_ref().trim_end_matches('/').to_string();
59 assert!(base_path.starts_with('/'), "base path must start with '/'");
60 assert!(!base_path.len() > 1, "base path must not be empty");
61
62 Self {
63 base_path,
64 versions,
65 }
66 }
67}
68
69impl<const N: usize, S> Layer<S> for ApiVersionLayer<N> {
70 type Service = ApiVersionService<N, S>;
71
72 fn layer(&self, inner: S) -> Self::Service {
73 ApiVersionService {
74 inner,
75 base_path: self.base_path.clone(),
76 versions: self.versions,
77 }
78 }
79}
80
81#[derive(Debug, Clone, Copy, PartialEq, Eq)]
83pub struct ApiVersions<const N: usize>([u16; N]);
84
85impl<const N: usize> ApiVersions<N> {
86 pub const fn new(versions: [u16; N]) -> Self {
112 assert!(!versions.is_empty(), "API versions must not be empty");
113 assert!(
114 is_monotonically_increasing(versions),
115 "API versions must be strictly monotonically increasing"
116 );
117 assert!(
118 versions[N - 1] < 10_000,
119 "API versions must be within 0u16..10_000"
120 );
121
122 Self(versions)
123 }
124}
125
126impl<const N: usize> Deref for ApiVersions<N> {
127 type Target = [u16; N];
128
129 fn deref(&self) -> &Self::Target {
130 &self.0
131 }
132}
133
134#[derive(Clone)]
136pub struct ApiVersionService<const N: usize, S> {
137 inner: S,
138 base_path: String,
139 versions: ApiVersions<N>,
140}
141
142impl<const N: usize, S> Service<Request> for ApiVersionService<N, S>
143where
144 S: Service<Request, Response = Response> + Clone + Send + 'static,
145 S::Future: Send + 'static,
146{
147 type Response = S::Response;
148 type Error = S::Error;
149 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
150
151 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
152 self.inner.poll_ready(cx)
153 }
154
155 fn call(&mut self, mut request: Request) -> Self::Future {
156 let mut inner = self.inner.clone();
157 let base_path = self.base_path.clone();
158 let versions = self.versions;
159
160 Box::pin(async move {
161 let Some(path) = request.uri().path().strip_prefix(&base_path) else {
163 debug!(
164 uri = %request.uri(),
165 "not rewriting the path, because does not start with base path"
166 );
167 return inner.call(request).await;
168 };
169 let path = path.to_owned();
170
171 let has_version_prefix = versions
173 .iter()
174 .any(|version| path.starts_with(&format!("/v{version}/")));
175 if has_version_prefix {
176 debug!(
177 uri = %request.uri(),
178 "not rewriting the path, because starts with valid version prefix"
179 );
180 return inner.call(request).await;
181 }
182
183 let version = request.extract_parts::<TypedHeader<XApiVersion>>().await;
185 let version = version
186 .as_ref()
187 .map(|TypedHeader(XApiVersion(v))| v)
188 .unwrap_or_else(|_| versions.last().expect("versions is not empty"));
189 if !versions.contains(version) {
190 let response = (
191 StatusCode::NOT_FOUND,
192 format!("unknown version '{version}'"),
193 );
194 return Ok(response.into_response());
195 }
196 debug!(?version, "using API version");
197
198 let mut parts = request.uri().to_owned().into_parts();
200 let paq = parts.path_and_query.expect("uri has 'path and query'");
201 let mut paq_parts = paq.as_str().split('?').skip(1);
202 let paq = match paq_parts.next() {
203 Some(query) => format!("{base_path}/v{version}{path}?{query}"),
204 None => format!("{base_path}/v{version}{path}"),
205 };
206 let paq = PathAndQuery::from_maybe_shared(paq).expect("new 'path and query' is valid");
207 parts.path_and_query = Some(paq);
208 let uri = Uri::from_parts(parts).expect("parts are valid");
209
210 debug!(original_uri = %request.uri(), %uri, "rewrote the path");
212 request.uri_mut().clone_from(&uri);
213 inner.call(request).await
214 })
215 }
216}
217
218pub static X_API_VERSION: HeaderName = HeaderName::from_static("x-api-version");
220
221#[derive(Debug)]
224pub struct XApiVersion(u16);
225
226impl Header for XApiVersion {
227 fn name() -> &'static HeaderName {
228 &X_API_VERSION
229 }
230
231 fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
232 where
233 Self: Sized,
234 I: Iterator<Item = &'i HeaderValue>,
235 {
236 values
237 .next()
238 .and_then(|v| v.to_str().ok())
239 .and_then(|s| VERSION.captures(s).and_then(|c| c.get(1)))
240 .and_then(|m| m.as_str().parse().ok())
241 .map(XApiVersion)
242 .ok_or_else(headers::Error::invalid)
243 }
244
245 fn encode<E: Extend<HeaderValue>>(&self, _values: &mut E) {
246 unimplemented!("not yet needed");
248 }
249}
250
/// Returns `true` if each element of `versions` is strictly greater than its predecessor.
///
/// Arrays with fewer than two elements are trivially increasing. This is a `const fn` so the
/// `const fn` [`ApiVersions::new`] can call it. Iterator adapters such as `windows` are not
/// available in `const` context, which is why the comparison walks the array with a `while` loop.
const fn is_monotonically_increasing<const N: usize>(versions: [u16; N]) -> bool {
    let mut left = 0;
    while left + 1 < N {
        if versions[left] >= versions[left + 1] {
            return false;
        }
        left += 1;
    }
    true
}
266
#[cfg(test)]
mod tests {
    use crate::{VERSION, is_monotonically_increasing};

    /// Returns the digits captured by `VERSION` for `input`, or `None` if it does not match.
    fn captured_digits(input: &str) -> Option<&str> {
        VERSION
            .captures(input)
            .and_then(|captures| captures.get(1))
            .map(|digits| digits.as_str())
    }

    #[test]
    fn test_x_api_header() {
        let accepted = [("v0", "0"), ("v1", "1"), ("v99", "99"), ("v9999", "9999")];
        for (input, digits) in accepted {
            assert_eq!(captured_digits(input), Some(digits), "input {input:?}");
        }

        for input in ["v10000", "vx"] {
            assert_eq!(captured_digits(input), None, "input {input:?}");
        }
    }

    #[test]
    fn test_is_monotonically_increasing() {
        // Fewer than two elements, or strictly increasing neighbours.
        assert!(is_monotonically_increasing([]));
        assert!(is_monotonically_increasing([0]));
        assert!(is_monotonically_increasing([0, 1]));

        // Equal or decreasing neighbours are rejected.
        assert!(!is_monotonically_increasing([0, 0]));
        assert!(!is_monotonically_increasing([1, 0]));
    }
}