Skip to main content

reinhardt_rest/versioning/
middleware.rs

1//! Middleware for automatic API version detection
2//!
3//! This module provides middleware that automatically detects the API version
4//! from requests and stores it in request extensions for easy access in handlers.
5
6use super::{BaseVersioning, VersioningError};
7use async_trait::async_trait;
8use reinhardt_core::exception::{Error, Result};
9use reinhardt_http::{Handler, Middleware};
10use reinhardt_http::{Request, Response};
11use std::sync::Arc;
12
/// API version extracted from a request.
///
/// A newtype around the version string (for example `"1.0"`). The versioning
/// middleware in this module inserts it into the request extensions.
///
/// Besides `Debug` and `Clone`, it implements `PartialEq`, `Eq` and `Hash`, so
/// versions can be compared directly or used as keys in maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ApiVersion(pub String);

impl ApiVersion {
	/// Get the version string as a string slice.
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_rest::versioning::ApiVersion;
	///
	/// let version = ApiVersion::new("2.0".to_string());
	/// assert_eq!(version.as_str(), "2.0");
	/// ```
	pub fn as_str(&self) -> &str {
		&self.0
	}

	/// Create a new `ApiVersion` wrapping the given version string.
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_rest::versioning::ApiVersion;
	///
	/// let version = ApiVersion::new("1.0".to_string());
	/// assert_eq!(version.as_str(), "1.0");
	/// ```
	pub fn new(version: String) -> Self {
		Self(version)
	}
}

impl std::fmt::Display for ApiVersion {
	/// Writes the bare version string.
	///
	/// `Formatter::pad` honours width, alignment and precision flags exactly
	/// like `str`'s `Display`, so `format!("{:>5}", v)` pads as expected.
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		f.pad(&self.0)
	}
}
52
/// Middleware for automatic API version detection
///
/// This middleware uses a versioning strategy to automatically detect
/// the API version from incoming requests and stores it in request extensions.
///
/// The strategy is held behind an [`Arc`]. Cloning the middleware only bumps a
/// reference count, and every clone shares the same strategy instance.
///
/// No trait bound is placed on the struct itself. The impls that need
/// `V: BaseVersioning` (construction, `Clone`, `Middleware`) declare it.
///
/// # Example
///
/// ```rust
/// use reinhardt_rest::versioning::{URLPathVersioning, VersioningMiddleware};
///
/// let versioning = URLPathVersioning::new()
///     .with_default_version("1.0")
///     .with_allowed_versions(vec!["1.0", "2.0"]);
///
/// let middleware = VersioningMiddleware::new(versioning);
/// ```
pub struct VersioningMiddleware<V> {
	/// Shared versioning strategy used to determine the version of each request.
	versioning: Arc<V>,
}
72
73impl<V: BaseVersioning> VersioningMiddleware<V> {
74	/// Create a new versioning middleware with the given versioning strategy
75	///
76	/// # Examples
77	///
78	/// ```
79	/// use reinhardt_rest::versioning::{URLPathVersioning, VersioningMiddleware};
80	///
81	/// let versioning = URLPathVersioning::new()
82	///     .with_default_version("1.0");
83	/// let middleware = VersioningMiddleware::new(versioning);
84	/// ```
85	pub fn new(versioning: V) -> Self {
86		Self {
87			versioning: Arc::new(versioning),
88		}
89	}
90	/// Get a reference to the underlying versioning strategy
91	///
92	/// # Examples
93	///
94	/// ```
95	/// use reinhardt_rest::versioning::{URLPathVersioning, VersioningMiddleware, BaseVersioning};
96	///
97	/// let url_versioning = URLPathVersioning::new()
98	///     .with_default_version("1.0");
99	/// let middleware = VersioningMiddleware::new(url_versioning);
100	///
101	/// assert_eq!(middleware.versioning().default_version(), Some("1.0"));
102	/// ```
103	pub fn versioning(&self) -> &V {
104		&self.versioning
105	}
106}
107
108impl<V: BaseVersioning> Clone for VersioningMiddleware<V> {
109	fn clone(&self) -> Self {
110		Self {
111			versioning: Arc::clone(&self.versioning),
112		}
113	}
114}
115
116#[async_trait]
117impl<V: BaseVersioning + 'static> Middleware for VersioningMiddleware<V> {
118	async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
119		// Determine version from request
120		let version = self
121			.versioning
122			.determine_version(&request)
123			.await
124			.map_err(|e| match e {
125				Error::Validation(msg) => Error::Validation(msg),
126				_ => Error::Validation(VersioningError::InvalidAcceptHeader.to_string()),
127			})?;
128
129		// Store version in request extensions
130		request.extensions.insert(ApiVersion(version));
131
132		// Call next handler
133		next.handle(request).await
134	}
135}
136
137/// Extension trait to get API version from request
138pub trait RequestVersionExt {
139	/// Get the API version from request extensions
140	fn version(&self) -> Option<String>;
141
142	/// Get the API version or return default
143	fn version_or(&self, default: &str) -> String;
144}
145
146impl RequestVersionExt for Request {
147	fn version(&self) -> Option<String> {
148		self.extensions.get::<ApiVersion>().map(|v| v.0)
149	}
150
151	fn version_or(&self, default: &str) -> String {
152		self.version().unwrap_or_else(|| default.to_string())
153	}
154}
155
#[cfg(test)]
mod tests {
	use super::*;
	use crate::versioning::{QueryParameterVersioning, URLPathVersioning};
	use bytes::Bytes;
	use hyper::{HeaderMap, Method, Uri, Version};
	use std::sync::Mutex;

