1pub mod config;
35pub mod handler;
36pub mod middleware;
37pub mod reverse;
38
39use async_trait::async_trait;
40pub use config::{VersioningConfig, VersioningManager, VersioningStrategy};
41pub use handler::{
42 ConfigurableVersionedHandler, SimpleVersionedHandler, VersionResponseBuilder, VersionedHandler,
43 VersionedHandlerBuilder, VersionedHandlerWrapper,
44};
45pub use middleware::{ApiVersion, RequestVersionExt, VersioningMiddleware};
46use regex::Regex;
47use reinhardt_core::exception::{Error, Result};
48use reinhardt_http::Request;
49pub use reverse::{
50 ApiDocFormat, ApiDocUrlBuilder, UrlReverseManager, VersionedUrlBuilder,
51 VersioningStrategy as ReverseVersioningStrategy,
52};
53use std::collections::{HashMap, HashSet};
54use std::sync::OnceLock;
55use thiserror::Error as ThisError;
56
/// Errors raised while extracting or validating an API version from a request.
///
/// The strategies in this module report these through `Error::Validation`, using the
/// `Display` text produced by each variant's `#[error(...)]` string.
#[derive(Debug, ThisError)]
pub enum VersioningError {
    /// The `Accept` header value could not be converted to a string.
    #[error("Invalid version in Accept header")]
    InvalidAcceptHeader,

    /// Invalid version in the URL path.
    /// NOTE(review): not constructed by the code in this file; presumably used elsewhere.
    #[error("Invalid version in URL path")]
    InvalidURLPath,

    /// Invalid version in the URL namespace.
    /// NOTE(review): not constructed by the code in this file; presumably used elsewhere.
    #[error("Invalid version in URL namespace")]
    InvalidNamespace,

    /// The `Host` header value could not be converted to a string.
    #[error("Invalid version in hostname")]
    InvalidHostname,

    /// Invalid version in the query string.
    /// NOTE(review): not constructed by the code in this file; presumably used elsewhere.
    #[error("Invalid version in query parameter")]
    InvalidQueryParameter,

    /// The requested version is not accepted; carries the rejected version string.
    #[error("Version not allowed: {0}")]
    VersionNotAllowed(String),
}
84
85#[async_trait]
87pub trait BaseVersioning: Send + Sync {
88 async fn determine_version(&self, request: &Request) -> Result<String>;
90
91 fn default_version(&self) -> Option<&str>;
93
94 fn allowed_versions(&self) -> Option<&HashSet<String>>;
96
97 fn is_allowed_version(&self, version: &str) -> bool {
99 if let Some(allowed) = self.allowed_versions() {
100 if allowed.is_empty() {
101 return true;
102 }
103 return allowed.contains(version) || (self.default_version() == Some(version));
104 }
105 true
106 }
107
108 fn version_param(&self) -> &str {
110 "version"
111 }
112}
113
/// Versioning strategy that reads the API version from a media-type parameter of the
/// `Accept` header, e.g. `Accept: application/json; version=2.0`.
#[derive(Debug, Clone)]
pub struct AcceptHeaderVersioning {
    /// Version used when the header carries no version parameter; `None` falls back to `"1.0"`.
    pub default_version: Option<String>,
    /// Accepted versions; an empty set accepts any version.
    pub allowed_versions: HashSet<String>,
    /// Name of the media-type parameter holding the version.
    pub version_param: String,
}

impl AcceptHeaderVersioning {
    /// Creates a strategy with no default version, no version restriction and the parameter
    /// name `"version"`.
    pub fn new() -> Self {
        Self {
            version_param: String::from("version"),
            allowed_versions: HashSet::default(),
            default_version: None,
        }
    }

    /// Sets the version used when the request does not specify one.
    pub fn with_default_version(mut self, version: impl Into<String>) -> Self {
        let version: String = version.into();
        self.default_version = Some(version);
        self
    }

