use super::version::ApiVersion;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Describes where in a request the API version is expected to be found.
///
/// Pattern-based variants mark the position of the version text with the
/// `{version}` placeholder (for example `/v{version}/`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VersionStrategy {
    /// Version embedded in the URL path; `pattern` contains `{version}`.
    Path { pattern: String },
    /// Version carried verbatim in the request header called `name`.
    Header { name: String },
    /// Version carried in the query-string parameter called `param`.
    Query { param: String },
    /// Version embedded in the `Accept` header media type; `pattern` contains `{version}`.
    Accept { pattern: String },
    /// Application-defined strategy identified by `name`; the built-in
    /// `VersionExtractor` lookups skip it.
    Custom { name: String },
}
impl VersionStrategy {
    /// Path strategy using the conventional `/v{version}/` pattern.
    pub fn path() -> Self {
        Self::path_with_pattern("/v{version}/")
    }

    /// Path strategy using a caller-supplied pattern containing `{version}`.
    pub fn path_with_pattern(pattern: impl Into<String>) -> Self {
        let pattern = pattern.into();
        Self::Path { pattern }
    }

    /// Header strategy reading the `X-API-Version` header.
    pub fn header() -> Self {
        Self::header_with_name("X-API-Version")
    }

    /// Header strategy reading the header called `name`.
    pub fn header_with_name(name: impl Into<String>) -> Self {
        let name = name.into();
        Self::Header { name }
    }

    /// Query strategy reading the `version` parameter.
    pub fn query() -> Self {
        Self::query_with_param("version")
    }

    /// Query strategy reading the parameter called `param`.
    pub fn query_with_param(param: impl Into<String>) -> Self {
        let param = param.into();
        Self::Query { param }
    }

    /// Accept-header strategy using the `application/vnd.api.v{version}+json` pattern.
    pub fn accept() -> Self {
        Self::accept_with_pattern("application/vnd.api.v{version}+json")
    }

    /// Accept-header strategy using a caller-supplied pattern containing `{version}`.
    pub fn accept_with_pattern(pattern: impl Into<String>) -> Self {
        let pattern = pattern.into();
        Self::Accept { pattern }
    }

    /// Custom strategy identified by `name`.
    pub fn custom(name: impl Into<String>) -> Self {
        let name = name.into();
        Self::Custom { name }
    }
}
impl Default for VersionStrategy {
fn default() -> Self {
Self::path()
}
}
/// Resolves an [`ApiVersion`] from request data by consulting an ordered
/// list of [`VersionStrategy`] values.
#[derive(Debug, Clone)]
pub struct VersionExtractor {
    /// Strategies in insertion order; lookups return the first match.
    strategies: Vec<VersionStrategy>,
    /// Version reported by `get_default`.
    default: ApiVersion,
}
impl VersionExtractor {
    /// Creates an extractor that only uses the default path strategy
    /// ([`VersionStrategy::path`]) and reports `ApiVersion::v1()` as default.
    pub fn new() -> Self {
        Self {
            strategies: vec![VersionStrategy::path()],
            default: ApiVersion::v1(),
        }
    }

    /// Creates an extractor with a single `strategy` and `ApiVersion::v1()` as default.
    pub fn with_strategy(strategy: VersionStrategy) -> Self {
        Self::with_strategies(vec![strategy])
    }

    /// Creates an extractor that consults `strategies` in order, with
    /// `ApiVersion::v1()` as default.
    pub fn with_strategies(strategies: Vec<VersionStrategy>) -> Self {
        Self {
            strategies,
            default: ApiVersion::v1(),
        }
    }

    /// Replaces the version returned by [`Self::get_default`].
    pub fn default_version(mut self, version: ApiVersion) -> Self {
        self.default = version;
        self
    }

    /// Appends `strategy`; it is consulted after all previously added strategies.
    pub fn add_strategy(mut self, strategy: VersionStrategy) -> Self {
        self.strategies.push(strategy);
        self
    }

    /// Extracts a version from a URL `path` using the configured
    /// [`VersionStrategy::Path`] strategies, in order.
    ///
    /// Every occurrence of a pattern's prefix is tried, so `/videos/v2/clips`
    /// yields version 2 even though `/v` first matches inside `/videos`.
    /// Returns `None` when no path strategy matches.
    pub fn extract_from_path(&self, path: &str) -> Option<ApiVersion> {
        self.strategies.iter().find_map(|strategy| match strategy {
            VersionStrategy::Path { pattern } => Self::extract_path_version(path, pattern),
            _ => None,
        })
    }

