pub mod config;
pub mod handler;
pub mod middleware;
pub mod reverse;
pub mod settings;
use async_trait::async_trait;
pub use config::{VersioningConfig, VersioningManager, VersioningStrategy};
pub use handler::{
ConfigurableVersionedHandler, SimpleVersionedHandler, VersionResponseBuilder, VersionedHandler,
VersionedHandlerBuilder, VersionedHandlerWrapper,
};
pub use middleware::{ApiVersion, RequestVersionExt, VersioningMiddleware};
use regex::Regex;
use reinhardt_core::exception::{Error, Result};
use reinhardt_http::Request;
pub use reverse::{
ApiDocFormat, ApiDocUrlBuilder, UrlReverseManager, VersionedUrlBuilder,
VersioningStrategy as ReverseVersioningStrategy,
};
pub use settings::VersioningSettings;
use std::collections::{HashMap, HashSet};
use std::sync::OnceLock;
use thiserror::Error as ThisError;
/// Errors that can arise while resolving an API version from a request.
///
/// The `Display` text of each variant is the message given in its `#[error]`
/// attribute.
#[derive(Debug, ThisError)]
pub enum VersioningError {
/// Problem with the `Accept` header. `AcceptHeaderVersioning` uses this
/// variant's message when the header value fails `to_str()`.
#[error("Invalid version in Accept header")]
InvalidAcceptHeader,
/// Problem with the version in the URL path.
/// Not constructed by the code visible in this file.
#[error("Invalid version in URL path")]
InvalidURLPath,
/// Problem with the version in the URL namespace.
/// Not constructed by the code visible in this file.
#[error("Invalid version in URL namespace")]
InvalidNamespace,
/// Problem with the `Host` header. `HostNameVersioning` uses this variant's
/// message when the header value fails `to_str()`.
#[error("Invalid version in hostname")]
InvalidHostname,
/// Problem with the version in a query parameter.
/// Not constructed by the code visible in this file.
#[error("Invalid version in query parameter")]
InvalidQueryParameter,
/// The requested version (carried as the payload) is not permitted.
/// NOTE(review): the strategies in this file produce the same
/// "Version not allowed: ..." text via `format!` instead of this variant.
#[error("Version not allowed: {0}")]
VersionNotAllowed(String),
}
/// Contract shared by every API versioning strategy in this module.
///
/// Implementors supply the version lookup plus accessors for their default
/// version and allow-list; the allow-list check and the parameter name have
/// default implementations.
#[async_trait]
pub trait BaseVersioning: Send + Sync {
    /// Resolves the API version that applies to `request`.
    async fn determine_version(&self, request: &Request) -> Result<String>;

    /// The configured default version, if there is one.
    fn default_version(&self) -> Option<&str>;

    /// The set of accepted versions, or `None` when the strategy has no
    /// allow-list at all.
    fn allowed_versions(&self) -> Option<&HashSet<String>>;

    /// Reports whether `version` may be used.
    ///
    /// A missing or empty allow-list accepts everything. Otherwise the
    /// version must be in the list or equal to the default version.
    fn is_allowed_version(&self, version: &str) -> bool {
        match self.allowed_versions() {
            None => true,
            Some(allowed) if allowed.is_empty() => true,
            Some(allowed) => {
                allowed.contains(version) || self.default_version() == Some(version)
            }
        }
    }

    /// Name of the request parameter carrying the version; `"version"`
    /// unless overridden.
    fn version_param(&self) -> &str {
        "version"
    }
}
/// Versioning strategy driven by a parameter of the `Accept` header
/// (for example `Accept: application/json; version=2.0`).
///
/// Start from [`AcceptHeaderVersioning::new`] and refine with the chainable
/// `with_*` methods.
#[derive(Debug, Clone)]
pub struct AcceptHeaderVersioning {
    /// Version used when the request does not name one. When this is `None`
    /// the strategy falls back to `"1.0"`.
    pub default_version: Option<String>,
    /// Versions a request may ask for. An empty set places no restriction.
    pub allowed_versions: HashSet<String>,
    /// Name of the `Accept` parameter that carries the version.
    pub version_param: String,
}

impl AcceptHeaderVersioning {
    /// Creates a strategy with no default version, an empty allow-list and
    /// the parameter name `"version"`.
    pub fn new() -> Self {
        let version_param = String::from("version");
        let allowed_versions = HashSet::new();
        Self {
            default_version: None,
            allowed_versions,
            version_param,
        }
    }