    /// Replaces the set of accepted versions (duplicates collapse).
    pub fn with_allowed_versions(mut self, versions: Vec<impl Into<String>>) -> Self {
        let set: HashSet<String> = versions.into_iter().map(Into::into).collect();
        self.allowed_versions = set;
        self
    }

    /// Changes the name of the media-type parameter that carries the version.
    pub fn with_version_param(mut self, param: impl Into<String>) -> Self {
        let param: String = param.into();
        self.version_param = param;
        self
    }
}

impl Default for AcceptHeaderVersioning {
    /// Equivalent to [`AcceptHeaderVersioning::new`].
    fn default() -> Self {
        AcceptHeaderVersioning::new()
    }
}
200
201#[async_trait]
202impl BaseVersioning for AcceptHeaderVersioning {
203 async fn determine_version(&self, request: &Request) -> Result<String> {
204 if let Some(accept) = request.headers.get("accept") {
206 let accept_str = accept
207 .to_str()
208 .map_err(|_| Error::Validation(VersioningError::InvalidAcceptHeader.to_string()))?;
209
210 if let Some(params_start) = accept_str.find(';') {
212 let params = &accept_str[params_start + 1..];
213 for param in params.split(';') {
214 let param = param.trim();
215 if let Some((key, value)) = param.split_once('=')
216 && key.trim() == self.version_param
217 {
218 let version = value.trim().trim_matches('"');
219 if self.is_allowed_version(version) {
220 return Ok(version.to_string());
221 } else {
222 return Err(Error::Validation(
223 VersioningError::VersionNotAllowed(version.to_string()).to_string(),
224 ));
225 }
226 }
227 }
228 }
229 }
230
231 Ok(self
233 .default_version
234 .clone()
235 .unwrap_or_else(|| "1.0".to_string()))
236 }
237
238 fn default_version(&self) -> Option<&str> {
239 self.default_version.as_deref()
240 }
241
242 fn allowed_versions(&self) -> Option<&HashSet<String>> {
243 Some(&self.allowed_versions)
244 }
245
246 fn version_param(&self) -> &str {
247 &self.version_param
248 }
249}
250
251#[derive(Debug, Clone)]
255pub struct URLPathVersioning {
256 pub default_version: Option<String>,
258 pub allowed_versions: HashSet<String>,
260 pub version_param: String,
262 pub path_regex: Regex,
264}
265
266impl URLPathVersioning {
267 pub fn new() -> Self {
278 Self {
279 default_version: None,
280 allowed_versions: HashSet::new(),
281 version_param: "version".to_string(),
282 path_regex: Regex::new(r"/v(\d+\.?\d*)(?:/|$)").unwrap(),
283 }
284 }
285 pub fn with_default_version(mut self, version: impl Into<String>) -> Self {
297 self.default_version = Some(version.into());
298 self
299 }
300 pub fn with_allowed_versions(mut self, versions: Vec<impl Into<String>>) -> Self {
313 self.allowed_versions = versions.into_iter().map(|v| v.into()).collect();
314 self
315 }
316 pub fn with_version_param(mut self, param: impl Into<String>) -> Self {
328 self.version_param = param.into();
329 self
330 }
331 pub fn with_path_regex(mut self, regex: Regex) -> Self {
345 self.path_regex = regex;
346 self
347 }
348
349 pub fn with_pattern(mut self, pattern: &str) -> Self {
363 let regex_pattern = pattern.replace("{version}", "([^/]+)");
365 if let Ok(regex) = Regex::new(®ex_pattern) {
366 self.path_regex = regex;
367 }
368 self
369 }
370}
371
372impl Default for URLPathVersioning {
373 fn default() -> Self {
374 Self::new()
375 }
376}
377
378#[async_trait]
379impl BaseVersioning for URLPathVersioning {
380 async fn determine_version(&self, request: &Request) -> Result<String> {
381 let path = request.uri.path();
382
383 if let Some(captures) = self.path_regex.captures(path)
385 && let Some(version_match) = captures.get(1)
386 {
387 let version = version_match.as_str();
388 if self.is_allowed_version(version) {
389 return Ok(version.to_string());
390 } else {
391 return Err(Error::Validation(
392 VersioningError::VersionNotAllowed(version.to_string()).to_string(),
393 ));
394 }
395 }
396
397 Ok(self
399 .default_version
400 .clone()
401 .unwrap_or_else(|| "1.0".to_string()))
402 }
403
404 fn default_version(&self) -> Option<&str> {
405 self.default_version.as_deref()
406 }
407
408 fn allowed_versions(&self) -> Option<&HashSet<String>> {
409 Some(&self.allowed_versions)
410 }
411
412 fn version_param(&self) -> &str {
413 &self.version_param
414 }
415}
416
417#[derive(Debug, Clone)]
421pub struct HostNameVersioning {
422 pub default_version: Option<String>,
424 pub allowed_versions: HashSet<String>,
426 pub hostname_regex: Regex,
428 pub hostname_to_version: HashMap<String, String>,
431}
432
433impl HostNameVersioning {
434 pub fn new() -> Self {
445 Self {
446 default_version: None,
447 allowed_versions: HashSet::new(),
448 hostname_regex: Regex::new(r"^([a-zA-Z0-9]+)\.").unwrap(),
449 hostname_to_version: HashMap::new(),
450 }
451 }
452 pub fn with_default_version(mut self, version: impl Into<String>) -> Self {
464 self.default_version = Some(version.into());
465 self
466 }
467 pub fn with_allowed_versions(mut self, versions: Vec<impl Into<String>>) -> Self {
480 self.allowed_versions = versions.into_iter().map(|v| v.into()).collect();
481 self
482 }
483 pub fn with_hostname_regex(mut self, regex: Regex) -> Self {
497 self.hostname_regex = regex;
498 self
499 }
500
501 pub fn with_host_format(mut self, format: &str) -> Self {
515 const PLACEHOLDER: &str = "__REINHARDT_VERSION_PLACEHOLDER__";
518 let pattern = format.replace("{version}", PLACEHOLDER);
519 let pattern = pattern.replace(".", "\\.");
520 let pattern = pattern.replace(PLACEHOLDER, "([^.]+)");
521 let pattern = format!("^{}", pattern);
522 if let Ok(regex) = Regex::new(&pattern) {
523 self.hostname_regex = regex;
524 }
525 self
526 }
527
528 pub fn with_hostname_pattern(mut self, version: &str, hostname: &str) -> Self {
545 self.allowed_versions.insert(version.to_string());
546 self.hostname_to_version
547 .insert(hostname.to_string(), version.to_string());
548 self
549 }
550}
551
552impl Default for HostNameVersioning {
553 fn default() -> Self {
554 Self::new()
555 }
556}
557
558#[async_trait]
559impl BaseVersioning for HostNameVersioning {
560 async fn determine_version(&self, request: &Request) -> Result<String> {
561 if let Some(host) = request.headers.get("host") {
563 let host_str = host
564 .to_str()
565 .map_err(|_| Error::Validation(VersioningError::InvalidHostname.to_string()))?;
566
567 let hostname = host_str.split(':').next().unwrap_or(host_str);
569
570 if let Some(version) = self.hostname_to_version.get(hostname)
572 && self.is_allowed_version(version)
573 {
574 return Ok(version.clone());
575 }
576
577 if let Some(captures) = self.hostname_regex.captures(hostname)
579 && let Some(version_match) = captures.get(1)
580 {
581 let version = version_match.as_str();
582 if self.is_allowed_version(version) {
583 return Ok(version.to_string());
584 }
585 }
586 }
587
588 Ok(self
590 .default_version
591 .clone()
592 .unwrap_or_else(|| "1.0".to_string()))
593 }
594
595 fn default_version(&self) -> Option<&str> {
596 self.default_version.as_deref()
597 }
598
599 fn allowed_versions(&self) -> Option<&HashSet<String>> {
600 Some(&self.allowed_versions)
601 }
602}
603
/// Versioning strategy that reads the API version from a query-string parameter,
/// e.g. `/users/?version=2.0`.
#[derive(Debug, Clone)]
pub struct QueryParameterVersioning {
    /// Version used when the query has no version parameter; `None` falls back to `"1.0"`.
    pub default_version: Option<String>,
    /// Accepted versions; an empty set accepts any version.
    pub allowed_versions: HashSet<String>,
    /// Name of the query parameter holding the version.
    pub version_param: String,
}