    /// Extracts a version from request headers using the configured
    /// [`VersionStrategy::Header`] and [`VersionStrategy::Accept`] strategies,
    /// in order.
    ///
    /// Header names are matched case-insensitively, because HTTP field names
    /// are case-insensitive. The lowercased name is tried first, so maps that
    /// already store lowercase keys take the direct lookup. Header values are
    /// trimmed before parsing.
    pub fn extract_from_headers(&self, headers: &HashMap<String, String>) -> Option<ApiVersion> {
        self.strategies.iter().find_map(|strategy| match strategy {
            VersionStrategy::Header { name } => Self::header_value(headers, name)
                .and_then(|value| value.trim().parse::<ApiVersion>().ok()),
            VersionStrategy::Accept { pattern } => Self::header_value(headers, "accept")
                .and_then(|accept| Self::extract_accept_version(accept, pattern)),
            _ => None,
        })
    }

    /// Extracts a version from a raw query string (`a=1&version=2`) using the
    /// configured [`VersionStrategy::Query`] strategies, in order.
    ///
    /// A leading `?` is ignored. Pairs without `=` are skipped. When a
    /// parameter repeats, its last value is used. Values are not
    /// percent-decoded.
    pub fn extract_from_query(&self, query: &str) -> Option<ApiVersion> {
        let query = query.strip_prefix('?').unwrap_or(query);
        // Later duplicates overwrite earlier ones on insert, so the last occurrence wins.
        let params: HashMap<&str, &str> = query
            .split('&')
            .filter_map(|pair| pair.split_once('='))
            .collect();
        self.strategies.iter().find_map(|strategy| match strategy {
            VersionStrategy::Query { param } => params
                .get(param.as_str())
                .and_then(|value| value.parse::<ApiVersion>().ok()),
            _ => None,
        })
    }

    /// Returns the configured default version.
    pub fn get_default(&self) -> ApiVersion {
        self.default
    }

