1use crate::{
2 errors::HttpError,
3 middleware::v2::{Middleware, Next, NextFuture},
4 request::ElifRequest,
5 response::ElifResponse,
6};
7use once_cell::sync::Lazy;
8use serde::{Deserialize, Serialize};
9use service_builder::builder;
10use std::collections::HashMap;
11use std::future::Future;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14use tower::{Layer, Service};
15
16static URL_PATH_VERSION_REGEX: Lazy<regex::Regex> = Lazy::new(|| {
18 regex::Regex::new(r"/api/v?(\d+(?:\.\d+)?)/").expect("Invalid URL path version regex")
19});
20
/// Captures the value of a `version=` parameter inside an `Accept` header, e.g. `2` from
/// `application/vnd.example+json; version=2`. The captured value runs until the next `;`, `,`
/// or whitespace character and is taken verbatim (it may or may not carry a leading `v`).
///
/// NOTE(review): the pattern is not anchored to a parameter boundary, so something like
/// `apiversion=3` would also match.
static ACCEPT_HEADER_VERSION_REGEX: Lazy<regex::Regex> = Lazy::new(|| {
    regex::Regex::new(r"version=([^;,\s]+)").expect("Invalid Accept header version regex")
});
24
/// Where the API version of an incoming request is read from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VersionStrategy {
    /// An `/api/v<N>` style segment in the request path.
    UrlPath,
    /// The value of the named request header, taken verbatim.
    Header(String),
    /// The value of the named query-string parameter.
    QueryParam(String),
    /// The `version=` parameter of the `Accept` header.
    AcceptHeader,
}
37
38impl Default for VersionStrategy {
39 fn default() -> Self {
40 Self::UrlPath
41 }
42}
43
/// Metadata describing one supported API version.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiVersion {
    /// Version identifier, e.g. `"v1"`.
    /// NOTE(review): request resolution in this file looks versions up by their key in
    /// `VersioningConfig::versions`, not by this field, so keep the two in sync.
    pub version: String,
    /// Whether this version is deprecated. When `include_deprecation_headers` is on, this
    /// triggers the `Deprecation` response header.
    pub deprecated: bool,
    /// Optional human-readable notice, emitted as `Warning: 299 - "<message>"`.
    pub deprecation_message: Option<String>,
    /// Optional sunset date, copied verbatim into the `Sunset` response header.
    pub sunset_date: Option<String>,
    /// Marks this version as the fallback that `VersioningConfig::get_default_version` returns
    /// when `default_version` is unset.
    pub is_default: bool,
}
58
/// Configuration shared by [`VersioningMiddleware`] and [`VersioningLayer`].
///
/// NOTE(review): the derived `Default` yields `false` / empty strings for
/// `include_deprecation_headers`, `strict_validation`, `version_header_name` and
/// `version_param_name`. The `#[builder(default = ...)]` attributes suggest the builder fills in
/// `true` / `"Api-Version"` / `"version"` instead, so the two construction paths disagree.
#[derive(Debug, Clone, Default)]
#[builder]
pub struct VersioningConfig {
    #[builder(default)]
    /// Known versions, keyed by the identifier requests are matched against (e.g. `"v1"`).
    pub versions: HashMap<String, ApiVersion>,
    #[builder(default)]
    /// How the requested version is extracted from each request.
    pub strategy: VersionStrategy,
    #[builder(optional)]
    /// Key into `versions` used when a request does not name a version.
    pub default_version: Option<String>,
    #[builder(default = "true")]
    /// Whether deprecated versions get `Deprecation` / `Warning` / `Sunset` response headers.
    pub include_deprecation_headers: bool,
    #[builder(default = "\"Api-Version\".to_string()")]
    /// NOTE(review): not read by the extraction code in this file. Header-based versioning
    /// takes its header name from `VersionStrategy::Header` instead.
    pub version_header_name: String,
    #[builder(default = "\"version\".to_string()")]
    /// NOTE(review): not read by the extraction code in this file. Query-based versioning
    /// takes its parameter name from `VersionStrategy::QueryParam` instead.
    pub version_param_name: String,
    #[builder(default = "true")]
    /// Controls rejection of missing or unknown versions. The exact rules differ between the
    /// `Middleware` path and the tower `Service` path (see their `resolve_version` functions).
    pub strict_validation: bool,
}
85
86impl VersioningConfig {
87 pub fn add_version(&mut self, version: String, api_version: ApiVersion) {
89 self.versions.insert(version, api_version);
90 }
91
92 pub fn deprecate_version(
94 &mut self,
95 version: &str,
96 message: Option<String>,
97 sunset_date: Option<String>,
98 ) {
99 if let Some(api_version) = self.versions.get_mut(version) {
100 api_version.deprecated = true;
101 api_version.deprecation_message = message;
102 api_version.sunset_date = sunset_date;
103 }
104 }
105
106 pub fn get_default_version(&self) -> Option<&ApiVersion> {
108 if let Some(default_version) = &self.default_version {
109 return self.versions.get(default_version);
110 }
111
112 self.versions.values().find(|v| v.is_default)
114 }
115
116 pub fn get_version(&self, version: &str) -> Option<&ApiVersion> {
118 self.versions.get(version)
119 }
120
121 pub fn get_strategy(&self) -> &VersionStrategy {
123 &self.strategy
124 }
125
126 pub fn get_versions(&self) -> &HashMap<String, ApiVersion> {
128 &self.versions
129 }
130
131 pub fn get_versions_mut(&mut self) -> &mut HashMap<String, ApiVersion> {
133 &mut self.versions
134 }
135
136 pub fn clone_config(
138 &self,
139 ) -> (
140 HashMap<String, ApiVersion>,
141 VersionStrategy,
142 Option<String>,
143 bool,
144 String,
145 String,
146 bool,
147 ) {
148 (
149 self.versions.clone(),
150 self.strategy.clone(),
151 self.default_version.clone(),
152 self.include_deprecation_headers,
153 self.version_header_name.clone(),
154 self.version_param_name.clone(),
155 self.strict_validation,
156 )
157 }
158}
159
/// The version resolved for a single request. The versioning middleware and service insert it as
/// a request extension. For axum requests it can be read back through [`RequestVersionExt`].
#[derive(Debug, Clone)]
pub struct VersionInfo {
    /// The key the request resolved to (a key of `VersioningConfig::versions`).
    pub version: String,
    /// Snapshot of the configured metadata for `version`.
    pub api_version: ApiVersion,
    /// Copy of `api_version.deprecated`, kept for convenient access.
    pub is_deprecated: bool,
}
170
/// Request-versioning middleware for the v2 [`Middleware`] pipeline. The per-request behaviour
/// is implemented in its `Middleware::handle`.
#[derive(Debug)]
pub struct VersioningMiddleware {
    /// Cloned into each request's future by `handle`.
    config: VersioningConfig,
}
176
177impl VersioningMiddleware {
178 pub fn new(config: VersioningConfig) -> Self {
180 Self { config }
181 }
182}
183
184fn extract_version_from_request(
186 request: &ElifRequest,
187 strategy: &VersionStrategy,
188) -> Result<Option<String>, HttpError> {
189 match strategy {
190 VersionStrategy::UrlPath => {
191 let path = request.path();
192 if let Some(captures) = URL_PATH_VERSION_REGEX.captures(path) {
193 Ok(Some(captures[1].to_string()))
194 } else {
195 Ok(None)
196 }
197 }
198 VersionStrategy::Header(header_name) => {
199 if let Some(header_value) = request.header(header_name) {
200 if let Ok(version_str) = header_value.to_str() {
201 Ok(Some(version_str.to_string()))
202 } else {
203 Err(HttpError::bad_request("Invalid version header"))
204 }
205 } else {
206 Ok(None)
207 }
208 }
209 VersionStrategy::QueryParam(param_name) => {
210 if let Some(query) = request.uri.query() {
211 for pair in query.split('&') {
212 let mut parts = pair.split('=');
213 if let (Some(key), Some(value)) = (parts.next(), parts.next()) {
214 if key == param_name {
215 return Ok(Some(value.to_string()));
216 }
217 }
218 }
219 }
220 Ok(None)
221 }
222 VersionStrategy::AcceptHeader => {
223 if let Some(accept_header) = request.header("Accept") {
224 if let Ok(accept_str) = accept_header.to_str() {
225 if let Some(captures) = ACCEPT_HEADER_VERSION_REGEX.captures(accept_str) {
226 return Ok(Some(captures[1].to_string()));
227 }
228 }
229 }
230 Ok(None)
231 }
232 }
233}
234
235fn resolve_version(
237 config: &VersioningConfig,
238 extracted_version: Option<String>,
239) -> Result<VersionInfo, HttpError> {
240 let version_key = match extracted_version {
241 Some(v) => v,
242 None => {
243 if let Some(default) = &config.default_version {
244 default.clone()
245 } else if config.strict_validation {
246 return Err(HttpError::bad_request("Version is required"));
247 } else {
248 let mut sorted_keys: Vec<_> = config.versions.keys().cloned().collect();
250 sorted_keys.sort();
251 if let Some(first_version) = sorted_keys.first() {
252 first_version.clone()
253 } else {
254 return Err(HttpError::bad_request("No versions configured"));
255 }
256 }
257 }
258 };
259
260 if let Some(api_version) = config.versions.get(&version_key) {
261 Ok(VersionInfo {
262 version: version_key,
263 api_version: api_version.clone(),
264 is_deprecated: api_version.deprecated,
265 })
266 } else {
267 Err(HttpError::bad_request(format!(
268 "Unsupported version: {}",
269 version_key
270 )))
271 }
272}
273
274impl Middleware for VersioningMiddleware {
275 fn handle(&self, mut request: ElifRequest, next: Next) -> NextFuture<'static> {
276 let config = self.config.clone();
277
278 Box::pin(async move {
279 let extracted_version = match extract_version_from_request(&request, &config.strategy) {
281 Ok(version) => version,
282 Err(err) => {
283 return ElifResponse::bad_request().json_value(serde_json::json!({
284 "error": {
285 "code": "VERSION_EXTRACTION_FAILED",
286 "message": err.to_string()
287 }
288 }));
289 }
290 };
291
292 let version_info = match resolve_version(&config, extracted_version) {
294 Ok(info) => info,
295 Err(err) => {
296 return ElifResponse::bad_request().json_value(serde_json::json!({
297 "error": {
298 "code": "VERSION_RESOLUTION_FAILED",
299 "message": err.to_string()
300 }
301 }));
302 }
303 };
304
305 request.insert_extension(version_info.clone());
307
308 let mut response = next.run(request).await;
310
311 if config.include_deprecation_headers && version_info.api_version.deprecated {
313 let _ = response.add_header("Deprecation", "true");
315
316 if let Some(message) = &version_info.api_version.deprecation_message {
318 let _ = response.add_header("Warning", format!("299 - \"{}\"", message));
319 }
320
321 if let Some(sunset) = &version_info.api_version.sunset_date {
323 let _ = response.add_header("Sunset", sunset);
324 }
325 }
326
327 response
328 })
329 }
330
331 fn name(&self) -> &'static str {
332 "VersioningMiddleware"
333 }
334}
335
/// Tower [`Layer`] that wraps an inner service in a [`VersioningService`].
#[derive(Debug, Clone)]
pub struct VersioningLayer {
    /// Cloned into every service produced by `layer`.
    config: VersioningConfig,
}
341
342impl VersioningLayer {
343 pub fn new(config: VersioningConfig) -> Self {
345 Self { config }
346 }
347}
348
349impl<S> Layer<S> for VersioningLayer {
350 type Service = VersioningService<S>;
351
352 fn layer(&self, inner: S) -> Self::Service {
353 VersioningService {
354 inner,
355 config: self.config.clone(),
356 }
357 }
358}
359
/// Tower [`Service`] that resolves the API version of each axum request before delegating to
/// `inner`, and then adds version and deprecation headers to the response.
#[derive(Debug, Clone)]
pub struct VersioningService<S> {
    /// The wrapped service.
    inner: S,
    /// Versioning configuration, cloned into each call's future.
    config: VersioningConfig,
}
366
367impl<S> Service<axum::extract::Request> for VersioningService<S>
368where
369 S: Service<axum::extract::Request, Response = axum::response::Response>
370 + Clone
371 + Send
372 + 'static,
373 S::Future: Send + 'static,
374 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
375{
376 type Response = axum::response::Response;
377 type Error = S::Error;
378 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
379
380 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
381 self.inner.poll_ready(cx)
382 }
383
384 fn call(&mut self, mut request: axum::extract::Request) -> Self::Future {
385 let config = self.config.clone();
386 let mut inner = self.inner.clone();
387
388 Box::pin(async move {
389 let extracted_version = match Self::extract_version_from_request(&config, &request) {
391 Ok(version) => version,
392 Err(error_response) => return Ok(error_response),
393 };
394
395 let version_info = match Self::resolve_version(&config, extracted_version) {
396 Ok(info) => info,
397 Err(error_response) => return Ok(error_response),
398 };
399
400 request.extensions_mut().insert(version_info.clone());
402
403 let mut response = inner.call(request).await?;
405
406 Self::add_version_headers(&config, &version_info, &mut response);
408
409 Ok(response)
410 })
411 }
412}
413
414impl<S> VersioningService<S> {
415 fn extract_version_from_request(
417 config: &VersioningConfig,
418 request: &axum::extract::Request,
419 ) -> Result<Option<String>, axum::response::Response> {
420 static URL_PATH_REGEX: Lazy<regex::Regex> = Lazy::new(|| {
422 regex::Regex::new(r"/api/v?(\d+(?:\.\d+)?)/").expect("Failed to compile URL path regex")
423 });
424 static ACCEPT_HEADER_REGEX: Lazy<regex::Regex> = Lazy::new(|| {
425 regex::Regex::new(r"version=([^;,\s]+)").expect("Failed to compile Accept header regex")
426 });
427
428 let extracted = match &config.strategy {
429 VersionStrategy::UrlPath => {
430 let path = request.uri().path();
432 if let Some(captures) = URL_PATH_REGEX.captures(path) {
433 captures
434 .get(1)
435 .map(|version| format!("v{}", version.as_str()))
436 } else {
437 None
438 }
439 }
440 VersionStrategy::Header(header_name) => request
441 .headers()
442 .get(header_name)
443 .and_then(|h| h.to_str().ok())
444 .map(|s| s.to_string()),
445 VersionStrategy::QueryParam(param_name) => {
446 if let Some(query) = request.uri().query() {
448 if let Ok(params) = serde_urlencoded::from_str::<HashMap<String, String>>(query)
449 {
450 params.get(param_name).map(|s| s.to_string())
451 } else {
452 None
453 }
454 } else {
455 None
456 }
457 }
458 VersionStrategy::AcceptHeader => {
459 if let Some(accept) = request.headers().get("accept") {
460 if let Ok(accept_str) = accept.to_str() {
461 if let Some(captures) = ACCEPT_HEADER_REGEX.captures(accept_str) {
463 captures
464 .get(1)
465 .map(|version| format!("v{}", version.as_str()))
466 } else {
467 None
468 }
469 } else {
470 None
471 }
472 } else {
473 None
474 }
475 }
476 };
477
478 Ok(extracted)
479 }
480
481 fn resolve_version(
483 config: &VersioningConfig,
484 requested_version: Option<String>,
485 ) -> Result<VersionInfo, axum::response::Response> {
486 let version_key = if let Some(version) = requested_version {
487 if config.versions.contains_key(&version) {
488 version
489 } else if config.strict_validation {
490 let error_response = axum::response::Response::builder()
491 .status(400)
492 .body(axum::body::Body::from(format!(
493 "Unsupported API version: {}",
494 version
495 )))
496 .unwrap();
497 return Err(error_response);
498 } else if let Some(default) = &config.default_version {
499 default.clone()
500 } else {
501 let error_response = axum::response::Response::builder()
502 .status(400)
503 .body(axum::body::Body::from(
504 "No valid API version specified and no default available",
505 ))
506 .unwrap();
507 return Err(error_response);
508 }
509 } else if let Some(default) = &config.default_version {
510 default.clone()
511 } else {
512 let error_response = axum::response::Response::builder()
513 .status(400)
514 .body(axum::body::Body::from("API version is required"))
515 .unwrap();
516 return Err(error_response);
517 };
518
519 let api_version = config.versions.get(&version_key).ok_or_else(|| {
520 axum::response::Response::builder()
521 .status(500)
522 .body(axum::body::Body::from(format!(
523 "Version configuration not found: {}",
524 version_key
525 )))
526 .unwrap()
527 })?;
528
529 Ok(VersionInfo {
530 version: version_key,
531 is_deprecated: api_version.deprecated,
532 api_version: api_version.clone(),
533 })
534 }
535
536 fn add_version_headers(
538 config: &VersioningConfig,
539 version_info: &VersionInfo,
540 response: &mut axum::response::Response,
541 ) {
542 let headers = response.headers_mut();
543
544 if let Ok(value) = version_info.version.parse() {
546 headers.insert("X-Api-Version", value);
547 }
548
549 if let Some(default_version) = &config.default_version {
551 if let Ok(value) = default_version.parse() {
552 headers.insert("X-Api-Default-Version", value);
553 }
554 }
555
556 let supported_versions: Vec<String> = config.versions.keys().cloned().collect();
558 if !supported_versions.is_empty() {
559 let versions_str = supported_versions.join(",");
560 if let Ok(value) = versions_str.parse() {
561 headers.insert("X-Api-Supported-Versions", value);
562 }
563 }
564
565 if config.include_deprecation_headers && version_info.is_deprecated {
567 headers.insert("Deprecation", axum::http::HeaderValue::from_static("true"));
569
570 if let Some(message) = &version_info.api_version.deprecation_message {
572 let warning_value = format!("299 - \"{}\"", message);
573 if let Ok(value) = warning_value.parse() {
574 headers.insert("Warning", value);
575 }
576 }
577
578 if let Some(sunset) = &version_info.api_version.sunset_date {
580 if let Ok(value) = sunset.parse() {
581 headers.insert("Sunset", value);
582 }
583 }
584 }
585 }
586}
587
588pub fn versioning_middleware(config: VersioningConfig) -> VersioningMiddleware {
590 VersioningMiddleware::new(config)
591}
592
593pub fn versioning_layer(config: VersioningConfig) -> VersioningLayer {
595 VersioningLayer::new(config)
596}
597
598pub fn default_versioning_middleware() -> VersioningMiddleware {
600 let mut config = VersioningConfig {
601 versions: HashMap::new(),
602 strategy: VersionStrategy::UrlPath,
603 default_version: Some("v1".to_string()),
604 include_deprecation_headers: true,
605 version_header_name: "Api-Version".to_string(),
606 version_param_name: "version".to_string(),
607 strict_validation: true,
608 };
609
610 config.add_version(
612 "v1".to_string(),
613 ApiVersion {
614 version: "v1".to_string(),
615 deprecated: false,
616 deprecation_message: None,
617 sunset_date: None,
618 is_default: true,
619 },
620 );
621
622 VersioningMiddleware::new(config)
623}
624
/// Read access to the [`VersionInfo`] that the versioning middleware or service attaches to a
/// request.
pub trait RequestVersionExt {
    /// The resolved version info, if any is available for this request.
    fn version_info(&self) -> Option<&VersionInfo>;