impl QueryParameterVersioning {
    /// Creates a strategy with no default version, no version restriction and the parameter
    /// name `"version"`.
    pub fn new() -> Self {
        Self {
            version_param: String::from("version"),
            allowed_versions: HashSet::default(),
            default_version: None,
        }
    }

    /// Sets the version used when the request does not specify one.
    pub fn with_default_version(mut self, version: impl Into<String>) -> Self {
        let version: String = version.into();
        self.default_version = Some(version);
        self
    }

    /// Replaces the set of accepted versions (duplicates collapse).
    pub fn with_allowed_versions(mut self, versions: Vec<impl Into<String>>) -> Self {
        let set: HashSet<String> = versions.into_iter().map(Into::into).collect();
        self.allowed_versions = set;
        self
    }

    /// Changes the name of the query parameter that carries the version.
    pub fn with_version_param(mut self, param: impl Into<String>) -> Self {
        let param: String = param.into();
        self.version_param = param;
        self
    }
}

impl Default for QueryParameterVersioning {
    /// Equivalent to [`QueryParameterVersioning::new`].
    fn default() -> Self {
        QueryParameterVersioning::new()
    }
}
690
691#[async_trait]
692impl BaseVersioning for QueryParameterVersioning {
693 async fn determine_version(&self, request: &Request) -> Result<String> {
694 if let Some(query) = request.uri.query() {
696 for param in query.split('&') {
697 if let Some((key, value)) = param.split_once('=')
698 && key == self.version_param
699 {
700 if self.is_allowed_version(value) {
701 return Ok(value.to_string());
702 } else {
703 return Err(Error::Validation(
704 VersioningError::VersionNotAllowed(value.to_string()).to_string(),
705 ));
706 }
707 }
708 }
709 }
710
711 Ok(self
713 .default_version
714 .clone()
715 .unwrap_or_else(|| "1.0".to_string()))
716 }
717
718 fn default_version(&self) -> Option<&str> {
719 self.default_version.as_deref()
720 }
721
722 fn allowed_versions(&self) -> Option<&HashSet<String>> {
723 Some(&self.allowed_versions)
724 }
725
726 fn version_param(&self) -> &str {
727 &self.version_param
728 }
729}
730
731#[derive(Debug)]
736pub struct NamespaceVersioning {
737 pub default_version: Option<String>,
739 pub allowed_versions: HashSet<String>,
741 pub pattern: String,
743 pub namespace_prefix: Option<String>,
745 compiled_regex: OnceLock<Option<Regex>>,
747}
748
749impl Clone for NamespaceVersioning {
750 fn clone(&self) -> Self {
751 Self {
752 default_version: self.default_version.clone(),
753 allowed_versions: self.allowed_versions.clone(),
754 pattern: self.pattern.clone(),
755 namespace_prefix: self.namespace_prefix.clone(),
756 compiled_regex: OnceLock::new(),
758 }
759 }
760}
761
762impl NamespaceVersioning {
763 pub fn new() -> Self {
775 Self {
776 default_version: None,
777 allowed_versions: HashSet::new(),
778 pattern: "/v{version}/".to_string(),
779 namespace_prefix: None,
780 compiled_regex: OnceLock::new(),
781 }
782 }
783 pub fn with_default_version(mut self, version: impl Into<String>) -> Self {
795 self.default_version = Some(version.into());
796 self
797 }
798 pub fn with_allowed_versions(mut self, versions: Vec<impl Into<String>>) -> Self {
812 self.allowed_versions = versions.into_iter().map(|v| v.into()).collect();
813 self
814 }
815
816 pub fn with_namespace_prefix(mut self, prefix: &str) -> Self {
830 self.namespace_prefix = Some(prefix.to_string());
831 self
832 }
833
834 pub fn with_pattern(mut self, pattern: &str) -> Self {
849 self.pattern = pattern.to_string();
850 self.compiled_regex = OnceLock::new();
852 self
853 }
854}
855
856impl Default for NamespaceVersioning {
857 fn default() -> Self {
858 Self::new()
859 }
860}
861
862#[async_trait]
863impl BaseVersioning for NamespaceVersioning {
864 async fn determine_version(&self, request: &Request) -> Result<String> {
865 let path = request.uri.path();
866
867 if let Some(version) = self.extract_version_from_path(path)
869 && self.is_allowed_version(&version)
870 {
871 return Ok(version);
872 }
873
874 Ok(self
876 .default_version
877 .clone()
878 .unwrap_or_else(|| "1.0".to_string()))
879 }
880
881 fn default_version(&self) -> Option<&str> {
882 self.default_version.as_deref()
883 }
884
885 fn allowed_versions(&self) -> Option<&HashSet<String>> {
886 Some(&self.allowed_versions)
887 }
888}
889
890impl NamespaceVersioning {
891 fn get_compiled_regex(&self) -> Option<&Regex> {
893 self.compiled_regex
894 .get_or_init(|| {
895 let regex_pattern = self
896 .pattern
897 .replace("{version}", r"([^/]+)")
898 .replace("/", r"\/");
899 let full_pattern = format!("^{}", regex_pattern);
900 regex::Regex::new(&full_pattern).ok()
901 })
902 .as_ref()
903 }
904
905 fn extract_version_from_path(&self, path: &str) -> Option<String> {
907 if let Some(regex) = self.get_compiled_regex()
908 && let Some(captures) = regex.captures(path)
909 && let Some(version_match) = captures.get(1)
910 {
911 return Some(version_match.as_str().to_string());
912 }
913 None
914 }
915
916 fn is_allowed_version(&self, version: &str) -> bool {
918 self.allowed_versions.is_empty() || self.allowed_versions.contains(version)
919 }
920
921 #[allow(dead_code)]
941 fn extract_version_from_router_stub(&self, _router: &(), path: &str) -> Option<String> {
942 self.extract_version_from_path(path)
943 }
944
945 #[allow(dead_code)]
981 fn get_available_versions_from_router_stub(&self, _router: &()) -> Vec<String> {
982 Vec::new()
983 }
984}
985
#[cfg(test)]
pub mod test_utils {
    use bytes::Bytes;
    use hyper::header::HeaderName;
    use hyper::{HeaderMap, Method, Uri, Version};
    use reinhardt_http::Request;

