// elif_http/middleware/versioning.rs

1use crate::{
2    errors::HttpError,
3    middleware::v2::{Middleware, Next, NextFuture},
4    request::ElifRequest,
5    response::ElifResponse,
6};
7use once_cell::sync::Lazy;
8use serde::{Deserialize, Serialize};
9use service_builder::builder;
10use std::collections::HashMap;
11use std::future::Future;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14use tower::{Layer, Service};
15
16// Static regex patterns compiled once for performance
17static URL_PATH_VERSION_REGEX: Lazy<regex::Regex> = Lazy::new(|| {
18    regex::Regex::new(r"/api/v?(\d+(?:\.\d+)?)/").expect("Invalid URL path version regex")
19});
20
21static ACCEPT_HEADER_VERSION_REGEX: Lazy<regex::Regex> = Lazy::new(|| {
22    regex::Regex::new(r"version=([^;,\s]+)").expect("Invalid Accept header version regex")
23});
24
25/// API versioning strategy
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub enum VersionStrategy {
28    /// Version specified in URL path (e.g., /api/v1/users)
29    UrlPath,
30    /// Version specified in header (e.g., Api-Version: v1)
31    Header(String),
32    /// Version specified in query parameter (e.g., ?version=v1)
33    QueryParam(String),
34    /// Version specified in Accept header (e.g., Accept: application/vnd.api+json;version=1)
35    AcceptHeader,
36}
37
38impl Default for VersionStrategy {
39    fn default() -> Self {
40        Self::UrlPath
41    }
42}
43
44/// API version configuration
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct ApiVersion {
47    /// Version identifier (e.g., "v1", "v2", "1.0", "2024-01-01")
48    pub version: String,
49    /// Whether this version is deprecated
50    pub deprecated: bool,
51    /// Deprecation warning message
52    pub deprecation_message: Option<String>,
53    /// Date when this version will be removed (ISO 8601)
54    pub sunset_date: Option<String>,
55    /// Whether this version is the default
56    pub is_default: bool,
57}
58
59/// API versioning middleware configuration
60#[derive(Debug, Clone, Default)]
61#[builder]
62pub struct VersioningConfig {
63    /// Available API versions
64    #[builder(default)]
65    pub versions: HashMap<String, ApiVersion>,
66    /// Versioning strategy to use
67    #[builder(default)]
68    pub strategy: VersionStrategy,
69    /// Default version if none specified
70    #[builder(optional)]
71    pub default_version: Option<String>,
72    /// Whether to include deprecation headers
73    #[builder(default = "true")]
74    pub include_deprecation_headers: bool,
75    /// Custom header name for version (when using Header strategy)
76    #[builder(default = "\"Api-Version\".to_string()")]
77    pub version_header_name: String,
78    /// Custom query parameter name for version (when using QueryParam strategy)
79    #[builder(default = "\"version\".to_string()")]
80    pub version_param_name: String,
81    /// Whether to be strict about version validation
82    #[builder(default = "true")]
83    pub strict_validation: bool,
84}
85
86impl VersioningConfig {
87    /// Add a new API version
88    pub fn add_version(&mut self, version: String, api_version: ApiVersion) {
89        self.versions.insert(version, api_version);
90    }
91
92    /// Set a version as deprecated
93    pub fn deprecate_version(
94        &mut self,
95        version: &str,
96        message: Option<String>,
97        sunset_date: Option<String>,
98    ) {
99        if let Some(api_version) = self.versions.get_mut(version) {
100            api_version.deprecated = true;
101            api_version.deprecation_message = message;
102            api_version.sunset_date = sunset_date;
103        }
104    }
105
106    /// Get the default version
107    pub fn get_default_version(&self) -> Option<&ApiVersion> {
108        if let Some(default_version) = &self.default_version {
109            return self.versions.get(default_version);
110        }
111
112        // Find the version marked as default
113        self.versions.values().find(|v| v.is_default)
114    }
115
116    /// Get a specific version
117    pub fn get_version(&self, version: &str) -> Option<&ApiVersion> {
118        self.versions.get(version)
119    }
120
121    /// Get the versioning strategy
122    pub fn get_strategy(&self) -> &VersionStrategy {
123        &self.strategy
124    }
125
126    /// Get all versions
127    pub fn get_versions(&self) -> &HashMap<String, ApiVersion> {
128        &self.versions
129    }
130
131    /// Get all versions as mutable reference
132    pub fn get_versions_mut(&mut self) -> &mut HashMap<String, ApiVersion> {
133        &mut self.versions
134    }
135
136    /// Clone all configuration for rebuilding
137    pub fn clone_config(
138        &self,
139    ) -> (
140        HashMap<String, ApiVersion>,
141        VersionStrategy,
142        Option<String>,
143        bool,
144        String,
145        String,
146        bool,
147    ) {
148        (
149            self.versions.clone(),
150            self.strategy.clone(),
151            self.default_version.clone(),
152            self.include_deprecation_headers,
153            self.version_header_name.clone(),
154            self.version_param_name.clone(),
155            self.strict_validation,
156        )
157    }
158}
159
160/// Extracted version information from request
161#[derive(Debug, Clone)]
162pub struct VersionInfo {
163    /// The requested version
164    pub version: String,
165    /// The API version configuration
166    pub api_version: ApiVersion,
167    /// Whether this version is deprecated
168    pub is_deprecated: bool,
169}
170
171/// API versioning middleware
172#[derive(Debug)]
173pub struct VersioningMiddleware {
174    config: VersioningConfig,
175}
176
177impl VersioningMiddleware {
178    /// Create new versioning middleware
179    pub fn new(config: VersioningConfig) -> Self {
180        Self { config }
181    }
182}
183
184/// Extract version from ElifRequest based on strategy
185fn extract_version_from_request(
186    request: &ElifRequest,
187    strategy: &VersionStrategy,
188) -> Result<Option<String>, HttpError> {
189    match strategy {
190        VersionStrategy::UrlPath => {
191            let path = request.path();
192            if let Some(captures) = URL_PATH_VERSION_REGEX.captures(path) {
193                Ok(Some(captures[1].to_string()))
194            } else {
195                Ok(None)
196            }
197        }
198        VersionStrategy::Header(header_name) => {
199            if let Some(header_value) = request.header(header_name) {
200                if let Ok(version_str) = header_value.to_str() {
201                    Ok(Some(version_str.to_string()))
202                } else {
203                    Err(HttpError::bad_request("Invalid version header"))
204                }
205            } else {
206                Ok(None)
207            }
208        }
209        VersionStrategy::QueryParam(param_name) => {
210            if let Some(query) = request.uri.query() {
211                for pair in query.split('&') {
212                    let mut parts = pair.split('=');
213                    if let (Some(key), Some(value)) = (parts.next(), parts.next()) {
214                        if key == param_name {
215                            return Ok(Some(value.to_string()));
216                        }
217                    }
218                }
219            }
220            Ok(None)
221        }
222        VersionStrategy::AcceptHeader => {
223            if let Some(accept_header) = request.header("Accept") {
224                if let Ok(accept_str) = accept_header.to_str() {
225                    if let Some(captures) = ACCEPT_HEADER_VERSION_REGEX.captures(accept_str) {
226                        return Ok(Some(captures[1].to_string()));
227                    }
228                }
229            }
230            Ok(None)
231        }
232    }
233}
234
235/// Resolve version info from extracted version and config
236fn resolve_version(
237    config: &VersioningConfig,
238    extracted_version: Option<String>,
239) -> Result<VersionInfo, HttpError> {
240    let version_key = match extracted_version {
241        Some(v) => v,
242        None => {
243            if let Some(default) = &config.default_version {
244                default.clone()
245            } else if config.strict_validation {
246                return Err(HttpError::bad_request("Version is required"));
247            } else {
248                // Pick first available version if not strict (sorted for deterministic behavior)
249                let mut sorted_keys: Vec<_> = config.versions.keys().cloned().collect();
250                sorted_keys.sort();
251                if let Some(first_version) = sorted_keys.first() {
252                    first_version.clone()
253                } else {
254                    return Err(HttpError::bad_request("No versions configured"));
255                }
256            }
257        }
258    };
259
260    if let Some(api_version) = config.versions.get(&version_key) {
261        Ok(VersionInfo {
262            version: version_key,
263            api_version: api_version.clone(),
264            is_deprecated: api_version.deprecated,
265        })
266    } else {
267        Err(HttpError::bad_request(format!(
268            "Unsupported version: {}",
269            version_key
270        )))
271    }
272}
273
274impl Middleware for VersioningMiddleware {
275    fn handle(&self, mut request: ElifRequest, next: Next) -> NextFuture<'static> {
276        let config = self.config.clone();
277
278        Box::pin(async move {
279            // Extract version from request
280            let extracted_version = match extract_version_from_request(&request, &config.strategy) {
281                Ok(version) => version,
282                Err(err) => {
283                    return ElifResponse::bad_request().json_value(serde_json::json!({
284                        "error": {
285                            "code": "VERSION_EXTRACTION_FAILED",
286                            "message": err.to_string()
287                        }
288                    }));
289                }
290            };
291
292            // Resolve version using the extracted version
293            let version_info = match resolve_version(&config, extracted_version) {
294                Ok(info) => info,
295                Err(err) => {
296                    return ElifResponse::bad_request().json_value(serde_json::json!({
297                        "error": {
298                            "code": "VERSION_RESOLUTION_FAILED",
299                            "message": err.to_string()
300                        }
301                    }));
302                }
303            };
304
305            // Store version info in request extensions for handlers to use
306            request.insert_extension(version_info.clone());
307
308            // Call next middleware/handler
309            let mut response = next.run(request).await;
310
311            // Add deprecation headers if needed
312            if config.include_deprecation_headers && version_info.api_version.deprecated {
313                // Add Deprecation header
314                let _ = response.add_header("Deprecation", "true");
315
316                // Add Warning header if deprecation message exists
317                if let Some(message) = &version_info.api_version.deprecation_message {
318                    let _ = response.add_header("Warning", format!("299 - \"{}\"", message));
319                }
320
321                // Add Sunset header if sunset date exists
322                if let Some(sunset) = &version_info.api_version.sunset_date {
323                    let _ = response.add_header("Sunset", sunset);
324                }
325            }
326
327            response
328        })
329    }
330
331    fn name(&self) -> &'static str {
332        "VersioningMiddleware"
333    }
334}
335
336/// Tower Layer implementation for VersioningMiddleware
337#[derive(Debug, Clone)]
338pub struct VersioningLayer {
339    config: VersioningConfig,
340}
341
342impl VersioningLayer {
343    /// Create a new versioning layer
344    pub fn new(config: VersioningConfig) -> Self {
345        Self { config }
346    }
347}
348
349impl<S> Layer<S> for VersioningLayer {
350    type Service = VersioningService<S>;
351
352    fn layer(&self, inner: S) -> Self::Service {
353        VersioningService {
354            inner,
355            config: self.config.clone(),
356        }
357    }
358}
359
360/// Tower Service implementation for versioning
361#[derive(Debug, Clone)]
362pub struct VersioningService<S> {
363    inner: S,
364    config: VersioningConfig,
365}
366
367impl<S> Service<axum::extract::Request> for VersioningService<S>
368where
369    S: Service<axum::extract::Request, Response = axum::response::Response>
370        + Clone
371        + Send
372        + 'static,
373    S::Future: Send + 'static,
374    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
375{
376    type Response = axum::response::Response;
377    type Error = S::Error;
378    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
379
380    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
381        self.inner.poll_ready(cx)
382    }
383
384    fn call(&mut self, mut request: axum::extract::Request) -> Self::Future {
385        let config = self.config.clone();
386        let mut inner = self.inner.clone();
387
388        Box::pin(async move {
389            // Extract version from request
390            let extracted_version = match Self::extract_version_from_request(&config, &request) {
391                Ok(version) => version,
392                Err(error_response) => return Ok(error_response),
393            };
394
395            let version_info = match Self::resolve_version(&config, extracted_version) {
396                Ok(info) => info,
397                Err(error_response) => return Ok(error_response),
398            };
399
400            // Store version info in request extensions
401            request.extensions_mut().insert(version_info.clone());
402
403            // Call the inner service
404            let mut response = inner.call(request).await?;
405
406            // Add versioning headers to response
407            Self::add_version_headers(&config, &version_info, &mut response);
408
409            Ok(response)
410        })
411    }
412}
413
414impl<S> VersioningService<S> {
415    /// Extract version from request based on strategy
416    fn extract_version_from_request(
417        config: &VersioningConfig,
418        request: &axum::extract::Request,
419    ) -> Result<Option<String>, axum::response::Response> {
420        // Local static regex definitions for better encapsulation and performance
421        static URL_PATH_REGEX: Lazy<regex::Regex> = Lazy::new(|| {
422            regex::Regex::new(r"/api/v?(\d+(?:\.\d+)?)/").expect("Failed to compile URL path regex")
423        });
424        static ACCEPT_HEADER_REGEX: Lazy<regex::Regex> = Lazy::new(|| {
425            regex::Regex::new(r"version=([^;,\s]+)").expect("Failed to compile Accept header regex")
426        });
427
428        let extracted = match &config.strategy {
429            VersionStrategy::UrlPath => {
430                // Extract version from URL path (e.g., /api/v1/users -> v1)
431                let path = request.uri().path();
432                if let Some(captures) = URL_PATH_REGEX.captures(path) {
433                    captures
434                        .get(1)
435                        .map(|version| format!("v{}", version.as_str()))
436                } else {
437                    None
438                }
439            }
440            VersionStrategy::Header(header_name) => request
441                .headers()
442                .get(header_name)
443                .and_then(|h| h.to_str().ok())
444                .map(|s| s.to_string()),
445            VersionStrategy::QueryParam(param_name) => {
446                // Parse query parameters from URI
447                if let Some(query) = request.uri().query() {
448                    if let Ok(params) = serde_urlencoded::from_str::<HashMap<String, String>>(query)
449                    {
450                        params.get(param_name).map(|s| s.to_string())
451                    } else {
452                        None
453                    }
454                } else {
455                    None
456                }
457            }
458            VersionStrategy::AcceptHeader => {
459                if let Some(accept) = request.headers().get("accept") {
460                    if let Ok(accept_str) = accept.to_str() {
461                        // Parse Accept header for version (e.g., application/vnd.api+json;version=1)
462                        if let Some(captures) = ACCEPT_HEADER_REGEX.captures(accept_str) {
463                            captures
464                                .get(1)
465                                .map(|version| format!("v{}", version.as_str()))
466                        } else {
467                            None
468                        }
469                    } else {
470                        None
471                    }
472                } else {
473                    None
474                }
475            }
476        };
477
478        Ok(extracted)
479    }
480
481    /// Resolve version to API version configuration
482    fn resolve_version(
483        config: &VersioningConfig,
484        requested_version: Option<String>,
485    ) -> Result<VersionInfo, axum::response::Response> {
486        let version_key = if let Some(version) = requested_version {
487            if config.versions.contains_key(&version) {
488                version
489            } else if config.strict_validation {
490                let error_response = axum::response::Response::builder()
491                    .status(400)
492                    .body(axum::body::Body::from(format!(
493                        "Unsupported API version: {}",
494                        version
495                    )))
496                    .unwrap();
497                return Err(error_response);
498            } else if let Some(default) = &config.default_version {
499                default.clone()
500            } else {
501                let error_response = axum::response::Response::builder()
502                    .status(400)
503                    .body(axum::body::Body::from(
504                        "No valid API version specified and no default available",
505                    ))
506                    .unwrap();
507                return Err(error_response);
508            }
509        } else if let Some(default) = &config.default_version {
510            default.clone()
511        } else {
512            let error_response = axum::response::Response::builder()
513                .status(400)
514                .body(axum::body::Body::from("API version is required"))
515                .unwrap();
516            return Err(error_response);
517        };
518
519        let api_version = config.versions.get(&version_key).ok_or_else(|| {
520            axum::response::Response::builder()
521                .status(500)
522                .body(axum::body::Body::from(format!(
523                    "Version configuration not found: {}",
524                    version_key
525                )))
526                .unwrap()
527        })?;
528
529        Ok(VersionInfo {
530            version: version_key,
531            is_deprecated: api_version.deprecated,
532            api_version: api_version.clone(),
533        })
534    }
535
536    /// Add version headers to response
537    fn add_version_headers(
538        config: &VersioningConfig,
539        version_info: &VersionInfo,
540        response: &mut axum::response::Response,
541    ) {
542        let headers = response.headers_mut();
543
544        // Add current version header
545        if let Ok(value) = version_info.version.parse() {
546            headers.insert("X-Api-Version", value);
547        }
548
549        // Add API version support information
550        if let Some(default_version) = &config.default_version {
551            if let Ok(value) = default_version.parse() {
552                headers.insert("X-Api-Default-Version", value);
553            }
554        }
555
556        // Add supported versions list
557        let supported_versions: Vec<String> = config.versions.keys().cloned().collect();
558        if !supported_versions.is_empty() {
559            let versions_str = supported_versions.join(",");
560            if let Ok(value) = versions_str.parse() {
561                headers.insert("X-Api-Supported-Versions", value);
562            }
563        }
564
565        // Add deprecation headers if needed
566        if config.include_deprecation_headers && version_info.is_deprecated {
567            // Use from_static for known static values
568            headers.insert("Deprecation", axum::http::HeaderValue::from_static("true"));
569
570            // Handle dynamic warning message safely
571            if let Some(message) = &version_info.api_version.deprecation_message {
572                let warning_value = format!("299 - \"{}\"", message);
573                if let Ok(value) = warning_value.parse() {
574                    headers.insert("Warning", value);
575                }
576            }
577
578            // Handle dynamic sunset date safely
579            if let Some(sunset) = &version_info.api_version.sunset_date {
580                if let Ok(value) = sunset.parse() {
581                    headers.insert("Sunset", value);
582                }
583            }
584        }
585    }
586}
587
588/// Convenience functions for creating versioning middleware
589pub fn versioning_middleware(config: VersioningConfig) -> VersioningMiddleware {
590    VersioningMiddleware::new(config)
591}
592
593/// Create versioning layer for use with axum routers
594pub fn versioning_layer(config: VersioningConfig) -> VersioningLayer {
595    VersioningLayer::new(config)
596}
597
598/// Create versioning middleware with default configuration
599pub fn default_versioning_middleware() -> VersioningMiddleware {
600    let mut config = VersioningConfig {
601        versions: HashMap::new(),
602        strategy: VersionStrategy::UrlPath,
603        default_version: Some("v1".to_string()),
604        include_deprecation_headers: true,
605        version_header_name: "Api-Version".to_string(),
606        version_param_name: "version".to_string(),
607        strict_validation: true,
608    };
609
610    // Add default v1 version
611    config.add_version(
612        "v1".to_string(),
613        ApiVersion {
614            version: "v1".to_string(),
615            deprecated: false,
616            deprecation_message: None,
617            sunset_date: None,
618            is_default: true,
619        },
620    );
621
622    VersioningMiddleware::new(config)
623}
624
625/// Extension trait to get version info from request
626pub trait RequestVersionExt {
627    /// Get version information from request
628    fn version_info(&self) -> Option<&VersionInfo>;
629
630    /// Get current API version string
631    fn api_version(&self) -> Option<&str>;
632
633    /// Check if current version is deprecated
634    fn is_deprecated_version(&self) -> bool;
635}
636
637impl RequestVersionExt for axum::extract::Request {
638    fn version_info(&self) -> Option<&VersionInfo> {
639        self.extensions().get::<VersionInfo>()
640    }
641
642    fn api_version(&self) -> Option<&str> {
643        self.version_info().map(|v| v.version.as_str())
644    }
645
646    fn is_deprecated_version(&self) -> bool {
647        self.version_info()
648            .map(|v| v.is_deprecated)
649            .unwrap_or(false)
650    }
651}
652
653impl RequestVersionExt for ElifRequest {
654    fn version_info(&self) -> Option<&VersionInfo> {
655        // Note: This will need implementation when ElifRequest has extensions support
656        None
657    }
658
659    fn api_version(&self) -> Option<&str> {
660        self.version_info().map(|v| v.version.as_str())
661    }
662
663    fn is_deprecated_version(&self) -> bool {
664        self.version_info()
665            .map(|v| v.is_deprecated)
666            .unwrap_or(false)
667    }
668}
669
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a non-default `ApiVersion` named `version`.
    fn api_version(version: &str, deprecated: bool) -> ApiVersion {
        ApiVersion {
            version: version.to_string(),
            deprecated,
            deprecation_message: None,
            sunset_date: None,
            is_default: false,
        }
    }

    /// Resolved version key, or `None` if resolution failed.
    fn resolved_key(config: &VersioningConfig, requested: Option<&str>) -> Option<String> {
        resolve_version(config, requested.map(str::to_string))
            .ok()
            .map(|info| info.version)
    }

    #[test]
    fn test_version_config_builder() {
        let config = VersioningConfig::builder()
            .strategy(VersionStrategy::Header("X-Api-Version".to_string()))
            .default_version(Some("v2".to_string()))
            .strict_validation(false)
            .build()
            .unwrap();

        assert!(!config.strict_validation);
        assert_eq!(config.default_version, Some("v2".to_string()));
        match config.strategy {
            VersionStrategy::Header(name) => assert_eq!(name, "X-Api-Version"),
            _ => panic!("Expected Header strategy"),
        }
    }

    #[test]
    fn test_version_deprecation() {
        let mut config = VersioningConfig::builder().build().unwrap();
        config.add_version("v1".to_string(), api_version("v1", false));

        config.deprecate_version(
            "v1",
            Some("Version v1 is deprecated, please use v2".to_string()),
            Some("2024-12-31".to_string()),
        );

        let version = config.versions.get("v1").unwrap();
        assert!(version.deprecated);
        assert_eq!(
            version.deprecation_message,
            Some("Version v1 is deprecated, please use v2".to_string())
        );
        assert_eq!(version.sunset_date, Some("2024-12-31".to_string()));
    }

    #[test]
    fn test_url_path_version_extraction() {
        let captures = URL_PATH_VERSION_REGEX
            .captures("/api/v1/users")
            .expect("a versioned path should match");
        assert_eq!(&captures[1], "1");

        // The leading `v` is optional and minor versions are captured too
        let captures = URL_PATH_VERSION_REGEX
            .captures("/api/2.1/users")
            .expect("a path without `v` should match");
        assert_eq!(&captures[1], "2.1");

        assert!(URL_PATH_VERSION_REGEX.captures("/users/1/").is_none());
    }

    #[test]
    fn test_accept_header_version_extraction() {
        let captures = ACCEPT_HEADER_VERSION_REGEX
            .captures("application/vnd.api+json;version=2, text/html")
            .expect("the version parameter should match");
        assert_eq!(&captures[1], "2");

        assert!(ACCEPT_HEADER_VERSION_REGEX.captures("application/json").is_none());
    }

    #[test]
    fn test_resolve_version() {
        let mut config = VersioningConfig::builder()
            .default_version(Some("v1".to_string()))
            .strict_validation(true)
            .build()
            .unwrap();
        config.add_version("v1".to_string(), api_version("v1", false));
        config.add_version("v2".to_string(), api_version("v2", true));

        // An explicitly requested, registered version wins over the default...
        assert_eq!(resolved_key(&config, Some("v2")), Some("v2".to_string()));
        // ...and carries its deprecation flag
        let deprecated = resolve_version(&config, Some("v2".to_string()))
            .ok()
            .map(|info| info.is_deprecated);
        assert_eq!(deprecated, Some(true));

        // A missing version falls back to the configured default
        assert_eq!(resolved_key(&config, None), Some("v1".to_string()));

        // Strict validation rejects unregistered versions
        assert_eq!(resolved_key(&config, Some("v9")), None);
    }

    #[test]
    fn test_strict_resolution_requires_version_without_default() {
        let mut config = VersioningConfig::builder()
            .strict_validation(true)
            .build()
            .unwrap();
        config.add_version("v1".to_string(), api_version("v1", false));

        assert_eq!(resolved_key(&config, None), None);
        assert_eq!(resolved_key(&config, Some("v1")), Some("v1".to_string()));
    }
}