reinhardt_rest/versioning/
middleware.rs1use super::{BaseVersioning, VersioningError};
7use async_trait::async_trait;
8use reinhardt_core::exception::{Error, Result};
9use reinhardt_http::{Handler, Middleware};
10use reinhardt_http::{Request, Response};
11use std::sync::Arc;
12
/// API version resolved by [`VersioningMiddleware`] and stored in the request extensions.
///
/// Wraps the raw version string (for example `"1.0"`) exactly as produced by the
/// configured versioning scheme.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ApiVersion(pub String);

impl ApiVersion {
    /// Returns the version as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Creates an `ApiVersion` from anything convertible into a `String`
    /// (`String`, `&str`, ...).
    pub fn new(version: impl Into<String>) -> Self {
        Self(version.into())
    }
}

impl std::fmt::Display for ApiVersion {
    /// Writes the raw version string.
    ///
    /// Uses [`std::fmt::Formatter::pad`] so width, fill and alignment flags
    /// (e.g. `{:>5}`) are honoured instead of silently ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.0)
    }
}
52
53pub struct VersioningMiddleware<V: BaseVersioning> {
70 versioning: Arc<V>,
71}
72
73impl<V: BaseVersioning> VersioningMiddleware<V> {
74 pub fn new(versioning: V) -> Self {
86 Self {
87 versioning: Arc::new(versioning),
88 }
89 }
90 pub fn versioning(&self) -> &V {
104 &self.versioning
105 }
106}
107
108impl<V: BaseVersioning> Clone for VersioningMiddleware<V> {
109 fn clone(&self) -> Self {
110 Self {
111 versioning: Arc::clone(&self.versioning),
112 }
113 }
114}
115
116#[async_trait]
117impl<V: BaseVersioning + 'static> Middleware for VersioningMiddleware<V> {
118 async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
119 let version = self
121 .versioning
122 .determine_version(&request)
123 .await
124 .map_err(|e| match e {
125 Error::Validation(msg) => Error::Validation(msg),
126 _ => Error::Validation(VersioningError::InvalidAcceptHeader.to_string()),
127 })?;
128
129 request.extensions.insert(ApiVersion(version));
131
132 next.handle(request).await
134 }
135}
136
/// Read access to the [`ApiVersion`] that [`VersioningMiddleware`] attached to a request.
pub trait RequestVersionExt {
    /// Returns the resolved API version, or `None` if none was attached
    /// (for example when the request never passed through the middleware).
    fn version(&self) -> Option<String>;

    /// Returns the resolved API version, or `default` when none was attached.
    ///
    /// Provided in terms of [`RequestVersionExt::version`], so implementors only
    /// need to supply `version`; they may still override this method.
    fn version_or(&self, default: &str) -> String {
        self.version().unwrap_or_else(|| default.to_owned())
    }
}
145
146impl RequestVersionExt for Request {
147 fn version(&self) -> Option<String> {
148 self.extensions.get::<ApiVersion>().map(|v| v.0)
149 }
150
151 fn version_or(&self, default: &str) -> String {
152 self.version().unwrap_or_else(|| default.to_string())
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::versioning::{QueryParameterVersioning, URLPathVersioning};
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Uri, Version};
    use std::sync::Mutex;

    /// Builds a bare HTTP/1.1 `GET` request for `uri` with no headers and an empty body.
    fn create_test_request(uri: &str) -> Request {
        let uri = uri.parse::<Uri>().unwrap();
        Request::builder()
            .method(Method::GET)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Terminal handler that records, in call order, the API version each request
    /// carried when it reached it, so tests can assert on what handlers observe.
    #[derive(Default)]
    struct RecordingHandler {
        seen: Mutex<Vec<Option<String>>>,
    }

    impl RecordingHandler {
        /// Versions observed so far, in call order.
        fn seen(&self) -> Vec<Option<String>> {
            self.seen.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl Handler for RecordingHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            // Fully qualified so it cannot be confused with the request's HTTP `version`.
            let version = RequestVersionExt::version(&request);
            self.seen.lock().unwrap().push(version);
            Ok(Response::ok())
        }
    }

    #[tokio::test]
    async fn test_middleware_url_path_versioning() {
        let versioning = URLPathVersioning::new()
            .with_default_version("1.0")
            .with_allowed_versions(vec!["1.0", "2.0", "2"]);

        let middleware = VersioningMiddleware::new(versioning);
        let handler = Arc::new(RecordingHandler::default());

        let request = create_test_request("/v2/users/");
        let _ = middleware.process(request, handler.clone()).await.unwrap();

        let request = create_test_request("/users/");
        let _ = middleware.process(request, handler.clone()).await.unwrap();

        let seen = handler.seen();
        assert_eq!(seen.len(), 2);
        // `/v2/` may be reported as "2" or "2.0" depending on the path pattern;
        // both spellings are in the allow-list.
        assert!(matches!(seen[0].as_deref(), Some("2") | Some("2.0")));
        // No version segment in the path: the configured default applies.
        assert_eq!(seen[1].as_deref(), Some("1.0"));
    }

    #[tokio::test]
    async fn test_middleware_query_parameter_versioning() {
        let versioning = QueryParameterVersioning::new()
            .with_default_version("1.0")
            .with_allowed_versions(vec!["1.0", "2.0", "3.0"]);

        let middleware = VersioningMiddleware::new(versioning);
        let handler = Arc::new(RecordingHandler::default());

        let request = create_test_request("/users/?version=2.0");
        let _ = middleware.process(request, handler.clone()).await.unwrap();

        let request = create_test_request("/users/");
        let _ = middleware.process(request, handler.clone()).await.unwrap();

        assert_eq!(
            handler.seen(),
            vec![Some("2.0".to_string()), Some("1.0".to_string())]
        );
    }

    #[tokio::test]
    async fn test_request_version_extension() {
        let versioning = URLPathVersioning::new()
            .with_default_version("1.0")
            .with_allowed_versions(vec!["1.0", "2.0", "2"]);

        let middleware = VersioningMiddleware::new(versioning);
        let handler = Arc::new(RecordingHandler::default());

        let request = create_test_request("/v2/users/");
        let _ = middleware.process(request, handler.clone()).await.unwrap();

        let seen = handler.seen();
        assert_eq!(seen.len(), 1);
        // The handler must see the version from the path, not the default.
        assert!(matches!(seen[0].as_deref(), Some("2") | Some("2.0")));
    }

    #[test]
    fn test_request_version_extension_with_default() {
        let request = create_test_request("/users/");

        assert_eq!(request.version(), None);

        assert_eq!(request.version_or("fallback"), "fallback");
    }

    #[tokio::test]
    async fn test_middleware_invalid_version() {
        let versioning = URLPathVersioning::new()
            .with_default_version("1.0")
            .with_allowed_versions(vec!["1.0", "2.0"]);

        let middleware = VersioningMiddleware::new(versioning);
        let handler = Arc::new(RecordingHandler::default());

        let request = create_test_request("/v3/users/");
        let result = middleware.process(request, handler.clone()).await;

        assert!(matches!(result, Err(Error::Validation(_))));
        assert!(
            handler.seen().is_empty(),
            "handler must not run when version resolution fails"
        );
    }

    #[test]
    fn test_api_version_methods() {
        let version = ApiVersion("2.0".to_string());

        assert_eq!(version.as_str(), "2.0");
        assert_eq!(version.to_string(), "2.0");
    }
}