    /// Builds a `GET` HTTP/1.1 request with an empty body for use in tests.
    ///
    /// Headers are inserted in order; a repeated name keeps only its last value.
    ///
    /// # Panics
    ///
    /// Panics if `uri`, a header name, or a header value fails to parse, or if the request
    /// builder rejects the request.
    pub fn create_test_request(uri: &str, headers: Vec<(String, String)>) -> Request {
        let target: Uri = uri.parse().unwrap();

        let mut header_map = HeaderMap::new();
        for (name, value) in headers {
            header_map.insert(name.parse::<HeaderName>().unwrap(), value.parse().unwrap());
        }

        let builder = Request::builder()
            .method(Method::GET)
            .uri(target)
            .version(Version::HTTP_11)
            .headers(header_map)
            .body(Bytes::new());
        builder.build().unwrap()
    }
}
1011
#[cfg(test)]
mod tests {
    use super::*;
    use test_utils::create_test_request;

    /// Builds a request for `uri` with `headers`, runs `strategy` on it and unwraps the version.
    async fn resolve_version<S: BaseVersioning>(
        strategy: &S,
        uri: &str,
        headers: &[(&str, &str)],
    ) -> String {
        let owned_headers = headers
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect();
        let request = create_test_request(uri, owned_headers);
        strategy.determine_version(&request).await.unwrap()
    }