    /// Shorthand for the resolved version key (e.g. `"v1"`).
    fn api_version(&self) -> Option<&str>;

    /// Whether the resolved version is deprecated. The implementations in this file return
    /// `false` when no version info is present.
    fn is_deprecated_version(&self) -> bool;
}
636
637impl RequestVersionExt for axum::extract::Request {
638 fn version_info(&self) -> Option<&VersionInfo> {
639 self.extensions().get::<VersionInfo>()
640 }
641
642 fn api_version(&self) -> Option<&str> {
643 self.version_info().map(|v| v.version.as_str())
644 }
645
646 fn is_deprecated_version(&self) -> bool {
647 self.version_info()
648 .map(|v| v.is_deprecated)
649 .unwrap_or(false)
650 }
651}
652
653impl RequestVersionExt for ElifRequest {
654 fn version_info(&self) -> Option<&VersionInfo> {
655 None
657 }
658
659 fn api_version(&self) -> Option<&str> {
660 self.version_info().map(|v| v.version.as_str())
661 }
662
663 fn is_deprecated_version(&self) -> bool {
664 self.version_info()
665 .map(|v| v.is_deprecated)
666 .unwrap_or(false)
667 }
668}
669
#[cfg(test)]
mod tests {
    use super::*;

    /// Strict, path-versioned config with a single non-deprecated `"v1"` that is also the
    /// default.
    fn v1_config() -> VersioningConfig {
        let mut config = VersioningConfig {
            versions: HashMap::new(),
            strategy: VersionStrategy::UrlPath,
            default_version: Some("v1".to_string()),
            include_deprecation_headers: true,
            version_header_name: "Api-Version".to_string(),
            version_param_name: "version".to_string(),
            strict_validation: true,
        };
        config.add_version(
            "v1".to_string(),
            ApiVersion {
                version: "v1".to_string(),
                deprecated: false,
                deprecation_message: None,
                sunset_date: None,
                is_default: true,
            },
        );
        config
    }