    /// Sets the version returned when the request does not specify one.
    pub fn with_default_version(self, version: impl Into<String>) -> Self {
        Self {
            default_version: Some(version.into()),
            ..self
        }
    }

    /// Replaces the allow-list with `versions` (duplicates collapse).
    pub fn with_allowed_versions(self, versions: Vec<impl Into<String>>) -> Self {
        Self {
            allowed_versions: versions.into_iter().map(Into::into).collect(),
            ..self
        }
    }

    /// Sets the name of the `Accept` parameter that carries the version.
    pub fn with_version_param(self, param: impl Into<String>) -> Self {
        Self {
            version_param: param.into(),
            ..self
        }
    }
}

impl Default for AcceptHeaderVersioning {
    /// Same as [`AcceptHeaderVersioning::new`].
    fn default() -> Self {
        AcceptHeaderVersioning::new()
    }
}
#[async_trait]
impl BaseVersioning for AcceptHeaderVersioning {
async fn determine_version(&self, request: &Request) -> Result<String> {
if let Some(accept) = request.headers.get("accept") {
let accept_str = accept
.to_str()
.map_err(|_| Error::Validation(VersioningError::InvalidAcceptHeader.to_string()))?;
if let Some(params_start) = accept_str.find(';') {
let params = &accept_str[params_start + 1..];
for param in params.split(';') {
let param = param.trim();
if let Some((key, value)) = param.split_once('=')
&& key.trim() == self.version_param
{
let version = value.trim().trim_matches('"');
if self.is_allowed_version(version) {
return Ok(version.to_owned());
} else {
return Err(Error::Validation(format!(
"Version not allowed: {version}"
)));
}
}
}
}
}
Ok(self.default_version.as_deref().unwrap_or("1.0").to_owned())
}
fn default_version(&self) -> Option<&str> {
self.default_version.as_deref()
}
fn allowed_versions(&self) -> Option<&HashSet<String>> {
Some(&self.allowed_versions)
}
fn version_param(&self) -> &str {
&self.version_param
}
}
/// Versioning strategy that extracts the version from the request path
/// using a regular expression whose first capture group is the version.
#[derive(Debug, Clone)]
pub struct URLPathVersioning {
    /// Version used when the path carries none. When this is `None` the
    /// strategy falls back to `"1.0"`.
    pub default_version: Option<String>,
    /// Versions a request may ask for. An empty set places no restriction.
    pub allowed_versions: HashSet<String>,
    /// Parameter name reported through `BaseVersioning::version_param`; the
    /// path matching itself only uses `path_regex`.
    pub version_param: String,
    /// Expression matched against the path; capture group 1 is the version.
    pub path_regex: Regex,
}

impl URLPathVersioning {
    /// Creates a strategy with no default version, an empty allow-list, the
    /// parameter name `"version"` and a regex that matches `/v<digits>`
    /// (optionally `<digits>.<digits>`) followed by `/` or the end of the
    /// path.
    pub fn new() -> Self {
        Self {
            default_version: None,
            allowed_versions: HashSet::new(),
            version_param: "version".to_string(),
            path_regex: Regex::new(r"/v(\d+\.?\d*)(?:/|$)")
                .expect("built-in path regex is a valid expression"),
        }
    }

    /// Sets the version returned when the path does not specify one.
    pub fn with_default_version(mut self, version: impl Into<String>) -> Self {
        self.default_version = Some(version.into());
        self
    }

    /// Replaces the allow-list with `versions`.
    pub fn with_allowed_versions(mut self, versions: Vec<impl Into<String>>) -> Self {
        self.allowed_versions = versions.into_iter().map(Into::into).collect();
        self
    }

    /// Sets the parameter name reported by `version_param`.
    pub fn with_version_param(mut self, param: impl Into<String>) -> Self {
        self.version_param = param.into();
        self
    }

    /// Uses `regex` for path matching; its first capture group must be the
    /// version.
    pub fn with_path_regex(mut self, regex: Regex) -> Self {
        self.path_regex = regex;
        self
    }