    #[tokio::test]
    async fn test_accept_header_versioning() {
        let versioning = AcceptHeaderVersioning::new()
            .with_default_version("1.0")
            .with_allowed_versions(vec!["1.0", "2.0"]);

        let with_version = [("accept", "application/json; version=2.0")];
        assert_eq!(
            resolve_version(&versioning, "/users/", &with_version).await,
            "2.0"
        );

        let without_version = [("accept", "application/json")];
        assert_eq!(
            resolve_version(&versioning, "/users/", &without_version).await,
            "1.0"
        );
    }

    #[tokio::test]
    async fn test_url_path_versioning() {
        let versioning = URLPathVersioning::new()
            .with_default_version("1.0")
            .with_allowed_versions(vec!["1.0", "2.0", "2"]);

        assert_eq!(resolve_version(&versioning, "/v2/users/", &[]).await, "2");
        assert_eq!(resolve_version(&versioning, "/users/", &[]).await, "1.0");
    }

    #[tokio::test]
    async fn test_hostname_versioning() {
        let versioning = HostNameVersioning::new()
            .with_default_version("1.0")
            .with_allowed_versions(vec!["v1", "v2"]);

        let versioned_host = [("host", "v2.api.example.com")];
        assert_eq!(
            resolve_version(&versioning, "/users/", &versioned_host).await,
            "v2"
        );

        let plain_host = [("host", "api.example.com")];
        assert_eq!(
            resolve_version(&versioning, "/users/", &plain_host).await,
            "1.0"
        );
    }