	fn create_test_request(uri: &str) -> Request {
		let uri = uri.parse::<Uri>().unwrap();
		Request::builder()
			.method(Method::GET)
			.uri(uri)
			.version(Version::HTTP_11)
			.headers(HeaderMap::new())
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	/// Shorthand for the expected value of `RequestVersionExt::version`.
	fn some_version(v: &str) -> Option<String> {
		Some(v.to_string())
	}

	/// Handler that records, for every request it receives, the API version
	/// visible through `RequestVersionExt::version`. This lets tests check
	/// what the middleware stored and whether the handler ran at all.
	#[derive(Default)]
	struct RecordingHandler {
		seen: Mutex<Vec<Option<String>>>,
	}

	impl RecordingHandler {
		/// Drain and return the recorded versions, oldest first.
		fn take_seen(&self) -> Vec<Option<String>> {
			std::mem::take(&mut *self.seen.lock().unwrap())
		}
	}

	#[async_trait]
	impl Handler for RecordingHandler {
		async fn handle(&self, request: Request) -> Result<Response> {
			self.seen.lock().unwrap().push(request.version());
			Ok(Response::ok())
		}
	}

	#[tokio::test]
	async fn test_middleware_url_path_versioning() {
		let versioning = URLPathVersioning::new()
			.with_default_version("1.0")
			.with_allowed_versions(vec!["1.0", "2.0", "2"]);

		let middleware = VersioningMiddleware::new(versioning);
		let handler = Arc::new(RecordingHandler::default());

		// Test with version in path
		let request = create_test_request("/v2/users/");
		let _ = middleware.process(request, handler.clone()).await.unwrap();

		// Test without version (should use default)
		let request = create_test_request("/users/");
		let _ = middleware.process(request, handler.clone()).await.unwrap();

		assert_eq!(
			handler.take_seen(),
			vec![some_version("2"), some_version("1.0")]
		);
	}

	#[tokio::test]
	async fn test_middleware_query_parameter_versioning() {
		let versioning = QueryParameterVersioning::new()
			.with_default_version("1.0")
			.with_allowed_versions(vec!["1.0", "2.0", "3.0"]);

		let middleware = VersioningMiddleware::new(versioning);
		let handler = Arc::new(RecordingHandler::default());

		// Test with version in query
		let request = create_test_request("/users/?version=2.0");
		let _ = middleware.process(request, handler.clone()).await.unwrap();

		// Test without version (should use default)
		let request = create_test_request("/users/");
		let _ = middleware.process(request, handler.clone()).await.unwrap();

		assert_eq!(
			handler.take_seen(),
			vec![some_version("2.0"), some_version("1.0")]
		);
	}

	#[tokio::test]
	async fn test_request_version_extension() {
		let versioning = URLPathVersioning::new()
			.with_default_version("1.0")
			.with_allowed_versions(vec!["1.0", "2.0", "2"]);

		let middleware = VersioningMiddleware::new(versioning);
		let handler = Arc::new(RecordingHandler::default());

		let request = create_test_request("/v2/users/");
		let _ = middleware.process(request, handler.clone()).await.unwrap();

		// The downstream handler must see the version through the extension trait.
		assert_eq!(handler.take_seen(), vec![some_version("2")]);
	}

	#[test]
	fn test_request_version_extension_with_default() {
		let request = create_test_request("/users/");

		// No version set, should return None
		assert_eq!(request.version(), None);

		// Should use provided default
		assert_eq!(request.version_or("fallback"), "fallback");
	}

	#[tokio::test]
	async fn test_middleware_invalid_version() {
		let versioning = URLPathVersioning::new()
			.with_default_version("1.0")
			.with_allowed_versions(vec!["1.0", "2.0"]);

		let middleware = VersioningMiddleware::new(versioning);
		let handler = Arc::new(RecordingHandler::default());

		// Test with invalid version (not in allowed list)
		let request = create_test_request("/v3/users/");
		let result = middleware.process(request, handler.clone()).await;

		assert!(result.is_err());
		// The request must be rejected before reaching the handler.
		assert!(handler.take_seen().is_empty());
	}

	#[test]
	fn test_api_version_methods() {
		let version = ApiVersion("2.0".to_string());

		assert_eq!(version.as_str(), "2.0");
		assert_eq!(version.to_string(), "2.0");
	}
}
264
265		assert_eq!(version.as_str(), "2.0");
266		assert_eq!(version.to_string(), "2.0");
267	}
268}