    /// Builds the path regex from a template in which `{version}` marks the
    /// version segment (it becomes the capture group `([^/]+)`).
    ///
    /// The rest of `pattern` is passed to the regex engine as-is, without
    /// escaping. If the resulting expression does not compile, the
    /// previously configured regex is kept unchanged.
    pub fn with_pattern(mut self, pattern: &str) -> Self {
        let regex_pattern = pattern.replace("{version}", "([^/]+)");
        // Best effort by design: an invalid template leaves the current regex in place.
        if let Ok(regex) = Regex::new(&regex_pattern) {
            self.path_regex = regex;
        }
        self
    }
}

impl Default for URLPathVersioning {
    /// Same as [`URLPathVersioning::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl BaseVersioning for URLPathVersioning {
async fn determine_version(&self, request: &Request) -> Result<String> {
let path = request.uri.path();
if let Some(captures) = self.path_regex.captures(path)
&& let Some(version_match) = captures.get(1)
{
let version = version_match.as_str();
if self.is_allowed_version(version) {
return Ok(version.to_owned());
} else {
return Err(Error::Validation(format!("Version not allowed: {version}")));
}
}
Ok(self.default_version.as_deref().unwrap_or("1.0").to_owned())
}
fn default_version(&self) -> Option<&str> {
self.default_version.as_deref()
}
fn allowed_versions(&self) -> Option<&HashSet<String>> {
Some(&self.allowed_versions)
}
fn version_param(&self) -> &str {
&self.version_param
}
}
/// Versioning strategy that derives the version from the `Host` header,
/// either through an exact hostname table or through a regular expression
/// whose first capture group is the version.
#[derive(Debug, Clone)]
pub struct HostNameVersioning {
    /// Version used when the host yields none. When this is `None` the
    /// strategy falls back to `"1.0"`.
    pub default_version: Option<String>,
    /// Versions a request may resolve to. An empty set places no restriction.
    pub allowed_versions: HashSet<String>,
    /// Expression matched against the hostname; capture group 1 is the version.
    pub hostname_regex: Regex,
    /// Exact hostname -> version table, consulted before the regex.
    pub hostname_to_version: HashMap<String, String>,
}

impl HostNameVersioning {
    /// Creates a strategy with no default version, an empty allow-list, an
    /// empty hostname table and a regex capturing the leading alphanumeric
    /// label of the hostname (the text before the first `.`).
    pub fn new() -> Self {
        let hostname_regex = Regex::new(r"^([a-zA-Z0-9]+)\.").unwrap();
        Self {
            default_version: None,
            allowed_versions: HashSet::new(),
            hostname_regex,
            hostname_to_version: HashMap::new(),
        }
    }

    /// Sets the version returned when the host does not yield one.
    pub fn with_default_version(self, version: impl Into<String>) -> Self {
        Self {
            default_version: Some(version.into()),
            ..self
        }
    }

    /// Replaces the allow-list with `versions`. This also discards versions
    /// previously added by `with_hostname_pattern`.
    pub fn with_allowed_versions(self, versions: Vec<impl Into<String>>) -> Self {
        Self {
            allowed_versions: versions.into_iter().map(Into::into).collect(),
            ..self
        }
    }

    /// Uses `regex` for hostname matching; its first capture group must be
    /// the version.
    pub fn with_hostname_regex(self, regex: Regex) -> Self {
        Self {
            hostname_regex: regex,
            ..self
        }
    }

    /// Builds the hostname regex from a template such as
    /// `{version}.api.example.com`.
    ///
    /// `{version}` becomes the capture group `([^.]+)`, every literal `.` is
    /// escaped, and the expression is anchored at the start with `^`. Other
    /// characters are passed to the regex engine unchanged. If the result
    /// does not compile, the previously configured regex is kept.
    pub fn with_host_format(self, format: &str) -> Self {
        // The placeholder shields `{version}` while the dots are escaped, so
        // the `.` inside the capture group is not escaped as well.
        const PLACEHOLDER: &str = "__REINHARDT_VERSION_PLACEHOLDER__";
        let body = format
            .replace("{version}", PLACEHOLDER)
            .replace(".", "\\.")
            .replace(PLACEHOLDER, "([^.]+)");
        let anchored = format!("^{}", body);

        match Regex::new(&anchored) {
            Ok(hostname_regex) => Self {
                hostname_regex,
                ..self
            },
            Err(_) => self,
        }
    }