    #[tokio::test]
    async fn test_query_parameter_versioning() {
        let versioning = QueryParameterVersioning::new()
            .with_default_version("1.0")
            .with_allowed_versions(vec!["1.0", "2.0"]);

        assert_eq!(
            resolve_version(&versioning, "/users/?version=2.0", &[]).await,
            "2.0"
        );
        assert_eq!(resolve_version(&versioning, "/users/", &[]).await, "1.0");
    }

    #[tokio::test]
    async fn test_namespace_versioning() {
        let versioning = NamespaceVersioning::new()
            .with_default_version("1.0")
            .with_allowed_versions(vec!["1", "1.0", "2", "2.0", "3.0"]);

        let cases = [
            ("/v1/users/", "1"),
            ("/v2.0/users/", "2.0"),
            ("/users/", "1.0"),
            ("/api/users/", "1.0"),
        ];
        for (uri, expected) in cases {
            assert_eq!(resolve_version(&versioning, uri, &[]).await, expected);
        }
    }

    #[tokio::test]
    async fn test_namespace_versioning_with_custom_pattern() {
        let versioning = NamespaceVersioning::new()
            .with_default_version("1.0")
            .with_pattern("/api/v{version}/")
            .with_allowed_versions(vec!["1", "2"]);

        let cases = [
            ("/api/v1/users/", "1"),
            ("/api/v2/users/", "2"),
            // The custom pattern requires the `/api` prefix, so this falls back to the default.
            ("/v1/users/", "1.0"),
        ];
        for (uri, expected) in cases {
            assert_eq!(resolve_version(&versioning, uri, &[]).await, expected);
        }
    }

    #[tokio::test]
    async fn test_hostname_versioning_with_host_format_dots_not_corrupted() {
        let versioning = HostNameVersioning::new()
            .with_host_format("{version}.api.v2.example.com")
            .with_allowed_versions(vec!["v1", "v3"]);

        let v1_host = [("host", "v1.api.v2.example.com")];
        assert_eq!(
            resolve_version(&versioning, "/users/", &v1_host).await,
            "v1"
        );

        let v3_host = [("host", "v3.api.v2.example.com")];
        assert_eq!(
            resolve_version(&versioning, "/users/", &v3_host).await,
            "v3"
        );
    }
}