    /// Returns the value of header `name`, matching the name case-insensitively.
    ///
    /// If the map holds several keys that differ only in case, the
    /// exact-lowercase key wins. Otherwise an arbitrary one is returned,
    /// because `HashMap` iteration order is unspecified.
    fn header_value<'a>(headers: &'a HashMap<String, String>, name: &str) -> Option<&'a str> {
        headers
            .get(&name.to_lowercase())
            .or_else(|| {
                headers
                    .iter()
                    .find(|(key, _)| key.eq_ignore_ascii_case(name))
                    .map(|(_, value)| value)
            })
            .map(String::as_str)
    }

    /// Splits `pattern` into the text before and after its `{version}` placeholder.
    ///
    /// Returns `None` when the pattern contains no placeholder. If the
    /// placeholder repeats, only the text up to its second occurrence is used
    /// as the suffix.
    fn split_pattern(pattern: &str) -> Option<(&str, &str)> {
        let mut parts = pattern.split("{version}");
        let before = parts.next()?;
        let after = parts.next()?;
        Some((before, after))
    }

    /// Finds the first `before{version}after` occurrence in `haystack` whose
    /// version text parses as an [`ApiVersion`].
    ///
    /// The version text ends at the next `after`. If `after` is empty or does
    /// not occur, it runs to the end of `haystack`. Returns the parsed version
    /// and the byte range of the whole match (prefix, version and suffix when
    /// the suffix is present).
    fn find_version(
        haystack: &str,
        before: &str,
        after: &str,
    ) -> Option<(ApiVersion, std::ops::Range<usize>)> {
        // An empty prefix "matches" at every char boundary; only anchor it at the start.
        let max_candidates = if before.is_empty() { 1 } else { usize::MAX };
        haystack
            .match_indices(before)
            .take(max_candidates)
            .find_map(|(start, _)| {
                let version_start = start + before.len();
                let remaining = &haystack[version_start..];
                let (version_len, consumed) = match remaining.find(after) {
                    Some(end) if !after.is_empty() => (end, end + after.len()),
                    _ => (remaining.len(), remaining.len()),
                };
                let version = remaining[..version_len].parse::<ApiVersion>().ok()?;
                Some((version, start..version_start + consumed))
            })
    }

    /// Extracts the version that `pattern` locates inside `path`.
    fn extract_path_version(path: &str, pattern: &str) -> Option<ApiVersion> {
        let (before, after) = Self::split_pattern(pattern)?;
        Self::find_version(path, before, after).map(|(version, _)| version)
    }

    /// Extracts the version that `pattern` locates in the comma-separated
    /// media types of an `Accept` header value, scanning them left to right.
    fn extract_accept_version(accept: &str, pattern: &str) -> Option<ApiVersion> {
        let (before, after) = Self::split_pattern(pattern)?;
        accept
            .split(',')
            .map(str::trim)
            .find_map(|media_type| Self::find_version(media_type, before, after))
            .map(|(version, _)| version)
    }

    /// Removes the version segment from `path` using the configured
    /// [`VersionStrategy::Path`] strategies, in order.
    ///
    /// Returns `path` unchanged when no strategy matches.
    pub fn strip_version_from_path(&self, path: &str) -> String {
        self.strategies
            .iter()
            .find_map(|strategy| match strategy {
                VersionStrategy::Path { pattern } => Self::strip_path_version(path, pattern),
                _ => None,
            })
            .unwrap_or_else(|| path.to_string())
    }

    /// Removes the segment matched by `pattern` from `path`.
    ///
    /// Matching uses the same rules as `extract_path_version`, so any path
    /// that yields a version can also be stripped. Returns `None` when nothing
    /// matches.
    fn strip_path_version(path: &str, pattern: &str) -> Option<String> {
        let (before, after) = Self::split_pattern(pattern)?;
        let (_, matched) = Self::find_version(path, before, after)?;
        let removed = &path[matched.start..matched.end];
        let prefix = &path[..matched.start];
        let suffix = &path[matched.end..];
        // The removed segment usually carries the separators on both sides
        // (e.g. `/v1/` in `/api/v1/users`). Put one back so neighbouring
        // segments are not glued together (`/api/users`, not `/apiusers`).
        if removed.starts_with('/') && !prefix.ends_with('/') && !suffix.starts_with('/') {
            Some(format!("{}/{}", prefix, suffix))
        } else {
            Some(format!("{}{}", prefix, suffix))
        }
    }
}
impl Default for VersionExtractor {
fn default() -> Self {
Self::new()
}
}
/// A resolved API version paired with the part of the request it came from.
#[derive(Clone, Debug)]
pub struct ExtractedVersion {
    /// The resolved version.
    pub version: ApiVersion,
    /// Where `version` was obtained.
    pub source: VersionSource,
}
/// Identifies which part of a request supplied an extracted version.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VersionSource {
    /// The URL path.
    Path,
    /// A dedicated version header.
    Header,
    /// A query-string parameter.
    Query,
    /// The `Accept` header media type.
    Accept,
    /// The configured default version (presumably used when nothing matched).
    Default,
    /// An application-defined strategy.
    Custom,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned header map from borrowed `(name, value)` pairs.
    fn header_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_extract_from_path() {
        let extractor = VersionExtractor::new();
        let cases = [
            ("/v1/users", ApiVersion::major(1)),
            ("/v2/products/123", ApiVersion::major(2)),
            ("/v1.2/items", ApiVersion::new(1, 2, 0)),
        ];
        for &(path, expected) in &cases {
            assert_eq!(extractor.extract_from_path(path), Some(expected), "path: {}", path);
        }
    }

    #[test]
    fn test_extract_from_header() {
        let extractor = VersionExtractor::with_strategy(VersionStrategy::header());
        let headers = header_map(&[("x-api-version", "2.0")]);
        assert_eq!(
            extractor.extract_from_headers(&headers),
            Some(ApiVersion::new(2, 0, 0))
        );
    }

    #[test]
    fn test_extract_from_query() {
        let extractor = VersionExtractor::with_strategy(VersionStrategy::query());
        let cases = [
            ("version=1&other=value", ApiVersion::major(1)),
            ("foo=bar&version=2.1", ApiVersion::new(2, 1, 0)),
        ];
        for &(query, expected) in &cases {
            assert_eq!(extractor.extract_from_query(query), Some(expected), "query: {}", query);
        }
    }

    #[test]
    fn test_extract_from_accept() {
        let extractor = VersionExtractor::with_strategy(VersionStrategy::accept());
        let headers = header_map(&[("accept", "application/vnd.api.v2+json")]);
        assert_eq!(
            extractor.extract_from_headers(&headers),
            Some(ApiVersion::major(2))
        );
    }

    #[test]
    fn test_strip_version_from_path() {
        let extractor = VersionExtractor::new();
        let cases = [
            ("/v1/users", "/users"),
            ("/v2.0/products/123", "/products/123"),
        ];
        for &(path, expected) in &cases {
            assert_eq!(extractor.strip_version_from_path(path), expected, "path: {}", path);
        }
    }

    #[test]
    fn test_multiple_strategies() {
        let strategies = vec![
            VersionStrategy::path(),
            VersionStrategy::header(),
            VersionStrategy::query(),
        ];
        let extractor =
            VersionExtractor::with_strategies(strategies).default_version(ApiVersion::v1());
        assert_eq!(extractor.extract_from_path("/v2/test"), Some(ApiVersion::major(2)));
        assert_eq!(extractor.extract_from_query("version=3"), Some(ApiVersion::major(3)));
    }

    #[test]
    fn test_custom_path_pattern() {
        let strategy = VersionStrategy::path_with_pattern("/api/{version}/");
        let extractor = VersionExtractor::with_strategy(strategy);
        let cases = [
            ("/api/1/users", ApiVersion::major(1)),
            ("/api/2.0/products", ApiVersion::new(2, 0, 0)),
        ];
        for &(path, expected) in &cases {
            assert_eq!(extractor.extract_from_path(path), Some(expected), "path: {}", path);
        }
    }
}