1pub use array_macro;
2
3use axum::{
4 extract::Request,
5 http::{uri::PathAndQuery, HeaderName, HeaderValue, StatusCode, Uri},
6 response::{IntoResponse, Response},
7 RequestExt,
8};
9use axum_extra::{
10 headers::{self, Header},
11 TypedHeader,
12};
13use futures::future::BoxFuture;
14use regex::Regex;
15use std::{
16 convert::Infallible,
17 error::Error as StdError,
18 fmt::Debug,
19 future::Future,
20 sync::OnceLock,
21 task::{Context, Poll},
22};
23use thiserror::Error;
24use tower::{Layer, Service};
25use tracing::{debug, error};
26
/// Builds an [`ApiVersionLayer`] serving every version in the inclusive range `$from..=$to`.
///
/// The one-argument form uses the [`All`] filter; the two-argument form takes any filter
/// expression as the second argument. Both bounds must be integer literals.
///
/// Panics (via `expect`) if [`ApiVersionLayer::new`] rejects the generated versions.
/// NOTE(review): versions are computed as `n as u16 + $from`, so bounds presumably must fit
/// in `u16` — confirm intended limits.
///
/// ```ignore
/// let layer = api_version!(0..=2); // serves versions 0, 1 and 2
/// ```
#[macro_export]
macro_rules! api_version {
    ($from:literal..=$to:literal) => {
        {
            $crate::api_version!($from..=$to, $crate::All)
        }
    };

    ($from:literal..=$to:literal, $filter:expr) => {
        {
            // Array of `$to - $from + 1` elements; assumes `array![n => expr; len]` evaluates
            // `expr` for n = 0..len, giving `$from`, `$from + 1`, ..., `$to`.
            let versions = $crate::array_macro::array![n => n as u16 + $from; $to - $from + 1];
            $crate::ApiVersionLayer::new(versions, $filter).expect("versions are valid")
        }
    };
}
44
/// Layer wrapping services in [`ApiVersion`], which prefixes request paths with `/v{version}`.
///
/// Construct it with [`ApiVersionLayer::new`] or the `api_version!` macro.
#[derive(Clone, Debug)]
pub struct ApiVersionLayer<const N: usize, F> {
    /// Served versions, validated by `new` to be non-empty and strictly increasing.
    versions: [u16; N],
    /// Decides per request whether the path gets a version prefix.
    filter: F,
}
56
57impl<const N: usize, F> ApiVersionLayer<N, F> {
58 pub fn new(versions: [u16; N], filter: F) -> Result<Self, NewApiVersionLayerError> {
63 if versions.is_empty() {
64 return Err(NewApiVersionLayerError::Empty);
65 }
66
67 if versions.as_slice().windows(2).any(|w| w[0] >= w[1]) {
68 return Err(NewApiVersionLayerError::NotIncreasing);
69 }
70
71 Ok(Self { versions, filter })
72 }
73}
74
75impl<const N: usize, S, F> Layer<S> for ApiVersionLayer<N, F>
76where
77 F: ApiVersionFilter,
78{
79 type Service = ApiVersion<N, S, F>;
80
81 fn layer(&self, inner: S) -> Self::Service {
82 ApiVersion {
83 inner,
84 versions: self.versions,
85 filter: self.filter.clone(),
86 }
87 }
88}
89
/// Decides per request whether [`ApiVersion`] prefixes the request path with a version.
pub trait ApiVersionFilter: Clone + Send + 'static {
    /// Error produced when no decision can be made; [`ApiVersion`] logs it together with its
    /// source chain and answers the request with `500 Internal Server Error`.
    type Error: std::error::Error;

    /// Returns `Ok(true)` to rewrite the path of `uri` with the API version prefix, or
    /// `Ok(false)` to forward the request to the inner service unchanged.
    fn filter(&self, uri: &Uri) -> impl Future<Output = Result<bool, Self::Error>> + Send;
}
97
/// Filter that versions every request; the default filter of the `api_version!` macro.
#[derive(Clone, Copy, Debug)]
pub struct All;
101
102impl ApiVersionFilter for All {
103 type Error = Infallible;
104
105 async fn filter(&self, _uri: &Uri) -> Result<bool, Self::Error> {
106 Ok(true)
107 }
108}
109
/// Error returned by [`ApiVersionLayer::new`] when the given versions are unusable.
#[derive(Debug, Error)]
pub enum NewApiVersionLayerError {
    /// The versions array has no elements.
    #[error("versions must not be empty")]
    Empty,

