1use crate::middleware::v2::{Middleware, Next, NextFuture};
7use crate::request::{ElifMethod, ElifRequest};
8use crate::response::{ElifResponse, ElifStatusCode};
9use std::collections::HashSet;
10use std::path::Path;
11use std::sync::{Arc, RwLock};
12
/// Body returned to clients whose request is blocked by maintenance mode.
///
/// Rendered by `MaintenanceResponse::to_elif_response`; every variant except
/// `Custom` is sent with `503 Service Unavailable`.
#[derive(Debug, Clone)]
pub enum MaintenanceResponse {
    /// Plain-text body.
    Text(String),
    /// JSON body.
    Json(serde_json::Value),
    /// HTML body; `text/html` is requested as the content type.
    Html(String),
    /// Fully caller-controlled response.
    Custom {
        /// Status code sent instead of 503.
        status_code: ElifStatusCode,
        /// Requested `Content-Type`; if it is rejected the response is sent without it.
        content_type: String,
        /// Raw response body bytes.
        body: Vec<u8>,
    },
    /// Path of a file that is read on every render and sent as the body; the content
    /// type is derived from the file extension. If the file cannot be read, a generic
    /// JSON maintenance error is sent instead.
    File(String),
}
31
32impl Default for MaintenanceResponse {
33 fn default() -> Self {
34 Self::Json(serde_json::json!({
35 "error": {
36 "code": "maintenance_mode",
37 "message": "Service temporarily unavailable due to maintenance",
38 "hint": "Please try again later"
39 }
40 }))
41 }
42}
43
44impl MaintenanceResponse {
45 pub async fn to_elif_response(&self) -> ElifResponse {
47 match self {
48 Self::Text(text) => {
49 ElifResponse::with_status(ElifStatusCode::SERVICE_UNAVAILABLE).text(text.clone())
50 }
51 Self::Json(json) => ElifResponse::with_status(ElifStatusCode::SERVICE_UNAVAILABLE)
52 .json_value(json.clone()),
53 Self::Html(html) => ElifResponse::with_status(ElifStatusCode::SERVICE_UNAVAILABLE)
54 .content_type("text/html")
55 .unwrap_or_else(|_| ElifResponse::with_status(ElifStatusCode::SERVICE_UNAVAILABLE))
56 .text(html.clone()),
57 Self::Custom {
58 status_code,
59 content_type,
60 body,
61 } => ElifResponse::with_status(*status_code)
62 .content_type(content_type)
63 .unwrap_or_else(|_| ElifResponse::with_status(*status_code))
64 .bytes(axum::body::Bytes::copy_from_slice(body)),
65 Self::File(path) => {
66 match tokio::fs::read(path).await {
68 Ok(content) => {
69 let content_type =
71 match Path::new(path).extension().and_then(|ext| ext.to_str()) {
72 Some("html") => "text/html",
73 Some("json") => "application/json",
74 Some("txt") => "text/plain",
75 _ => "text/plain",
76 };
77
78 ElifResponse::with_status(ElifStatusCode::SERVICE_UNAVAILABLE)
79 .content_type(content_type)
80 .unwrap_or_else(|_| {
81 ElifResponse::with_status(ElifStatusCode::SERVICE_UNAVAILABLE)
82 })
83 .bytes(axum::body::Bytes::from(content))
84 }
85 Err(_) => {
86 ElifResponse::with_status(ElifStatusCode::SERVICE_UNAVAILABLE).json_value(
88 serde_json::json!({
89 "error": {
90 "code": "maintenance_mode",
91 "message": "Service temporarily unavailable"
92 }
93 }),
94 )
95 }
96 }
97 }
98 }
99 }
100}
101
/// Rule deciding whether a request path may bypass maintenance mode.
///
/// `Clone` and `PartialEq` are implemented by hand rather than derived.
#[derive(Debug)]
pub enum PathMatch {
    /// The path must equal this string exactly.
    Exact(String),
    /// The path must start with this string. This is a plain string-prefix test, so
    /// `/api` also matches `/apiv2`.
    Prefix(String),
    /// The path must match this pre-compiled regular expression. `is_match` searches
    /// anywhere in the path unless the pattern itself is anchored with `^`/`$`.
    Regex(regex::Regex),
    /// The path is passed to this function; returning `true` allows it.
    Custom(fn(&str) -> bool),
}
114
115impl PathMatch {
116 pub fn regex(pattern: &str) -> Result<Self, regex::Error> {
118 Ok(Self::Regex(regex::Regex::new(pattern)?))
119 }
120
121 pub fn matches(&self, path: &str) -> bool {
123 match self {
124 Self::Exact(exact_path) => path == exact_path,
125 Self::Prefix(prefix) => path.starts_with(prefix),
126 Self::Regex(compiled_regex) => compiled_regex.is_match(path),
127 Self::Custom(matcher) => matcher(path),
128 }
129 }
130}
131
132impl Clone for PathMatch {
133 fn clone(&self) -> Self {
134 match self {
135 Self::Exact(s) => Self::Exact(s.clone()),
136 Self::Prefix(s) => Self::Prefix(s.clone()),
137 Self::Regex(regex) => Self::Regex(regex.clone()),
138 Self::Custom(f) => Self::Custom(*f),
139 }
140 }
141}
142
143impl PartialEq for PathMatch {
144 fn eq(&self, other: &Self) -> bool {
145 match (self, other) {
146 (Self::Exact(a), Self::Exact(b)) => a == b,
147 (Self::Prefix(a), Self::Prefix(b)) => a == b,
148 (Self::Regex(a), Self::Regex(b)) => a.as_str() == b.as_str(),
149 (Self::Custom(a), Self::Custom(b)) => std::ptr::eq(a as *const _, b as *const _),
150 _ => false,
151 }
152 }
153}
154
/// Settings for [`MaintenanceModeMiddleware`].
#[derive(Debug)]
pub struct MaintenanceModeConfig {
    /// Shared on/off flag. It is held in an `Arc` so that a builder and every
    /// middleware built from it see the same state.
    pub enabled: Arc<RwLock<bool>>,
    /// Body sent to blocked requests.
    pub response: MaintenanceResponse,
    /// Requests whose path matches any of these rules pass through.
    pub allowed_paths: Vec<PathMatch>,
    /// NOTE(review): not read by `MaintenanceModeMiddleware::should_allow_request`,
    /// so it currently has no effect on which requests are blocked.
    pub allowed_methods: HashSet<ElifMethod>,
    /// Client IPs that pass through. Compared against the first
    /// `X-Forwarded-For` entry, or `X-Real-IP` when that header is absent.
    pub allowed_ips: HashSet<String>,
    /// `(header name, expected value)`; a request carrying this exact value passes through.
    pub bypass_header: Option<(String, String)>,
    /// Seconds to advertise in a `Retry-After` header; `None` omits the header.
    pub add_retry_after: Option<u64>,
}
173
174impl Default for MaintenanceModeConfig {
175 fn default() -> Self {
176 let mut allowed_methods = HashSet::new();
177 allowed_methods.insert(ElifMethod::GET); Self {
180 enabled: Arc::new(RwLock::new(false)),
181 response: MaintenanceResponse::default(),
182 allowed_paths: vec![
183 PathMatch::Exact("/health".to_string()),
184 PathMatch::Exact("/ping".to_string()),
185 PathMatch::Prefix("/status".to_string()),
186 ],
187 allowed_methods,
188 allowed_ips: HashSet::new(),
189 bypass_header: None,
190 add_retry_after: Some(3600), }
192 }
193}
194
/// Middleware that answers requests with a maintenance response while maintenance
/// mode is enabled. Requests matched by the configured allow rules are still
/// passed to the next handler.
#[derive(Debug)]
pub struct MaintenanceModeMiddleware {
    config: MaintenanceModeConfig,
}
200
201impl MaintenanceModeMiddleware {
202 pub fn new() -> Self {
204 Self {
205 config: MaintenanceModeConfig::default(),
206 }
207 }
208
209 pub fn with_config(config: MaintenanceModeConfig) -> Self {
211 Self { config }
212 }
213
214 pub fn enable(
216 &self,
217 ) -> Result<(), std::sync::PoisonError<std::sync::RwLockWriteGuard<'_, bool>>> {
218 let mut enabled = self.config.enabled.write()?;
219 *enabled = true;
220 Ok(())
221 }
222
223 pub fn disable(
225 &self,
226 ) -> Result<(), std::sync::PoisonError<std::sync::RwLockWriteGuard<'_, bool>>> {
227 let mut enabled = self.config.enabled.write()?;
228 *enabled = false;
229 Ok(())
230 }
231
232 pub fn is_enabled(&self) -> bool {
234 self.config
235 .enabled
236 .read()
237 .map(|enabled| *enabled)
238 .unwrap_or(false)
239 }
240
241 pub fn response(mut self, response: MaintenanceResponse) -> Self {
243 self.config.response = response;
244 self
245 }
246
247 pub fn allow_path(mut self, path: impl Into<String>) -> Self {
249 self.config
250 .allowed_paths
251 .push(PathMatch::Exact(path.into()));
252 self
253 }
254
255 pub fn allow_prefix(mut self, prefix: impl Into<String>) -> Self {
257 self.config
258 .allowed_paths
259 .push(PathMatch::Prefix(prefix.into()));
260 self
261 }
262
263 pub fn allow_regex(mut self, pattern: &str) -> Result<Self, regex::Error> {
265 self.config.allowed_paths.push(PathMatch::regex(pattern)?);
266 Ok(self)
267 }
268
269 pub fn allow_custom(mut self, matcher: fn(&str) -> bool) -> Self {
271 self.config.allowed_paths.push(PathMatch::Custom(matcher));
272 self
273 }
274
275 pub fn allow_method(mut self, method: ElifMethod) -> Self {
277 self.config.allowed_methods.insert(method);
278 self
279 }
280
281 pub fn allow_ip(mut self, ip: impl Into<String>) -> Self {
283 self.config.allowed_ips.insert(ip.into());
284 self
285 }
286
287 pub fn bypass_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
289 self.config.bypass_header = Some((name.into(), value.into()));
290 self
291 }
292
293 pub fn retry_after(mut self, seconds: u64) -> Self {
295 self.config.add_retry_after = Some(seconds);
296 self
297 }
298
299 pub fn no_retry_after(mut self) -> Self {
301 self.config.add_retry_after = None;
302 self
303 }
304
305 pub fn enabled(self) -> Self {
307 let _ = self.enable();
308 self
309 }
310
311 fn should_allow_request(&self, request: &ElifRequest) -> bool {
313 if !self.is_enabled() {
315 return true;
316 }
317
318 let path = request.path();
322 for path_match in &self.config.allowed_paths {
323 if path_match.matches(path) {
324 return true;
325 }
326 }
327
328 if let Some((header_name, expected_value)) = &self.config.bypass_header {
330 if let Some(header_value) = request.header(header_name) {
331 if let Ok(value_str) = header_value.to_str() {
332 if value_str == expected_value {
333 return true;
334 }
335 }
336 }
337 }
338
339 let client_ip = request
342 .header("x-forwarded-for")
343 .or_else(|| request.header("x-real-ip"))
344 .and_then(|h| h.to_str().ok())
345 .map(|s| s.split(',').next().unwrap_or(s).trim());
346
347 if let Some(ip) = client_ip {
348 if self.config.allowed_ips.contains(ip) {
349 return true;
350 }
351 }
352
353 false
354 }
355
356 async fn create_maintenance_response(&self) -> ElifResponse {
358 let mut response = self.config.response.to_elif_response().await;
359
360 if let Some(retry_after) = self.config.add_retry_after {
362 response = response
363 .header("retry-after", retry_after.to_string())
364 .unwrap_or_else(|_| ElifResponse::with_status(ElifStatusCode::SERVICE_UNAVAILABLE));
365 }
366
367 response
368 }
369}
370
371impl Default for MaintenanceModeMiddleware {
372 fn default() -> Self {
373 Self::new()
374 }
375}
376
377impl Middleware for MaintenanceModeMiddleware {
378 fn handle(&self, request: ElifRequest, next: Next) -> NextFuture<'static> {
379 let should_allow = self.should_allow_request(&request);
380 let config = MaintenanceModeConfig {
381 enabled: Arc::clone(&self.config.enabled),
382 response: self.config.response.clone(),
383 allowed_paths: self.config.allowed_paths.clone(),
384 allowed_methods: self.config.allowed_methods.clone(),
385 allowed_ips: self.config.allowed_ips.clone(),
386 bypass_header: self.config.bypass_header.clone(),
387 add_retry_after: self.config.add_retry_after,
388 };
389
390 Box::pin(async move {
391 if should_allow {
392 next.run(request).await
394 } else {
395 let middleware = MaintenanceModeMiddleware { config };
397 middleware.create_maintenance_response().await
398 }
399 })
400 }
401
402 fn name(&self) -> &'static str {
403 "MaintenanceModeMiddleware"
404 }
405}
406
/// Owns a maintenance flag that is shared by every middleware it builds.
///
/// Toggling the builder therefore switches maintenance for all of them at once.
/// Derives `Debug` like the other public types in this module.
#[derive(Debug)]
pub struct MaintenanceModeBuilder {
    /// Flag handed (via `Arc::clone`) to each built middleware's config.
    enabled: Arc<RwLock<bool>>,
}
411
412impl MaintenanceModeBuilder {
413 pub fn new() -> Self {
415 Self {
416 enabled: Arc::new(RwLock::new(false)),
417 }
418 }
419
420 pub fn enable(&self) {
422 if let Ok(mut enabled) = self.enabled.write() {
423 *enabled = true;
424 }
425 }
426
427 pub fn disable(&self) {
429 if let Ok(mut enabled) = self.enabled.write() {
430 *enabled = false;
431 }
432 }
433
434 pub fn is_enabled(&self) -> bool {
436 self.enabled.read().map(|enabled| *enabled).unwrap_or(false)
437 }
438
439 pub fn build(&self) -> MaintenanceModeMiddleware {
441 let config = MaintenanceModeConfig {
442 enabled: Arc::clone(&self.enabled),
443 ..Default::default()
444 };
445 MaintenanceModeMiddleware::with_config(config)
446 }
447
448 pub fn build_with_config(
450 &self,
451 mut config: MaintenanceModeConfig,
452 ) -> MaintenanceModeMiddleware {
453 config.enabled = Arc::clone(&self.enabled);
454 MaintenanceModeMiddleware::with_config(config)
455 }
456}
457
458impl Default for MaintenanceModeBuilder {
459 fn default() -> Self {
460 Self::new()
461 }
462}
463
// Unit tests for path matching, response rendering and the allow/deny rules of
// the maintenance-mode middleware.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::ElifRequest;
    use crate::response::headers::ElifHeaderMap;
    use crate::response::ElifResponse;

    // Each `PathMatch` variant accepts the paths it should and rejects near-misses.
    #[test]
    fn test_path_matching() {
        let exact = PathMatch::Exact("/health".to_string());
        assert!(exact.matches("/health"));
        assert!(!exact.matches("/health-check"));

        let prefix = PathMatch::Prefix("/api/".to_string());
        assert!(prefix.matches("/api/users"));
        assert!(prefix.matches("/api/"));
        assert!(!prefix.matches("/v1/api/users"));

        let regex = PathMatch::regex(r"^/api/v\d+/.*").unwrap();
        assert!(regex.matches("/api/v1/users"));
        assert!(regex.matches("/api/v2/posts"));
        assert!(!regex.matches("/api/users"));

        let custom = PathMatch::Custom(|path| path.ends_with(".json"));
        assert!(custom.matches("/data.json"));
        assert!(!custom.matches("/data.xml"));
    }

    // Built-in variants answer 503; `Custom` uses its own status code.
    #[tokio::test]
    async fn test_maintenance_response_types() {
        let text_response = MaintenanceResponse::Text("Under maintenance".to_string());
        let response = text_response.to_elif_response().await;
        assert_eq!(response.status_code(), ElifStatusCode::SERVICE_UNAVAILABLE);

        let json_response = MaintenanceResponse::Json(serde_json::json!({
            "error": "maintenance"
        }));
        let response = json_response.to_elif_response().await;
        assert_eq!(response.status_code(), ElifStatusCode::SERVICE_UNAVAILABLE);

        let html_response = MaintenanceResponse::Html("<h1>Maintenance</h1>".to_string());
        let response = html_response.to_elif_response().await;
        assert_eq!(response.status_code(), ElifStatusCode::SERVICE_UNAVAILABLE);

        let custom_response = MaintenanceResponse::Custom {
            status_code: ElifStatusCode::LOCKED,
            content_type: "text/plain".to_string(),
            body: b"Locked".to_vec(),
        };
        let response = custom_response.to_elif_response().await;
        assert_eq!(response.status_code(), ElifStatusCode::LOCKED);
    }

    // With maintenance off (the default), requests reach the next handler.
    #[tokio::test]
    async fn test_maintenance_mode_disabled() {
        let middleware = MaintenanceModeMiddleware::new();
        let request = ElifRequest::new(
            ElifMethod::GET,
            "/api/data".parse().unwrap(),
            ElifHeaderMap::new(),
        );

        let next =
            Next::new(|_req| Box::pin(async move { ElifResponse::ok().text("Normal response") }));

        let response = middleware.handle(request, next).await;
        assert_eq!(response.status_code(), ElifStatusCode::OK);
    }

    // With maintenance on, a request matching no allow rule gets a 503.
    #[tokio::test]
    async fn test_maintenance_mode_enabled() {
        let middleware = MaintenanceModeMiddleware::new().enabled();

        let request = ElifRequest::new(
            ElifMethod::POST,
            "/api/data".parse().unwrap(),
            ElifHeaderMap::new(),
        );

        let next = Next::new(|_req| {
            Box::pin(async move { ElifResponse::ok().text("Should not reach here") })
        });

        let response = middleware.handle(request, next).await;
        assert_eq!(response.status_code(), ElifStatusCode::SERVICE_UNAVAILABLE);
    }

    // An allowed path passes; any other path is blocked. The GET is blocked even
    // though GET is in the default `allowed_methods`, because that set is not
    // consulted.
    #[tokio::test]
    async fn test_maintenance_mode_allowed_paths() {
        let middleware = MaintenanceModeMiddleware::new()
            .enabled()
            .allow_path("/health");

        let request = ElifRequest::new(
            ElifMethod::GET,
            "/health".parse().unwrap(),
            ElifHeaderMap::new(),
        );

        let next = Next::new(|_req| Box::pin(async move { ElifResponse::ok().text("Healthy") }));

        let response = middleware.handle(request, next).await;
        assert_eq!(response.status_code(), ElifStatusCode::OK);

        let request = ElifRequest::new(
            ElifMethod::GET,
            "/api/data".parse().unwrap(),
            ElifHeaderMap::new(),
        );

        let next =
            Next::new(|_req| Box::pin(async move { ElifResponse::ok().text("Should be blocked") }));

        let response = middleware.handle(request, next).await;
        assert_eq!(response.status_code(), ElifStatusCode::SERVICE_UNAVAILABLE);
    }

    // The bypass header only works with the exact configured value.
    #[tokio::test]
    async fn test_maintenance_mode_bypass_header() {
        let middleware = MaintenanceModeMiddleware::new()
            .enabled()
            .bypass_header("x-admin-key", "secret123");

        let mut headers = ElifHeaderMap::new();
        headers.insert(
            crate::response::headers::ElifHeaderName::from_str("x-admin-key").unwrap(),
            "secret123".parse().unwrap(),
        );
        let request = ElifRequest::new(ElifMethod::GET, "/admin/panel".parse().unwrap(), headers);

        let next =
            Next::new(|_req| Box::pin(async move { ElifResponse::ok().text("Admin panel") }));

        let response = middleware.handle(request, next).await;
        assert_eq!(response.status_code(), ElifStatusCode::OK);

        let mut headers = ElifHeaderMap::new();
        headers.insert(
            crate::response::headers::ElifHeaderName::from_str("x-admin-key").unwrap(),
            "wrong-key".parse().unwrap(),
        );
        let request = ElifRequest::new(ElifMethod::GET, "/admin/panel".parse().unwrap(), headers);

        let next =
            Next::new(|_req| Box::pin(async move { ElifResponse::ok().text("Should be blocked") }));

        let response = middleware.handle(request, next).await;
        assert_eq!(response.status_code(), ElifStatusCode::SERVICE_UNAVAILABLE);
    }

    // A client IP taken from `X-Forwarded-For` that is on the allowlist passes.
    #[tokio::test]
    async fn test_maintenance_mode_allowed_ips() {
        let middleware = MaintenanceModeMiddleware::new()
            .enabled()
            .allow_ip("192.168.1.100");

        let mut headers = ElifHeaderMap::new();
        headers.insert(
            crate::response::headers::ElifHeaderName::from_str("x-forwarded-for").unwrap(),
            "192.168.1.100".parse().unwrap(),
        );
        let request = ElifRequest::new(ElifMethod::GET, "/api/data".parse().unwrap(), headers);

        let next = Next::new(|_req| Box::pin(async move { ElifResponse::ok().text("Allowed IP") }));

        let response = middleware.handle(request, next).await;
        assert_eq!(response.status_code(), ElifStatusCode::OK);
    }

    // The middleware is built *before* `enable()`, which shows that it shares the
    // builder's flag rather than copying it.
    #[tokio::test]
    async fn test_maintenance_mode_builder() {
        let builder = MaintenanceModeBuilder::new();
        let middleware = builder.build();

        assert!(!builder.is_enabled());

        builder.enable();
        assert!(builder.is_enabled());

        let request = ElifRequest::new(
            ElifMethod::GET,
            "/api/data".parse().unwrap(),
            ElifHeaderMap::new(),
        );

        let next =
            Next::new(|_req| Box::pin(async move { ElifResponse::ok().text("Should be blocked") }));

        let response = middleware.handle(request, next).await;
        assert_eq!(response.status_code(), ElifStatusCode::SERVICE_UNAVAILABLE);

        builder.disable();
        assert!(!builder.is_enabled());
    }

    // Builder methods accumulate into the config. There are five paths: the three
    // defaults plus the two added here.
    #[test]
    fn test_middleware_builder_pattern() {
        let middleware = MaintenanceModeMiddleware::new()
            .allow_path("/health")
            .allow_prefix("/status")
            .allow_method(ElifMethod::OPTIONS)
            .allow_ip("127.0.0.1")
            .bypass_header("x-bypass", "secret")
            .retry_after(7200)
            .enabled();

        assert!(middleware.is_enabled());
        assert_eq!(middleware.config.allowed_paths.len(), 5);
        assert!(middleware
            .config
            .allowed_methods
            .contains(&ElifMethod::OPTIONS));
        assert!(middleware.config.allowed_ips.contains("127.0.0.1"));
        assert_eq!(
            middleware.config.bypass_header,
            Some(("x-bypass".to_string(), "secret".to_string()))
        );
        assert_eq!(middleware.config.add_retry_after, Some(7200));
    }

    // The regex is compiled once by `PathMatch::regex`; an invalid pattern is
    // rejected at construction time rather than per request.
    #[test]
    fn test_regex_performance_improvement() {
        let regex_matcher = PathMatch::regex(r"^/api/v\d+/.*").unwrap();

        assert!(regex_matcher.matches("/api/v1/users"));
        assert!(regex_matcher.matches("/api/v2/posts"));
        assert!(regex_matcher.matches("/api/v3/comments"));
        assert!(!regex_matcher.matches("/api/users"));
        assert!(!regex_matcher.matches("/v1/api/users"));

        let invalid_regex = PathMatch::regex(r"[invalid");
        assert!(invalid_regex.is_err());
    }

    // The hand-written `Clone`/`PartialEq` agree for the string and regex variants.
    #[test]
    fn test_path_match_clone_and_equality() {
        let exact1 = PathMatch::Exact("/test".to_string());
        let exact2 = PathMatch::Exact("/test".to_string());
        let exact3 = PathMatch::Exact("/other".to_string());

        assert_eq!(exact1, exact2);
        assert_ne!(exact1, exact3);

        let cloned = exact1.clone();
        assert_eq!(exact1, cloned);

        let regex1 = PathMatch::regex(r"^/api/.*").unwrap();
        let regex2 = PathMatch::regex(r"^/api/.*").unwrap();
        let regex3 = PathMatch::regex(r"^/other/.*").unwrap();

        assert_eq!(regex1, regex2);
        assert_ne!(regex1, regex3);
        let cloned_regex = regex1.clone();
        assert_eq!(regex1, cloned_regex);
    }
}