    /// Maps the exact `hostname` to `version` and adds `version` to the
    /// allow-list.
    pub fn with_hostname_pattern(mut self, version: &str, hostname: &str) -> Self {
        let version = version.to_string();
        self.allowed_versions.insert(version.clone());
        self.hostname_to_version.insert(hostname.to_string(), version);
        self
    }
}

impl Default for HostNameVersioning {
    /// Same as [`HostNameVersioning::new`].
    fn default() -> Self {
        HostNameVersioning::new()
    }
}
#[async_trait]
impl BaseVersioning for HostNameVersioning {
    /// Resolves the version from the `Host` header.
    ///
    /// The hostname is the header value up to the first `:`. It is looked up
    /// in `hostname_to_version` first; if that gives no allowed version,
    /// capture group 1 of `hostname_regex` is tried. A candidate that is not
    /// allowed is skipped rather than rejected, so the call then falls back
    /// to the default version (or `"1.0"`), as it does when the header is
    /// absent.
    ///
    /// # Errors
    ///
    /// `Error::Validation` when the header value fails `to_str()`.
    async fn determine_version(&self, request: &Request) -> Result<String> {
        let resolved = match request.headers.get("host") {
            None => None,
            Some(raw_host) => {
                let host_value = raw_host
                    .to_str()
                    .map_err(|_| Error::Validation(VersioningError::InvalidHostname.to_string()))?;
                // Drop a trailing ":port" if present.
                let hostname = host_value.split(':').next().unwrap_or(host_value);

                let from_table = self
                    .hostname_to_version
                    .get(hostname)
                    .map(String::as_str)
                    .filter(|candidate| self.is_allowed_version(candidate));
                let from_regex = || {
                    self.hostname_regex
                        .captures(hostname)
                        .and_then(|captures| captures.get(1))
                        .map(|found| found.as_str())
                        .filter(|candidate| self.is_allowed_version(candidate))
                };
                from_table.or_else(from_regex)
            }
        };

        let version = resolved
            .or(self.default_version.as_deref())
            .unwrap_or("1.0");
        Ok(version.to_owned())
    }

    /// The configured default version, if any.
    fn default_version(&self) -> Option<&str> {
        self.default_version.as_deref()
    }

    /// Always `Some`; an empty set means "no restriction".
    fn allowed_versions(&self) -> Option<&HashSet<String>> {
        Some(&self.allowed_versions)
    }
}
/// Versioning strategy that reads the version from a query-string parameter
/// (for example `/users/?version=2.0`).
#[derive(Debug, Clone)]
pub struct QueryParameterVersioning {
    /// Version used when the query string does not name one. When this is
    /// `None` the strategy falls back to `"1.0"`.
    pub default_version: Option<String>,
    /// Versions a request may ask for. An empty set places no restriction.
    pub allowed_versions: HashSet<String>,
    /// Name of the query parameter that carries the version.
    pub version_param: String,
}

impl QueryParameterVersioning {
    /// Creates a strategy with no default version, an empty allow-list and
    /// the parameter name `"version"`.
    pub fn new() -> Self {
        let version_param = String::from("version");
        let allowed_versions = HashSet::new();
        Self {
            default_version: None,
            allowed_versions,
            version_param,
        }
    }

    /// Sets the version returned when the query string does not specify one.
    pub fn with_default_version(self, version: impl Into<String>) -> Self {
        Self {
            default_version: Some(version.into()),
            ..self
        }
    }

    /// Replaces the allow-list with `versions` (duplicates collapse).
    pub fn with_allowed_versions(self, versions: Vec<impl Into<String>>) -> Self {
        Self {
            allowed_versions: versions.into_iter().map(Into::into).collect(),
            ..self
        }
    }

    /// Sets the name of the query parameter that carries the version.
    pub fn with_version_param(self, param: impl Into<String>) -> Self {
        Self {
            version_param: param.into(),
            ..self
        }
    }
}

