1pub use array_macro;
2
3use axum::{
4 RequestExt,
5 extract::Request,
6 http::{HeaderName, HeaderValue, StatusCode, Uri, uri::PathAndQuery},
7 response::{IntoResponse, Response},
8};
9use axum_extra::{
10 TypedHeader,
11 headers::{self, Header},
12};
13use futures::future::BoxFuture;
14use regex::Regex;
15use std::{
16 convert::Infallible,
17 error::Error as StdError,
18 fmt::Debug,
19 future::Future,
20 sync::LazyLock,
21 task::{Context, Poll},
22};
23use thiserror::Error;
24use tower::{Layer, Service};
25use tracing::{debug, error};
26
/// Creates an [`ApiVersionLayer`] covering every version in the inclusive range `$from..=$to`.
///
/// With a single argument, every request is rewritten (the [`All`] filter is used); a second
/// argument supplies a custom [`ApiVersionFilter`].
///
/// # Panics
///
/// Panics (via `expect`) if [`ApiVersionLayer::new`] rejects the generated versions.
#[macro_export]
macro_rules! api_version {
    ($from:literal..=$to:literal) => {
        $crate::api_version!($from..=$to, $crate::All)
    };

    ($from:literal..=$to:literal, $filter:expr) => {{
        let supported =
            $crate::array_macro::array![offset => $from + offset as u16; $to - $from + 1];
        $crate::ApiVersionLayer::new(supported, $filter).expect("versions are valid")
    }};
}
44
45static VERSION: LazyLock<Regex> =
46 LazyLock::new(|| Regex::new(r#"^v(0|[1-9][0-9]?)$"#).expect("version regex is valid"));
47
/// Tower [`Layer`] that wraps services in [`ApiVersion`].
///
/// `versions` is validated by [`ApiVersionLayer::new`] (non-empty, strictly increasing);
/// `filter` decides per request whether the path is rewritten.
#[derive(Clone, Debug)]
pub struct ApiVersionLayer<const N: usize, F> {
    versions: [u16; N],
    filter: F,
}
59
60impl<const N: usize, F> ApiVersionLayer<N, F> {
61 pub fn new(versions: [u16; N], filter: F) -> Result<Self, NewApiVersionLayerError> {
66 if versions.is_empty() {
67 return Err(NewApiVersionLayerError::Empty);
68 }
69
70 if versions.as_slice().windows(2).any(|w| w[0] >= w[1]) {
71 return Err(NewApiVersionLayerError::NotIncreasing);
72 }
73
74 Ok(Self { versions, filter })
75 }
76}
77
78impl<const N: usize, S, F> Layer<S> for ApiVersionLayer<N, F>
79where
80 F: ApiVersionFilter,
81{
82 type Service = ApiVersion<N, S, F>;
83
84 fn layer(&self, inner: S) -> Self::Service {
85 ApiVersion {
86 inner,
87 versions: self.versions,
88 filter: self.filter.clone(),
89 }
90 }
91}
92
93pub trait ApiVersionFilter: Clone + Send + 'static {
95 type Error: std::error::Error;
96
97 fn filter(&self, uri: &Uri) -> impl Future<Output = Result<bool, Self::Error>> + Send;
99}
100
/// [`ApiVersionFilter`] that selects every request for rewriting.
#[derive(Clone, Copy, Debug)]
pub struct All;
104
105impl ApiVersionFilter for All {
106 type Error = Infallible;
107
108 async fn filter(&self, _uri: &Uri) -> Result<bool, Self::Error> {
109 Ok(true)
110 }
111}
112
113#[derive(Debug, Error)]
115pub enum NewApiVersionLayerError {
116 #[error("versions must not be empty")]
117 Empty,
118
119 #[error("versions must be strictly monotonically increasing")]
120 NotIncreasing,
121}
122
/// Service produced by [`ApiVersionLayer`]: prefixes request paths with `/v{version}` before
/// delegating to `inner`.
#[derive(Clone, Debug)]
pub struct ApiVersion<const N: usize, S, F> {
    inner: S,
    versions: [u16; N],
    filter: F,
}
130
131impl<const N: usize, S, F> Service<Request> for ApiVersion<N, S, F>
132where
133 S: Service<Request, Response = Response> + Clone + Send + 'static,
134 S::Future: Send + 'static,
135 F: ApiVersionFilter,
136{
137 type Response = S::Response;
138 type Error = S::Error;
139 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
140
141 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
142 self.inner.poll_ready(cx)
143 }
144
145 fn call(&mut self, mut request: Request) -> Self::Future {
146 let mut inner = self.inner.clone();
147 let versions = self.versions;
148 let filter = self.filter.clone();
149
150 Box::pin(async move {
151 if versions
153 .iter()
154 .any(|version| request.uri().path().starts_with(&format!("/v{version}")))
155 {
156 let response = (
157 StatusCode::BAD_REQUEST,
158 "path must not start with version prefix like '/v0'",
159 );
160 return Ok(response.into_response());
161 }
162
163 let pass_through = match filter.filter(request.uri()).await {
164 Ok(pass_through) => pass_through,
165
166 Err(error) => {
167 error!(error = error.as_chain(), "cannot apply filter");
168 return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response());
169 }
170 };
171
172 if !pass_through {
173 debug!(uri = %request.uri(), "not rewriting the path");
174 return inner.call(request).await;
175 }
176
177 let version = request.extract_parts::<TypedHeader<XApiVersion>>().await;
179 let version = version
180 .as_ref()
181 .map(|TypedHeader(XApiVersion(v))| v)
182 .unwrap_or_else(|_| versions.last().expect("versions is not empty"));
183 if !versions.contains(version) {
184 let response = (
185 StatusCode::NOT_FOUND,
186 format!("unknown version '{version}'"),
187 );
188 return Ok(response.into_response());
189 }
190 debug!(?version, "using API version");
191
192 let mut parts = request.uri().to_owned().into_parts();
194 let paq = parts.path_and_query.expect("uri has 'path and query'");
195 let mut paq_parts = paq.as_str().split('?');
196 let path = paq_parts.next().expect("uri has path");
197 let paq = match paq_parts.next() {
198 Some(query) => format!("/v{version}{path}?{query}"),
199 None if path != "/" => format!("/v{version}{path}"),
200 None => format!("/v{version}"),
201 };
202 let paq = PathAndQuery::from_maybe_shared(paq).expect("new 'path and query' is valid");
203 parts.path_and_query = Some(paq);
204 let uri = Uri::from_parts(parts).expect("parts are valid");
205
206 debug!(original_uri = %request.uri(), %uri, "rewrote the path");
208 request.uri_mut().clone_from(&uri);
209 inner.call(request).await
210 })
211 }
212}
213
214pub static X_API_VERSION: HeaderName = HeaderName::from_static("x-api-version");
216
/// Typed `x-api-version` header; the wrapped number is the version, e.g. `v1` holds `1`.
#[derive(Debug)]
pub struct XApiVersion(u16);
221
222impl Header for XApiVersion {
223 fn name() -> &'static HeaderName {
224 &X_API_VERSION
225 }
226
227 fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
228 where
229 Self: Sized,
230 I: Iterator<Item = &'i HeaderValue>,
231 {
232 values
233 .next()
234 .and_then(|v| v.to_str().ok())
235 .and_then(|s| VERSION.captures(s).and_then(|c| c.get(1)))
236 .and_then(|m| m.as_str().parse().ok())
237 .map(XApiVersion)
238 .ok_or_else(headers::Error::invalid)
239 }
240
241 fn encode<E: Extend<HeaderValue>>(&self, _values: &mut E) {
242 unimplemented!("not yet needed");
244 }
245}
246
247trait StdErrorExt
248where
249 Self: StdError,
250{
251 fn as_chain(&self) -> String {
252 let mut sources = vec![];
253 sources.push(self.to_string());
254
255 let mut source = self.source();
256 while let Some(s) = source {
257 sources.push(s.to_string());
258 source = s.source();
259 }
260
261 sources.join(": ")
262 }
263}
264
265impl<T> StdErrorExt for T where T: StdError {}