    #[test]
    fn test_version_config_builder() {
        let config = VersioningConfig::builder()
            .strategy(VersionStrategy::Header("X-Api-Version".to_string()))
            .default_version(Some("v2".to_string()))
            .strict_validation(false)
            .build()
            .unwrap();

        assert!(!config.strict_validation);
        assert_eq!(config.default_version, Some("v2".to_string()));
        match config.strategy {
            VersionStrategy::Header(name) => assert_eq!(name, "X-Api-Version"),
            _ => panic!("Expected Header strategy"),
        }
    }

    #[test]
    fn test_version_deprecation() {
        let mut config = VersioningConfig::builder().build().unwrap();

        config.add_version(
            "v1".to_string(),
            ApiVersion {
                version: "v1".to_string(),
                deprecated: false,
                deprecation_message: None,
                sunset_date: None,
                is_default: false,
            },
        );

        config.deprecate_version(
            "v1",
            Some("Version v1 is deprecated, please use v2".to_string()),
            Some("2024-12-31".to_string()),
        );

        let version = config.versions.get("v1").unwrap();
        assert!(version.deprecated);
        assert_eq!(
            version.deprecation_message,
            Some("Version v1 is deprecated, please use v2".to_string())
        );
        assert_eq!(version.sunset_date, Some("2024-12-31".to_string()));
    }