impl Default for QueryParameterVersioning {
    /// Same as [`QueryParameterVersioning::new`].
    fn default() -> Self {
        QueryParameterVersioning::new()
    }
}
#[async_trait]
impl BaseVersioning for QueryParameterVersioning {
async fn determine_version(&self, request: &Request) -> Result<String> {
if let Some(query) = request.uri.query() {
for param in query.split('&') {
if let Some((key, value)) = param.split_once('=')
&& key == self.version_param
{
if self.is_allowed_version(value) {
return Ok(value.to_owned());
} else {
return Err(Error::Validation(format!("Version not allowed: {value}")));
}
}
}
}
Ok(self.default_version.as_deref().unwrap_or("1.0").to_owned())
}
fn default_version(&self) -> Option<&str> {
self.default_version.as_deref()
}
fn allowed_versions(&self) -> Option<&HashSet<String>> {
Some(&self.allowed_versions)
}
fn version_param(&self) -> &str {
&self.version_param
}
}
/// Versioning strategy that extracts the version from the start of the
/// request path according to a `{version}` template such as `/v{version}/`.
#[derive(Debug)]
pub struct NamespaceVersioning {
    /// Version used when the path yields none. When this is `None` the
    /// strategy falls back to `"1.0"`.
    pub default_version: Option<String>,
    /// Versions a request may resolve to. An empty set places no restriction.
    pub allowed_versions: HashSet<String>,
    /// Path template; `{version}` marks the version segment.
    pub pattern: String,
    /// Optional namespace prefix. It is stored by `with_namespace_prefix`
    /// and is not consulted by the code visible in this file.
    pub namespace_prefix: Option<String>,
    /// Lazily compiled form of `pattern`; holds `None` once compilation has
    /// been attempted and failed.
    compiled_regex: OnceLock<Option<Regex>>,
}

impl Clone for NamespaceVersioning {
    /// Copies the configuration. The clone starts with an empty regex cache
    /// and recompiles its pattern on first use.
    fn clone(&self) -> Self {
        let Self {
            default_version,
            allowed_versions,
            pattern,
            namespace_prefix,
            compiled_regex: _,
        } = self;
        Self {
            default_version: default_version.clone(),
            allowed_versions: allowed_versions.clone(),
            pattern: pattern.clone(),
            namespace_prefix: namespace_prefix.clone(),
            compiled_regex: OnceLock::new(),
        }
    }
}

impl NamespaceVersioning {
    /// Creates a strategy with no default version, an empty allow-list, the
    /// template `/v{version}/` and no namespace prefix.
    pub fn new() -> Self {
        let pattern = String::from("/v{version}/");
        Self {
            default_version: None,
            allowed_versions: HashSet::new(),
            pattern,
            namespace_prefix: None,
            compiled_regex: OnceLock::new(),
        }
    }

    /// Sets the version returned when the path does not yield one.
    pub fn with_default_version(self, version: impl Into<String>) -> Self {
        Self {
            default_version: Some(version.into()),
            ..self
        }
    }

    /// Replaces the allow-list with `versions`.
    pub fn with_allowed_versions(self, versions: Vec<impl Into<String>>) -> Self {
        Self {
            allowed_versions: versions.into_iter().map(Into::into).collect(),
            ..self
        }
    }

    /// Stores `prefix` as the namespace prefix.
    pub fn with_namespace_prefix(self, prefix: &str) -> Self {
        Self {
            namespace_prefix: Some(prefix.to_string()),
            ..self
        }
    }

    /// Replaces the path template and clears the compiled-regex cache so the
    /// new template is compiled on next use.
    pub fn with_pattern(self, pattern: &str) -> Self {
        Self {
            pattern: pattern.to_string(),
            compiled_regex: OnceLock::new(),
            ..self
        }
    }
}

impl Default for NamespaceVersioning {
    /// Same as [`NamespaceVersioning::new`].
    fn default() -> Self {
        NamespaceVersioning::new()
    }
}
#[async_trait]
impl BaseVersioning for NamespaceVersioning {
    /// Resolves the version from the request path using the configured
    /// template.
    ///
    /// An extracted version that passes the allow-list is returned. In every
    /// other case (no match, or a version that is not allowed) the default
    /// version (or `"1.0"`) is returned; this method never reports an error
    /// for a disallowed version.
    async fn determine_version(&self, request: &Request) -> Result<String> {
        // `self.is_allowed_version` resolves to the inherent method on
        // `NamespaceVersioning`, which checks only the allow-list.
        let version = self
            .extract_version_from_path(request.uri.path())
            .filter(|candidate| self.is_allowed_version(candidate))
            .unwrap_or_else(|| self.default_version.as_deref().unwrap_or("1.0").to_owned());
        Ok(version)
    }

