1use async_trait::async_trait;
6use hyper::header::HOST;
7use hyper::{Method, StatusCode};
8use reinhardt_http::{Handler, Middleware, Request, Response, Result};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11
12#[non_exhaustive]
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct CommonConfig {
16 pub append_slash: bool,
18 pub prepend_www: bool,
20}
21
22impl CommonConfig {
23 pub fn new() -> Self {
39 Self {
40 append_slash: true,
41 prepend_www: false,
42 }
43 }
44}
45
46impl Default for CommonConfig {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52pub struct CommonMiddleware {
99 config: CommonConfig,
100}
101
102impl CommonMiddleware {
103 pub fn new() -> Self {
113 Self {
114 config: CommonConfig::default(),
115 }
116 }
117
118 pub fn with_config(config: CommonConfig) -> Self {
132 Self { config }
133 }
134
135 fn should_append_slash(&self, path: &str) -> bool {
137 if !self.config.append_slash {
138 return false;
139 }
140
141 if path.ends_with('/') {
143 return false;
144 }
145
146 if let Some(last_segment) = path.rsplit('/').next()
148 && last_segment.contains('.')
149 {
150 return false;
151 }
152
153 true
154 }
155
156 fn should_prepend_www(&self, host: &str) -> bool {
158 if !self.config.prepend_www {
159 return false;
160 }
161
162 if host.starts_with("www.") {
164 return false;
165 }
166
167 if host.starts_with("localhost") || host.starts_with("127.") || host.starts_with("192.168.")
169 {
170 return false;
171 }
172
173 true
174 }
175
176 fn build_redirect_url(&self, request: &Request) -> Option<String> {
178 let path = request.uri.path();
179 let query = request.uri.query();
180
181 let host = request
182 .headers
183 .get(HOST)
184 .and_then(|h| h.to_str().ok())
185 .unwrap_or("localhost");
186
187 let mut redirect_needed = false;
188 let mut new_path = path.to_string();
189 let mut new_host = host.to_string();
190
191 if self.should_append_slash(path) {
193 new_path.push('/');
194 redirect_needed = true;
195 }
196
197 if self.should_prepend_www(host) {
199 new_host = format!("www.{}", host);
200 redirect_needed = true;
201 }
202
203 if !redirect_needed {
204 return None;
205 }
206
207 let scheme = request.scheme();
210
211 let url = if let Some(q) = query {
212 format!("{}://{}{}?{}", scheme, new_host, new_path, q)
213 } else {
214 format!("{}://{}{}", scheme, new_host, new_path)
215 };
216
217 Some(url)
218 }
219}
220
221impl Default for CommonMiddleware {
222 fn default() -> Self {
223 Self::new()
224 }
225}
226
227#[async_trait]
228impl Middleware for CommonMiddleware {
229 async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
230 if let Some(redirect_url) = self.build_redirect_url(&request) {
232 let status = if matches!(request.method, Method::GET | Method::HEAD) {
235 StatusCode::MOVED_PERMANENTLY
236 } else {
237 StatusCode::TEMPORARY_REDIRECT
238 };
239 let mut response = Response::new(status);
240 response.headers.insert(
241 hyper::header::LOCATION,
242 redirect_url
243 .parse()
244 .unwrap_or_else(|_| hyper::header::HeaderValue::from_static("/")),
245 );
246 return Ok(response);
247 }
248
249 handler.handle(request).await
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version};
    use rstest::rstest;

    /// Downstream handler that always answers `200 OK`. Seeing that status
    /// means the middleware let the request through instead of redirecting.
    struct TestHandler;

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::new(StatusCode::OK).with_body("test response".as_bytes()))
        }
    }

    /// Builds a middleware with the given feature switches.
    fn build_middleware(append_slash: bool, prepend_www: bool) -> CommonMiddleware {
        CommonMiddleware::with_config(CommonConfig {
            append_slash,
            prepend_www,
        })
    }

    /// Sends an HTTP/1.1 request with an empty body through `middleware`.
    /// A `Host` header is added only when `host` is `Some`.
    async fn send(
        middleware: &CommonMiddleware,
        method: Method,
        uri: &'static str,
        host: Option<&'static str>,
    ) -> Response {
        let mut headers = HeaderMap::new();
        if let Some(host) = host {
            headers.insert(HOST, host.parse().unwrap());
        }

        let request = Request::builder()
            .method(method)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap();

        middleware
            .process(request, Arc::new(TestHandler))
            .await
            .unwrap()
    }

    /// Returns the `Location` header of a redirect response as a string.
    fn location(response: &Response) -> &str {
        response
            .headers
            .get(hyper::header::LOCATION)
            .unwrap()
            .to_str()
            .unwrap()
    }

    #[tokio::test]
    async fn test_append_slash_redirects() {
        let middleware = build_middleware(true, false);

        let response = send(&middleware, Method::GET, "/path/to/page", None).await;

        assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
        assert!(location(&response).contains("/path/to/page/"));
    }

    #[tokio::test]
    async fn test_no_redirect_with_trailing_slash() {
        let middleware = build_middleware(true, false);

        let response = send(&middleware, Method::GET, "/path/to/page/", None).await;

        assert_eq!(response.status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_no_redirect_for_file_extensions() {
        let middleware = build_middleware(true, false);

        let response = send(&middleware, Method::GET, "/static/file.css", None).await;

        assert_eq!(response.status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_append_slash_with_query_params() {
        let middleware = build_middleware(true, false);

        let response = send(&middleware, Method::GET, "/search?q=test", None).await;

        assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
        let target = location(&response);
        assert!(target.contains("/search/"));
        assert!(target.contains("?q=test"));
    }

    #[tokio::test]
    async fn test_prepend_www() {
        let middleware = build_middleware(false, true);

        let response = send(&middleware, Method::GET, "/page/", Some("example.com")).await;

        assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
        assert!(location(&response).contains("www.example.com"));
    }

    #[tokio::test]
    async fn test_no_prepend_www_for_localhost() {
        let middleware = build_middleware(false, true);

        let response = send(&middleware, Method::GET, "/page/", Some("localhost:8000")).await;

        assert_eq!(response.status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_no_prepend_www_when_already_present() {
        let middleware = build_middleware(false, true);

        let response = send(&middleware, Method::GET, "/page/", Some("www.example.com")).await;

        assert_eq!(response.status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_both_transformations() {
        let middleware = build_middleware(true, true);

        let response = send(&middleware, Method::GET, "/page", Some("example.com")).await;

        assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
        let target = location(&response);
        assert!(target.contains("www.example.com"));
        assert!(target.contains("/page/"));
    }

    #[tokio::test]
    async fn test_both_disabled() {
        let middleware = build_middleware(false, false);

        let response = send(&middleware, Method::GET, "/page", Some("example.com")).await;

        assert_eq!(response.status, StatusCode::OK);
    }

    #[rstest]
    #[case::get_returns_301(Method::GET, StatusCode::MOVED_PERMANENTLY)]
    #[case::head_returns_301(Method::HEAD, StatusCode::MOVED_PERMANENTLY)]
    #[case::post_returns_307(Method::POST, StatusCode::TEMPORARY_REDIRECT)]
    #[case::put_returns_307(Method::PUT, StatusCode::TEMPORARY_REDIRECT)]
    #[case::patch_returns_307(Method::PATCH, StatusCode::TEMPORARY_REDIRECT)]
    #[case::delete_returns_307(Method::DELETE, StatusCode::TEMPORARY_REDIRECT)]
    #[tokio::test]
    async fn test_redirect_status_by_method(
        #[case] method: Method,
        #[case] expected_status: StatusCode,
    ) {
        let middleware = build_middleware(true, false);

        let response = send(&middleware, method, "/path/to/page", None).await;

        assert_eq!(response.status, expected_status);
        assert!(response.headers.contains_key(hyper::header::LOCATION));
    }
}