    /// Some version is not greater than its predecessor (duplicate or descending order).
    #[error("versions must be strictly monotonically increasing")]
    NotIncreasing,
}
119
/// Service produced by [`ApiVersionLayer`] that prefixes request paths with `/v{version}`
/// before calling the wrapped service.
#[derive(Clone, Debug)]
pub struct ApiVersion<const N: usize, S, F> {
    /// Wrapped service receiving the (possibly rewritten) request.
    inner: S,
    /// Served versions copied from the layer; the last one is the fallback version.
    versions: [u16; N],
    /// Decides per request whether the path is rewritten.
    filter: F,
}
127
128impl<const N: usize, S, F> Service<Request> for ApiVersion<N, S, F>
129where
130 S: Service<Request, Response = Response> + Clone + Send + 'static,
131 S::Future: Send + 'static,
132 F: ApiVersionFilter,
133{
134 type Response = S::Response;
135 type Error = S::Error;
136 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
137
138 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
139 self.inner.poll_ready(cx)
140 }
141
142 fn call(&mut self, mut request: Request) -> Self::Future {
143 let mut inner = self.inner.clone();
144 let versions = self.versions;
145 let filter = self.filter.clone();
146
147 Box::pin(async move {
148 if versions
150 .iter()
151 .any(|version| request.uri().path().starts_with(&format!("/v{version}")))
152 {
153 let response = (
154 StatusCode::BAD_REQUEST,
155 "path must not start with version prefix like '/v0'",
156 );
157 return Ok(response.into_response());
158 }
159
160 let pass_through = match filter.filter(request.uri()).await {
161 Ok(pass_through) => pass_through,
162
163 Err(error) => {
164 error!(error = error.as_chain(), "cannot apply filter");
165 return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response());
166 }
167 };
168
169 if !pass_through {
170 debug!(uri = %request.uri(), "not rewriting the path");
171 return inner.call(request).await;
172 }
173
174 let version = request.extract_parts::<TypedHeader<XApiVersion>>().await;
176 let version = version
177 .as_ref()
178 .map(|TypedHeader(XApiVersion(v))| v)
179 .unwrap_or_else(|_| versions.last().expect("versions is not empty"));
180 if !versions.contains(version) {
181 let response = (
182 StatusCode::NOT_FOUND,
183 format!("unknown version '{version}'"),
184 );
185 return Ok(response.into_response());
186 }
187 debug!(?version, "using API version");
188
189 let mut parts = request.uri().to_owned().into_parts();
191 let paq = parts.path_and_query.expect("uri has 'path and query'");
192 let mut paq_parts = paq.as_str().split('?');
193 let path = paq_parts.next().expect("uri has path");
194 let paq = match paq_parts.next() {
195 Some(query) => format!("/v{version}{path}?{query}"),
196 None if path != "/" => format!("/v{version}{path}"),
197 None => format!("/v{version}"),
198 };
199 let paq = PathAndQuery::from_maybe_shared(paq).expect("new 'path and query' is valid");
200 parts.path_and_query = Some(paq);
201 let uri = Uri::from_parts(parts).expect("parts are valid");
202
203 debug!(original_uri = %request.uri(), %uri, "rewrote the path");
205 request.uri_mut().clone_from(&uri);
206 inner.call(request).await
207 })
208 }
209}
210
/// Name of the request header selecting the API version: `x-api-version`.
pub static X_API_VERSION: HeaderName = HeaderName::from_static("x-api-version");
213
/// Typed `x-api-version` header holding the numeric API version, e.g. `2` for `v2`.
///
/// Convert it with `u16::from(header)` to read the version.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct XApiVersion(u16);

impl From<XApiVersion> for u16 {
    /// Returns the version number carried by the header.
    fn from(XApiVersion(version): XApiVersion) -> Self {
        version
    }
}
218
219impl Header for XApiVersion {
220 fn name() -> &'static HeaderName {
221 &X_API_VERSION
222 }
223
224 fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
225 where
226 Self: Sized,
227 I: Iterator<Item = &'i HeaderValue>,
228 {
229 values
230 .next()
231 .and_then(|v| v.to_str().ok())
232 .and_then(|s| version().captures(s).and_then(|c| c.get(1)))
233 .and_then(|m| m.as_str().parse().ok())
234 .map(XApiVersion)
235 .ok_or_else(headers::Error::invalid)
236 }
237
238 fn encode<E: Extend<HeaderValue>>(&self, _values: &mut E) {
239 unimplemented!("not yet needed");
241 }
242}
243
244fn version() -> &'static Regex {
246 static VERSION: OnceLock<Regex> = OnceLock::new();
247 VERSION.get_or_init(|| Regex::new(r#"^v(0|[1-9][0-9]?)$"#).expect("version regex is valid"))
248}
249
/// Extension for rendering an error together with its chain of sources.
trait StdErrorExt: StdError {
    /// Returns this error's message followed by the message of each successive
    /// [`StdError::source`], separated by `": "`, e.g. `"outer: inner"`.
    fn as_chain(&self) -> String {
        std::iter::successors(self.source(), |&error| error.source()).fold(
            self.to_string(),
            |mut chain, error| {
                chain.push_str(": ");
                chain.push_str(&error.to_string());
                chain
            },
        )
    }
}

impl<T> StdErrorExt for T where T: StdError {}