    /// The configured default version, if any.
    fn default_version(&self) -> Option<&str> {
        self.default_version.as_deref()
    }

    /// Always `Some`; an empty set means "no restriction".
    fn allowed_versions(&self) -> Option<&HashSet<String>> {
        Some(&self.allowed_versions)
    }
}
impl NamespaceVersioning {
    /// Returns the regex compiled from `pattern`, compiling and caching it
    /// on first use. Yields `None` when the pattern does not compile.
    ///
    /// `{version}` becomes the capture group `([^/]+)`, every `/` is then
    /// escaped, and the expression is anchored at the start with `^`.
    fn get_compiled_regex(&self) -> Option<&Regex> {
        let cached = self.compiled_regex.get_or_init(|| {
            let with_capture = self.pattern.replace("{version}", r"([^/]+)");
            let escaped = with_capture.replace("/", r"\/");
            let anchored = format!("^{}", escaped);
            regex::Regex::new(&anchored).ok()
        });
        cached.as_ref()
    }

    /// Extracts the version (capture group 1) from `path`, or `None` when
    /// the pattern is invalid or does not match.
    fn extract_version_from_path(&self, path: &str) -> Option<String> {
        self.get_compiled_regex()?
            .captures(path)?
            .get(1)
            .map(|found| found.as_str().to_string())
    }

    /// Allow-list check: an empty list accepts everything, otherwise the
    /// version must be listed. The default version is not considered here.
    fn is_allowed_version(&self, version: &str) -> bool {
        let allowed = &self.allowed_versions;
        allowed.is_empty() || allowed.contains(version)
    }

    /// Finds the first route info whose `path_prefix` is a prefix of `path`
    /// and extracts the version from that prefix.
    ///
    /// Returns `None` when no route info matches or the matching prefix does
    /// not fit the configured pattern.
    pub fn extract_version_from_router<R: reinhardt_router::VersionedRouter + ?Sized>(
        &self,
        router: &R,
        path: &str,
    ) -> Option<String> {
        for info in router.route_version_infos() {
            if path.starts_with(&info.path_prefix) {
                // Only the first matching route info is considered.
                return self.extract_version_from_path(&info.path_prefix);
            }
        }
        None
    }

    /// Collects the allowed versions found in the router's path prefixes,
    /// sorted and without duplicates.
    pub fn get_available_versions_from_router<R: reinhardt_router::VersionedRouter + ?Sized>(
        &self,
        router: &R,
    ) -> Vec<String> {
        let mut versions = Vec::new();
        for info in router.route_version_infos() {
            if let Some(version) = self.extract_version_from_path(&info.path_prefix) {
                if self.is_allowed_version(&version) {
                    versions.push(version);
                }
            }
        }
        versions.sort();
        versions.dedup();
        versions
    }
}
/// Helpers shared by the versioning tests.
#[cfg(test)]
pub mod test_utils {
    use bytes::Bytes;
    use hyper::header::HeaderName;
    use hyper::{HeaderMap, Method, Uri, Version};
    use reinhardt_http::Request;