    #[test]
    fn test_url_path_version_extraction() {
        let config = v1_config();
        let request = axum::http::Request::builder()
            .uri("/api/v1/users")
            .body(axum::body::Body::empty())
            .unwrap();

        let extracted = VersioningService::<()>::extract_version_from_request(&config, &request);
        assert_eq!(extracted.ok().flatten(), Some("v1".to_string()));
    }

    #[test]
    fn test_resolve_known_version() {
        let config = v1_config();
        match resolve_version(&config, Some("v1".to_string())) {
            Ok(info) => {
                assert_eq!(info.version, "v1");
                assert!(!info.is_deprecated);
            }
            Err(_) => panic!("v1 is registered and must resolve"),
        }
    }

    #[test]
    fn test_resolve_missing_version_uses_default() {
        let config = v1_config();
        match resolve_version(&config, None) {
            Ok(info) => assert_eq!(info.version, "v1"),
            Err(_) => panic!("a configured default must be used when no version is given"),
        }
        match VersioningService::<()>::resolve_version(&config, None) {
            Ok(info) => assert_eq!(info.version, "v1"),
            Err(_) => panic!("a configured default must be used when no version is given"),
        }
    }

    #[test]
    fn test_resolve_unknown_version_is_rejected() {
        let config = v1_config();
        assert!(resolve_version(&config, Some("v9".to_string())).is_err());
        match VersioningService::<()>::resolve_version(&config, Some("v9".to_string())) {
            Ok(_) => panic!("strict validation must reject unknown versions"),
            Err(response) => assert_eq!(response.status(), 400),
        }
    }
}