    /// Builds a `GET` HTTP/1.1 request for `uri` with the given headers and
    /// an empty body.
    ///
    /// Panics if the URI, a header name or a header value fails to parse, or
    /// if the request builder fails.
    pub fn create_test_request(uri: &str, headers: Vec<(String, String)>) -> Request {
        let parsed_uri: Uri = uri.parse().unwrap();

        let mut header_map = HeaderMap::new();
        for (raw_name, raw_value) in headers {
            let name = raw_name.parse::<HeaderName>().unwrap();
            let value = raw_value.parse().unwrap();
            header_map.insert(name, value);
        }

        let builder = Request::builder()
            .method(Method::GET)
            .uri(parsed_uri)
            .version(Version::HTTP_11)
            .headers(header_map)
            .body(Bytes::new());
        builder.build().unwrap()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use test_utils::create_test_request;

    /// Builds a request for `uri` carrying `headers` and returns the version
    /// that `versioning` resolves for it, panicking on error.
    async fn resolve<V: BaseVersioning>(
        versioning: &V,
        uri: &str,
        headers: &[(&str, &str)],
    ) -> String {
        let owned_headers: Vec<(String, String)> = headers
            .iter()
            .map(|(name, value)| (name.to_string(), value.to_string()))
            .collect();
        let request = create_test_request(uri, owned_headers);
        versioning.determine_version(&request).await.unwrap()
    }

    /// The `Accept` parameter wins; without it the default applies.
    #[tokio::test]
    async fn test_accept_header_versioning() {
        let versioning = AcceptHeaderVersioning::new()
            .with_default_version("1.0")
            .with_allowed_versions(vec!["1.0", "2.0"]);

        let explicit = resolve(
            &versioning,
            "/users/",
            &[("accept", "application/json; version=2.0")],
        )
        .await;
        assert_eq!(explicit, "2.0");

        let fallback = resolve(&versioning, "/users/", &[("accept", "application/json")]).await;
        assert_eq!(fallback, "1.0");
    }

    /// A `/v<N>/` path segment wins; without it the default applies.
    #[tokio::test]
    async fn test_url_path_versioning() {
        let versioning = URLPathVersioning::new()
            .with_default_version("1.0")
            .with_allowed_versions(vec!["1.0", "2.0", "2"]);

        let explicit = resolve(&versioning, "/v2/users/", &[]).await;
        assert_eq!(explicit, "2");

        let fallback = resolve(&versioning, "/users/", &[]).await;
        assert_eq!(fallback, "1.0");
    }

    /// The leading host label is used when allowed; otherwise the default.
    #[tokio::test]
    async fn test_hostname_versioning() {
        let versioning = HostNameVersioning::new()
            .with_default_version("1.0")
            .with_allowed_versions(vec!["v1", "v2"]);

        let explicit = resolve(&versioning, "/users/", &[("host", "v2.api.example.com")]).await;
        assert_eq!(explicit, "v2");

        let fallback = resolve(&versioning, "/users/", &[("host", "api.example.com")]).await;
        assert_eq!(fallback, "1.0");
    }

    /// The query parameter wins; without it the default applies.
    #[tokio::test]
    async fn test_query_parameter_versioning() {
        let versioning = QueryParameterVersioning::new()
            .with_default_version("1.0")
            .with_allowed_versions(vec!["1.0", "2.0"]);

        let explicit = resolve(&versioning, "/users/?version=2.0", &[]).await;
        assert_eq!(explicit, "2.0");

        let fallback = resolve(&versioning, "/users/", &[]).await;
        assert_eq!(fallback, "1.0");
    }

    /// The default `/v{version}/` template extracts the version from the
    /// start of the path; other paths get the default.
    #[tokio::test]
    async fn test_namespace_versioning() {
        let versioning = NamespaceVersioning::new()
            .with_default_version("1.0")
            .with_allowed_versions(vec!["1", "1.0", "2", "2.0", "3.0"]);

        let cases: &[(&str, &str)] = &[
            ("/v1/users/", "1"),
            ("/v2.0/users/", "2.0"),
            ("/users/", "1.0"),
            ("/api/users/", "1.0"),
        ];
        for &(uri, expected) in cases {
            let version = resolve(&versioning, uri, &[]).await;
            assert_eq!(version, expected);
        }
    }

    /// A custom template only matches paths that start with it.
    #[tokio::test]
    async fn test_namespace_versioning_with_custom_pattern() {
        let versioning = NamespaceVersioning::new()
            .with_default_version("1.0")
            .with_pattern("/api/v{version}/")
            .with_allowed_versions(vec!["1", "2"]);

        let cases: &[(&str, &str)] = &[
            ("/api/v1/users/", "1"),
            ("/api/v2/users/", "2"),
            ("/v1/users/", "1.0"),
        ];
        for &(uri, expected) in cases {
            let version = resolve(&versioning, uri, &[]).await;
            assert_eq!(version, expected);
        }
    }

    /// Literal dots and a literal `v2` label in the host format must not
    /// interfere with the `{version}` capture.
    #[tokio::test]
    async fn test_hostname_versioning_with_host_format_dots_not_corrupted() {
        let versioning = HostNameVersioning::new()
            .with_host_format("{version}.api.v2.example.com")
            .with_allowed_versions(vec!["v1", "v3"]);

        let first = resolve(&versioning, "/users/", &[("host", "v1.api.v2.example.com")]).await;
        assert_eq!(first, "v1");

        let second = resolve(&versioning, "/users/", &[("host", "v3.api.v2.example.com")]).await;
        assert_eq!(second, "v3");
    }
}