1use std::collections::HashSet;
54use std::future::Future;
55use std::ops::ControlFlow as StdControlFlow;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::time::Instant;
59
60use crate::context::RequestContext;
61use crate::dependency::DependencyOverrides;
62use crate::logging::{LogConfig, RequestLogger};
63use crate::request::{Body, Request};
64use crate::response::Response;
65
/// A heap-allocated, pinned, `Send` future yielding `T` that may borrow data for `'a`.
///
/// This is the return type of the middleware hooks and of `Handler::call`
/// in this file.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
68
69#[derive(Debug)]
74pub enum ControlFlow {
75 Continue,
77 Break(Response),
82}
83
84impl ControlFlow {
85 #[must_use]
87 pub fn is_continue(&self) -> bool {
88 matches!(self, Self::Continue)
89 }
90
91 #[must_use]
93 pub fn is_break(&self) -> bool {
94 matches!(self, Self::Break(_))
95 }
96}
97
98impl From<ControlFlow> for StdControlFlow<Response, ()> {
99 fn from(cf: ControlFlow) -> Self {
100 match cf {
101 ControlFlow::Continue => StdControlFlow::Continue(()),
102 ControlFlow::Break(r) => StdControlFlow::Break(r),
103 }
104 }
105}
106
107pub trait Middleware: Send + Sync {
148 fn before<'a>(
164 &'a self,
165 _ctx: &'a RequestContext,
166 _req: &'a mut Request,
167 ) -> BoxFuture<'a, ControlFlow> {
168 Box::pin(async { ControlFlow::Continue })
169 }
170
171 fn after<'a>(
187 &'a self,
188 _ctx: &'a RequestContext,
189 _req: &'a Request,
190 response: Response,
191 ) -> BoxFuture<'a, Response> {
192 Box::pin(async move { response })
193 }
194
195 fn name(&self) -> &'static str {
199 std::any::type_name::<Self>()
200 }
201}
202
/// An asynchronous request handler.
pub trait Handler: Send + Sync {
    /// Handles `req` within `ctx` and resolves to the response.
    fn call<'a>(&'a self, ctx: &'a RequestContext, req: &'a mut Request)
    -> BoxFuture<'a, Response>;

    /// Dependency overrides attached to this handler, if any.
    ///
    /// The default implementation reports none.
    fn dependency_overrides(&self) -> Option<Arc<DependencyOverrides>> {
        None
    }
}
219
220impl<F, Fut> Handler for F
225where
226 F: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync,
227 Fut: Future<Output = Response> + Send + 'static,
228{
229 fn call<'a>(
230 &'a self,
231 ctx: &'a RequestContext,
232 req: &'a mut Request,
233 ) -> BoxFuture<'a, Response> {
234 let fut = self(ctx, req);
235 Box::pin(fut)
236 }
237}
238
239impl<H: Handler + ?Sized> Handler for Arc<H> {
244 fn call<'a>(
245 &'a self,
246 ctx: &'a RequestContext,
247 req: &'a mut Request,
248 ) -> BoxFuture<'a, Response> {
249 (**self).call(ctx, req)
250 }
251
252 fn dependency_overrides(&self) -> Option<Arc<DependencyOverrides>> {
253 (**self).dependency_overrides()
254 }
255}
256
257#[derive(Default)]
275pub struct MiddlewareStack {
276 middleware: Vec<Arc<dyn Middleware>>,
277}
278
279impl MiddlewareStack {
280 #[must_use]
282 pub fn new() -> Self {
283 Self {
284 middleware: Vec::new(),
285 }
286 }
287
288 #[must_use]
290 pub fn with_capacity(capacity: usize) -> Self {
291 Self {
292 middleware: Vec::with_capacity(capacity),
293 }
294 }
295
296 pub fn push<M: Middleware + 'static>(&mut self, middleware: M) {
300 self.middleware.push(Arc::new(middleware));
301 }
302
303 pub fn push_arc(&mut self, middleware: Arc<dyn Middleware>) {
307 self.middleware.push(middleware);
308 }
309
310 #[must_use]
312 pub fn len(&self) -> usize {
313 self.middleware.len()
314 }
315
316 #[must_use]
318 pub fn is_empty(&self) -> bool {
319 self.middleware.is_empty()
320 }
321
322 pub async fn execute<H: Handler>(
340 &self,
341 handler: &H,
342 ctx: &RequestContext,
343 req: &mut Request,
344 ) -> Response {
345 let mut ran_before_count = 0;
347
348 for mw in &self.middleware {
350 let _ = ctx.checkpoint();
351 match mw.before(ctx, req).await {
352 ControlFlow::Continue => {
353 ran_before_count += 1;
354 }
355 ControlFlow::Break(response) => {
356 return self
358 .run_after_hooks(ctx, req, response, ran_before_count)
359 .await;
360 }
361 }
362 }
363
364 let _ = ctx.checkpoint();
366 let response = handler.call(ctx, req).await;
367
368 self.run_after_hooks(ctx, req, response, ran_before_count)
370 .await
371 }
372
373 async fn run_after_hooks(
375 &self,
376 ctx: &RequestContext,
377 req: &Request,
378 mut response: Response,
379 count: usize,
380 ) -> Response {
381 for mw in self.middleware[..count].iter().rev() {
383 let _ = ctx.checkpoint();
384 response = mw.after(ctx, req, response).await;
385 }
386 response
387 }
388}
389
390pub struct Layer<M> {
401 middleware: M,
402}
403
404impl<M: Middleware + Clone> Layer<M> {
405 pub fn new(middleware: M) -> Self {
407 Self { middleware }
408 }
409
410 pub fn wrap<H: Handler>(&self, handler: H) -> Layered<M, H> {
412 Layered {
413 middleware: self.middleware.clone(),
414 inner: handler,
415 }
416 }
417}
418
419pub struct Layered<M, H> {
421 middleware: M,
422 inner: H,
423}
424
425impl<M: Middleware, H: Handler> Handler for Layered<M, H> {
426 fn call<'a>(
427 &'a self,
428 ctx: &'a RequestContext,
429 req: &'a mut Request,
430 ) -> BoxFuture<'a, Response> {
431 Box::pin(async move {
432 let _ = ctx.checkpoint();
434 match self.middleware.before(ctx, req).await {
435 ControlFlow::Continue => {
436 let _ = ctx.checkpoint();
438 let response = self.inner.call(ctx, req).await;
439 let _ = ctx.checkpoint();
441 self.middleware.after(ctx, req, response).await
442 }
443 ControlFlow::Break(response) => {
444 let _ = ctx.checkpoint();
446 self.middleware.after(ctx, req, response).await
447 }
448 }
449 })
450 }
451}
452
/// A middleware that relies entirely on the default `before`/`after` hooks
/// and only overrides its reported name.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopMiddleware;

impl Middleware for NoopMiddleware {
    /// Always reports `"Noop"`.
    fn name(&self) -> &'static str {
        "Noop"
    }
}
468
469#[derive(Debug, Clone)]
479pub struct AddResponseHeader {
480 name: String,
481 value: Vec<u8>,
482}
483
484impl AddResponseHeader {
485 pub fn new(name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
487 Self {
488 name: name.into(),
489 value: value.into(),
490 }
491 }
492}
493
494impl Middleware for AddResponseHeader {
495 fn after<'a>(
496 &'a self,
497 _ctx: &'a RequestContext,
498 _req: &'a Request,
499 response: Response,
500 ) -> BoxFuture<'a, Response> {
501 let name = self.name.clone();
502 let value = self.value.clone();
503 Box::pin(async move { response.header(name, value) })
504 }
505
506 fn name(&self) -> &'static str {
507 "AddResponseHeader"
508 }
509}
510
511#[derive(Debug, Clone)]
523pub struct RequireHeader {
524 name: String,
525}
526
527impl RequireHeader {
528 pub fn new(name: impl Into<String>) -> Self {
530 Self { name: name.into() }
531 }
532}
533
534impl Middleware for RequireHeader {
535 fn before<'a>(
536 &'a self,
537 _ctx: &'a RequestContext,
538 req: &'a mut Request,
539 ) -> BoxFuture<'a, ControlFlow> {
540 let has_header = req.headers().get(&self.name).is_some();
541 let name = self.name.clone();
542 Box::pin(async move {
543 if has_header {
544 ControlFlow::Continue
545 } else {
546 let body = format!("Missing required header: {name}");
547 ControlFlow::Break(
548 Response::with_status(crate::response::StatusCode::BAD_REQUEST)
549 .header("content-type", b"text/plain".to_vec())
550 .body(crate::response::ResponseBody::Bytes(body.into_bytes())),
551 )
552 }
553 })
554 }
555
556 fn name(&self) -> &'static str {
557 "RequireHeader"
558 }
559}
560
561#[derive(Debug, Clone)]
574pub struct PathPrefixFilter {
575 prefix: String,
576}
577
578impl PathPrefixFilter {
579 pub fn new(prefix: impl Into<String>) -> Self {
581 Self {
582 prefix: prefix.into(),
583 }
584 }
585}
586
587impl Middleware for PathPrefixFilter {
588 fn before<'a>(
589 &'a self,
590 _ctx: &'a RequestContext,
591 req: &'a mut Request,
592 ) -> BoxFuture<'a, ControlFlow> {
593 let path_matches = req.path().starts_with(&self.prefix);
594 Box::pin(async move {
595 if path_matches {
596 ControlFlow::Continue
597 } else {
598 ControlFlow::Break(Response::with_status(
599 crate::response::StatusCode::NOT_FOUND,
600 ))
601 }
602 })
603 }
604
605 fn name(&self) -> &'static str {
606 "PathPrefixFilter"
607 }
608}
609
610#[derive(Debug, Clone)]
614pub struct ConditionalStatus<F>
615where
616 F: Fn(&Request) -> bool + Send + Sync,
617{
618 condition: F,
619 status_if_true: crate::response::StatusCode,
620 status_if_false: crate::response::StatusCode,
621}
622
623impl<F> ConditionalStatus<F>
624where
625 F: Fn(&Request) -> bool + Send + Sync,
626{
627 pub fn new(
632 condition: F,
633 status_if_true: crate::response::StatusCode,
634 status_if_false: crate::response::StatusCode,
635 ) -> Self {
636 Self {
637 condition,
638 status_if_true,
639 status_if_false,
640 }
641 }
642}
643
644impl<F> Middleware for ConditionalStatus<F>
645where
646 F: Fn(&Request) -> bool + Send + Sync,
647{
648 fn after<'a>(
649 &'a self,
650 _ctx: &'a RequestContext,
651 req: &'a Request,
652 response: Response,
653 ) -> BoxFuture<'a, Response> {
654 let matches = (self.condition)(req);
655 let status = if matches {
656 self.status_if_true
657 } else {
658 self.status_if_false
659 };
660 Box::pin(async move { Response::with_status(status).body(response.body_ref().into()) })
661 }
662
663 fn name(&self) -> &'static str {
664 "ConditionalStatus"
665 }
666}
667
668#[derive(Debug, Clone)]
674pub enum OriginPattern {
675 Any,
677 Exact(String),
679 Wildcard(String),
681 Regex(String),
683}
684
685impl OriginPattern {
686 fn matches(&self, origin: &str) -> bool {
687 match self {
688 Self::Any => true,
689 Self::Exact(value) => value == origin,
690 Self::Wildcard(pattern) => wildcard_match(pattern, origin),
691 Self::Regex(pattern) => regex_match(pattern, origin),
692 }
693 }
694}
695
696#[derive(Debug, Clone)]
745pub struct CorsConfig {
746 allow_any_origin: bool,
747 allow_credentials: bool,
748 allowed_methods: Vec<crate::request::Method>,
749 allowed_headers: Vec<String>,
750 expose_headers: Vec<String>,
751 max_age: Option<u32>,
752 origins: Vec<OriginPattern>,
753}
754
755impl Default for CorsConfig {
756 fn default() -> Self {
757 Self {
758 allow_any_origin: false,
759 allow_credentials: false,
760 allowed_methods: vec![
761 crate::request::Method::Get,
762 crate::request::Method::Post,
763 crate::request::Method::Put,
764 crate::request::Method::Patch,
765 crate::request::Method::Delete,
766 crate::request::Method::Options,
767 crate::request::Method::Head,
768 ],
769 allowed_headers: Vec::new(),
770 expose_headers: Vec::new(),
771 max_age: None,
772 origins: Vec::new(),
773 }
774 }
775}
776
777#[derive(Debug, Clone)]
779pub struct Cors {
780 config: CorsConfig,
781}
782
783impl Cors {
784 #[must_use]
786 pub fn new() -> Self {
787 Self {
788 config: CorsConfig::default(),
789 }
790 }
791
792 #[must_use]
794 pub fn config(mut self, config: CorsConfig) -> Self {
795 self.config = config;
796 self
797 }
798
799 #[must_use]
801 pub fn allow_any_origin(mut self) -> Self {
802 self.config.allow_any_origin = true;
803 self
804 }
805
806 #[must_use]
808 pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
809 self.config
810 .origins
811 .push(OriginPattern::Exact(origin.into()));
812 self
813 }
814
815 #[must_use]
817 pub fn allow_origin_wildcard(mut self, pattern: impl Into<String>) -> Self {
818 self.config
819 .origins
820 .push(OriginPattern::Wildcard(pattern.into()));
821 self
822 }
823
824 #[must_use]
826 pub fn allow_origin_regex(mut self, pattern: impl Into<String>) -> Self {
827 self.config
828 .origins
829 .push(OriginPattern::Regex(pattern.into()));
830 self
831 }
832
833 #[must_use]
835 pub fn allow_credentials(mut self, allow: bool) -> Self {
836 self.config.allow_credentials = allow;
837 self
838 }
839
840 #[must_use]
842 pub fn allow_methods<I>(mut self, methods: I) -> Self
843 where
844 I: IntoIterator<Item = crate::request::Method>,
845 {
846 self.config.allowed_methods = methods.into_iter().collect();
847 self
848 }
849
850 #[must_use]
852 pub fn allow_headers<I, S>(mut self, headers: I) -> Self
853 where
854 I: IntoIterator<Item = S>,
855 S: Into<String>,
856 {
857 self.config.allowed_headers = headers.into_iter().map(Into::into).collect();
858 self
859 }
860
861 #[must_use]
863 pub fn expose_headers<I, S>(mut self, headers: I) -> Self
864 where
865 I: IntoIterator<Item = S>,
866 S: Into<String>,
867 {
868 self.config.expose_headers = headers.into_iter().map(Into::into).collect();
869 self
870 }
871
872 #[must_use]
874 pub fn max_age(mut self, seconds: u32) -> Self {
875 self.config.max_age = Some(seconds);
876 self
877 }
878
879 fn is_origin_allowed(&self, origin: &str) -> bool {
880 if self.config.allow_any_origin {
881 return true;
882 }
883 self.config
884 .origins
885 .iter()
886 .any(|pattern| pattern.matches(origin))
887 }
888
889 fn allow_origin_value(&self, origin: &str) -> Option<String> {
890 if !self.is_origin_allowed(origin) {
891 return None;
892 }
893 if self.config.allow_any_origin && !self.config.allow_credentials {
894 Some("*".to_string())
895 } else {
896 Some(origin.to_string())
897 }
898 }
899
900 fn allow_methods_value(&self) -> String {
901 self.config
902 .allowed_methods
903 .iter()
904 .map(|method| method.as_str())
905 .collect::<Vec<_>>()
906 .join(", ")
907 }
908
909 fn allow_headers_value(&self, request: &Request) -> Option<String> {
910 if self.config.allowed_headers.is_empty() {
911 return None;
916 }
917
918 if self.config.allowed_headers.iter().any(|h| h == "*") {
923 if self.config.allow_credentials {
924 return request
927 .headers()
928 .get("access-control-request-headers")
929 .and_then(|value| std::str::from_utf8(value).ok())
930 .map(ToString::to_string);
931 }
932 return Some("*".to_string());
933 }
934
935 Some(self.config.allowed_headers.join(", "))
936 }
937
938 fn apply_common_headers(&self, mut response: Response, origin: &str) -> Response {
939 if let Some(allow_origin) = self.allow_origin_value(origin) {
940 let is_wildcard = allow_origin == "*";
941 response = response.header("access-control-allow-origin", allow_origin.into_bytes());
942 if !is_wildcard {
943 response = response.header("vary", b"Origin".to_vec());
944 }
945 if self.config.allow_credentials {
946 response = response.header("access-control-allow-credentials", b"true".to_vec());
947 }
948 if !self.config.expose_headers.is_empty() {
949 response = response.header(
950 "access-control-expose-headers",
951 self.config.expose_headers.join(", ").into_bytes(),
952 );
953 }
954 }
955 response
956 }
957}
958
959impl Default for Cors {
960 fn default() -> Self {
961 Self::new()
962 }
963}
964
965#[derive(Debug, Clone)]
966struct CorsOrigin(String);
967
968impl Middleware for Cors {
969 fn before<'a>(
970 &'a self,
971 _ctx: &'a RequestContext,
972 req: &'a mut Request,
973 ) -> BoxFuture<'a, ControlFlow> {
974 let origin = req
975 .headers()
976 .get("origin")
977 .and_then(|value| std::str::from_utf8(value).ok())
978 .map(ToString::to_string);
979
980 let Some(origin) = origin else {
981 return Box::pin(async { ControlFlow::Continue });
982 };
983
984 if !self.is_origin_allowed(&origin) {
985 let is_preflight = req.method() == crate::request::Method::Options
986 && req.headers().get("access-control-request-method").is_some();
987 if is_preflight {
988 return Box::pin(async {
989 ControlFlow::Break(Response::with_status(
990 crate::response::StatusCode::FORBIDDEN,
991 ))
992 });
993 }
994 return Box::pin(async { ControlFlow::Continue });
995 }
996
997 let is_preflight = req.method() == crate::request::Method::Options
998 && req.headers().get("access-control-request-method").is_some();
999
1000 if is_preflight {
1001 let mut response = Response::no_content();
1002 response = self.apply_common_headers(response, &origin);
1003 response = response.header(
1004 "access-control-allow-methods",
1005 self.allow_methods_value().into_bytes(),
1006 );
1007
1008 if let Some(value) = self.allow_headers_value(req) {
1009 response = response.header("access-control-allow-headers", value.into_bytes());
1010 }
1011
1012 if let Some(max_age) = self.config.max_age {
1013 response =
1014 response.header("access-control-max-age", max_age.to_string().into_bytes());
1015 }
1016
1017 return Box::pin(async move { ControlFlow::Break(response) });
1018 }
1019
1020 req.insert_extension(CorsOrigin(origin));
1021 Box::pin(async { ControlFlow::Continue })
1022 }
1023
1024 fn after<'a>(
1025 &'a self,
1026 _ctx: &'a RequestContext,
1027 req: &'a Request,
1028 response: Response,
1029 ) -> BoxFuture<'a, Response> {
1030 let origin = req.get_extension::<CorsOrigin>().map(|v| v.0.clone());
1031 Box::pin(async move {
1032 if let Some(origin) = origin {
1033 return self.apply_common_headers(response, &origin);
1034 }
1035 response
1036 })
1037 }
1038
1039 fn name(&self) -> &'static str {
1040 "Cors"
1041 }
1042}
1043
/// Returns `true` when `value` matches `pattern`, where each `*` in the
/// pattern stands for any run of characters (including none) and every other
/// character must match literally.
///
/// The match is anchored at both ends and compares `char`s, not bytes.
///
/// The previous implementation only backtracked on a literal mismatch. A
/// pattern that ran out while the value still had characters left was
/// rejected, so a trailing `*` never matched a non-empty remainder
/// (`"https://*"` vs `"https://foo"`), and `"*.example.com"` failed on values
/// containing the suffix twice. This version uses the standard greedy
/// two-cursor algorithm with a single backtrack point.
fn wildcard_match(pattern: &str, value: &str) -> bool {
    let pat: Vec<char> = pattern.chars().collect();
    let val: Vec<char> = value.chars().collect();

    let mut p = 0;
    let mut v = 0;
    // (pattern index just after the latest `*`, value index that `*` currently ends at)
    let mut backtrack: Option<(usize, usize)> = None;

    while v < val.len() {
        match pat.get(p) {
            Some('*') => {
                // Start with the star matching nothing; extend it on demand.
                p += 1;
                backtrack = Some((p, v));
            }
            Some(&literal) if literal == val[v] => {
                p += 1;
                v += 1;
            }
            _ => match backtrack {
                // Mismatch or exhausted pattern: let the latest star swallow
                // one more character and retry from just after it.
                Some((after_star, star_end)) => {
                    p = after_star;
                    v = star_end + 1;
                    backtrack = Some((after_star, v));
                }
                None => return false,
            },
        }
    }

    // The value is consumed; anything left in the pattern must be stars.
    pat[p..].iter().all(|&c| c == '*')
}
1097
1098fn regex_match(pattern: &str, value: &str) -> bool {
1099 let pat = pattern.as_bytes();
1101 let text = value.as_bytes();
1102
1103 if pat.first() == Some(&b'^') {
1104 return regex_match_here(&pat[1..], text);
1105 }
1106
1107 let mut i = 0;
1108 loop {
1109 if regex_match_here(pat, &text[i..]) {
1110 return true;
1111 }
1112 if i == text.len() {
1113 break;
1114 }
1115 i += 1;
1116 }
1117 false
1118}
1119
/// Matches `pattern` against the beginning of `text`.
///
/// Supported syntax: `.` matches any single byte, `x*` matches zero or more
/// of `x`, and a pattern consisting solely of `$` matches the end of the
/// text. Every other byte matches itself.
fn regex_match_here(pattern: &[u8], text: &[u8]) -> bool {
    match pattern {
        [] => true,
        [b'$'] => text.is_empty(),
        [repeated, b'*', rest @ ..] => regex_match_star(*repeated, rest, text),
        [expected, rest @ ..] => match text.split_first() {
            Some((actual, tail)) if *expected == b'.' || expected == actual => {
                regex_match_here(rest, tail)
            }
            _ => false,
        },
    }
}

/// Matches zero or more occurrences of `ch` (`.` meaning any byte) followed
/// by `pattern`, trying the shortest repetition first.
fn regex_match_star(ch: u8, pattern: &[u8], text: &[u8]) -> bool {
    let mut consumed = 0;
    loop {
        if regex_match_here(pattern, &text[consumed..]) {
            return true;
        }
        match text.get(consumed) {
            // Another repetition is available: take it and retry.
            Some(&byte) if ch == b'.' || byte == ch => consumed += 1,
            // Out of text, or the next byte cannot extend the repetition.
            _ => return false,
        }
    }
}
1151
1152#[derive(Debug, Clone)]
1158pub struct RequestResponseLogger {
1159 log_config: LogConfig,
1160 redact_headers: HashSet<String>,
1161 log_request_headers: bool,
1162 log_response_headers: bool,
1163 log_body: bool,
1164 max_body_bytes: usize,
1165}
1166
1167impl Default for RequestResponseLogger {
1168 fn default() -> Self {
1169 Self {
1170 log_config: LogConfig::production(),
1171 redact_headers: default_redacted_headers(),
1172 log_request_headers: true,
1173 log_response_headers: true,
1174 log_body: false,
1175 max_body_bytes: 1024,
1176 }
1177 }
1178}
1179
1180impl RequestResponseLogger {
1181 #[must_use]
1183 pub fn new() -> Self {
1184 Self::default()
1185 }
1186
1187 #[must_use]
1189 pub fn log_config(mut self, config: LogConfig) -> Self {
1190 self.log_config = config;
1191 self
1192 }
1193
1194 #[must_use]
1196 pub fn log_request_headers(mut self, enabled: bool) -> Self {
1197 self.log_request_headers = enabled;
1198 self
1199 }
1200
1201 #[must_use]
1203 pub fn log_response_headers(mut self, enabled: bool) -> Self {
1204 self.log_response_headers = enabled;
1205 self
1206 }
1207
1208 #[must_use]
1210 pub fn log_body(mut self, enabled: bool) -> Self {
1211 self.log_body = enabled;
1212 self
1213 }
1214
1215 #[must_use]
1217 pub fn max_body_bytes(mut self, max: usize) -> Self {
1218 self.max_body_bytes = max;
1219 self
1220 }
1221
1222 #[must_use]
1224 pub fn redact_header(mut self, name: impl Into<String>) -> Self {
1225 self.redact_headers.insert(name.into().to_ascii_lowercase());
1226 self
1227 }
1228}
1229
1230#[derive(Debug, Clone)]
1231struct RequestStart(Instant);
1232
1233impl Middleware for RequestResponseLogger {
1234 fn before<'a>(
1235 &'a self,
1236 ctx: &'a RequestContext,
1237 req: &'a mut Request,
1238 ) -> BoxFuture<'a, ControlFlow> {
1239 let logger = RequestLogger::new(ctx, self.log_config.clone());
1240 req.insert_extension(RequestStart(Instant::now()));
1241
1242 let method = req.method();
1243 let path = req.path();
1244 let query = req.query();
1245 let body_bytes = body_len(req.body());
1246
1247 logger.info_with_fields("request", |entry| {
1248 let mut entry = entry
1249 .field("method", method)
1250 .field("path", path)
1251 .field("body_bytes", body_bytes);
1252
1253 if let Some(q) = query {
1254 entry = entry.field("query", q);
1255 }
1256
1257 if self.log_request_headers {
1258 let headers = format_headers(req.headers().iter(), &self.redact_headers);
1259 entry = entry.field("headers", headers);
1260 }
1261
1262 if self.log_body {
1263 if let Some(body) = preview_body(req.body(), self.max_body_bytes) {
1264 entry = entry.field("body", body);
1265 }
1266 }
1267
1268 entry
1269 });
1270
1271 Box::pin(async { ControlFlow::Continue })
1272 }
1273
1274 fn after<'a>(
1275 &'a self,
1276 ctx: &'a RequestContext,
1277 req: &'a Request,
1278 response: Response,
1279 ) -> BoxFuture<'a, Response> {
1280 let logger = RequestLogger::new(ctx, self.log_config.clone());
1281 let duration = req
1282 .get_extension::<RequestStart>()
1283 .map(|start| start.0.elapsed())
1284 .unwrap_or_default();
1285
1286 let status = response.status();
1287 let body_bytes = response.body_ref().len();
1288
1289 logger.info_with_fields("response", |entry| {
1290 let mut entry = entry
1291 .field("status", status.as_u16())
1292 .field("duration_us", duration.as_micros())
1293 .field("body_bytes", body_bytes);
1294
1295 if self.log_response_headers {
1296 let headers = format_response_headers(response.headers(), &self.redact_headers);
1297 entry = entry.field("headers", headers);
1298 }
1299
1300 if self.log_body {
1301 if let Some(body) = preview_response_body(response.body_ref(), self.max_body_bytes)
1302 {
1303 entry = entry.field("body", body);
1304 }
1305 }
1306
1307 entry
1308 });
1309
1310 Box::pin(async move { response })
1311 }
1312
1313 fn name(&self) -> &'static str {
1314 "RequestResponseLogger"
1315 }
1316}
1317
/// Header names redacted by default: `authorization`, `proxy-authorization`,
/// `cookie` and `set-cookie`.
fn default_redacted_headers() -> HashSet<String> {
    const SENSITIVE: [&str; 4] = [
        "authorization",
        "proxy-authorization",
        "cookie",
        "set-cookie",
    ];
    HashSet::from(SENSITIVE.map(String::from))
}
1329
1330fn body_len(body: &Body) -> usize {
1331 match body {
1332 Body::Empty => 0,
1333 Body::Bytes(bytes) => bytes.len(),
1334 Body::Stream { content_length, .. } => content_length.unwrap_or(0),
1335 }
1336}
1337
1338fn preview_body(body: &Body, max_bytes: usize) -> Option<String> {
1339 if max_bytes == 0 {
1340 return None;
1341 }
1342 match body {
1343 Body::Empty => None,
1344 Body::Bytes(bytes) => {
1345 if bytes.is_empty() {
1346 None
1347 } else {
1348 Some(format_bytes(bytes, max_bytes))
1349 }
1350 }
1351 Body::Stream { .. } => None,
1352 }
1353}
1354
1355fn preview_response_body(body: &crate::response::ResponseBody, max_bytes: usize) -> Option<String> {
1356 if max_bytes == 0 {
1357 return None;
1358 }
1359 match body {
1360 crate::response::ResponseBody::Empty => None,
1361 crate::response::ResponseBody::Bytes(bytes) => {
1362 if bytes.is_empty() {
1363 None
1364 } else {
1365 Some(format_bytes(bytes, max_bytes))
1366 }
1367 }
1368 crate::response::ResponseBody::Stream(_) => None,
1369 }
1370}
1371
/// Renders headers as `name=value` pairs separated by `", "`.
///
/// A header whose ASCII-lower-cased name is in `redacted` is shown as
/// `<redacted>`; a value that is not valid UTF-8 is shown as `<binary>`.
/// Names are emitted as given.
fn format_headers<'a>(
    headers: impl Iterator<Item = (&'a str, &'a [u8])>,
    redacted: &HashSet<String>,
) -> String {
    headers
        .map(|(name, value)| {
            let shown = if redacted.contains(&name.to_ascii_lowercase()) {
                "<redacted>"
            } else {
                std::str::from_utf8(value).unwrap_or("<binary>")
            };
            format!("{name}={shown}")
        })
        .collect::<Vec<_>>()
        .join(", ")
}
1397
1398fn format_response_headers(headers: &[(String, Vec<u8>)], redacted: &HashSet<String>) -> String {
1399 format_headers(
1400 headers
1401 .iter()
1402 .map(|(name, value)| (name.as_str(), value.as_slice())),
1403 redacted,
1404 )
1405}
1406
/// Renders at most `max_bytes` of `bytes` as text for logging.
///
/// - Valid UTF-8 is returned as-is, with `"..."` appended when `bytes` is
///   longer than `max_bytes`.
/// - When the cut lands inside a multi-byte character, the preview is
///   shortened to the last complete character. Previously such input was
///   reported as binary even though it was valid text.
/// - Anything else is rendered as `<N bytes binary>`, where `N` is the full length.
fn format_bytes(bytes: &[u8], max_bytes: usize) -> String {
    let truncated = bytes.len() > max_bytes;
    let window = &bytes[..max_bytes.min(bytes.len())];

    let text = match std::str::from_utf8(window) {
        Ok(text) => text,
        // `error_len() == None` means the window ended in the middle of a
        // character; that is only benign when we did the cutting ourselves.
        Err(err) if truncated && err.error_len().is_none() => {
            match std::str::from_utf8(&window[..err.valid_up_to()]) {
                Ok(prefix) => prefix,
                Err(_) => return format!("<{} bytes binary>", bytes.len()),
            }
        }
        Err(_) => return format!("<{} bytes binary>", bytes.len()),
    };

    let mut output = String::with_capacity(text.len() + 3);
    output.push_str(text);
    if truncated {
        output.push_str("...");
    }
    output
}
1420
1421impl From<&crate::response::ResponseBody> for crate::response::ResponseBody {
1423 fn from(body: &crate::response::ResponseBody) -> Self {
1424 match body {
1425 crate::response::ResponseBody::Empty => crate::response::ResponseBody::Empty,
1426 crate::response::ResponseBody::Bytes(b) => {
1427 crate::response::ResponseBody::Bytes(b.clone())
1428 }
1429 crate::response::ResponseBody::Stream(_) => crate::response::ResponseBody::Empty,
1430 }
1431 }
1432}
1433
/// Identifier attached to a request.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);

impl RequestId {
    /// Wraps an existing identifier.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        let id = id.into();
        Self(id)
    }

    /// Borrows the identifier as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Generates an identifier of the form `<hex micros>-<hex counter>`.
    ///
    /// The first part is the number of microseconds since the Unix epoch
    /// (`0` if the system clock reads earlier than the epoch); the second is
    /// a process-wide counter incremented on every call.
    #[must_use]
    pub fn generate() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};
        use std::time::{SystemTime, UNIX_EPOCH};

        static COUNTER: AtomicU64 = AtomicU64::new(0);

        let micros = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_micros() as u64,
            Err(_) => 0,
        };
        let sequence = COUNTER.fetch_add(1, Ordering::Relaxed);

        Self(format!("{micros:x}-{sequence:x}"))
    }
}

impl std::fmt::Display for RequestId {
    /// Writes the raw identifier.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<String> for RequestId {
    fn from(s: String) -> Self {
        RequestId(s)
    }
}

impl From<&str> for RequestId {
    fn from(s: &str) -> Self {
        RequestId(s.to_owned())
    }
}
1497
/// Settings for [`RequestIdMiddleware`].
#[derive(Debug, Clone)]
pub struct RequestIdConfig {
    /// Header read for a client-supplied id and written on the response.
    pub header_name: String,
    /// Whether an id supplied by the client may be used.
    pub accept_from_client: bool,
    /// Whether the id is added to the response.
    pub add_to_response: bool,
    /// Longest client-supplied id that is accepted.
    pub max_client_id_length: usize,
}

impl Default for RequestIdConfig {
    /// `x-request-id`, client ids accepted, id echoed on the response, and a
    /// maximum client id length of 128.
    fn default() -> Self {
        Self {
            header_name: String::from("x-request-id"),
            accept_from_client: true,
            add_to_response: true,
            max_client_id_length: 128,
        }
    }
}

impl RequestIdConfig {
    /// Creates the default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the header name.
    #[must_use]
    pub fn header_name(self, name: impl Into<String>) -> Self {
        Self {
            header_name: name.into(),
            ..self
        }
    }

    /// Sets whether client-supplied ids are accepted.
    #[must_use]
    pub fn accept_from_client(self, accept: bool) -> Self {
        Self {
            accept_from_client: accept,
            ..self
        }
    }

    /// Sets whether the id is added to the response.
    #[must_use]
    pub fn add_to_response(self, add: bool) -> Self {
        Self {
            add_to_response: add,
            ..self
        }
    }

    /// Sets the maximum accepted length of a client-supplied id.
    #[must_use]
    pub fn max_client_id_length(self, max: usize) -> Self {
        Self {
            max_client_id_length: max,
            ..self
        }
    }
}
1557
1558#[derive(Debug, Clone)]
1583pub struct RequestIdMiddleware {
1584 config: RequestIdConfig,
1585}
1586
1587impl Default for RequestIdMiddleware {
1588 fn default() -> Self {
1589 Self::new()
1590 }
1591}
1592
1593impl RequestIdMiddleware {
1594 #[must_use]
1596 pub fn new() -> Self {
1597 Self {
1598 config: RequestIdConfig::default(),
1599 }
1600 }
1601
1602 #[must_use]
1604 pub fn with_config(config: RequestIdConfig) -> Self {
1605 Self { config }
1606 }
1607
1608 fn get_or_generate_id(&self, req: &Request) -> RequestId {
1610 if self.config.accept_from_client {
1611 if let Some(header_value) = req.headers().get(&self.config.header_name) {
1612 if let Ok(client_id) = std::str::from_utf8(header_value) {
1613 if !client_id.is_empty()
1615 && client_id.len() <= self.config.max_client_id_length
1616 && is_valid_request_id(client_id)
1617 {
1618 return RequestId::new(client_id);
1619 }
1620 }
1621 }
1622 }
1623 RequestId::generate()
1624 }
1625}
1626
/// Returns `true` when `id` is non-empty and consists only of ASCII letters,
/// ASCII digits, `-`, `_` and `.`.
fn is_valid_request_id(id: &str) -> bool {
    if id.is_empty() {
        return false;
    }
    // Any non-ASCII character contributes bytes >= 0x80, which fail this test.
    id.bytes()
        .all(|byte| byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.'))
}
1634
1635impl Middleware for RequestIdMiddleware {
1636 fn before<'a>(
1637 &'a self,
1638 _ctx: &'a RequestContext,
1639 req: &'a mut Request,
1640 ) -> BoxFuture<'a, ControlFlow> {
1641 let request_id = self.get_or_generate_id(req);
1642 req.insert_extension(request_id);
1643 Box::pin(async { ControlFlow::Continue })
1644 }
1645
1646 fn after<'a>(
1647 &'a self,
1648 _ctx: &'a RequestContext,
1649 req: &'a Request,
1650 response: Response,
1651 ) -> BoxFuture<'a, Response> {
1652 if !self.config.add_to_response {
1653 return Box::pin(async move { response });
1654 }
1655
1656 let request_id = req.get_extension::<RequestId>().cloned();
1657 let header_name = self.config.header_name.clone();
1658
1659 Box::pin(async move {
1660 if let Some(id) = request_id {
1661 response.header(header_name, id.0.into_bytes())
1662 } else {
1663 response
1664 }
1665 })
1666 }
1667
1668 fn name(&self) -> &'static str {
1669 "RequestId"
1670 }
1671}
1672
/// Supported values for the frame-options setting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum XFrameOptions {
    /// Serialized as `DENY`.
    Deny,
    /// Serialized as `SAMEORIGIN`.
    SameOrigin,
}

impl XFrameOptions {
    /// Header value bytes for this variant.
    fn as_bytes(self) -> &'static [u8] {
        let value: &'static str = match self {
            Self::Deny => "DENY",
            Self::SameOrigin => "SAMEORIGIN",
        };
        value.as_bytes()
    }
}
1696
/// Supported values for the referrer-policy setting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReferrerPolicy {
    /// Serialized as `no-referrer`.
    NoReferrer,
    /// Serialized as `no-referrer-when-downgrade`.
    NoReferrerWhenDowngrade,
    /// Serialized as `origin`.
    Origin,
    /// Serialized as `origin-when-cross-origin`.
    OriginWhenCrossOrigin,
    /// Serialized as `same-origin`.
    SameOrigin,
    /// Serialized as `strict-origin`.
    StrictOrigin,
    /// Serialized as `strict-origin-when-cross-origin`.
    StrictOriginWhenCrossOrigin,
    /// Serialized as `unsafe-url`.
    UnsafeUrl,
}

impl ReferrerPolicy {
    /// Header value bytes for this variant.
    fn as_bytes(self) -> &'static [u8] {
        let value: &'static str = match self {
            Self::NoReferrer => "no-referrer",
            Self::NoReferrerWhenDowngrade => "no-referrer-when-downgrade",
            Self::Origin => "origin",
            Self::OriginWhenCrossOrigin => "origin-when-cross-origin",
            Self::SameOrigin => "same-origin",
            Self::StrictOrigin => "strict-origin",
            Self::StrictOriginWhenCrossOrigin => "strict-origin-when-cross-origin",
            Self::UnsafeUrl => "unsafe-url",
        };
        value.as_bytes()
    }
}
1734
/// Configuration of security-related response headers.
///
/// Each field is optional; `None` means the setting is unset.
/// NOTE(review): the code that turns these fields into actual response
/// headers is outside this view — presumably each field maps to the header
/// its name suggests; confirm against the middleware that consumes it.
#[derive(Debug, Clone)]
pub struct SecurityHeadersConfig {
    /// Content-type-options value; defaults to `"nosniff"`.
    pub x_content_type_options: Option<&'static str>,
    /// Frame-options value; defaults to `XFrameOptions::Deny`.
    pub x_frame_options: Option<XFrameOptions>,
    /// XSS-protection value; defaults to `"0"`.
    pub x_xss_protection: Option<&'static str>,
    /// Content security policy string; unset by default.
    pub content_security_policy: Option<String>,
    /// HSTS settings as `(max_age, include_sub_domains, preload)`; unset by default.
    pub hsts: Option<(u64, bool, bool)>,
    /// Referrer policy; defaults to `ReferrerPolicy::StrictOriginWhenCrossOrigin`.
    pub referrer_policy: Option<ReferrerPolicy>,
    /// Permissions policy string; unset by default.
    pub permissions_policy: Option<String>,
}

impl Default for SecurityHeadersConfig {
    /// Sets content-type-options, frame-options, XSS-protection and referrer
    /// policy; leaves CSP, HSTS and permissions policy unset.
    fn default() -> Self {
        Self {
            x_content_type_options: Some("nosniff"),
            x_frame_options: Some(XFrameOptions::Deny),
            x_xss_protection: Some("0"),
            content_security_policy: None,
            hsts: None,
            referrer_policy: Some(ReferrerPolicy::StrictOriginWhenCrossOrigin),
            permissions_policy: None,
        }
    }
}
1800
1801impl SecurityHeadersConfig {
1802 #[must_use]
1804 pub fn new() -> Self {
1805 Self::default()
1806 }
1807
1808 #[must_use]
1810 pub fn none() -> Self {
1811 Self {
1812 x_content_type_options: None,
1813 x_frame_options: None,
1814 x_xss_protection: None,
1815 content_security_policy: None,
1816 hsts: None,
1817 referrer_policy: None,
1818 permissions_policy: None,
1819 }
1820 }
1821
1822 #[must_use]
1829 pub fn strict() -> Self {
1830 Self {
1831 x_content_type_options: Some("nosniff"),
1832 x_frame_options: Some(XFrameOptions::Deny),
1833 x_xss_protection: Some("0"),
1834 content_security_policy: Some("default-src 'self'".to_string()),
1835 hsts: Some((31536000, true, false)), referrer_policy: Some(ReferrerPolicy::NoReferrer),
1837 permissions_policy: Some("geolocation=(), camera=(), microphone=()".to_string()),
1838 }
1839 }
1840
1841 #[must_use]
1843 pub fn x_content_type_options(mut self, value: Option<&'static str>) -> Self {
1844 self.x_content_type_options = value;
1845 self
1846 }
1847
1848 #[must_use]
1850 pub fn x_frame_options(mut self, value: Option<XFrameOptions>) -> Self {
1851 self.x_frame_options = value;
1852 self
1853 }
1854
1855 #[must_use]
1857 pub fn x_xss_protection(mut self, value: Option<&'static str>) -> Self {
1858 self.x_xss_protection = value;
1859 self
1860 }
1861
1862 #[must_use]
1864 pub fn content_security_policy(mut self, value: impl Into<String>) -> Self {
1865 self.content_security_policy = Some(value.into());
1866 self
1867 }
1868
1869 #[must_use]
1871 pub fn no_content_security_policy(mut self) -> Self {
1872 self.content_security_policy = None;
1873 self
1874 }
1875
1876 #[must_use]
1889 pub fn hsts(mut self, max_age: u64, include_sub_domains: bool, preload: bool) -> Self {
1890 self.hsts = Some((max_age, include_sub_domains, preload));
1891 self
1892 }
1893
1894 #[must_use]
1896 pub fn no_hsts(mut self) -> Self {
1897 self.hsts = None;
1898 self
1899 }
1900
1901 #[must_use]
1903 pub fn referrer_policy(mut self, value: Option<ReferrerPolicy>) -> Self {
1904 self.referrer_policy = value;
1905 self
1906 }
1907
1908 #[must_use]
1910 pub fn permissions_policy(mut self, value: impl Into<String>) -> Self {
1911 self.permissions_policy = Some(value.into());
1912 self
1913 }
1914
1915 #[must_use]
1917 pub fn no_permissions_policy(mut self) -> Self {
1918 self.permissions_policy = None;
1919 self
1920 }
1921
1922 fn build_hsts_value(&self) -> Option<String> {
1924 self.hsts.map(|(max_age, include_sub, preload)| {
1925 let mut value = format!("max-age={}", max_age);
1926 if include_sub {
1927 value.push_str("; includeSubDomains");
1928 }
1929 if preload {
1930 value.push_str("; preload");
1931 }
1932 value
1933 })
1934 }
1935}
1936
1937#[derive(Debug, Clone)]
1968pub struct SecurityHeaders {
1969 config: SecurityHeadersConfig,
1970}
1971
1972impl Default for SecurityHeaders {
1973 fn default() -> Self {
1974 Self::new()
1975 }
1976}
1977
1978impl SecurityHeaders {
1979 #[must_use]
1981 pub fn new() -> Self {
1982 Self {
1983 config: SecurityHeadersConfig::default(),
1984 }
1985 }
1986
1987 #[must_use]
1989 pub fn with_config(config: SecurityHeadersConfig) -> Self {
1990 Self { config }
1991 }
1992
1993 #[must_use]
1995 pub fn strict() -> Self {
1996 Self {
1997 config: SecurityHeadersConfig::strict(),
1998 }
1999 }
2000}
2001
2002impl Middleware for SecurityHeaders {
2003 fn after<'a>(
2004 &'a self,
2005 _ctx: &'a RequestContext,
2006 _req: &'a Request,
2007 response: Response,
2008 ) -> BoxFuture<'a, Response> {
2009 let config = self.config.clone();
2010 Box::pin(async move {
2011 let mut resp = response;
2012
2013 if let Some(value) = config.x_content_type_options {
2015 resp = resp.header("X-Content-Type-Options", value.as_bytes().to_vec());
2016 }
2017
2018 if let Some(value) = config.x_frame_options {
2020 resp = resp.header("X-Frame-Options", value.as_bytes().to_vec());
2021 }
2022
2023 if let Some(value) = config.x_xss_protection {
2025 resp = resp.header("X-XSS-Protection", value.as_bytes().to_vec());
2026 }
2027
2028 if let Some(ref value) = config.content_security_policy {
2030 resp = resp.header("Content-Security-Policy", value.as_bytes().to_vec());
2031 }
2032
2033 if let Some(ref hsts_value) = config.build_hsts_value() {
2035 resp = resp.header("Strict-Transport-Security", hsts_value.as_bytes().to_vec());
2036 }
2037
2038 if let Some(value) = config.referrer_policy {
2040 resp = resp.header("Referrer-Policy", value.as_bytes().to_vec());
2041 }
2042
2043 if let Some(ref value) = config.permissions_policy {
2045 resp = resp.header("Permissions-Policy", value.as_bytes().to_vec());
2046 }
2047
2048 resp
2049 })
2050 }
2051
2052 fn name(&self) -> &'static str {
2053 "SecurityHeaders"
2054 }
2055}
2056
/// A CSRF token, stored as an owned string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CsrfToken(pub String);

impl CsrfToken {
    /// Wraps an existing token value.
    #[must_use]
    pub fn new(token: impl Into<String>) -> Self {
        CsrfToken(token.into())
    }

    /// Borrows the token text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Generates a fresh token: 32 bytes read from `/dev/urandom`, rendered as
    /// 64 lowercase hex characters.
    ///
    /// # Panics
    ///
    /// Panics if `/dev/urandom` cannot be opened or read; there is no
    /// fallback to a weaker random source.
    #[must_use]
    pub fn generate() -> Self {
        match Self::read_urandom(32) {
            Ok(entropy) => Self(Self::bytes_to_hex(&entropy)),
            Err(_) => panic!(
                "FATAL: Cryptographically secure random source (/dev/urandom) is unavailable. \
                 CSRF token generation requires a CSPRNG. Cannot safely generate CSRF tokens \
                 without cryptographic entropy."
            ),
        }
    }

    /// Reads exactly `len` bytes from `/dev/urandom`.
    fn read_urandom(len: usize) -> std::io::Result<Vec<u8>> {
        use std::io::Read;

        let mut source = std::fs::File::open("/dev/urandom")?;
        let mut entropy = vec![0u8; len];
        source.read_exact(&mut entropy)?;
        Ok(entropy)
    }

    /// Encodes `bytes` as lowercase hexadecimal, two characters per byte.
    fn bytes_to_hex(bytes: &[u8]) -> String {
        const DIGITS: &[u8; 16] = b"0123456789abcdef";

        let mut hex = String::with_capacity(bytes.len() * 2);
        for &byte in bytes {
            hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
            hex.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        hex
    }
}

impl std::fmt::Display for CsrfToken {
    /// Writes the raw token text (formatter width/padding flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

impl From<&str> for CsrfToken {
    /// Copies `s` into a new token.
    fn from(s: &str) -> Self {
        Self::new(s)
    }
}
2131
/// Strategy used to validate the CSRF token on unsafe requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CsrfMode {
    /// The token header must be non-empty and equal to the token cookie.
    /// This is the default mode.
    #[default]
    DoubleSubmit,
    /// Only a non-empty token header is required; the cookie is not compared.
    HeaderOnly,
}
2142
2143#[derive(Debug, Clone)]
2145pub struct CsrfConfig {
2146 pub cookie_name: String,
2148 pub header_name: String,
2150 pub mode: CsrfMode,
2152 pub rotate_token: bool,
2154 pub production: bool,
2156 pub error_message: Option<String>,
2158}
2159
2160impl Default for CsrfConfig {
2161 fn default() -> Self {
2162 Self {
2163 cookie_name: "csrf_token".to_string(),
2164 header_name: "x-csrf-token".to_string(),
2165 mode: CsrfMode::DoubleSubmit,
2166 rotate_token: false,
2167 production: true,
2168 error_message: None,
2169 }
2170 }
2171}
2172
2173impl CsrfConfig {
2174 #[must_use]
2176 pub fn new() -> Self {
2177 Self::default()
2178 }
2179
2180 #[must_use]
2182 pub fn cookie_name(mut self, name: impl Into<String>) -> Self {
2183 self.cookie_name = name.into();
2184 self
2185 }
2186
2187 #[must_use]
2189 pub fn header_name(mut self, name: impl Into<String>) -> Self {
2190 self.header_name = name.into();
2191 self
2192 }
2193
2194 #[must_use]
2196 pub fn mode(mut self, mode: CsrfMode) -> Self {
2197 self.mode = mode;
2198 self
2199 }
2200
2201 #[must_use]
2203 pub fn rotate_token(mut self, rotate: bool) -> Self {
2204 self.rotate_token = rotate;
2205 self
2206 }
2207
2208 #[must_use]
2210 pub fn production(mut self, production: bool) -> Self {
2211 self.production = production;
2212 self
2213 }
2214
2215 #[must_use]
2217 pub fn error_message(mut self, message: impl Into<String>) -> Self {
2218 self.error_message = Some(message.into());
2219 self
2220 }
2221}
2222
2223#[derive(Debug, Clone)]
2253pub struct CsrfMiddleware {
2254 config: CsrfConfig,
2255}
2256
2257impl Default for CsrfMiddleware {
2258 fn default() -> Self {
2259 Self::new()
2260 }
2261}
2262
2263impl CsrfMiddleware {
2264 #[must_use]
2266 pub fn new() -> Self {
2267 Self {
2268 config: CsrfConfig::default(),
2269 }
2270 }
2271
2272 #[must_use]
2274 pub fn with_config(config: CsrfConfig) -> Self {
2275 Self { config }
2276 }
2277
2278 fn is_safe_method(method: crate::request::Method) -> bool {
2280 matches!(
2281 method,
2282 crate::request::Method::Get
2283 | crate::request::Method::Head
2284 | crate::request::Method::Options
2285 | crate::request::Method::Trace
2286 )
2287 }
2288
2289 fn get_cookie_token(&self, req: &Request) -> Option<String> {
2291 let cookie_header = req.headers().get("cookie")?;
2292 let cookie_str = std::str::from_utf8(cookie_header).ok()?;
2293
2294 for part in cookie_str.split(';') {
2296 let part = part.trim();
2297 if let Some((name, value)) = part.split_once('=') {
2298 if name.trim() == self.config.cookie_name {
2299 return Some(value.trim().to_string());
2300 }
2301 }
2302 }
2303 None
2304 }
2305
2306 fn get_header_token(&self, req: &Request) -> Option<String> {
2308 let header_value = req.headers().get(&self.config.header_name)?;
2309 std::str::from_utf8(header_value)
2310 .ok()
2311 .map(|s| s.trim().to_string())
2312 }
2313
2314 fn validate_token(&self, req: &Request) -> Result<Option<CsrfToken>, Response> {
2316 let header_token = self.get_header_token(req);
2317
2318 match self.config.mode {
2319 CsrfMode::DoubleSubmit => {
2320 let cookie_token = self.get_cookie_token(req);
2321
2322 match (header_token, cookie_token) {
2323 (Some(header), Some(cookie))
2324 if !header.is_empty()
2325 && crate::password::constant_time_eq(
2326 header.as_bytes(),
2327 cookie.as_bytes(),
2328 ) =>
2329 {
2330 Ok(Some(CsrfToken::new(header)))
2331 }
2332 (None, _) | (_, None) => Err(self.csrf_error_response("CSRF token missing")),
2333 _ => Err(self.csrf_error_response("CSRF token mismatch")),
2334 }
2335 }
2336 CsrfMode::HeaderOnly => match header_token {
2337 Some(token) if !token.is_empty() => Ok(Some(CsrfToken::new(token))),
2338 _ => Err(self.csrf_error_response("CSRF token missing in header")),
2339 },
2340 }
2341 }
2342
2343 fn csrf_error_response(&self, default_message: &str) -> Response {
2345 let message = self
2346 .config
2347 .error_message
2348 .as_deref()
2349 .unwrap_or(default_message);
2350
2351 let detail = serde_json::json!({
2354 "detail": [{
2355 "type": "csrf_error",
2356 "loc": ["header", self.config.header_name],
2357 "msg": message,
2358 }]
2359 });
2360 let body = detail.to_string();
2361
2362 Response::with_status(crate::response::StatusCode::FORBIDDEN)
2363 .header("content-type", b"application/json".to_vec())
2364 .body(crate::response::ResponseBody::Bytes(body.into_bytes()))
2365 }
2366
2367 fn make_set_cookie_header_value(cookie_name: &str, token: &str, production: bool) -> Vec<u8> {
2369 let mut cookie = format!("{}={}; Path=/; SameSite=Strict", cookie_name, token);
2370
2371 if production {
2372 cookie.push_str("; Secure");
2373 }
2374
2375 cookie.into_bytes()
2378 }
2379}
2380
2381impl Middleware for CsrfMiddleware {
2382 fn before<'a>(
2383 &'a self,
2384 _ctx: &'a RequestContext,
2385 req: &'a mut Request,
2386 ) -> BoxFuture<'a, ControlFlow> {
2387 Box::pin(async move {
2388 if Self::is_safe_method(req.method()) {
2389 let existing_token = self.get_cookie_token(req);
2391 let token = existing_token
2392 .map(CsrfToken::new)
2393 .unwrap_or_else(CsrfToken::generate);
2394 req.insert_extension(token);
2395 ControlFlow::Continue
2396 } else {
2397 match self.validate_token(req) {
2399 Ok(Some(token)) => {
2400 req.insert_extension(token);
2401 ControlFlow::Continue
2402 }
2403 Ok(None) => ControlFlow::Continue,
2404 Err(response) => ControlFlow::Break(response),
2405 }
2406 }
2407 })
2408 }
2409
2410 fn after<'a>(
2411 &'a self,
2412 _ctx: &'a RequestContext,
2413 req: &'a Request,
2414 response: Response,
2415 ) -> BoxFuture<'a, Response> {
2416 let config = self.config.clone();
2417 let is_safe = Self::is_safe_method(req.method());
2418 let existing_cookie_token = self.get_cookie_token(req);
2419 let token = req.get_extension::<CsrfToken>().cloned();
2420
2421 Box::pin(async move {
2422 if is_safe {
2426 let should_set_cookie = existing_cookie_token.is_none() || config.rotate_token;
2427
2428 if should_set_cookie {
2429 if let Some(token) = token {
2430 let cookie_value = Self::make_set_cookie_header_value(
2431 &config.cookie_name,
2432 token.as_str(),
2433 config.production,
2434 );
2435 return response.header("set-cookie", cookie_value);
2436 }
2437 }
2438 }
2439 response
2440 })
2441 }
2442
2443 fn name(&self) -> &'static str {
2444 "CSRF"
2445 }
2446}
2447
/// Settings for [`CompressionMiddleware`].
#[cfg(feature = "compression")]
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Bodies shorter than this many bytes are left uncompressed.
    pub min_size: usize,
    /// gzip compression level handed to the encoder. The [`level`](Self::level)
    /// builder clamps it to `1..=9`.
    pub level: u32,
    /// Content types that are never compressed. An entry ending in `/` is a
    /// prefix (e.g. `video/`); any other entry must equal the media type.
    pub skip_content_types: Vec<&'static str>,
}

#[cfg(feature = "compression")]
impl Default for CompressionConfig {
    /// 1 KiB minimum size, level 6, and a skip list of image, video, audio,
    /// archive, PDF and web-font content types.
    fn default() -> Self {
        Self {
            min_size: 1024,
            level: 6,
            skip_content_types: vec![
                "image/jpeg",
                "image/png",
                "image/gif",
                "image/webp",
                "image/avif",
                "video/",
                "audio/",
                "application/zip",
                "application/gzip",
                "application/x-gzip",
                "application/x-bzip2",
                "application/x-xz",
                "application/x-7z-compressed",
                "application/x-rar-compressed",
                "application/pdf",
                "application/woff",
                "application/woff2",
                "font/woff",
                "font/woff2",
            ],
        }
    }
}

#[cfg(feature = "compression")]
impl CompressionConfig {
    /// Creates the default configuration (same as [`Default::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the minimum body size, in bytes, eligible for compression.
    #[must_use]
    pub fn min_size(mut self, size: usize) -> Self {
        self.min_size = size;
        self
    }

    /// Sets the gzip level; values outside `1..=9` are clamped into range.
    #[must_use]
    pub fn level(mut self, level: u32) -> Self {
        self.level = level.clamp(1, 9);
        self
    }

    /// Adds a content type (or, with a trailing `/`, a type prefix) to the
    /// skip list.
    #[must_use]
    pub fn skip_content_type(mut self, content_type: &'static str) -> Self {
        self.skip_content_types.push(content_type);
        self
    }

    /// Returns `true` when a response with this `Content-Type` value must not
    /// be compressed.
    ///
    /// Matching is ASCII case-insensitive on both sides and ignores
    /// surrounding whitespace. Prefix entries (ending in `/`) are compared
    /// against the start of the value; other entries are compared against the
    /// media type with any `;parameters` removed, or against the whole value
    /// for entries that themselves include parameters.
    fn should_skip_content_type(&self, content_type: &str) -> bool {
        let full = content_type.trim();
        // "text/html ; charset=utf-8" -> "text/html"
        let media_type = full.split(';').next().unwrap_or(full).trim_end();

        self.skip_content_types.iter().any(|skip| {
            if skip.ends_with('/') {
                // Byte-wise prefix test so no lowercased copy is needed.
                full.as_bytes()
                    .get(..skip.len())
                    .map_or(false, |head| head.eq_ignore_ascii_case(skip.as_bytes()))
            } else {
                media_type.eq_ignore_ascii_case(skip) || full.eq_ignore_ascii_case(skip)
            }
        })
    }
}
2579
/// Middleware that gzip-compresses eligible response bodies.
#[cfg(feature = "compression")]
#[derive(Debug, Clone)]
pub struct CompressionMiddleware {
    /// Size threshold, level and skip list.
    config: CompressionConfig,
}

#[cfg(feature = "compression")]
impl Default for CompressionMiddleware {
    /// Equivalent to [`CompressionMiddleware::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "compression")]
impl CompressionMiddleware {
    /// Creates the middleware with the default [`CompressionConfig`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(CompressionConfig::default())
    }

    /// Creates the middleware with a caller-supplied configuration.
    #[must_use]
    pub fn with_config(config: CompressionConfig) -> Self {
        Self { config }
    }

    /// Returns `true` when the request's `Accept-Encoding` header permits a
    /// gzip response.
    ///
    /// A missing or non-UTF-8 header yields `false`.
    fn accepts_gzip(req: &Request) -> bool {
        req.headers()
            .get("accept-encoding")
            .and_then(|raw| std::str::from_utf8(raw).ok())
            .map_or(false, Self::header_allows_gzip)
    }

    /// Evaluates an `Accept-Encoding` header value.
    ///
    /// An explicit `gzip` entry decides the outcome; otherwise a `*` entry
    /// does. An entry with a zero quality value (`q=0`) is a refusal, so
    /// `gzip;q=0` returns `false` even when `*` is also listed.
    fn header_allows_gzip(value: &str) -> bool {
        let mut gzip: Option<bool> = None;
        let mut wildcard: Option<bool> = None;

        for entry in value.split(',') {
            let mut parts = entry.split(';');
            let coding = parts.next().unwrap_or("").trim();
            let acceptable = !parts.any(Self::is_zero_qvalue);

            if coding.eq_ignore_ascii_case("gzip") {
                gzip = Some(gzip.unwrap_or(false) || acceptable);
            } else if coding == "*" {
                wildcard = Some(wildcard.unwrap_or(false) || acceptable);
            }
        }

        gzip.or(wildcard).unwrap_or(false)
    }

    /// Returns `true` when `param` is a `q=<weight>` parameter whose weight
    /// parses to zero (or less). Malformed weights are not treated as a
    /// refusal.
    fn is_zero_qvalue(param: &str) -> bool {
        match param.split_once('=') {
            Some((name, weight)) if name.trim().eq_ignore_ascii_case("q") => weight
                .trim()
                .parse::<f32>()
                .map_or(false, |q| q <= 0.0),
            _ => false,
        }
    }

    /// Returns the value of the first `content-type` header, if it is valid
    /// UTF-8.
    fn get_content_type(headers: &[(String, Vec<u8>)]) -> Option<String> {
        headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
            .and_then(|(_, value)| std::str::from_utf8(value).ok())
            .map(String::from)
    }

    /// Returns `true` when any `content-encoding` header is already present.
    fn has_content_encoding(headers: &[(String, Vec<u8>)]) -> bool {
        headers
            .iter()
            .any(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
    }

    /// gzip-compresses `data` at the given level.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error reported by the encoder.
    fn compress_gzip(data: &[u8], level: u32) -> Result<Vec<u8>, std::io::Error> {
        use flate2::Compression;
        use flate2::write::GzEncoder;
        use std::io::Write;

        let mut encoder = GzEncoder::new(Vec::new(), Compression::new(level));
        encoder.write_all(data)?;
        encoder.finish()
    }
}
2690
#[cfg(feature = "compression")]
impl Middleware for CompressionMiddleware {
    /// Compresses the response body with gzip when all of these hold:
    ///
    /// * the request accepts gzip,
    /// * the response has no `content-encoding` header yet,
    /// * the body is an in-memory byte buffer of at least `min_size` bytes,
    /// * the content type is not on the skip list,
    /// * compression succeeds and actually makes the body smaller.
    ///
    /// Otherwise the response is passed through with its original status,
    /// headers and body. A compressed response drops `content-length` and
    /// gains `Content-Encoding: gzip` and `Vary: Accept-Encoding`.
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        use crate::response::ResponseBody;

        // `self` is borrowed for `'a`; no need to clone the config (and its
        // skip-list Vec) for every response.
        let config = &self.config;

        Box::pin(async move {
            if !Self::accepts_gzip(req) {
                return response;
            }

            let (status, headers, body) = response.into_parts();

            // Only plain byte bodies that are not already encoded qualify.
            let original = match body {
                ResponseBody::Bytes(bytes) if !Self::has_content_encoding(&headers) => bytes,
                untouched => {
                    return Response::with_status(status)
                        .body(untouched)
                        .rebuild_with_headers(headers);
                }
            };

            let skipped_type = Self::get_content_type(&headers)
                .map_or(false, |ct| config.should_skip_content_type(&ct));
            let eligible = original.len() >= config.min_size && !skipped_type;

            // `None` covers: not eligible, encoder failure, or no size win.
            let compressed = if eligible {
                Self::compress_gzip(&original, config.level)
                    .ok()
                    .filter(|gz| gz.len() < original.len())
            } else {
                None
            };

            match compressed {
                Some(gz) => {
                    let base = Response::with_status(status).body(ResponseBody::Bytes(gz));
                    headers
                        .into_iter()
                        // The old length no longer describes the body.
                        .filter(|(name, _)| !name.eq_ignore_ascii_case("content-length"))
                        .fold(base, |resp, (name, value)| resp.header(name, value))
                        .header("Content-Encoding", b"gzip".to_vec())
                        .header("Vary", b"Accept-Encoding".to_vec())
                }
                None => Response::with_status(status)
                    .body(ResponseBody::Bytes(original))
                    .rebuild_with_headers(headers),
            }
        })
    }

    /// Static middleware name: `"Compression"`.
    fn name(&self) -> &'static str {
        "Compression"
    }
}
2785
2786use parking_lot::Mutex;
2791use std::collections::HashMap as StdHashMap;
2792use std::time::Duration;
2793
/// Algorithm used by the rate-limit store to decide whether a request is
/// allowed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitAlgorithm {
    /// Tokens refill continuously at `max_requests / window`; each allowed
    /// request consumes one token.
    TokenBucket,
    /// A counter that starts over once the window has elapsed.
    FixedWindow,
    /// Current-window count plus the previous window's count weighted by how
    /// much of the current window is still ahead.
    SlidingWindow,
}
2804
/// Outcome of a single rate-limit check.
#[derive(Debug, Clone)]
pub struct RateLimitResult {
    /// Whether the request may proceed.
    pub allowed: bool,
    /// The configured maximum for the key.
    pub limit: u64,
    /// Requests (or tokens) still available after this check.
    pub remaining: u64,
    /// Seconds reported to the client until the limit resets or a retry can
    /// succeed.
    pub reset_after_secs: u64,
}
2817
2818pub trait KeyExtractor: Send + Sync {
2823 fn extract_key(&self, req: &Request) -> Option<String>;
2827}
2828
/// IP address of the connected peer, stored on the request as an extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RemoteAddr(pub std::net::IpAddr);

impl std::fmt::Display for RemoteAddr {
    /// Writes the address in its standard textual form (formatter
    /// width/padding flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(addr) = self;
        f.write_fmt(format_args!("{addr}"))
    }
}
2854
2855#[derive(Debug, Clone)]
2876pub struct ConnectedIpKeyExtractor;
2877
2878impl KeyExtractor for ConnectedIpKeyExtractor {
2879 fn extract_key(&self, req: &Request) -> Option<String> {
2880 req.get_extension::<RemoteAddr>().map(ToString::to_string)
2881 }
2882}
2883
2884#[derive(Debug, Clone)]
2913pub struct IpKeyExtractor;
2914
2915impl KeyExtractor for IpKeyExtractor {
2916 fn extract_key(&self, req: &Request) -> Option<String> {
2917 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
2919 if let Ok(s) = std::str::from_utf8(forwarded) {
2920 if let Some(ip) = s.split(',').next() {
2922 return Some(ip.trim().to_string());
2923 }
2924 }
2925 }
2926 if let Some(real_ip) = req.headers().get("x-real-ip") {
2927 if let Ok(s) = std::str::from_utf8(real_ip) {
2928 return Some(s.trim().to_string());
2929 }
2930 }
2931 Some("unknown".to_string())
2932 }
2933}
2934
2935#[derive(Debug, Clone)]
2966pub struct TrustedProxyIpKeyExtractor {
2967 trusted_cidrs: Vec<(std::net::IpAddr, u8)>,
2969}
2970
2971impl TrustedProxyIpKeyExtractor {
2972 #[must_use]
2974 pub fn new() -> Self {
2975 Self {
2976 trusted_cidrs: Vec::new(),
2977 }
2978 }
2979
2980 #[must_use]
2986 pub fn trust_cidr(mut self, cidr: &str) -> Self {
2987 let (ip, prefix) = parse_cidr(cidr).expect("invalid CIDR notation");
2988 self.trusted_cidrs.push((ip, prefix));
2989 self
2990 }
2991
2992 #[must_use]
2994 pub fn trust_loopback(mut self) -> Self {
2995 self.trusted_cidrs.push((
2996 std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 0)),
2997 8,
2998 ));
2999 self.trusted_cidrs
3000 .push((std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST), 128));
3001 self
3002 }
3003
3004 fn is_trusted(&self, ip: std::net::IpAddr) -> bool {
3006 self.trusted_cidrs
3007 .iter()
3008 .any(|(cidr_ip, prefix)| ip_in_cidr(ip, *cidr_ip, *prefix))
3009 }
3010
3011 fn extract_from_header(&self, req: &Request) -> Option<String> {
3013 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
3014 if let Ok(s) = std::str::from_utf8(forwarded) {
3015 if let Some(ip) = s.split(',').next() {
3016 return Some(ip.trim().to_string());
3017 }
3018 }
3019 }
3020 if let Some(real_ip) = req.headers().get("x-real-ip") {
3021 if let Ok(s) = std::str::from_utf8(real_ip) {
3022 return Some(s.trim().to_string());
3023 }
3024 }
3025 None
3026 }
3027}
3028
3029impl Default for TrustedProxyIpKeyExtractor {
3030 fn default() -> Self {
3031 Self::new()
3032 }
3033}
3034
3035impl KeyExtractor for TrustedProxyIpKeyExtractor {
3036 fn extract_key(&self, req: &Request) -> Option<String> {
3037 let remote = req.get_extension::<RemoteAddr>()?;
3038
3039 if self.is_trusted(remote.0) {
3040 self.extract_from_header(req)
3042 .or_else(|| Some(remote.to_string()))
3043 } else {
3044 Some(remote.to_string())
3046 }
3047 }
3048}
3049
/// Parses `address/prefix` CIDR notation.
///
/// Returns `None` when the `/` is missing, either half fails to parse, or the
/// prefix exceeds 32 (IPv4) / 128 (IPv6).
fn parse_cidr(cidr: &str) -> Option<(std::net::IpAddr, u8)> {
    let mut halves = cidr.splitn(2, '/');
    let address = halves.next()?.parse::<std::net::IpAddr>().ok()?;
    let prefix = halves.next()?.parse::<u8>().ok()?;

    let widest: u8 = if address.is_ipv4() { 32 } else { 128 };
    if prefix <= widest {
        Some((address, prefix))
    } else {
        None
    }
}
3067
/// Returns `true` when `ip` lies inside the network `cidr_ip/prefix`.
///
/// * Same-family addresses are compared on their leading `prefix` bits.
/// * An IPv4-mapped IPv6 address (`::ffff:a.b.c.d`, the form in which IPv4
///   peers can appear on dual-stack sockets) is compared against an IPv4
///   network using its embedded IPv4 address.
/// * An IPv4 address never matches an IPv6 network.
/// * A prefix longer than the address width (32 / 128) matches nothing
///   instead of overflowing the mask computation.
fn ip_in_cidr(ip: std::net::IpAddr, cidr_ip: std::net::IpAddr, prefix: u8) -> bool {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    /// Compares the leading `prefix` bits of two IPv4 addresses.
    fn v4_match(ip: Ipv4Addr, network: Ipv4Addr, prefix: u8) -> bool {
        match prefix {
            0 => true,
            1..=32 => {
                let mask = u32::MAX << (32 - u32::from(prefix));
                ((u32::from(ip) ^ u32::from(network)) & mask) == 0
            }
            _ => false,
        }
    }

    /// Compares the leading `prefix` bits of two IPv6 addresses.
    fn v6_match(ip: Ipv6Addr, network: Ipv6Addr, prefix: u8) -> bool {
        match prefix {
            0 => true,
            1..=128 => {
                let mask = u128::MAX << (128 - u32::from(prefix));
                ((u128::from(ip) ^ u128::from(network)) & mask) == 0
            }
            _ => false,
        }
    }

    match (ip, cidr_ip) {
        (IpAddr::V4(ip), IpAddr::V4(network)) => v4_match(ip, network, prefix),
        (IpAddr::V6(ip), IpAddr::V6(network)) => v6_match(ip, network, prefix),
        (IpAddr::V6(ip), IpAddr::V4(network)) => match ip.octets() {
            // ::ffff:a.b.c.d — the same host as a.b.c.d.
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] => {
                v4_match(Ipv4Addr::new(a, b, c, d), network, prefix)
            }
            _ => false,
        },
        (IpAddr::V4(_), IpAddr::V6(_)) => false,
    }
}
3092
3093#[derive(Debug, Clone)]
3095pub struct HeaderKeyExtractor {
3096 header_name: String,
3097}
3098
3099impl HeaderKeyExtractor {
3100 #[must_use]
3102 pub fn new(header_name: impl Into<String>) -> Self {
3103 Self {
3104 header_name: header_name.into(),
3105 }
3106 }
3107}
3108
3109impl KeyExtractor for HeaderKeyExtractor {
3110 fn extract_key(&self, req: &Request) -> Option<String> {
3111 req.headers()
3112 .get(&self.header_name)
3113 .and_then(|v| std::str::from_utf8(v).ok())
3114 .map(str::to_string)
3115 }
3116}
3117
3118#[derive(Debug, Clone)]
3120pub struct PathKeyExtractor;
3121
3122impl KeyExtractor for PathKeyExtractor {
3123 fn extract_key(&self, req: &Request) -> Option<String> {
3124 Some(req.path().to_string())
3125 }
3126}
3127
3128pub struct CompositeKeyExtractor {
3133 extractors: Vec<Box<dyn KeyExtractor>>,
3134}
3135
3136impl CompositeKeyExtractor {
3137 #[must_use]
3139 pub fn new(extractors: Vec<Box<dyn KeyExtractor>>) -> Self {
3140 Self { extractors }
3141 }
3142}
3143
3144impl KeyExtractor for CompositeKeyExtractor {
3145 fn extract_key(&self, req: &Request) -> Option<String> {
3146 let parts: Vec<String> = self
3147 .extractors
3148 .iter()
3149 .filter_map(|e| e.extract_key(req))
3150 .collect();
3151 if parts.is_empty() {
3152 None
3153 } else {
3154 Some(parts.join(":"))
3155 }
3156 }
3157}
3158
/// Per-key state for the token-bucket algorithm.
#[derive(Debug, Clone)]
struct TokenBucketState {
    /// Tokens currently available (fractional while refilling).
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}
3165
/// Per-key state for the fixed-window algorithm.
#[derive(Debug, Clone)]
struct FixedWindowState {
    /// Requests counted in the current window.
    count: u64,
    /// Instant at which the current window began.
    window_start: Instant,
}
3172
/// Per-key state for the sliding-window algorithm.
#[derive(Debug, Clone)]
struct SlidingWindowState {
    /// Requests counted in the current window.
    current_count: u64,
    /// Count carried over from the window before the current one.
    previous_count: u64,
    /// Instant at which the current window began.
    current_window_start: Instant,
}
3180
3181pub struct InMemoryRateLimitStore {
3187 token_buckets: Mutex<StdHashMap<String, TokenBucketState>>,
3188 fixed_windows: Mutex<StdHashMap<String, FixedWindowState>>,
3189 sliding_windows: Mutex<StdHashMap<String, SlidingWindowState>>,
3190}
3191
3192impl InMemoryRateLimitStore {
3193 #[must_use]
3195 pub fn new() -> Self {
3196 Self {
3197 token_buckets: Mutex::new(StdHashMap::new()),
3198 fixed_windows: Mutex::new(StdHashMap::new()),
3199 sliding_windows: Mutex::new(StdHashMap::new()),
3200 }
3201 }
3202
3203 #[allow(clippy::cast_precision_loss, clippy::cast_sign_loss)]
3204 fn check_token_bucket(
3205 &self,
3206 key: &str,
3207 max_tokens: u64,
3208 refill_rate: f64,
3209 window: Duration,
3210 ) -> RateLimitResult {
3211 let mut buckets = self.token_buckets.lock();
3212 let now = Instant::now();
3213
3214 let state = buckets
3215 .entry(key.to_string())
3216 .or_insert_with(|| TokenBucketState {
3217 tokens: max_tokens as f64,
3218 last_refill: now,
3219 });
3220
3221 let elapsed = now.duration_since(state.last_refill);
3223 let refill = elapsed.as_secs_f64() * refill_rate;
3224 state.tokens = (state.tokens + refill).min(max_tokens as f64);
3225 state.last_refill = now;
3226
3227 if state.tokens >= 1.0 {
3228 state.tokens -= 1.0;
3229 RateLimitResult {
3230 allowed: true,
3231 limit: max_tokens,
3232 remaining: state.tokens as u64,
3233 reset_after_secs: if state.tokens < max_tokens as f64 {
3234 ((max_tokens as f64 - state.tokens) / refill_rate).ceil() as u64
3235 } else {
3236 window.as_secs()
3237 },
3238 }
3239 } else {
3240 let wait_secs = ((1.0 - state.tokens) / refill_rate).ceil() as u64;
3241 RateLimitResult {
3242 allowed: false,
3243 limit: max_tokens,
3244 remaining: 0,
3245 reset_after_secs: wait_secs,
3246 }
3247 }
3248 }
3249
3250 fn check_fixed_window(
3251 &self,
3252 key: &str,
3253 max_requests: u64,
3254 window: Duration,
3255 ) -> RateLimitResult {
3256 let mut windows = self.fixed_windows.lock();
3257 let now = Instant::now();
3258
3259 let state = windows
3260 .entry(key.to_string())
3261 .or_insert_with(|| FixedWindowState {
3262 count: 0,
3263 window_start: now,
3264 });
3265
3266 let elapsed = now.duration_since(state.window_start);
3268 if elapsed >= window {
3269 state.count = 0;
3270 state.window_start = now;
3271 }
3272
3273 let remaining_time = window
3274 .checked_sub(now.duration_since(state.window_start))
3275 .unwrap_or(Duration::ZERO);
3276
3277 if state.count < max_requests {
3278 state.count += 1;
3279 RateLimitResult {
3280 allowed: true,
3281 limit: max_requests,
3282 remaining: max_requests - state.count,
3283 reset_after_secs: remaining_time.as_secs(),
3284 }
3285 } else {
3286 RateLimitResult {
3287 allowed: false,
3288 limit: max_requests,
3289 remaining: 0,
3290 reset_after_secs: remaining_time.as_secs(),
3291 }
3292 }
3293 }
3294
3295 #[allow(clippy::cast_precision_loss, clippy::cast_sign_loss)]
3296 fn check_sliding_window(
3297 &self,
3298 key: &str,
3299 max_requests: u64,
3300 window: Duration,
3301 ) -> RateLimitResult {
3302 let mut windows = self.sliding_windows.lock();
3303 let now = Instant::now();
3304
3305 let state = windows
3306 .entry(key.to_string())
3307 .or_insert_with(|| SlidingWindowState {
3308 current_count: 0,
3309 previous_count: 0,
3310 current_window_start: now,
3311 });
3312
3313 let elapsed = now.duration_since(state.current_window_start);
3315 if elapsed >= window {
3316 state.previous_count = state.current_count;
3318 state.current_count = 0;
3319 state.current_window_start = now;
3320 }
3321
3322 let window_elapsed = now.duration_since(state.current_window_start);
3325 let window_fraction = window_elapsed.as_secs_f64() / window.as_secs_f64();
3326 let previous_weight = 1.0 - window_fraction;
3327 let weighted_count =
3328 (state.previous_count as f64 * previous_weight) + state.current_count as f64;
3329
3330 let remaining_time = window.checked_sub(window_elapsed).unwrap_or(Duration::ZERO);
3331
3332 if weighted_count < max_requests as f64 {
3333 state.current_count += 1;
3334 let new_weighted =
3335 (state.previous_count as f64 * previous_weight) + state.current_count as f64;
3336 let remaining = (max_requests as f64 - new_weighted).max(0.0) as u64;
3337 RateLimitResult {
3338 allowed: true,
3339 limit: max_requests,
3340 remaining,
3341 reset_after_secs: remaining_time.as_secs(),
3342 }
3343 } else {
3344 RateLimitResult {
3345 allowed: false,
3346 limit: max_requests,
3347 remaining: 0,
3348 reset_after_secs: remaining_time.as_secs(),
3349 }
3350 }
3351 }
3352
3353 #[allow(clippy::cast_precision_loss)]
3355 pub fn check(
3356 &self,
3357 key: &str,
3358 algorithm: RateLimitAlgorithm,
3359 max_requests: u64,
3360 window: Duration,
3361 ) -> RateLimitResult {
3362 match algorithm {
3363 RateLimitAlgorithm::TokenBucket => {
3364 let refill_rate = max_requests as f64 / window.as_secs_f64();
3365 self.check_token_bucket(key, max_requests, refill_rate, window)
3366 }
3367 RateLimitAlgorithm::FixedWindow => self.check_fixed_window(key, max_requests, window),
3368 RateLimitAlgorithm::SlidingWindow => {
3369 self.check_sliding_window(key, max_requests, window)
3370 }
3371 }
3372 }
3373}
3374
3375impl Default for InMemoryRateLimitStore {
3376 fn default() -> Self {
3377 Self::new()
3378 }
3379}
3380
3381#[derive(Clone)]
3415pub struct RateLimitConfig {
3416 pub max_requests: u64,
3418 pub window: Duration,
3420 pub algorithm: RateLimitAlgorithm,
3422 pub include_headers: bool,
3424 pub retry_message: String,
3426}
3427
3428impl Default for RateLimitConfig {
3429 fn default() -> Self {
3430 Self {
3431 max_requests: 100,
3432 window: Duration::from_secs(60),
3433 algorithm: RateLimitAlgorithm::TokenBucket,
3434 include_headers: true,
3435 retry_message: "Rate limit exceeded. Please retry later.".to_string(),
3436 }
3437 }
3438}
3439
3440pub struct RateLimitBuilder {
3442 config: RateLimitConfig,
3443 key_extractor: Option<Box<dyn KeyExtractor>>,
3444}
3445
3446impl RateLimitBuilder {
3447 #[must_use]
3449 pub fn new() -> Self {
3450 Self {
3451 config: RateLimitConfig::default(),
3452 key_extractor: None,
3453 }
3454 }
3455
3456 #[must_use]
3458 pub fn requests(mut self, max: u64) -> Self {
3459 self.config.max_requests = max;
3460 self
3461 }
3462
3463 #[must_use]
3465 pub fn per(mut self, window: Duration) -> Self {
3466 self.config.window = window;
3467 self
3468 }
3469
3470 #[must_use]
3472 pub fn per_second(self, secs: u64) -> Self {
3473 self.per(Duration::from_secs(secs))
3474 }
3475
3476 #[must_use]
3478 pub fn per_minute(self, minutes: u64) -> Self {
3479 self.per(Duration::from_secs(minutes * 60))
3480 }
3481
3482 #[must_use]
3484 pub fn per_hour(self, hours: u64) -> Self {
3485 self.per(Duration::from_secs(hours * 3600))
3486 }
3487
3488 #[must_use]
3490 pub fn algorithm(mut self, algo: RateLimitAlgorithm) -> Self {
3491 self.config.algorithm = algo;
3492 self
3493 }
3494
3495 #[must_use]
3497 pub fn key_extractor(mut self, extractor: impl KeyExtractor + 'static) -> Self {
3498 self.key_extractor = Some(Box::new(extractor));
3499 self
3500 }
3501
3502 #[must_use]
3504 pub fn include_headers(mut self, include: bool) -> Self {
3505 self.config.include_headers = include;
3506 self
3507 }
3508
3509 #[must_use]
3511 pub fn retry_message(mut self, msg: impl Into<String>) -> Self {
3512 self.config.retry_message = msg.into();
3513 self
3514 }
3515
3516 #[must_use]
3518 pub fn build(self) -> RateLimitMiddleware {
3519 let key_extractor = self
3520 .key_extractor
3521 .unwrap_or_else(|| Box::new(IpKeyExtractor));
3522 RateLimitMiddleware {
3523 config: self.config,
3524 store: Arc::new(InMemoryRateLimitStore::new()),
3525 key_extractor: Arc::from(key_extractor),
3526 }
3527 }
3528}
3529
3530impl Default for RateLimitBuilder {
3531 fn default() -> Self {
3532 Self::new()
3533 }
3534}
3535
/// Request extension carrying the outcome of the rate-limit check.
///
/// `RateLimitMiddleware::before` inserts it for allowed requests and
/// `after` reads it back to add the rate-limit headers.
#[derive(Debug, Clone)]
struct RateLimitInfo {
    /// Result returned by the store's `check` call for this request.
    result: RateLimitResult,
}
3541
/// Middleware that rejects requests with `429 Too Many Requests` once the
/// store reports that a key is over its limit.
pub struct RateLimitMiddleware {
    /// Limits, algorithm and response options.
    config: RateLimitConfig,
    /// Store consulted for every request that yields a key.
    store: Arc<InMemoryRateLimitStore>,
    /// Derives the rate-limit key from a request; requests without a key
    /// are let through unchecked.
    key_extractor: Arc<dyn KeyExtractor>,
}
3569
3570impl RateLimitMiddleware {
3571 #[must_use]
3573 pub fn new() -> Self {
3574 Self::builder().build()
3575 }
3576
3577 #[must_use]
3579 pub fn builder() -> RateLimitBuilder {
3580 RateLimitBuilder::new()
3581 }
3582
3583 fn too_many_requests_body(&self, result: &RateLimitResult) -> Vec<u8> {
3585 format!(
3586 r#"{{"detail":"{}","retry_after_secs":{}}}"#,
3587 self.config.retry_message, result.reset_after_secs
3588 )
3589 .into_bytes()
3590 }
3591
3592 fn add_headers(&self, response: Response, result: &RateLimitResult) -> Response {
3594 response
3595 .header("X-RateLimit-Limit", result.limit.to_string().into_bytes())
3596 .header(
3597 "X-RateLimit-Remaining",
3598 result.remaining.to_string().into_bytes(),
3599 )
3600 .header(
3601 "X-RateLimit-Reset",
3602 result.reset_after_secs.to_string().into_bytes(),
3603 )
3604 }
3605}
3606
3607impl Default for RateLimitMiddleware {
3608 fn default() -> Self {
3609 Self::new()
3610 }
3611}
3612
3613impl Middleware for RateLimitMiddleware {
3614 fn before<'a>(
3615 &'a self,
3616 _ctx: &'a RequestContext,
3617 req: &'a mut Request,
3618 ) -> BoxFuture<'a, ControlFlow> {
3619 Box::pin(async move {
3620 let Some(key) = self.key_extractor.extract_key(req) else {
3622 return ControlFlow::Continue;
3624 };
3625
3626 let result = self.store.check(
3628 &key,
3629 self.config.algorithm,
3630 self.config.max_requests,
3631 self.config.window,
3632 );
3633
3634 if result.allowed {
3635 req.insert_extension(RateLimitInfo { result });
3637 ControlFlow::Continue
3638 } else {
3639 let body = self.too_many_requests_body(&result);
3641 let mut response =
3642 Response::with_status(crate::response::StatusCode::TOO_MANY_REQUESTS)
3643 .header("Content-Type", b"application/json".to_vec())
3644 .header(
3645 "Retry-After",
3646 result.reset_after_secs.to_string().into_bytes(),
3647 )
3648 .body(crate::response::ResponseBody::Bytes(body));
3649
3650 if self.config.include_headers {
3651 response = self.add_headers(response, &result);
3652 }
3653
3654 ControlFlow::Break(response)
3655 }
3656 })
3657 }
3658
3659 fn after<'a>(
3660 &'a self,
3661 _ctx: &'a RequestContext,
3662 req: &'a Request,
3663 response: Response,
3664 ) -> BoxFuture<'a, Response> {
3665 Box::pin(async move {
3666 if !self.config.include_headers {
3667 return response;
3668 }
3669
3670 if let Some(info) = req.get_extension::<RateLimitInfo>() {
3672 self.add_headers(response, &info.result)
3673 } else {
3674 response
3675 }
3676 })
3677 }
3678
3679 fn name(&self) -> &'static str {
3680 "RateLimit"
3681 }
3682}
3683
/// How much detail [`RequestInspectionMiddleware`] logs per request/response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InspectionVerbosity {
    /// Only the request line and the response line.
    Minimal,

    /// The request/response line followed by the headers.
    Normal,

    /// Line, headers, and a preview of the body.
    Verbose,
}
3713
/// Middleware that logs incoming requests and outgoing responses.
pub struct RequestInspectionMiddleware {
    /// Configuration handed to each `RequestLogger` it creates.
    log_config: LogConfig,
    /// Amount of detail to log.
    verbosity: InspectionVerbosity,
    /// Lower-cased header names whose values are logged as `[REDACTED]`.
    redact_headers: HashSet<String>,
    /// Responses taking at least this many milliseconds are tagged `[SLOW]`
    /// and logged at warn level.
    slow_threshold_ms: u64,
    /// Maximum number of body bytes included in a preview; `0` disables previews.
    max_body_preview: usize,
}
3759
3760impl Default for RequestInspectionMiddleware {
3761 fn default() -> Self {
3762 Self {
3763 log_config: LogConfig::development(),
3764 verbosity: InspectionVerbosity::Normal,
3765 redact_headers: default_redacted_headers(),
3766 slow_threshold_ms: 1000,
3767 max_body_preview: 2048,
3768 }
3769 }
3770}
3771
3772impl RequestInspectionMiddleware {
3773 #[must_use]
3775 pub fn new() -> Self {
3776 Self::default()
3777 }
3778
3779 #[must_use]
3781 pub fn log_config(mut self, config: LogConfig) -> Self {
3782 self.log_config = config;
3783 self
3784 }
3785
3786 #[must_use]
3788 pub fn verbosity(mut self, level: InspectionVerbosity) -> Self {
3789 self.verbosity = level;
3790 self
3791 }
3792
3793 #[must_use]
3795 pub fn slow_threshold_ms(mut self, ms: u64) -> Self {
3796 self.slow_threshold_ms = ms;
3797 self
3798 }
3799
3800 #[must_use]
3802 pub fn max_body_preview(mut self, max: usize) -> Self {
3803 self.max_body_preview = max;
3804 self
3805 }
3806
3807 #[must_use]
3809 pub fn redact_header(mut self, name: impl Into<String>) -> Self {
3810 self.redact_headers.insert(name.into().to_ascii_lowercase());
3811 self
3812 }
3813
3814 fn format_body_preview(&self, bytes: &[u8], content_type: Option<&[u8]>) -> Option<String> {
3816 if bytes.is_empty() || self.max_body_preview == 0 {
3817 return None;
3818 }
3819
3820 let is_json = content_type
3821 .and_then(|ct| std::str::from_utf8(ct).ok())
3822 .is_some_and(|ct| ct.contains("application/json"));
3823
3824 let limit = self.max_body_preview.min(bytes.len());
3825 let truncated = bytes.len() > self.max_body_preview;
3826
3827 match std::str::from_utf8(&bytes[..limit]) {
3828 Ok(text) => {
3829 if is_json {
3830 if let Some(pretty) = try_pretty_json(text) {
3832 let mut output = pretty;
3833 if truncated {
3834 output.push_str("\n ... (truncated)");
3835 }
3836 return Some(output);
3837 }
3838 }
3839 let mut output = text.to_string();
3840 if truncated {
3841 output.push_str("...");
3842 }
3843 Some(output)
3844 }
3845 Err(_) => Some(format!("<{} bytes binary>", bytes.len())),
3846 }
3847 }
3848
3849 fn format_response_preview(
3851 &self,
3852 body: &crate::response::ResponseBody,
3853 content_type: Option<&[u8]>,
3854 ) -> Option<String> {
3855 match body {
3856 crate::response::ResponseBody::Empty => None,
3857 crate::response::ResponseBody::Bytes(bytes) => {
3858 self.format_body_preview(bytes, content_type)
3859 }
3860 crate::response::ResponseBody::Stream(_) => Some("<streaming body>".to_string()),
3861 }
3862 }
3863
3864 fn format_inspection_headers<'a>(
3866 &self,
3867 headers: impl Iterator<Item = (&'a str, &'a [u8])>,
3868 ) -> String {
3869 let mut out = String::new();
3870 for (name, value) in headers {
3871 out.push_str("\n ");
3872 out.push_str(name);
3873 out.push_str(": ");
3874
3875 let lowered = name.to_ascii_lowercase();
3876 if self.redact_headers.contains(&lowered) {
3877 out.push_str("[REDACTED]");
3878 } else {
3879 match std::str::from_utf8(value) {
3880 Ok(text) => out.push_str(text),
3881 Err(_) => out.push_str("<binary>"),
3882 }
3883 }
3884 }
3885 out
3886 }
3887
3888 fn format_response_inspection_headers(&self, headers: &[(String, Vec<u8>)]) -> String {
3890 self.format_inspection_headers(
3891 headers
3892 .iter()
3893 .map(|(name, value)| (name.as_str(), value.as_slice())),
3894 )
3895 }
3896}
3897
/// Request extension holding the instant at which
/// `RequestInspectionMiddleware::before` ran; `after` reads it to compute
/// the request duration.
#[derive(Debug, Clone)]
struct InspectionStart(Instant);
3901
3902impl Middleware for RequestInspectionMiddleware {
3903 fn before<'a>(
3904 &'a self,
3905 ctx: &'a RequestContext,
3906 req: &'a mut Request,
3907 ) -> BoxFuture<'a, ControlFlow> {
3908 let logger = RequestLogger::new(ctx, self.log_config.clone());
3909 req.insert_extension(InspectionStart(Instant::now()));
3910
3911 let method = req.method();
3912 let path = req.path();
3913 let query = req.query();
3914
3915 let mut request_line = format!("--> {method} {path}");
3917 if let Some(q) = query {
3918 request_line.push('?');
3919 request_line.push_str(q);
3920 }
3921
3922 let body_size = body_len(req.body());
3923 if body_size > 0 {
3924 request_line.push_str(&format!(" ({body_size} bytes)"));
3925 }
3926
3927 match self.verbosity {
3928 InspectionVerbosity::Minimal => {
3929 logger.info(request_line);
3930 }
3931 InspectionVerbosity::Normal => {
3932 let headers = self.format_inspection_headers(req.headers().iter());
3933 logger.info(format!("{request_line}{headers}"));
3934 }
3935 InspectionVerbosity::Verbose => {
3936 let headers = self.format_inspection_headers(req.headers().iter());
3937 let content_type = req.headers().get("content-type");
3938 let body_preview = match req.body() {
3939 Body::Empty => None,
3940 Body::Bytes(bytes) => self.format_body_preview(bytes, content_type),
3941 Body::Stream { .. } => None,
3942 };
3943
3944 let mut output = format!("{request_line}{headers}");
3945 if let Some(body) = body_preview {
3946 output.push_str("\n ");
3947 output.push_str(&body.replace('\n', "\n "));
3949 }
3950 logger.info(output);
3951 }
3952 }
3953
3954 Box::pin(async { ControlFlow::Continue })
3955 }
3956
3957 fn after<'a>(
3958 &'a self,
3959 ctx: &'a RequestContext,
3960 req: &'a Request,
3961 response: Response,
3962 ) -> BoxFuture<'a, Response> {
3963 let logger = RequestLogger::new(ctx, self.log_config.clone());
3964 let duration = req
3965 .get_extension::<InspectionStart>()
3966 .map(|start| start.0.elapsed())
3967 .unwrap_or_default();
3968
3969 let status = response.status();
3970 let duration_ms = duration.as_millis();
3971
3972 let mut response_line = format!(
3974 "<-- {} {} ({duration_ms}ms)",
3975 status.as_u16(),
3976 status.canonical_reason(),
3977 );
3978
3979 if duration_ms >= u128::from(self.slow_threshold_ms) {
3981 response_line.push_str(" [SLOW]");
3982 }
3983
3984 match self.verbosity {
3985 InspectionVerbosity::Minimal => {
3986 if duration_ms >= u128::from(self.slow_threshold_ms) {
3987 logger.warn(response_line);
3988 } else {
3989 logger.info(response_line);
3990 }
3991 }
3992 InspectionVerbosity::Normal => {
3993 let headers = self.format_response_inspection_headers(response.headers());
3994 let output = format!("{response_line}{headers}");
3995 if duration_ms >= u128::from(self.slow_threshold_ms) {
3996 logger.warn(output);
3997 } else {
3998 logger.info(output);
3999 }
4000 }
4001 InspectionVerbosity::Verbose => {
4002 let headers = self.format_response_inspection_headers(response.headers());
4003
4004 let resp_content_type: Option<&[u8]> = response
4006 .headers()
4007 .iter()
4008 .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
4009 .map(|(_, value)| value.as_slice());
4010
4011 let body_preview =
4012 self.format_response_preview(response.body_ref(), resp_content_type);
4013
4014 let mut output = format!("{response_line}{headers}");
4015 if let Some(body) = body_preview {
4016 output.push_str("\n ");
4017 output.push_str(&body.replace('\n', "\n "));
4018 }
4019
4020 if duration_ms >= u128::from(self.slow_threshold_ms) {
4021 logger.warn(output);
4022 } else {
4023 logger.info(output);
4024 }
4025 }
4026 }
4027
4028 Box::pin(async move { response })
4029 }
4030
4031 fn name(&self) -> &'static str {
4032 "RequestInspection"
4033 }
4034}
4035
4036fn try_pretty_json(input: &str) -> Option<String> {
4041 let trimmed = input.trim();
4042 if !trimmed.starts_with('{') && !trimmed.starts_with('[') {
4043 return None;
4044 }
4045
4046 let mut output = String::with_capacity(trimmed.len() * 2);
4048 if json_pretty_format(trimmed, &mut output).is_ok() {
4049 Some(output)
4050 } else {
4051 None
4052 }
4053}
4054
4055fn json_pretty_format(input: &str, output: &mut String) -> Result<(), ()> {
4060 let bytes = input.as_bytes();
4061 let mut pos = 0;
4062 let mut indent: usize = 0;
4063 let mut in_string = false;
4064 let mut escape_next = false;
4065
4066 while pos < bytes.len() {
4067 let ch = bytes[pos] as char;
4068
4069 if escape_next {
4070 output.push(ch);
4071 escape_next = false;
4072 pos += 1;
4073 continue;
4074 }
4075
4076 if in_string {
4077 output.push(ch);
4078 if ch == '\\' {
4079 escape_next = true;
4080 } else if ch == '"' {
4081 in_string = false;
4082 }
4083 pos += 1;
4084 continue;
4085 }
4086
4087 match ch {
4088 '"' => {
4089 in_string = true;
4090 output.push('"');
4091 }
4092 '{' | '[' => {
4093 output.push(ch);
4094 let peek = skip_whitespace(bytes, pos + 1);
4096 let closing = if ch == '{' { '}' } else { ']' };
4097 if peek < bytes.len() && bytes[peek] as char == closing {
4098 output.push(closing);
4099 pos = peek + 1;
4100 continue;
4101 }
4102 indent += 1;
4103 output.push('\n');
4104 push_indent(output, indent);
4105 }
4106 '}' | ']' => {
4107 indent = indent.saturating_sub(1);
4108 output.push('\n');
4109 push_indent(output, indent);
4110 output.push(ch);
4111 }
4112 ':' => {
4113 output.push_str(": ");
4114 }
4115 ',' => {
4116 output.push(',');
4117 output.push('\n');
4118 push_indent(output, indent);
4119 }
4120 c if c.is_ascii_whitespace() => {
4121 }
4123 _ => {
4124 output.push(ch);
4125 }
4126 }
4127
4128 pos += 1;
4129 }
4130
4131 if in_string || indent != 0 {
4132 return Err(());
4133 }
4134
4135 Ok(())
4136}
4137
/// Returns the index of the first byte at or after `start` that is not
/// ASCII whitespace, or the end of `bytes` when only whitespace remains.
/// A `start` beyond the end is returned unchanged.
fn skip_whitespace(bytes: &[u8], start: usize) -> usize {
    let run = bytes.get(start..).map_or(0, |tail| {
        tail.iter()
            .take_while(|byte| byte.is_ascii_whitespace())
            .count()
    });
    start + run
}
4145
/// Appends `level` indentation units to `output`.
fn push_indent(output: &mut String, level: usize) {
    output.extend(std::iter::repeat(" ").take(level));
}
4151
/// Controls how [`ETagMiddleware`] obtains an ETag for a response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ETagMode {
    /// Use an `ETag` header already on the response, otherwise generate one
    /// from the response body.
    Auto,
    /// Only use an `ETag` header already on the response; never generate one.
    Manual,
    /// Leave responses untouched.
    Disabled,
}
4172
4173impl Default for ETagMode {
4174 fn default() -> Self {
4175 Self::Auto
4176 }
4177}
4178
/// Settings for [`ETagMiddleware`].
#[derive(Debug, Clone)]
pub struct ETagConfig {
    /// How ETags are obtained.
    pub mode: ETagMode,
    /// When `true`, generated ETags carry the weak `W/` prefix.
    pub weak: bool,
    /// Minimum body length in bytes for an ETag to be generated.
    pub min_size: usize,
}
4191
4192impl Default for ETagConfig {
4193 fn default() -> Self {
4194 Self {
4195 mode: ETagMode::Auto,
4196 weak: false,
4197 min_size: 0,
4198 }
4199 }
4200}
4201
4202impl ETagConfig {
4203 #[must_use]
4205 pub fn new() -> Self {
4206 Self::default()
4207 }
4208
4209 #[must_use]
4211 pub fn mode(mut self, mode: ETagMode) -> Self {
4212 self.mode = mode;
4213 self
4214 }
4215
4216 #[must_use]
4218 pub fn weak(mut self, weak: bool) -> Self {
4219 self.weak = weak;
4220 self
4221 }
4222
4223 #[must_use]
4225 pub fn min_size(mut self, size: usize) -> Self {
4226 self.min_size = size;
4227 self
4228 }
4229}
4230
/// Middleware that attaches ETags to GET/HEAD responses and answers
/// matching `If-None-Match` requests with `304 Not Modified`.
pub struct ETagMiddleware {
    /// Mode, weak/strong selection and minimum body size.
    config: ETagConfig,
}
4274
4275impl Default for ETagMiddleware {
4276 fn default() -> Self {
4277 Self::new()
4278 }
4279}
4280
4281impl ETagMiddleware {
4282 #[must_use]
4284 pub fn new() -> Self {
4285 Self {
4286 config: ETagConfig::default(),
4287 }
4288 }
4289
4290 #[must_use]
4292 pub fn with_config(config: ETagConfig) -> Self {
4293 Self { config }
4294 }
4295
4296 fn generate_etag(data: &[u8], weak: bool) -> String {
4303 const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
4305 const FNV_PRIME: u64 = 0x100000001b3;
4306
4307 let mut hash = FNV_OFFSET_BASIS;
4308 for &byte in data {
4309 hash ^= u64::from(byte);
4310 hash = hash.wrapping_mul(FNV_PRIME);
4311 }
4312
4313 if weak {
4315 format!("W/\"{:016x}\"", hash)
4316 } else {
4317 format!("\"{:016x}\"", hash)
4318 }
4319 }
4320
4321 fn parse_if_none_match(value: &str) -> Vec<String> {
4329 let trimmed = value.trim();
4330
4331 if trimmed == "*" {
4333 return vec!["*".to_string()];
4334 }
4335
4336 let mut etags = Vec::new();
4337 let mut current = String::new();
4338 let mut in_quote = false;
4339 let mut prev_char = '\0';
4340
4341 for ch in trimmed.chars() {
4342 match ch {
4343 '"' if prev_char != '\\' => {
4344 current.push(ch);
4345 if in_quote {
4346 let etag = current.trim().to_string();
4348 if !etag.is_empty() {
4349 etags.push(etag);
4350 }
4351 current.clear();
4352 }
4353 in_quote = !in_quote;
4354 }
4355 ',' if !in_quote => {
4356 current.clear();
4358 }
4359 _ => {
4360 current.push(ch);
4361 }
4362 }
4363 prev_char = ch;
4364 }
4365
4366 etags
4367 }
4368
4369 fn etags_match_weak(etag1: &str, etag2: &str) -> bool {
4377 let e1 = Self::strip_weak_prefix(etag1);
4379 let e2 = Self::strip_weak_prefix(etag2);
4380 e1 == e2
4381 }
4382
4383 fn strip_weak_prefix(s: &str) -> &str {
4385 if s.starts_with("W/") || s.starts_with("w/") {
4386 &s[2..]
4387 } else {
4388 s
4389 }
4390 }
4391
4392 fn is_cacheable_method(method: crate::request::Method) -> bool {
4394 matches!(
4395 method,
4396 crate::request::Method::Get | crate::request::Method::Head
4397 )
4398 }
4399
4400 fn get_existing_etag(headers: &[(String, Vec<u8>)]) -> Option<String> {
4402 for (name, value) in headers {
4403 if name.eq_ignore_ascii_case("etag") {
4404 return std::str::from_utf8(value).ok().map(String::from);
4405 }
4406 }
4407 None
4408 }
4409}
4410
4411impl Middleware for ETagMiddleware {
4412 fn after<'a>(
4413 &'a self,
4414 _ctx: &'a RequestContext,
4415 req: &'a Request,
4416 response: Response,
4417 ) -> BoxFuture<'a, Response> {
4418 let config = self.config.clone();
4419
4420 Box::pin(async move {
4421 if config.mode == ETagMode::Disabled {
4423 return response;
4424 }
4425
4426 if !Self::is_cacheable_method(req.method()) {
4428 return response;
4429 }
4430
4431 let (status, headers, body) = response.into_parts();
4433
4434 let existing_etag = Self::get_existing_etag(&headers);
4436
4437 let body_bytes = match &body {
4439 crate::response::ResponseBody::Bytes(bytes) => Some(bytes.clone()),
4440 crate::response::ResponseBody::Empty => Some(Vec::new()),
4441 crate::response::ResponseBody::Stream(_) => None,
4442 };
4443
4444 let etag = if let Some(existing) = existing_etag {
4446 Some(existing)
4447 } else if config.mode == ETagMode::Auto {
4448 if let Some(ref bytes) = body_bytes {
4449 if bytes.len() >= config.min_size {
4450 Some(Self::generate_etag(bytes, config.weak))
4451 } else {
4452 None
4453 }
4454 } else {
4455 None
4456 }
4457 } else {
4458 None
4459 };
4460
4461 if let Some(ref etag_value) = etag {
4463 if let Some(if_none_match) = req.headers().get("if-none-match") {
4464 if let Ok(value) = std::str::from_utf8(if_none_match) {
4465 let client_etags = Self::parse_if_none_match(value);
4466
4467 let matches = client_etags.iter().any(|client_etag| {
4469 client_etag == "*" || Self::etags_match_weak(client_etag, etag_value)
4470 });
4471
4472 if matches {
4473 return Response::with_status(
4475 crate::response::StatusCode::NOT_MODIFIED,
4476 )
4477 .header("etag", etag_value.as_bytes().to_vec());
4478 }
4479 }
4480 }
4481 }
4482
4483 let mut new_response = Response::with_status(status)
4485 .body(body)
4486 .rebuild_with_headers(headers);
4487
4488 if let Some(etag_value) = etag {
4489 new_response = new_response.header("etag", etag_value.into_bytes());
4490 }
4491
4492 new_response
4493 })
4494 }
4495
4496 fn name(&self) -> &'static str {
4497 "ETagMiddleware"
4498 }
4499}
4500
/// A `Cache-Control` directive name; `as_str` yields the header token.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheDirective {
    /// `public`
    Public,
    /// `private`
    Private,
    /// `no-store`
    NoStore,
    /// `no-cache`
    NoCache,
    /// `no-transform`
    NoTransform,
    /// `must-revalidate`
    MustRevalidate,
    /// `proxy-revalidate`
    ProxyRevalidate,
    /// `stale-if-error`
    StaleIfError,
    /// `stale-while-revalidate`
    StaleWhileRevalidate,
    /// `s-maxage`
    SMaxAge,
    /// `only-if-cached`
    OnlyIfCached,
    /// `immutable`
    Immutable,
}
4536
4537impl CacheDirective {
4538 fn as_str(self) -> &'static str {
4540 match self {
4541 Self::Public => "public",
4542 Self::Private => "private",
4543 Self::NoStore => "no-store",
4544 Self::NoCache => "no-cache",
4545 Self::NoTransform => "no-transform",
4546 Self::MustRevalidate => "must-revalidate",
4547 Self::ProxyRevalidate => "proxy-revalidate",
4548 Self::StaleIfError => "stale-if-error",
4549 Self::StaleWhileRevalidate => "stale-while-revalidate",
4550 Self::SMaxAge => "s-maxage",
4551 Self::OnlyIfCached => "only-if-cached",
4552 Self::Immutable => "immutable",
4553 }
4554 }
4555}
4556
/// Builder for a `Cache-Control` header value.
#[derive(Debug, Clone, Default)]
pub struct CacheControlBuilder {
    /// Valueless directives, emitted first in insertion order.
    directives: Vec<CacheDirective>,
    /// Seconds for `max-age=`, when set.
    max_age: Option<u32>,
    /// Seconds for `s-maxage=`, when set.
    s_maxage: Option<u32>,
    /// Seconds for `stale-while-revalidate=`, when set.
    stale_while_revalidate: Option<u32>,
    /// Seconds for `stale-if-error=`, when set.
    stale_if_error: Option<u32>,
}
4594
4595impl CacheControlBuilder {
4596 #[must_use]
4598 pub fn new() -> Self {
4599 Self::default()
4600 }
4601
4602 #[must_use]
4604 pub fn public(mut self) -> Self {
4605 self.directives.push(CacheDirective::Public);
4606 self
4607 }
4608
4609 #[must_use]
4611 pub fn private(mut self) -> Self {
4612 self.directives.push(CacheDirective::Private);
4613 self
4614 }
4615
4616 #[must_use]
4618 pub fn no_store(mut self) -> Self {
4619 self.directives.push(CacheDirective::NoStore);
4620 self
4621 }
4622
4623 #[must_use]
4625 pub fn no_cache(mut self) -> Self {
4626 self.directives.push(CacheDirective::NoCache);
4627 self
4628 }
4629
4630 #[must_use]
4632 pub fn no_transform(mut self) -> Self {
4633 self.directives.push(CacheDirective::NoTransform);
4634 self
4635 }
4636
4637 #[must_use]
4639 pub fn must_revalidate(mut self) -> Self {
4640 self.directives.push(CacheDirective::MustRevalidate);
4641 self
4642 }
4643
4644 #[must_use]
4646 pub fn proxy_revalidate(mut self) -> Self {
4647 self.directives.push(CacheDirective::ProxyRevalidate);
4648 self
4649 }
4650
4651 #[must_use]
4653 pub fn immutable(mut self) -> Self {
4654 self.directives.push(CacheDirective::Immutable);
4655 self
4656 }
4657
4658 #[must_use]
4660 pub fn max_age_secs(mut self, seconds: u32) -> Self {
4661 self.max_age = Some(seconds);
4662 self
4663 }
4664
4665 #[must_use]
4667 pub fn max_age(self, duration: std::time::Duration) -> Self {
4668 self.max_age_secs(duration.as_secs() as u32)
4669 }
4670
4671 #[must_use]
4673 pub fn s_maxage_secs(mut self, seconds: u32) -> Self {
4674 self.s_maxage = Some(seconds);
4675 self
4676 }
4677
4678 #[must_use]
4680 pub fn s_maxage(self, duration: std::time::Duration) -> Self {
4681 self.s_maxage_secs(duration.as_secs() as u32)
4682 }
4683
4684 #[must_use]
4686 pub fn stale_while_revalidate_secs(mut self, seconds: u32) -> Self {
4687 self.stale_while_revalidate = Some(seconds);
4688 self
4689 }
4690
4691 #[must_use]
4693 pub fn stale_if_error_secs(mut self, seconds: u32) -> Self {
4694 self.stale_if_error = Some(seconds);
4695 self
4696 }
4697
4698 #[must_use]
4700 pub fn build(&self) -> String {
4701 let mut parts = Vec::new();
4702
4703 for directive in &self.directives {
4705 parts.push(directive.as_str().to_string());
4706 }
4707
4708 if let Some(age) = self.max_age {
4710 parts.push(format!("max-age={age}"));
4711 }
4712
4713 if let Some(age) = self.s_maxage {
4715 parts.push(format!("s-maxage={age}"));
4716 }
4717
4718 if let Some(seconds) = self.stale_while_revalidate {
4720 parts.push(format!("stale-while-revalidate={seconds}"));
4721 }
4722
4723 if let Some(seconds) = self.stale_if_error {
4725 parts.push(format!("stale-if-error={seconds}"));
4726 }
4727
4728 parts.join(", ")
4729 }
4730
4731 #[must_use]
4733 pub fn is_no_cache(&self) -> bool {
4734 self.directives.contains(&CacheDirective::NoStore)
4735 || self.directives.contains(&CacheDirective::NoCache)
4736 }
4737}
4738
/// Ready-made `Cache-Control` policies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CachePreset {
    /// `no-store, no-cache, must-revalidate`
    NoCache,
    /// `private, max-age=0, must-revalidate`
    PrivateNoCache,
    /// `public, max-age=3600`
    PublicOneHour,
    /// `public, max-age=31536000, immutable`
    Immutable,
    /// `public, max-age=60, s-maxage=3600`
    CdnFriendly,
    /// `public, max-age=86400`
    StaticAssets,
}
4755
4756impl CachePreset {
4757 #[must_use]
4759 pub fn to_header_value(&self) -> String {
4760 match self {
4761 Self::NoCache => "no-store, no-cache, must-revalidate".to_string(),
4762 Self::PrivateNoCache => "private, max-age=0, must-revalidate".to_string(),
4763 Self::PublicOneHour => "public, max-age=3600".to_string(),
4764 Self::Immutable => "public, max-age=31536000, immutable".to_string(),
4765 Self::CdnFriendly => "public, max-age=60, s-maxage=3600".to_string(),
4766 Self::StaticAssets => "public, max-age=86400".to_string(),
4767 }
4768 }
4769
4770 #[must_use]
4772 pub fn to_builder(&self) -> CacheControlBuilder {
4773 match self {
4774 Self::NoCache => CacheControlBuilder::new()
4775 .no_store()
4776 .no_cache()
4777 .must_revalidate(),
4778 Self::PrivateNoCache => CacheControlBuilder::new()
4779 .private()
4780 .max_age_secs(0)
4781 .must_revalidate(),
4782 Self::PublicOneHour => CacheControlBuilder::new().public().max_age_secs(3600),
4783 Self::Immutable => CacheControlBuilder::new()
4784 .public()
4785 .max_age_secs(31536000)
4786 .immutable(),
4787 Self::CdnFriendly => CacheControlBuilder::new()
4788 .public()
4789 .max_age_secs(60)
4790 .s_maxage_secs(3600),
4791 Self::StaticAssets => CacheControlBuilder::new().public().max_age_secs(86400),
4792 }
4793 }
4794}
4795
/// Settings for [`CacheControlMiddleware`].
#[derive(Debug, Clone)]
pub struct CacheControlConfig {
    /// Value written to the `Cache-Control` header.
    pub cache_control: String,
    /// Header names joined with `", "` into a `Vary` header; none is added when empty.
    pub vary: Vec<String>,
    /// When `true`, an `Expires` header is derived from the `max-age` directive.
    pub set_expires: bool,
    /// When `true`, responses that already carry `Cache-Control` are left as they are.
    pub preserve_existing: bool,
    /// Request methods the middleware applies to.
    pub methods: Vec<crate::request::Method>,
    /// Path patterns the middleware applies to; empty means every path.
    pub path_patterns: Vec<String>,
    /// Response status codes the middleware applies to.
    pub cacheable_statuses: Vec<u16>,
}
4814
4815impl Default for CacheControlConfig {
4816 fn default() -> Self {
4817 Self {
4818 cache_control: CachePreset::NoCache.to_header_value(),
4819 vary: Vec::new(),
4820 set_expires: false,
4821 preserve_existing: true,
4822 methods: vec![crate::request::Method::Get, crate::request::Method::Head],
4823 path_patterns: Vec::new(),
4824 cacheable_statuses: (200..300).collect(),
4825 }
4826 }
4827}
4828
4829impl CacheControlConfig {
4830 #[must_use]
4832 pub fn new() -> Self {
4833 Self::default()
4834 }
4835
4836 #[must_use]
4838 pub fn from_preset(preset: CachePreset) -> Self {
4839 Self {
4840 cache_control: preset.to_header_value(),
4841 ..Self::default()
4842 }
4843 }
4844
4845 #[must_use]
4847 pub fn from_builder(builder: CacheControlBuilder) -> Self {
4848 Self {
4849 cache_control: builder.build(),
4850 ..Self::default()
4851 }
4852 }
4853
4854 #[must_use]
4856 pub fn cache_control(mut self, value: impl Into<String>) -> Self {
4857 self.cache_control = value.into();
4858 self
4859 }
4860
4861 #[must_use]
4863 pub fn vary(mut self, header: impl Into<String>) -> Self {
4864 self.vary.push(header.into());
4865 self
4866 }
4867
4868 #[must_use]
4870 pub fn vary_headers(mut self, headers: Vec<String>) -> Self {
4871 self.vary.extend(headers);
4872 self
4873 }
4874
4875 #[must_use]
4877 pub fn with_expires(mut self, enable: bool) -> Self {
4878 self.set_expires = enable;
4879 self
4880 }
4881
4882 #[must_use]
4884 pub fn preserve_existing(mut self, preserve: bool) -> Self {
4885 self.preserve_existing = preserve;
4886 self
4887 }
4888
4889 #[must_use]
4891 pub fn methods(mut self, methods: Vec<crate::request::Method>) -> Self {
4892 self.methods = methods;
4893 self
4894 }
4895
4896 #[must_use]
4898 pub fn path_patterns(mut self, patterns: Vec<String>) -> Self {
4899 self.path_patterns = patterns;
4900 self
4901 }
4902
4903 #[must_use]
4905 pub fn cacheable_statuses(mut self, statuses: Vec<u16>) -> Self {
4906 self.cacheable_statuses = statuses;
4907 self
4908 }
4909}
4910
/// Middleware that adds `Cache-Control` (and optionally `Vary` and
/// `Expires`) headers to matching responses.
pub struct CacheControlMiddleware {
    /// Header value and the method/status/path filters.
    config: CacheControlConfig,
}
4960
4961impl Default for CacheControlMiddleware {
4962 fn default() -> Self {
4963 Self::new()
4964 }
4965}
4966
4967impl CacheControlMiddleware {
4968 #[must_use]
4972 pub fn new() -> Self {
4973 Self {
4974 config: CacheControlConfig::default(),
4975 }
4976 }
4977
4978 #[must_use]
4980 pub fn with_preset(preset: CachePreset) -> Self {
4981 Self {
4982 config: CacheControlConfig::from_preset(preset),
4983 }
4984 }
4985
4986 #[must_use]
4988 pub fn with_config(config: CacheControlConfig) -> Self {
4989 Self { config }
4990 }
4991
4992 fn is_cacheable_method(&self, method: crate::request::Method) -> bool {
4994 self.config.methods.contains(&method)
4995 }
4996
4997 fn is_cacheable_status(&self, status: u16) -> bool {
4999 self.config.cacheable_statuses.contains(&status)
5000 }
5001
5002 fn matches_path(&self, path: &str) -> bool {
5004 if self.config.path_patterns.is_empty() {
5005 return true; }
5007
5008 for pattern in &self.config.path_patterns {
5009 if path_matches_pattern(path, pattern) {
5010 return true;
5011 }
5012 }
5013 false
5014 }
5015
5016 fn has_cache_control(headers: &[(String, Vec<u8>)]) -> bool {
5018 headers
5019 .iter()
5020 .any(|(name, _)| name.eq_ignore_ascii_case("cache-control"))
5021 }
5022
5023 fn calculate_expires(cache_control: &str) -> Option<String> {
5025 for directive in cache_control.split(',') {
5027 let directive = directive.trim();
5028 if directive.starts_with("max-age=") {
5029 if let Ok(seconds) = directive[8..].parse::<u64>() {
5030 let now = std::time::SystemTime::now();
5032 if let Some(expires) = now.checked_add(std::time::Duration::from_secs(seconds))
5033 {
5034 return Some(format_http_date(expires));
5035 }
5036 }
5037 }
5038 }
5039 None
5040 }
5041}
5042
/// Matches `path` against a glob-style `pattern` in which `*` stands for any
/// run of characters (including none).
///
/// A pattern without `*` must equal the path exactly. Otherwise the text
/// before the first `*` must be a prefix of the path, the text after the last
/// `*` must be a suffix of what remains, and any literal segments in between
/// must occur in order in the middle, without overlapping one another.
fn path_matches_pattern(path: &str, pattern: &str) -> bool {
    let Some((prefix, rest)) = pattern.split_once('*') else {
        return path == pattern;
    };

    // Anchor the leading literal at the start of the path.
    let Some(remaining) = path.strip_prefix(prefix) else {
        return false;
    };

    let (middle, suffix) = rest.rsplit_once('*').unwrap_or(("", rest));

    // Anchor the trailing literal at the end. Stripping it from what is left
    // after the prefix guarantees prefix and suffix do not overlap.
    let Some(mut remaining) = remaining.strip_suffix(suffix) else {
        return false;
    };

    // Inner literals must appear left to right; taking the leftmost match for
    // each leaves the most room for the ones that follow.
    for part in middle.split('*').filter(|part| !part.is_empty()) {
        match remaining.find(part) {
            Some(pos) => remaining = &remaining[pos + part.len()..],
            None => return false,
        }
    }

    true
}
5071
5072fn format_http_date(time: std::time::SystemTime) -> String {
5074 match time.duration_since(std::time::UNIX_EPOCH) {
5076 Ok(duration) => {
5077 let secs = duration.as_secs();
5079 let days = secs / 86400;
5081 let remaining_secs = secs % 86400;
5082 let hours = remaining_secs / 3600;
5083 let minutes = (remaining_secs % 3600) / 60;
5084 let seconds = remaining_secs % 60;
5085
5086 let day_of_week = ((days + 4) % 7) as usize;
5088 let day_names = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"];
5089
5090 let (year, month, day) = days_to_date(days);
5092 let month_names = [
5093 "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
5094 ];
5095
5096 format!(
5097 "{}, {:02} {} {} {:02}:{:02}:{:02} GMT",
5098 day_names[day_of_week],
5099 day,
5100 month_names[(month - 1) as usize],
5101 year,
5102 hours,
5103 minutes,
5104 seconds
5105 )
5106 }
5107 Err(_) => "Thu, 01 Jan 1970 00:00:00 GMT".to_string(),
5108 }
5109}
5110
5111fn days_to_date(days: u64) -> (u64, u64, u64) {
5113 let mut remaining_days = days;
5115 let mut year = 1970u64;
5116
5117 loop {
5118 let days_in_year = if is_leap_year(year) { 366 } else { 365 };
5119 if remaining_days < days_in_year {
5120 break;
5121 }
5122 remaining_days -= days_in_year;
5123 year += 1;
5124 }
5125
5126 let leap = is_leap_year(year);
5127 let month_days: [u64; 12] = if leap {
5128 [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
5129 } else {
5130 [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
5131 };
5132
5133 let mut month = 1u64;
5134 for &days_in_month in &month_days {
5135 if remaining_days < days_in_month {
5136 break;
5137 }
5138 remaining_days -= days_in_month;
5139 month += 1;
5140 }
5141
5142 (year, month, remaining_days + 1)
5143}
5144
/// Gregorian leap-year rule: divisible by 400, or divisible by 4 but not by 100.
fn is_leap_year(year: u64) -> bool {
    match (year % 4, year % 100, year % 400) {
        (_, _, 0) => true,
        (_, 0, _) => false,
        (0, _, _) => true,
        _ => false,
    }
}
5149
5150impl Middleware for CacheControlMiddleware {
5151 fn after<'a>(
5152 &'a self,
5153 _ctx: &'a RequestContext,
5154 req: &'a Request,
5155 response: Response,
5156 ) -> BoxFuture<'a, Response> {
5157 let config = self.config.clone();
5158
5159 Box::pin(async move {
5160 if !self.is_cacheable_method(req.method()) {
5162 return response;
5163 }
5164
5165 if !self.is_cacheable_status(response.status().as_u16()) {
5166 return response;
5167 }
5168
5169 if !self.matches_path(req.path()) {
5170 return response;
5171 }
5172
5173 let (status, mut headers, body) = response.into_parts();
5175
5176 if config.preserve_existing && Self::has_cache_control(&headers) {
5178 let mut resp = Response::with_status(status);
5180 for (name, value) in headers {
5181 resp = resp.header(name, value);
5182 }
5183 return resp.body(body);
5184 }
5185
5186 headers.push((
5188 "Cache-Control".to_string(),
5189 config.cache_control.as_bytes().to_vec(),
5190 ));
5191
5192 if !config.vary.is_empty() {
5194 let vary_value = config.vary.join(", ");
5195 headers.push(("Vary".to_string(), vary_value.into_bytes()));
5196 }
5197
5198 if config.set_expires {
5200 if let Some(expires) = Self::calculate_expires(&config.cache_control) {
5201 headers.push(("Expires".to_string(), expires.into_bytes()));
5202 }
5203 }
5204
5205 let mut resp = Response::with_status(status);
5207 for (name, value) in headers {
5208 resp = resp.header(name, value);
5209 }
5210 resp.body(body)
5211 })
5212 }
5213
5214 fn name(&self) -> &'static str {
5215 "CacheControlMiddleware"
5216 }
5217}
5218
5219#[derive(Debug, Clone)]
5254pub struct TraceRejectionMiddleware {
5255 log_attempts: bool,
5257}
5258
5259impl Default for TraceRejectionMiddleware {
5260 fn default() -> Self {
5261 Self::new()
5262 }
5263}
5264
5265impl TraceRejectionMiddleware {
5266 #[must_use]
5270 pub fn new() -> Self {
5271 Self { log_attempts: true }
5272 }
5273
5274 #[must_use]
5279 pub fn log_attempts(mut self, log: bool) -> Self {
5280 self.log_attempts = log;
5281 self
5282 }
5283
5284 fn rejection_response(path: &str) -> Response {
5286 let body = format!(
5287 r#"{{"detail":"HTTP TRACE method is not allowed","path":"{}"}}"#,
5288 path.replace('"', "\\\"")
5289 );
5290 Response::with_status(crate::response::StatusCode::METHOD_NOT_ALLOWED)
5291 .header("Content-Type", b"application/json".to_vec())
5292 .header(
5293 "Allow",
5294 b"GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD".to_vec(),
5295 )
5296 .body(crate::response::ResponseBody::Bytes(body.into_bytes()))
5297 }
5298}
5299
5300impl Middleware for TraceRejectionMiddleware {
5301 fn before<'a>(
5302 &'a self,
5303 _ctx: &'a RequestContext,
5304 req: &'a mut Request,
5305 ) -> BoxFuture<'a, ControlFlow> {
5306 Box::pin(async move {
5307 if req.method() == crate::request::Method::Trace {
5308 if self.log_attempts {
5309 let path = req.path();
5311 let remote_ip = req
5312 .headers()
5313 .get("X-Forwarded-For")
5314 .or_else(|| req.headers().get("X-Real-IP"))
5315 .map(|v| String::from_utf8_lossy(v).to_string())
5316 .unwrap_or_else(|| "unknown".to_string());
5317
5318 eprintln!(
5319 "[SECURITY] TRACE request blocked: path={}, remote_ip={}",
5320 path, remote_ip
5321 );
5322 }
5323
5324 return ControlFlow::Break(Self::rejection_response(req.path()));
5325 }
5326
5327 ControlFlow::Continue
5328 })
5329 }
5330
5331 fn name(&self) -> &'static str {
5332 "TraceRejection"
5333 }
5334}
5335
/// Settings for the HTTPS redirect / HSTS middleware.
#[derive(Debug, Clone)]
#[allow(clippy::struct_excessive_bools)]
pub struct HttpsRedirectConfig {
    /// Master switch for redirecting insecure requests.
    pub redirect_enabled: bool,
    /// Selects the permanent redirect status instead of the temporary one.
    pub permanent_redirect: bool,
    /// `max-age` value for the HSTS header; `0` suppresses the header.
    pub hsts_max_age_secs: u64,
    /// Appends `includeSubDomains` to the HSTS header.
    pub hsts_include_subdomains: bool,
    /// Appends `preload` to the HSTS header.
    pub hsts_preload: bool,
    /// Path prefixes that are never redirected.
    pub exclude_paths: Vec<String>,
    /// Port used in redirect URLs; `443` is omitted from the URL.
    pub https_port: u16,
}

impl Default for HttpsRedirectConfig {
    /// Redirects enabled and permanent, one-year HSTS `max-age`, no
    /// sub-domain or preload directives, no exclusions, port 443.
    fn default() -> Self {
        let one_year_secs: u64 = 31_536_000;
        Self {
            redirect_enabled: true,
            permanent_redirect: true,
            hsts_max_age_secs: one_year_secs,
            hsts_include_subdomains: false,
            hsts_preload: false,
            exclude_paths: Vec::new(),
            https_port: 443,
        }
    }
}
5377
5378#[derive(Debug, Clone)]
5418pub struct HttpsRedirectMiddleware {
5419 config: HttpsRedirectConfig,
5420}
5421
5422impl Default for HttpsRedirectMiddleware {
5423 fn default() -> Self {
5424 Self::new()
5425 }
5426}
5427
5428impl HttpsRedirectMiddleware {
5429 #[must_use]
5431 pub fn new() -> Self {
5432 Self {
5433 config: HttpsRedirectConfig::default(),
5434 }
5435 }
5436
5437 #[must_use]
5439 pub fn redirect_enabled(mut self, enabled: bool) -> Self {
5440 self.config.redirect_enabled = enabled;
5441 self
5442 }
5443
5444 #[must_use]
5448 pub fn permanent_redirect(mut self, permanent: bool) -> Self {
5449 self.config.permanent_redirect = permanent;
5450 self
5451 }
5452
5453 #[must_use]
5458 pub fn hsts_max_age_secs(mut self, secs: u64) -> Self {
5459 self.config.hsts_max_age_secs = secs;
5460 self
5461 }
5462
5463 #[must_use]
5465 pub fn include_subdomains(mut self, include: bool) -> Self {
5466 self.config.hsts_include_subdomains = include;
5467 self
5468 }
5469
5470 #[must_use]
5475 pub fn preload(mut self, preload: bool) -> Self {
5476 self.config.hsts_preload = preload;
5477 self
5478 }
5479
5480 #[must_use]
5485 pub fn exclude_path(mut self, path: impl Into<String>) -> Self {
5486 self.config.exclude_paths.push(path.into());
5487 self
5488 }
5489
5490 #[must_use]
5492 pub fn exclude_paths(mut self, paths: Vec<String>) -> Self {
5493 self.config.exclude_paths = paths;
5494 self
5495 }
5496
5497 #[must_use]
5499 pub fn https_port(mut self, port: u16) -> Self {
5500 self.config.https_port = port;
5501 self
5502 }
5503
5504 fn is_secure(&self, req: &Request) -> bool {
5509 fn trim_ascii(mut bytes: &[u8]) -> &[u8] {
5510 while matches!(bytes.first(), Some(b' ' | b'\t')) {
5511 bytes = &bytes[1..];
5512 }
5513 while matches!(bytes.last(), Some(b' ' | b'\t')) {
5514 bytes = &bytes[..bytes.len() - 1];
5515 }
5516 bytes
5517 }
5518
5519 if let Some(info) = req.get_extension::<crate::request::ConnectionInfo>() {
5520 if info.is_tls {
5521 return true;
5522 }
5523 }
5524
5525 if let Some(forwarded) = req.headers().get("Forwarded") {
5527 if let Ok(s) = std::str::from_utf8(forwarded) {
5528 for entry in s.split(',') {
5529 for param in entry.split(';') {
5530 let param = param.trim();
5531 if let Some((k, v)) = param.split_once('=') {
5532 if k.trim().eq_ignore_ascii_case("proto") {
5533 let proto = v.trim().trim_matches('"');
5534 if proto.eq_ignore_ascii_case("https") {
5535 return true;
5536 }
5537 }
5538 }
5539 }
5540 }
5541 }
5542 }
5543
5544 if let Some(proto) = req.headers().get("X-Forwarded-Proto") {
5546 let first = proto.split(|&b| b == b',').next().unwrap_or(proto);
5547 return trim_ascii(first).eq_ignore_ascii_case(b"https");
5548 }
5549
5550 if let Some(ssl) = req.headers().get("X-Forwarded-Ssl") {
5552 return ssl.eq_ignore_ascii_case(b"on");
5553 }
5554
5555 if let Some(https) = req.headers().get("Front-End-Https") {
5557 return https.eq_ignore_ascii_case(b"on");
5558 }
5559
5560 false
5561 }
5562
5563 fn is_excluded(&self, path: &str) -> bool {
5565 self.config
5566 .exclude_paths
5567 .iter()
5568 .any(|p| path.starts_with(p))
5569 }
5570
5571 fn build_hsts_header(&self) -> Option<Vec<u8>> {
5573 if self.config.hsts_max_age_secs == 0 {
5574 return None;
5575 }
5576
5577 let mut value = format!("max-age={}", self.config.hsts_max_age_secs);
5578
5579 if self.config.hsts_include_subdomains {
5580 value.push_str("; includeSubDomains");
5581 }
5582
5583 if self.config.hsts_preload {
5584 value.push_str("; preload");
5585 }
5586
5587 Some(value.into_bytes())
5588 }
5589
5590 fn build_redirect_url(&self, req: &Request) -> String {
5592 let host = req
5593 .headers()
5594 .get("Host")
5595 .map(|h| String::from_utf8_lossy(h).to_string())
5596 .unwrap_or_else(|| "localhost".to_string());
5597
5598 let host_without_port = host.split(':').next().unwrap_or(&host);
5600
5601 let path = req.path();
5602 let query = req.query();
5603
5604 if self.config.https_port == 443 {
5605 match query {
5606 Some(q) => format!("https://{}{}?{}", host_without_port, path, q),
5607 None => format!("https://{}{}", host_without_port, path),
5608 }
5609 } else {
5610 match query {
5611 Some(q) => format!(
5612 "https://{}:{}{}?{}",
5613 host_without_port, self.config.https_port, path, q
5614 ),
5615 None => format!(
5616 "https://{}:{}{}",
5617 host_without_port, self.config.https_port, path
5618 ),
5619 }
5620 }
5621 }
5622}
5623
5624impl Middleware for HttpsRedirectMiddleware {
5625 fn before<'a>(
5626 &'a self,
5627 _ctx: &'a RequestContext,
5628 req: &'a mut Request,
5629 ) -> BoxFuture<'a, ControlFlow> {
5630 Box::pin(async move {
5631 if !self.config.redirect_enabled {
5633 return ControlFlow::Continue;
5634 }
5635
5636 if self.is_secure(req) {
5638 return ControlFlow::Continue;
5639 }
5640
5641 if self.is_excluded(req.path()) {
5643 return ControlFlow::Continue;
5644 }
5645
5646 let redirect_url = self.build_redirect_url(req);
5648
5649 let status = if self.config.permanent_redirect {
5651 crate::response::StatusCode::MOVED_PERMANENTLY
5652 } else {
5653 crate::response::StatusCode::TEMPORARY_REDIRECT
5654 };
5655
5656 let response = Response::with_status(status)
5658 .header("Location", redirect_url.into_bytes())
5659 .header("Content-Type", b"text/plain".to_vec())
5660 .body(crate::response::ResponseBody::Bytes(
5661 b"Redirecting to HTTPS...".to_vec(),
5662 ));
5663
5664 ControlFlow::Break(response)
5665 })
5666 }
5667
5668 fn after<'a>(
5669 &'a self,
5670 _ctx: &'a RequestContext,
5671 req: &'a Request,
5672 response: Response,
5673 ) -> BoxFuture<'a, Response> {
5674 Box::pin(async move {
5675 if !self.is_secure(req) {
5677 return response;
5678 }
5679
5680 if let Some(hsts_value) = self.build_hsts_header() {
5682 response.header("Strict-Transport-Security", hsts_value)
5683 } else {
5684 response
5685 }
5686 })
5687 }
5688
5689 fn name(&self) -> &'static str {
5690 "HttpsRedirect"
5691 }
5692}
5693
5694pub trait ResponseInterceptor: Send + Sync {
5733 fn intercept<'a>(
5744 &'a self,
5745 ctx: &'a ResponseInterceptorContext<'a>,
5746 response: Response,
5747 ) -> BoxFuture<'a, Response>;
5748
5749 fn name(&self) -> &'static str {
5751 std::any::type_name::<Self>()
5752 }
5753}
5754
5755#[derive(Debug)]
5760pub struct ResponseInterceptorContext<'a> {
5761 pub request: &'a Request,
5763 pub start_time: Instant,
5765 pub request_ctx: &'a RequestContext,
5767}
5768
5769impl<'a> ResponseInterceptorContext<'a> {
5770 pub fn new(request: &'a Request, request_ctx: &'a RequestContext, start_time: Instant) -> Self {
5772 Self {
5773 request,
5774 start_time,
5775 request_ctx,
5776 }
5777 }
5778
5779 pub fn elapsed(&self) -> std::time::Duration {
5781 self.start_time.elapsed()
5782 }
5783
5784 pub fn elapsed_ms(&self) -> u128 {
5786 self.start_time.elapsed().as_millis()
5787 }
5788}
5789
5790#[derive(Default)]
5805pub struct ResponseInterceptorStack {
5806 interceptors: Vec<Arc<dyn ResponseInterceptor>>,
5807}
5808
5809impl ResponseInterceptorStack {
5810 #[must_use]
5812 pub fn new() -> Self {
5813 Self {
5814 interceptors: Vec::new(),
5815 }
5816 }
5817
5818 #[must_use]
5820 pub fn with_capacity(capacity: usize) -> Self {
5821 Self {
5822 interceptors: Vec::with_capacity(capacity),
5823 }
5824 }
5825
5826 pub fn push<I: ResponseInterceptor + 'static>(&mut self, interceptor: I) {
5828 self.interceptors.push(Arc::new(interceptor));
5829 }
5830
5831 pub fn push_arc(&mut self, interceptor: Arc<dyn ResponseInterceptor>) {
5833 self.interceptors.push(interceptor);
5834 }
5835
5836 #[must_use]
5838 pub fn len(&self) -> usize {
5839 self.interceptors.len()
5840 }
5841
5842 #[must_use]
5844 pub fn is_empty(&self) -> bool {
5845 self.interceptors.is_empty()
5846 }
5847
5848 pub async fn process(
5850 &self,
5851 ctx: &ResponseInterceptorContext<'_>,
5852 mut response: Response,
5853 ) -> Response {
5854 for interceptor in &self.interceptors {
5855 let _ = ctx.request_ctx.checkpoint();
5856 response = interceptor.intercept(ctx, response).await;
5857 }
5858 response
5859 }
5860}
5861
/// Interceptor that reports the elapsed request time in a response header
/// and, optionally, as a `Server-Timing` metric.
#[derive(Debug, Clone)]
pub struct TimingInterceptor {
    /// Header that receives the `<n>ms` value.
    header_name: String,
    /// Whether a `Server-Timing` header is added as well.
    include_server_timing: bool,
    /// Metric name used inside `Server-Timing`.
    server_timing_name: String,
}

impl Default for TimingInterceptor {
    fn default() -> Self {
        Self::new()
    }
}

impl TimingInterceptor {
    /// Creates an interceptor that writes `X-Response-Time` only; the
    /// `Server-Timing` metric name is preset to `total` but disabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            header_name: String::from("X-Response-Time"),
            include_server_timing: false,
            server_timing_name: String::from("total"),
        }
    }

    /// Turns on the `Server-Timing` header using `metric_name`.
    #[must_use]
    pub fn with_server_timing(self, metric_name: impl Into<String>) -> Self {
        Self {
            include_server_timing: true,
            server_timing_name: metric_name.into(),
            ..self
        }
    }

    /// Overrides the name of the response-time header.
    #[must_use]
    pub fn header_name(self, name: impl Into<String>) -> Self {
        Self {
            header_name: name.into(),
            ..self
        }
    }
}
5920
5921impl ResponseInterceptor for TimingInterceptor {
5922 fn intercept<'a>(
5923 &'a self,
5924 ctx: &'a ResponseInterceptorContext<'a>,
5925 response: Response,
5926 ) -> BoxFuture<'a, Response> {
5927 Box::pin(async move {
5928 let elapsed_ms = ctx.elapsed_ms();
5929 let timing_value = format!("{}ms", elapsed_ms);
5930
5931 let response = response.header(&self.header_name, timing_value.clone().into_bytes());
5932
5933 if self.include_server_timing {
5934 let server_timing = format!("{};dur={}", self.server_timing_name, elapsed_ms);
5936 response.header("Server-Timing", server_timing.into_bytes())
5937 } else {
5938 response
5939 }
5940 })
5941 }
5942
5943 fn name(&self) -> &'static str {
5944 "TimingInterceptor"
5945 }
5946}
5947
/// Interceptor that attaches request diagnostics to the response as
/// prefixed headers. Every piece of information can be toggled separately.
#[derive(Debug, Clone)]
#[allow(clippy::struct_excessive_bools)]
pub struct DebugInfoInterceptor {
    /// Emit the request path.
    include_path: bool,
    /// Emit the request method.
    include_method: bool,
    /// Emit the request id, when one is attached to the request.
    include_request_id: bool,
    /// Emit the elapsed time.
    include_timing: bool,
    /// Prefix prepended to every emitted header name.
    header_prefix: String,
}

impl Default for DebugInfoInterceptor {
    fn default() -> Self {
        Self::new()
    }
}

impl DebugInfoInterceptor {
    /// Creates an interceptor with every item enabled and the prefix
    /// `X-Debug-`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            header_prefix: String::from("X-Debug-"),
            include_path: true,
            include_method: true,
            include_request_id: true,
            include_timing: true,
        }
    }

    /// Toggles the path header.
    #[must_use]
    pub fn include_path(self, include: bool) -> Self {
        Self {
            include_path: include,
            ..self
        }
    }

    /// Toggles the method header.
    #[must_use]
    pub fn include_method(self, include: bool) -> Self {
        Self {
            include_method: include,
            ..self
        }
    }

    /// Toggles the request-id header.
    #[must_use]
    pub fn include_request_id(self, include: bool) -> Self {
        Self {
            include_request_id: include,
            ..self
        }
    }

    /// Toggles the timing header.
    #[must_use]
    pub fn include_timing(self, include: bool) -> Self {
        Self {
            include_timing: include,
            ..self
        }
    }

    /// Replaces the header-name prefix.
    #[must_use]
    pub fn header_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            header_prefix: prefix.into(),
            ..self
        }
    }
}
6040
6041impl ResponseInterceptor for DebugInfoInterceptor {
6042 fn intercept<'a>(
6043 &'a self,
6044 ctx: &'a ResponseInterceptorContext<'a>,
6045 response: Response,
6046 ) -> BoxFuture<'a, Response> {
6047 Box::pin(async move {
6048 let mut resp = response;
6049
6050 if self.include_path {
6051 let header_name = format!("{}Path", self.header_prefix);
6052 resp = resp.header(header_name, ctx.request.path().as_bytes().to_vec());
6053 }
6054
6055 if self.include_method {
6056 let header_name = format!("{}Method", self.header_prefix);
6057 resp = resp.header(
6058 header_name,
6059 ctx.request.method().as_str().as_bytes().to_vec(),
6060 );
6061 }
6062
6063 if self.include_request_id {
6064 if let Some(request_id) = ctx.request.get_extension::<RequestId>() {
6065 let header_name = format!("{}Request-Id", self.header_prefix);
6066 resp = resp.header(header_name, request_id.0.as_bytes().to_vec());
6067 }
6068 }
6069
6070 if self.include_timing {
6071 let header_name = format!("{}Handler-Time", self.header_prefix);
6072 let timing = format!("{}ms", ctx.elapsed_ms());
6073 resp = resp.header(header_name, timing.into_bytes());
6074 }
6075
6076 resp
6077 })
6078 }
6079
6080 fn name(&self) -> &'static str {
6081 "DebugInfoInterceptor"
6082 }
6083}
6084
6085pub struct ResponseBodyTransform<F>
6106where
6107 F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync,
6108{
6109 transform_fn: F,
6110 content_type_filter: Option<String>,
6112}
6113
6114impl<F> ResponseBodyTransform<F>
6115where
6116 F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync,
6117{
6118 pub fn new(transform_fn: F) -> Self {
6120 Self {
6121 transform_fn,
6122 content_type_filter: None,
6123 }
6124 }
6125
6126 #[must_use]
6128 pub fn for_content_type(mut self, content_type: impl Into<String>) -> Self {
6129 self.content_type_filter = Some(content_type.into());
6130 self
6131 }
6132
6133 fn should_transform(&self, response: &Response) -> bool {
6134 match &self.content_type_filter {
6135 Some(filter) => response
6136 .headers()
6137 .iter()
6138 .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
6139 .and_then(|(_, ct)| std::str::from_utf8(ct).ok())
6140 .map(|ct| ct.starts_with(filter))
6141 .unwrap_or(false),
6142 None => true,
6143 }
6144 }
6145}
6146
6147impl<F> ResponseInterceptor for ResponseBodyTransform<F>
6148where
6149 F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync,
6150{
6151 fn intercept<'a>(
6152 &'a self,
6153 _ctx: &'a ResponseInterceptorContext<'a>,
6154 response: Response,
6155 ) -> BoxFuture<'a, Response> {
6156 Box::pin(async move {
6157 if !self.should_transform(&response) {
6158 return response;
6159 }
6160
6161 let body_bytes = match response.body_ref() {
6163 crate::response::ResponseBody::Empty => Vec::new(),
6164 crate::response::ResponseBody::Bytes(b) => b.clone(),
6165 crate::response::ResponseBody::Stream(_) => {
6166 return response;
6168 }
6169 };
6170
6171 let transformed = (self.transform_fn)(body_bytes);
6173
6174 response.body(crate::response::ResponseBody::Bytes(transformed))
6176 })
6177 }
6178
6179 fn name(&self) -> &'static str {
6180 "ResponseBodyTransform"
6181 }
6182}
6183
/// Interceptor that adds, removes and renames response headers according to
/// a list of rules collected through its builder methods.
#[derive(Debug, Clone, Default)]
pub struct HeaderTransformInterceptor {
    /// Headers to add, as `(name, value)` pairs.
    add_headers: Vec<(String, Vec<u8>)>,
    /// Names of headers to remove.
    remove_headers: Vec<String>,
    /// Headers to rename, as `(old_name, new_name)` pairs.
    rename_headers: Vec<(String, String)>,
}

impl HeaderTransformInterceptor {
    /// Creates an interceptor without any rules.
    #[must_use]
    pub fn new() -> Self {
        Self {
            add_headers: Vec::new(),
            remove_headers: Vec::new(),
            rename_headers: Vec::new(),
        }
    }

    /// Registers a header to add.
    #[must_use]
    pub fn add(mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
        let rule = (name.into(), value.into());
        self.add_headers.push(rule);
        self
    }

    /// Registers a header to remove.
    #[must_use]
    pub fn remove(mut self, name: impl Into<String>) -> Self {
        let target = name.into();
        self.remove_headers.push(target);
        self
    }

    /// Registers a header to rename from `old_name` to `new_name`.
    #[must_use]
    pub fn rename(mut self, old_name: impl Into<String>, new_name: impl Into<String>) -> Self {
        let rule = (old_name.into(), new_name.into());
        self.rename_headers.push(rule);
        self
    }
}
6238
6239impl ResponseInterceptor for HeaderTransformInterceptor {
6240 fn intercept<'a>(
6241 &'a self,
6242 _ctx: &'a ResponseInterceptorContext<'a>,
6243 response: Response,
6244 ) -> BoxFuture<'a, Response> {
6245 let add_headers = self.add_headers.clone();
6246 let remove_headers = self.remove_headers.clone();
6247 let rename_headers = self.rename_headers.clone();
6248
6249 Box::pin(async move {
6250 let mut resp = response;
6251
6252 for (old_name, new_name) in &rename_headers {
6254 let values: Vec<Vec<u8>> = resp
6255 .headers()
6256 .iter()
6257 .filter(|(name, _)| name.eq_ignore_ascii_case(old_name))
6258 .map(|(_, v)| v.clone())
6259 .collect();
6260
6261 if !values.is_empty() {
6262 resp = resp.remove_header(old_name);
6263 for v in values {
6264 resp = resp.header(new_name, v);
6265 }
6266 }
6267 }
6268
6269 for (name, value) in add_headers {
6271 resp = resp.header(name, value);
6272 }
6273
6274 for name in &remove_headers {
6276 resp = resp.remove_header(name);
6277 }
6278
6279 resp
6280 })
6281 }
6282
6283 fn name(&self) -> &'static str {
6284 "HeaderTransformInterceptor"
6285 }
6286}
6287
6288pub struct ConditionalInterceptor<I, F>
6304where
6305 I: ResponseInterceptor,
6306 F: Fn(&ResponseInterceptorContext, &Response) -> bool + Send + Sync,
6307{
6308 inner: I,
6309 condition: F,
6310}
6311
6312impl<I, F> ConditionalInterceptor<I, F>
6313where
6314 I: ResponseInterceptor,
6315 F: Fn(&ResponseInterceptorContext, &Response) -> bool + Send + Sync,
6316{
6317 pub fn new(inner: I, condition: F) -> Self {
6319 Self { inner, condition }
6320 }
6321}
6322
6323impl<I, F> ResponseInterceptor for ConditionalInterceptor<I, F>
6324where
6325 I: ResponseInterceptor,
6326 F: Fn(&ResponseInterceptorContext, &Response) -> bool + Send + Sync,
6327{
6328 fn intercept<'a>(
6329 &'a self,
6330 ctx: &'a ResponseInterceptorContext<'a>,
6331 response: Response,
6332 ) -> BoxFuture<'a, Response> {
6333 Box::pin(async move {
6334 if (self.condition)(ctx, &response) {
6335 self.inner.intercept(ctx, response).await
6336 } else {
6337 response
6338 }
6339 })
6340 }
6341
6342 fn name(&self) -> &'static str {
6343 "ConditionalInterceptor"
6344 }
6345}
6346
6347#[derive(Debug, Clone)]
6366pub struct ErrorResponseTransformer {
6367 status_codes: HashSet<u16>,
6369 replacement_body: Option<Vec<u8>>,
6371 add_error_id: bool,
6373}
6374
6375impl Default for ErrorResponseTransformer {
6376 fn default() -> Self {
6377 Self::new()
6378 }
6379}
6380
6381impl ErrorResponseTransformer {
6382 #[must_use]
6384 pub fn new() -> Self {
6385 Self {
6386 status_codes: HashSet::new(),
6387 replacement_body: None,
6388 add_error_id: false,
6389 }
6390 }
6391
6392 #[must_use]
6394 pub fn hide_details_for_status(mut self, status: crate::response::StatusCode) -> Self {
6395 self.status_codes.insert(status.as_u16());
6396 self
6397 }
6398
6399 #[must_use]
6401 pub fn with_replacement_body(mut self, body: impl Into<Vec<u8>>) -> Self {
6402 self.replacement_body = Some(body.into());
6403 self
6404 }
6405
6406 #[must_use]
6408 pub fn add_error_id(mut self, enable: bool) -> Self {
6409 self.add_error_id = enable;
6410 self
6411 }
6412}
6413
6414impl ResponseInterceptor for ErrorResponseTransformer {
6415 fn intercept<'a>(
6416 &'a self,
6417 ctx: &'a ResponseInterceptorContext<'a>,
6418 response: Response,
6419 ) -> BoxFuture<'a, Response> {
6420 Box::pin(async move {
6421 let status_code = response.status().as_u16();
6422
6423 if !self.status_codes.contains(&status_code) {
6424 return response;
6425 }
6426
6427 let mut resp = response;
6428
6429 if let Some(ref replacement) = self.replacement_body {
6431 resp = resp.body(crate::response::ResponseBody::Bytes(replacement.clone()));
6432 }
6433
6434 if self.add_error_id {
6436 let error_id = ctx
6438 .request
6439 .get_extension::<RequestId>()
6440 .map(|r| r.0.clone())
6441 .unwrap_or_else(|| format!("err-{}", ctx.elapsed_ms()));
6442 resp = resp.header("X-Error-Id", error_id.into_bytes());
6443 }
6444
6445 resp
6446 })
6447 }
6448
6449 fn name(&self) -> &'static str {
6450 "ErrorResponseTransformer"
6451 }
6452}
6453
6454pub struct ResponseInterceptorMiddleware<I>
6470where
6471 I: ResponseInterceptor,
6472{
6473 interceptor: I,
6474}
6475
6476impl<I> ResponseInterceptorMiddleware<I>
6477where
6478 I: ResponseInterceptor,
6479{
6480 pub fn new(interceptor: I) -> Self {
6482 Self { interceptor }
6483 }
6484}
6485
6486impl<I> Middleware for ResponseInterceptorMiddleware<I>
6487where
6488 I: ResponseInterceptor,
6489{
6490 fn before<'a>(
6491 &'a self,
6492 _ctx: &'a RequestContext,
6493 req: &'a mut Request,
6494 ) -> BoxFuture<'a, ControlFlow> {
6495 req.insert_extension(InterceptorStartTime(Instant::now()));
6497 Box::pin(async { ControlFlow::Continue })
6498 }
6499
6500 fn after<'a>(
6501 &'a self,
6502 ctx: &'a RequestContext,
6503 req: &'a Request,
6504 response: Response,
6505 ) -> BoxFuture<'a, Response> {
6506 Box::pin(async move {
6507 let start_time = req
6509 .get_extension::<InterceptorStartTime>()
6510 .map(|t| t.0)
6511 .unwrap_or_else(Instant::now);
6512
6513 let interceptor_ctx = ResponseInterceptorContext::new(req, ctx, start_time);
6514 self.interceptor.intercept(&interceptor_ctx, response).await
6515 })
6516 }
6517
6518 fn name(&self) -> &'static str {
6519 self.interceptor.name()
6520 }
6521}
6522
/// Request extension holding the instant at which interceptor timing began.
#[derive(Debug, Clone, Copy)]
struct InterceptorStartTime(
    /// The recorded start instant.
    Instant,
);
6526
/// One metric of a `Server-Timing` header: a name, a duration in
/// milliseconds and an optional human-readable description.
#[derive(Debug, Clone)]
pub struct ServerTimingEntry {
    /// Metric name.
    name: String,
    /// Duration in milliseconds.
    duration_ms: f64,
    /// Optional description, emitted as the `desc` parameter.
    description: Option<String>,
}

impl ServerTimingEntry {
    /// Creates an entry without a description.
    #[must_use]
    pub fn new(name: impl Into<String>, duration_ms: f64) -> Self {
        Self {
            name: name.into(),
            duration_ms,
            description: None,
        }
    }

    /// Attaches a description to the entry.
    #[must_use]
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }

    /// Renders the entry as `name;dur=<ms>` or
    /// `name;dur=<ms>;desc="<description>"`, with the duration printed to
    /// three decimal places.
    ///
    /// The description is emitted as an HTTP quoted-string: `"` and `\` are
    /// backslash-escaped, and control characters other than tab are dropped,
    /// so a description can neither terminate the quoted value early nor
    /// inject a line break into the header.
    #[must_use]
    pub fn to_header_value(&self) -> String {
        match &self.description {
            Some(desc) => format!(
                "{};dur={:.3};desc=\"{}\"",
                self.name,
                self.duration_ms,
                Self::quote_escape(desc)
            ),
            None => format!("{};dur={:.3}", self.name, self.duration_ms),
        }
    }

    /// Escapes `raw` for use between the double quotes of a quoted-string.
    fn quote_escape(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' | '\\' => {
                    escaped.push('\\');
                    escaped.push(ch);
                }
                // CR, LF, NUL and the other control characters are not
                // permitted in a header field value.
                c if c.is_ascii_control() && c != '\t' => {}
                c => escaped.push(c),
            }
        }
        escaped
    }
}
6598
6599#[derive(Debug, Clone, Default)]
6615pub struct ServerTimingBuilder {
6616 entries: Vec<ServerTimingEntry>,
6617}
6618
6619impl ServerTimingBuilder {
6620 #[must_use]
6622 pub fn new() -> Self {
6623 Self::default()
6624 }
6625
6626 #[must_use]
6628 pub fn add(mut self, name: impl Into<String>, duration_ms: f64) -> Self {
6629 self.entries.push(ServerTimingEntry::new(name, duration_ms));
6630 self
6631 }
6632
6633 #[must_use]
6635 pub fn add_with_desc(
6636 mut self,
6637 name: impl Into<String>,
6638 duration_ms: f64,
6639 description: impl Into<String>,
6640 ) -> Self {
6641 self.entries
6642 .push(ServerTimingEntry::new(name, duration_ms).with_description(description));
6643 self
6644 }
6645
6646 #[must_use]
6648 pub fn add_entry(mut self, entry: ServerTimingEntry) -> Self {
6649 self.entries.push(entry);
6650 self
6651 }
6652
6653 #[must_use]
6655 pub fn build(&self) -> String {
6656 self.entries
6657 .iter()
6658 .map(ServerTimingEntry::to_header_value)
6659 .collect::<Vec<_>>()
6660 .join(", ")
6661 }
6662
6663 #[must_use]
6665 pub fn is_empty(&self) -> bool {
6666 self.entries.is_empty()
6667 }
6668
6669 #[must_use]
6671 pub fn len(&self) -> usize {
6672 self.entries.len()
6673 }
6674}
6675
6676#[derive(Debug, Clone)]
6692pub struct TimingMetrics {
6693 pub start_time: Instant,
6695 pub first_byte_time: Option<Instant>,
6697 pub custom_metrics: Vec<(String, f64, Option<String>)>,
6699}
6700
6701impl TimingMetrics {
6702 #[must_use]
6704 pub fn new() -> Self {
6705 Self {
6706 start_time: Instant::now(),
6707 first_byte_time: None,
6708 custom_metrics: Vec::new(),
6709 }
6710 }
6711
6712 #[must_use]
6714 pub fn with_start_time(start_time: Instant) -> Self {
6715 Self {
6716 start_time,
6717 first_byte_time: None,
6718 custom_metrics: Vec::new(),
6719 }
6720 }
6721
6722 pub fn mark_first_byte(&mut self) {
6724 self.first_byte_time = Some(Instant::now());
6725 }
6726
6727 pub fn add_metric(&mut self, name: impl Into<String>, duration_ms: f64) {
6729 self.custom_metrics.push((name.into(), duration_ms, None));
6730 }
6731
6732 pub fn add_metric_with_desc(
6734 &mut self,
6735 name: impl Into<String>,
6736 duration_ms: f64,
6737 desc: impl Into<String>,
6738 ) {
6739 self.custom_metrics
6740 .push((name.into(), duration_ms, Some(desc.into())));
6741 }
6742
6743 #[must_use]
6745 pub fn total_ms(&self) -> f64 {
6746 self.start_time.elapsed().as_secs_f64() * 1000.0
6747 }
6748
6749 #[must_use]
6751 pub fn ttfb_ms(&self) -> Option<f64> {
6752 self.first_byte_time
6753 .map(|t| t.duration_since(self.start_time).as_secs_f64() * 1000.0)
6754 }
6755
6756 #[must_use]
6758 pub fn to_server_timing(&self) -> ServerTimingBuilder {
6759 let mut builder = ServerTimingBuilder::new().add_with_desc(
6760 "total",
6761 self.total_ms(),
6762 "Total request time",
6763 );
6764
6765 if let Some(ttfb) = self.ttfb_ms() {
6766 builder = builder.add_with_desc("ttfb", ttfb, "Time to first byte");
6767 }
6768
6769 for (name, duration, desc) in &self.custom_metrics {
6770 match desc {
6771 Some(d) => builder = builder.add_with_desc(name, *duration, d),
6772 None => builder = builder.add(name, *duration),
6773 }
6774 }
6775
6776 builder
6777 }
6778}
6779
6780impl Default for TimingMetrics {
6781 fn default() -> Self {
6782 Self::new()
6783 }
6784}
6785
/// Options controlling which timing headers the timing-metrics middleware
/// emits and what goes into them.
#[derive(Debug, Clone)]
#[allow(clippy::struct_excessive_bools)]
pub struct TimingMetricsConfig {
    /// Emit a `Server-Timing` header.
    pub add_server_timing_header: bool,
    /// Emit the response-time header.
    pub add_response_time_header: bool,
    /// Name of the response-time header.
    pub response_time_header_name: String,
    /// Include custom metrics in `Server-Timing`.
    pub include_custom_metrics: bool,
    /// Include the time-to-first-byte metric in `Server-Timing`.
    pub include_ttfb: bool,
}

impl Default for TimingMetricsConfig {
    /// Everything enabled; the response-time header is `X-Response-Time`.
    fn default() -> Self {
        Self {
            response_time_header_name: String::from("X-Response-Time"),
            add_server_timing_header: true,
            add_response_time_header: true,
            include_custom_metrics: true,
            include_ttfb: true,
        }
    }
}

impl TimingMetricsConfig {
    /// Same as [`TimingMetricsConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Toggles the `Server-Timing` header.
    #[must_use]
    pub fn server_timing(self, enabled: bool) -> Self {
        Self {
            add_server_timing_header: enabled,
            ..self
        }
    }

    /// Toggles the response-time header.
    #[must_use]
    pub fn response_time(self, enabled: bool) -> Self {
        Self {
            add_response_time_header: enabled,
            ..self
        }
    }

    /// Sets the name of the response-time header.
    #[must_use]
    pub fn response_time_header(self, name: impl Into<String>) -> Self {
        Self {
            response_time_header_name: name.into(),
            ..self
        }
    }

    /// Toggles inclusion of custom metrics.
    #[must_use]
    pub fn custom_metrics(self, enabled: bool) -> Self {
        Self {
            include_custom_metrics: enabled,
            ..self
        }
    }

    /// Toggles inclusion of the time-to-first-byte metric.
    #[must_use]
    pub fn ttfb(self, enabled: bool) -> Self {
        Self {
            include_ttfb: enabled,
            ..self
        }
    }

    /// Preset that emits only the `X-Response-Time` header: no
    /// `Server-Timing`, no custom metrics, no time-to-first-byte.
    #[must_use]
    pub fn production() -> Self {
        Self {
            add_server_timing_header: false,
            include_custom_metrics: false,
            include_ttfb: false,
            ..Self::default()
        }
    }

    /// Preset identical to the default configuration.
    #[must_use]
    pub fn development() -> Self {
        Self::default()
    }
}
6874
6875#[derive(Debug, Clone)]
6894pub struct TimingMetricsMiddleware {
6895 config: TimingMetricsConfig,
6896}
6897
6898impl TimingMetricsMiddleware {
6899 #[must_use]
6901 pub fn new() -> Self {
6902 Self {
6903 config: TimingMetricsConfig::default(),
6904 }
6905 }
6906
6907 #[must_use]
6909 pub fn with_config(config: TimingMetricsConfig) -> Self {
6910 Self { config }
6911 }
6912
6913 #[must_use]
6915 pub fn production() -> Self {
6916 Self {
6917 config: TimingMetricsConfig::production(),
6918 }
6919 }
6920
6921 #[must_use]
6923 pub fn development() -> Self {
6924 Self {
6925 config: TimingMetricsConfig::development(),
6926 }
6927 }
6928}
6929
6930impl Default for TimingMetricsMiddleware {
6931 fn default() -> Self {
6932 Self::new()
6933 }
6934}
6935
6936impl Middleware for TimingMetricsMiddleware {
6937 fn before<'a>(
6938 &'a self,
6939 _ctx: &'a RequestContext,
6940 req: &'a mut Request,
6941 ) -> BoxFuture<'a, ControlFlow> {
6942 req.insert_extension(TimingMetrics::new());
6944 Box::pin(async { ControlFlow::Continue })
6945 }
6946
6947 fn after<'a>(
6948 &'a self,
6949 _ctx: &'a RequestContext,
6950 req: &'a Request,
6951 response: Response,
6952 ) -> BoxFuture<'a, Response> {
6953 let config = self.config.clone();
6954
6955 Box::pin(async move {
6956 let mut resp = response;
6957
6958 let metrics = req.get_extension::<TimingMetrics>();
6960
6961 match metrics {
6962 Some(metrics) => {
6963 if config.add_response_time_header {
6965 let timing = format!("{:.3}ms", metrics.total_ms());
6966 resp = resp.header(&config.response_time_header_name, timing.into_bytes());
6967 }
6968
6969 if config.add_server_timing_header {
6971 let mut builder = ServerTimingBuilder::new().add_with_desc(
6972 "total",
6973 metrics.total_ms(),
6974 "Total request time",
6975 );
6976
6977 if config.include_ttfb {
6979 if let Some(ttfb) = metrics.ttfb_ms() {
6980 builder = builder.add_with_desc("ttfb", ttfb, "Time to first byte");
6981 }
6982 }
6983
6984 if config.include_custom_metrics {
6986 for (name, duration, desc) in &metrics.custom_metrics {
6987 match desc {
6988 Some(d) => builder = builder.add_with_desc(name, *duration, d),
6989 None => builder = builder.add(name, *duration),
6990 }
6991 }
6992 }
6993
6994 let header_value = builder.build();
6995 resp = resp.header("Server-Timing", header_value.into_bytes());
6996 }
6997 }
6998 None => {
6999 if config.add_response_time_header {
7002 resp = resp.header(&config.response_time_header_name, b"0.000ms".to_vec());
7003 }
7004 }
7005 }
7006
7007 resp
7008 })
7009 }
7010
7011 fn name(&self) -> &'static str {
7012 "TimingMetrics"
7013 }
7014}
7015
/// Snapshot of one histogram bucket.
#[derive(Debug, Clone)]
pub struct TimingHistogramBucket {
    /// Inclusive upper bound of the bucket.
    pub le: f64,
    /// Number of observations less than or equal to `le`.
    pub count: u64,
}

/// Histogram of timing observations with cumulative buckets: an observation
/// is counted in every bucket whose bound is greater than or equal to it.
#[derive(Debug, Clone)]
pub struct TimingHistogram {
    /// Upper bound of each bucket.
    bucket_bounds: Vec<f64>,
    /// Observation count of each bucket, parallel to `bucket_bounds`.
    bucket_counts: Vec<u64>,
    /// Sum of all observed values.
    sum: f64,
    /// Number of observations.
    count: u64,
}

impl TimingHistogram {
    /// Creates an empty histogram with the given bucket bounds. The bounds
    /// are stored as given.
    #[must_use]
    pub fn with_buckets(bucket_bounds: Vec<f64>) -> Self {
        let bucket_counts: Vec<u64> = bucket_bounds.iter().map(|_| 0).collect();
        Self {
            bucket_bounds,
            bucket_counts,
            sum: 0.0,
            count: 0,
        }
    }

    /// Creates a histogram with twelve preset bounds from 1 to 10000,
    /// interpreted as milliseconds.
    #[must_use]
    pub fn http_latency() -> Self {
        let bounds = [
            1.0, 5.0, 10.0, 25.0, 50.0, 100.0, 250.0, 500.0, 1000.0, 2500.0, 5000.0, 10000.0,
        ];
        Self::with_buckets(bounds.to_vec())
    }

    /// Records one observation: updates the sum and count and increments
    /// every bucket whose bound is at least `value_ms`.
    pub fn observe(&mut self, value_ms: f64) {
        self.count += 1;
        self.sum += value_ms;

        let slots = self.bucket_counts.iter_mut().zip(&self.bucket_bounds);
        for (slot, &bound) in slots {
            if value_ms <= bound {
                *slot += 1;
            }
        }
    }

    /// Number of observations recorded.
    #[must_use]
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Sum of all recorded values.
    #[must_use]
    pub fn sum(&self) -> f64 {
        self.sum
    }

    /// Arithmetic mean of the recorded values, or `0.0` when empty.
    #[must_use]
    pub fn mean(&self) -> f64 {
        if self.count == 0 {
            return 0.0;
        }
        #[allow(clippy::cast_precision_loss)]
        let observations = self.count as f64;
        self.sum / observations
    }

    /// Returns a snapshot of every bucket, in bound order as configured.
    #[must_use]
    pub fn buckets(&self) -> Vec<TimingHistogramBucket> {
        let mut snapshot = Vec::with_capacity(self.bucket_bounds.len());
        for (bound, tally) in self.bucket_bounds.iter().zip(&self.bucket_counts) {
            snapshot.push(TimingHistogramBucket {
                le: *bound,
                count: *tally,
            });
        }
        snapshot
    }

    /// Clears the sum, the count and every bucket; the bounds are kept.
    pub fn reset(&mut self) {
        self.bucket_counts.fill(0);
        self.count = 0;
        self.sum = 0.0;
    }
}

impl Default for TimingHistogram {
    /// Uses the [`TimingHistogram::http_latency`] buckets.
    fn default() -> Self {
        Self::http_latency()
    }
}
7145
// Unit tests for the Server-Timing helpers, `TimingMetrics`,
// `TimingMetricsConfig`, `TimingMetricsMiddleware` and `TimingHistogram`.
#[cfg(test)]
mod timing_metrics_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    // Builds a throwaway request context for driving middleware hooks.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    // Builds a minimal `GET /test` request.
    fn test_request() -> Request {
        Request::new(Method::Get, "/test")
    }

    // Runs the middleware's `before` hook to completion on the current thread.
    fn run_middleware_before(mw: &impl Middleware, req: &mut Request) -> ControlFlow {
        let ctx = test_context();
        futures_executor::block_on(mw.before(&ctx, req))
    }

    // Runs the middleware's `after` hook to completion on the current thread.
    fn run_middleware_after(mw: &impl Middleware, req: &Request, resp: Response) -> Response {
        let ctx = test_context();
        futures_executor::block_on(mw.after(&ctx, req, resp))
    }

    // An entry without a description renders as `name;dur=<3 decimals>`.
    #[test]
    fn server_timing_entry_basic() {
        let entry = ServerTimingEntry::new("db", 42.5);
        assert_eq!(entry.to_header_value(), "db;dur=42.500");
    }

    // A description is appended as a quoted `desc` parameter.
    #[test]
    fn server_timing_entry_with_description() {
        let entry = ServerTimingEntry::new("db", 42.5).with_description("Database query");
        assert_eq!(
            entry.to_header_value(),
            "db;dur=42.500;desc=\"Database query\""
        );
    }

    // A builder with one entry yields exactly that entry's header value.
    #[test]
    fn server_timing_builder_single_entry() {
        let timing = ServerTimingBuilder::new().add("total", 150.0).build();
        assert_eq!(timing, "total;dur=150.000");
    }

    // Several entries all appear in the output, joined by ", ".
    #[test]
    fn server_timing_builder_multiple_entries() {
        let timing = ServerTimingBuilder::new()
            .add("total", 150.0)
            .add_with_desc("db", 42.0, "Database")
            .add("cache", 5.0)
            .build();

        assert!(timing.contains("total;dur=150.000"));
        assert!(timing.contains("db;dur=42.000;desc=\"Database\""));
        assert!(timing.contains("cache;dur=5.000"));
        // Entries are expected to be separated by a comma and a space.
        assert!(timing.contains(", "));
    }

    // An empty builder reports length zero and builds an empty string.
    #[test]
    fn server_timing_builder_empty() {
        let builder = ServerTimingBuilder::new();
        assert!(builder.is_empty());
        assert_eq!(builder.len(), 0);
        assert_eq!(builder.build(), "");
    }

    // `total_ms` grows with wall-clock time since construction; TTFB stays
    // unset until it is explicitly marked.
    #[test]
    fn timing_metrics_basic() {
        let metrics = TimingMetrics::new();
        std::thread::sleep(std::time::Duration::from_millis(5));

        let total = metrics.total_ms();
        assert!(total >= 5.0, "Total should be at least 5ms");
        assert!(metrics.ttfb_ms().is_none(), "TTFB should not be set");
    }

    // Custom metrics (with and without description) show up in the
    // Server-Timing output alongside the `total` entry.
    #[test]
    fn timing_metrics_custom_metrics() {
        let mut metrics = TimingMetrics::new();
        metrics.add_metric("db", 42.5);
        metrics.add_metric_with_desc("cache", 5.0, "Cache lookup");

        let timing = metrics.to_server_timing();
        // Three entries expected; presumably `total` plus the two custom
        // metrics added above (the header assertions below check all three).
        assert_eq!(timing.len(), 3);
        let header = timing.build();
        assert!(header.contains("total"));
        assert!(header.contains("db;dur=42.500"));
        assert!(header.contains("cache;dur=5.000;desc=\"Cache lookup\""));
    }

    // After `mark_first_byte`, TTFB reflects the time elapsed since creation.
    #[test]
    fn timing_metrics_ttfb() {
        let mut metrics = TimingMetrics::new();
        std::thread::sleep(std::time::Duration::from_millis(5));
        metrics.mark_first_byte();

        let ttfb = metrics.ttfb_ms().unwrap();
        assert!(ttfb >= 5.0, "TTFB should be at least 5ms");
    }

    // The default configuration enables every header and metric option.
    #[test]
    fn timing_metrics_config_default() {
        let config = TimingMetricsConfig::default();
        assert!(config.add_server_timing_header);
        assert!(config.add_response_time_header);
        assert!(config.include_custom_metrics);
        assert!(config.include_ttfb);
    }

    // The production preset keeps the response-time header but turns off
    // Server-Timing and custom metrics.
    #[test]
    fn timing_metrics_config_production() {
        let config = TimingMetricsConfig::production();
        assert!(!config.add_server_timing_header);
        assert!(config.add_response_time_header);
        assert!(!config.include_custom_metrics);
    }

    // `before` must continue the chain and leave a `TimingMetrics` extension
    // on the request.
    #[test]
    fn timing_middleware_adds_metrics_to_request() {
        let mw = TimingMetricsMiddleware::new();
        let mut req = test_request();

        let result = run_middleware_before(&mw, &mut req);
        assert!(result.is_continue());

        let metrics = req.get_extension::<TimingMetrics>();
        assert!(metrics.is_some(), "TimingMetrics should be in extensions");
    }

    // With the default middleware, `after` adds an `X-Response-Time` header.
    #[test]
    fn timing_middleware_adds_response_time_header() {
        let mw = TimingMetricsMiddleware::new();
        let mut req = test_request();

        // `before` installs the metrics that `after` reads back.
        run_middleware_before(&mw, &mut req);

        let resp = Response::with_status(StatusCode::OK);
        let result = run_middleware_after(&mw, &req, resp);

        let has_timing = result
            .headers()
            .iter()
            .any(|(name, _)| name == "X-Response-Time");
        assert!(has_timing, "Should have X-Response-Time header");
    }

    // With the default middleware, `after` adds a `Server-Timing` header that
    // contains a `total` entry.
    #[test]
    fn timing_middleware_adds_server_timing_header() {
        let mw = TimingMetricsMiddleware::new();
        let mut req = test_request();

        run_middleware_before(&mw, &mut req);

        let resp = Response::with_status(StatusCode::OK);
        let result = run_middleware_after(&mw, &req, resp);

        let server_timing = result
            .headers()
            .iter()
            .find(|(name, _)| name == "Server-Timing")
            .map(|(_, v)| String::from_utf8_lossy(v).to_string());

        assert!(server_timing.is_some(), "Should have Server-Timing header");
        let header = server_timing.unwrap();
        assert!(header.contains("total"), "Should have total timing");
    }

    // The production middleware emits `X-Response-Time` but no `Server-Timing`.
    #[test]
    fn timing_middleware_production_mode() {
        let mw = TimingMetricsMiddleware::production();
        let mut req = test_request();

        run_middleware_before(&mw, &mut req);

        let resp = Response::with_status(StatusCode::OK);
        let result = run_middleware_after(&mw, &req, resp);

        // The response-time header is still expected in production mode.
        let has_response_time = result
            .headers()
            .iter()
            .any(|(name, _)| name == "X-Response-Time");
        assert!(has_response_time);

        // The detailed Server-Timing header must be absent.
        let has_server_timing = result
            .headers()
            .iter()
            .any(|(name, _)| name == "Server-Timing");
        assert!(!has_server_timing);
    }

    // Count, sum and mean track the observed values. Exact float comparison
    // is intentional here, hence the lint allowance.
    #[test]
    #[allow(clippy::float_cmp)]
    fn timing_histogram_basic() {
        let mut histogram = TimingHistogram::http_latency();
        assert_eq!(histogram.count(), 0);
        assert_eq!(histogram.sum(), 0.0);

        histogram.observe(42.0);
        histogram.observe(150.0);
        histogram.observe(5.0);

        assert_eq!(histogram.count(), 3);
        assert_eq!(histogram.sum(), 197.0);
        // 197 / 3 is approximately 65.667.
        assert!((histogram.mean() - 65.666).abs() < 0.01);
    }

    // Bucket counts are cumulative: each observation is counted in every
    // bucket whose bound is >= the value.
    #[test]
    fn timing_histogram_buckets() {
        let mut histogram = TimingHistogram::with_buckets(vec![10.0, 50.0, 100.0]);

        // 5.0 falls within all three bounds.
        histogram.observe(5.0);
        // 25.0 falls within the 50.0 and 100.0 bounds.
        histogram.observe(25.0);
        // 75.0 falls within the 100.0 bound only.
        histogram.observe(75.0);
        // 150.0 exceeds every bound and lands in no bucket.
        histogram.observe(150.0);
        let buckets = histogram.buckets();
        assert_eq!(buckets.len(), 3);

        // Cumulative totals for the bounds 10.0, 50.0 and 100.0 respectively.
        assert_eq!(buckets[0].count, 1);
        assert_eq!(buckets[1].count, 2);
        assert_eq!(buckets[2].count, 3);
    }

    // `reset` returns count and sum to zero.
    #[test]
    #[allow(clippy::float_cmp)]
    fn timing_histogram_reset() {
        let mut histogram = TimingHistogram::http_latency();
        histogram.observe(100.0);
        histogram.observe(200.0);

        assert_eq!(histogram.count(), 2);

        histogram.reset();

        assert_eq!(histogram.count(), 0);
        assert_eq!(histogram.sum(), 0.0);
    }
}
7395
// Unit tests for the response interceptor types: timing, debug-info and
// header interceptors, body and error transformers, the interceptor stack,
// the interceptor context and the conditional wrapper.
#[cfg(test)]
mod response_interceptor_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    // Builds a throwaway request context for the interceptor context.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    // Builds a minimal `GET /test` request.
    fn test_request() -> Request {
        Request::new(Method::Get, "/test")
    }

    // Drives a single interceptor to completion against `resp`, using a
    // context whose start time is "now".
    fn run_interceptor<I: ResponseInterceptor>(
        interceptor: &I,
        req: &Request,
        resp: Response,
    ) -> Response {
        let ctx = test_context();
        let start_time = Instant::now();
        let interceptor_ctx = ResponseInterceptorContext::new(req, &ctx, start_time);
        futures_executor::block_on(interceptor.intercept(&interceptor_ctx, resp))
    }

    // The default timing interceptor adds `X-Response-Time`.
    #[test]
    fn timing_interceptor_adds_header() {
        let interceptor = TimingInterceptor::new();
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        let has_timing = result
            .headers()
            .iter()
            .any(|(name, _)| name == "X-Response-Time");
        assert!(has_timing, "Should have X-Response-Time header");
    }

    // Enabling server timing makes the interceptor add `Server-Timing`.
    #[test]
    fn timing_interceptor_with_server_timing() {
        let interceptor = TimingInterceptor::new().with_server_timing("app");
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        let has_server_timing = result
            .headers()
            .iter()
            .any(|(name, _)| name == "Server-Timing");
        assert!(has_server_timing, "Should have Server-Timing header");
    }

    // The timing header name can be overridden.
    #[test]
    fn timing_interceptor_custom_header_name() {
        let interceptor = TimingInterceptor::new().header_name("X-Custom-Time");
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        let has_custom = result
            .headers()
            .iter()
            .any(|(name, _)| name == "X-Custom-Time");
        assert!(has_custom, "Should have X-Custom-Time header");
    }

    // The default debug interceptor adds path, method and handler-time
    // headers under the `X-Debug-` prefix.
    #[test]
    fn debug_info_interceptor_adds_headers() {
        let interceptor = DebugInfoInterceptor::new();
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        let has_path = result
            .headers()
            .iter()
            .any(|(name, _)| name == "X-Debug-Path");
        let has_method = result
            .headers()
            .iter()
            .any(|(name, _)| name == "X-Debug-Method");
        let has_timing = result
            .headers()
            .iter()
            .any(|(name, _)| name == "X-Debug-Handler-Time");

        assert!(has_path, "Should have X-Debug-Path header");
        assert!(has_method, "Should have X-Debug-Method header");
        assert!(has_timing, "Should have X-Debug-Handler-Time header");
    }

    // A custom prefix replaces `X-Debug-` in the emitted header names.
    #[test]
    fn debug_info_interceptor_custom_prefix() {
        let interceptor = DebugInfoInterceptor::new().header_prefix("X-Trace-");
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        let has_trace_path = result
            .headers()
            .iter()
            .any(|(name, _)| name == "X-Trace-Path");
        assert!(has_trace_path, "Should have X-Trace-Path header");
    }

    // Individual debug headers can be switched on and off independently.
    #[test]
    fn debug_info_interceptor_selective_options() {
        let interceptor = DebugInfoInterceptor::new()
            .include_path(true)
            .include_method(false)
            .include_timing(false)
            .include_request_id(false);
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        let has_path = result
            .headers()
            .iter()
            .any(|(name, _)| name == "X-Debug-Path");
        let has_method = result
            .headers()
            .iter()
            .any(|(name, _)| name == "X-Debug-Method");

        assert!(has_path, "Should have X-Debug-Path header");
        assert!(!has_method, "Should NOT have X-Debug-Method header");
    }

    // Every header registered with `add` appears on the response.
    #[test]
    fn header_transform_adds_headers() {
        let interceptor = HeaderTransformInterceptor::new()
            .add("X-Powered-By", b"fastapi_rust".to_vec())
            .add("X-Version", b"1.0".to_vec());
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        let has_powered_by = result
            .headers()
            .iter()
            .any(|(name, _)| name == "X-Powered-By");
        let has_version = result.headers().iter().any(|(name, _)| name == "X-Version");

        assert!(has_powered_by, "Should have X-Powered-By header");
        assert!(has_version, "Should have X-Version header");
    }

    // The body transform closure receives the original bytes and its return
    // value becomes the new body.
    #[test]
    fn response_body_transform_modifies_body() {
        // Wraps the body in square brackets.
        let transformer = ResponseBodyTransform::new(|body| {
            let mut result = b"[".to_vec();
            result.extend_from_slice(&body);
            result.extend_from_slice(b"]");
            result
        });
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK)
            .body(crate::response::ResponseBody::Bytes(b"hello".to_vec()));

        let result = run_interceptor(&transformer, &req, resp);

        match result.body_ref() {
            crate::response::ResponseBody::Bytes(b) => {
                assert_eq!(b, b"[hello]");
            }
            _ => panic!("Expected bytes body"),
        }
    }

    // With a content-type filter, only responses of that content type are
    // transformed.
    #[test]
    fn response_body_transform_with_content_type_filter() {
        let transformer =
            ResponseBodyTransform::new(|_| b"transformed".to_vec()).for_content_type("text/plain");
        let req = test_request();

        // A JSON response does not match the filter and must pass through.
        let json_resp = Response::with_status(StatusCode::OK)
            .header("content-type", b"application/json".to_vec())
            .body(crate::response::ResponseBody::Bytes(b"original".to_vec()));

        let result = run_interceptor(&transformer, &req, json_resp);

        match result.body_ref() {
            crate::response::ResponseBody::Bytes(b) => {
                assert_eq!(b, b"original", "JSON should not be transformed");
            }
            _ => panic!("Expected bytes body"),
        }

        // A text/plain response matches the filter and must be rewritten.
        let text_resp = Response::with_status(StatusCode::OK)
            .header("content-type", b"text/plain".to_vec())
            .body(crate::response::ResponseBody::Bytes(b"original".to_vec()));

        let result = run_interceptor(&transformer, &req, text_resp);

        match result.body_ref() {
            crate::response::ResponseBody::Bytes(b) => {
                assert_eq!(b, b"transformed", "Text should be transformed");
            }
            _ => panic!("Expected bytes body"),
        }
    }

    // The configured replacement body is used for the registered status code
    // only; other statuses keep their body.
    #[test]
    fn error_response_transformer_hides_details() {
        let transformer = ErrorResponseTransformer::new()
            .hide_details_for_status(StatusCode::INTERNAL_SERVER_ERROR)
            .with_replacement_body(b"An error occurred");

        let req = test_request();

        // A 500 response has its body replaced.
        let error_resp = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR).body(
            crate::response::ResponseBody::Bytes(b"Sensitive error details".to_vec()),
        );

        let result = run_interceptor(&transformer, &req, error_resp);

        match result.body_ref() {
            crate::response::ResponseBody::Bytes(b) => {
                assert_eq!(b, b"An error occurred");
            }
            _ => panic!("Expected bytes body"),
        }

        // A 200 response is left untouched.
        let ok_resp = Response::with_status(StatusCode::OK)
            .body(crate::response::ResponseBody::Bytes(b"Success".to_vec()));

        let result = run_interceptor(&transformer, &req, ok_resp);

        match result.body_ref() {
            crate::response::ResponseBody::Bytes(b) => {
                assert_eq!(b, b"Success");
            }
            _ => panic!("Expected bytes body"),
        }
    }

    // Every interceptor pushed onto a stack contributes to the final response.
    #[test]
    fn response_interceptor_stack_chains_interceptors() {
        let mut stack = ResponseInterceptorStack::new();
        stack.push(TimingInterceptor::new());
        stack.push(HeaderTransformInterceptor::new().add("X-Extra", b"value".to_vec()));

        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let ctx = test_context();
        let start_time = Instant::now();
        let interceptor_ctx = ResponseInterceptorContext::new(&req, &ctx, start_time);
        let result = futures_executor::block_on(stack.process(&interceptor_ctx, resp));

        let has_timing = result
            .headers()
            .iter()
            .any(|(name, _)| name == "X-Response-Time");
        let has_extra = result.headers().iter().any(|(name, _)| name == "X-Extra");

        assert!(
            has_timing,
            "Should have timing header from first interceptor"
        );
        assert!(
            has_extra,
            "Should have extra header from second interceptor"
        );
    }

    // An empty stack reports itself as empty and leaves the body unchanged.
    #[test]
    fn response_interceptor_stack_empty_is_noop() {
        let stack = ResponseInterceptorStack::new();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);

        let req = test_request();
        let resp = Response::with_status(StatusCode::OK)
            .body(crate::response::ResponseBody::Bytes(b"unchanged".to_vec()));

        let ctx = test_context();
        let start_time = Instant::now();
        let interceptor_ctx = ResponseInterceptorContext::new(&req, &ctx, start_time);
        let result = futures_executor::block_on(stack.process(&interceptor_ctx, resp));

        match result.body_ref() {
            crate::response::ResponseBody::Bytes(b) => {
                assert_eq!(b, b"unchanged");
            }
            _ => panic!("Expected bytes body"),
        }
    }

    // The context reports elapsed time relative to the start time it was
    // constructed with.
    #[test]
    fn interceptor_context_provides_timing() {
        let ctx = test_context();
        let req = test_request();
        let start_time = Instant::now();
        // Ensure a measurable gap between `start_time` and the checks below.
        std::thread::sleep(std::time::Duration::from_millis(5));

        let interceptor_ctx = ResponseInterceptorContext::new(&req, &ctx, start_time);

        assert!(
            interceptor_ctx.elapsed_ms() >= 5,
            "Elapsed time should be at least 5ms"
        );
        assert!(interceptor_ctx.elapsed().as_millis() >= 5);
    }

    // The wrapped interceptor only runs when the predicate returns true.
    #[test]
    fn conditional_interceptor_applies_conditionally() {
        let inner = HeaderTransformInterceptor::new().add("X-Success", b"true".to_vec());
        // Apply `inner` only to responses with status 200.
        let conditional =
            ConditionalInterceptor::new(inner, |_ctx, resp| resp.status().as_u16() == 200);

        let req = test_request();

        // 200: predicate holds, header is added.
        let ok_resp = Response::with_status(StatusCode::OK);
        let result = run_interceptor(&conditional, &req, ok_resp);
        let has_success = result.headers().iter().any(|(name, _)| name == "X-Success");
        assert!(has_success, "200 response should get X-Success header");

        // 404: predicate fails, response is left alone.
        let not_found = Response::with_status(StatusCode::NOT_FOUND);
        let result = run_interceptor(&conditional, &req, not_found);
        let has_success = result.headers().iter().any(|(name, _)| name == "X-Success");
        assert!(!has_success, "404 response should NOT get X-Success header");
    }
}
7736
// Unit tests for the cache-control pieces: directives, the header builder,
// presets, `CacheControlMiddleware`, path pattern matching and the HTTP date
// helpers.
#[cfg(test)]
mod cache_control_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    // Builds a throwaway request context for driving the middleware.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    // Runs the middleware's `after` hook to completion on the current thread.
    fn run_after(mw: &CacheControlMiddleware, req: &Request, resp: Response) -> Response {
        let ctx = test_context();
        let fut = mw.after(&ctx, req, resp);
        futures_executor::block_on(fut)
    }

    // Each directive maps to its lowercase header token.
    #[test]
    fn cache_directive_as_str_works() {
        assert_eq!(CacheDirective::Public.as_str(), "public");
        assert_eq!(CacheDirective::Private.as_str(), "private");
        assert_eq!(CacheDirective::NoStore.as_str(), "no-store");
        assert_eq!(CacheDirective::NoCache.as_str(), "no-cache");
        assert_eq!(CacheDirective::MustRevalidate.as_str(), "must-revalidate");
        assert_eq!(CacheDirective::Immutable.as_str(), "immutable");
    }

    // The builder emits a directive plus a `max-age` value.
    #[test]
    fn cache_control_builder_basic() {
        let cc = CacheControlBuilder::new()
            .public()
            .max_age_secs(3600)
            .build();
        assert!(cc.contains("public"));
        assert!(cc.contains("max-age=3600"));
    }

    // The builder combines several age-related directives in one value.
    #[test]
    fn cache_control_builder_complex() {
        let cc = CacheControlBuilder::new()
            .public()
            .max_age_secs(60)
            .s_maxage_secs(3600)
            .stale_while_revalidate_secs(86400)
            .build();
        assert!(cc.contains("public"));
        assert!(cc.contains("max-age=60"));
        assert!(cc.contains("s-maxage=3600"));
        assert!(cc.contains("stale-while-revalidate=86400"));
    }

    // The builder can express a "do not cache" policy.
    #[test]
    fn cache_control_builder_no_cache() {
        let cc = CacheControlBuilder::new()
            .no_store()
            .no_cache()
            .must_revalidate()
            .build();
        assert!(cc.contains("no-store"));
        assert!(cc.contains("no-cache"));
        assert!(cc.contains("must-revalidate"));
    }

    // The `NoCache` preset expands to no-store, no-cache and must-revalidate.
    #[test]
    fn cache_preset_no_cache() {
        let value = CachePreset::NoCache.to_header_value();
        assert!(value.contains("no-store"));
        assert!(value.contains("no-cache"));
        assert!(value.contains("must-revalidate"));
    }

    // The `Immutable` preset is public, immutable, with max-age 31536000
    // seconds (365 days).
    #[test]
    fn cache_preset_immutable() {
        let value = CachePreset::Immutable.to_header_value();
        assert!(value.contains("public"));
        assert!(value.contains("max-age=31536000"));
        assert!(value.contains("immutable"));
    }

    // The `StaticAssets` preset is public with max-age 86400 seconds (1 day).
    #[test]
    fn cache_preset_static_assets() {
        let value = CachePreset::StaticAssets.to_header_value();
        assert!(value.contains("public"));
        assert!(value.contains("max-age=86400"));
    }

    // A successful GET response receives the preset's Cache-Control value.
    #[test]
    fn middleware_adds_cache_control_header() {
        let mw = CacheControlMiddleware::with_preset(CachePreset::PublicOneHour);
        let req = Request::new(Method::Get, "/api/test");
        let resp = Response::with_status(StatusCode::OK);

        let result = run_after(&mw, &req, resp);
        let headers = result.headers();
        // Header names are compared case-insensitively.
        let cc_header = headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("cache-control"));
        assert!(
            cc_header.is_some(),
            "Cache-Control header should be present"
        );
        let (_, value) = cc_header.unwrap();
        let value_str = String::from_utf8_lossy(value);
        assert!(value_str.contains("public"));
        assert!(value_str.contains("max-age=3600"));
    }

    // POST responses are left without a Cache-Control header.
    #[test]
    fn middleware_skips_post_requests() {
        let mw = CacheControlMiddleware::with_preset(CachePreset::PublicOneHour);
        let req = Request::new(Method::Post, "/api/test");
        let resp = Response::with_status(StatusCode::OK);

        let result = run_after(&mw, &req, resp);
        let headers = result.headers();
        let cc_header = headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("cache-control"));
        assert!(
            cc_header.is_none(),
            "Cache-Control should not be added for POST"
        );
    }

    // A 500 response is left without a Cache-Control header.
    #[test]
    fn middleware_skips_error_responses() {
        let mw = CacheControlMiddleware::with_preset(CachePreset::PublicOneHour);
        let req = Request::new(Method::Get, "/api/test");
        let resp = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR);

        let result = run_after(&mw, &req, resp);
        let headers = result.headers();
        let cc_header = headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("cache-control"));
        assert!(
            cc_header.is_none(),
            "Cache-Control should not be added for error responses"
        );
    }

    // Configured `vary` entries all end up in the Vary header.
    #[test]
    fn middleware_with_vary_header() {
        let mw = CacheControlMiddleware::with_config(
            CacheControlConfig::from_preset(CachePreset::PublicOneHour)
                .vary("Accept-Encoding")
                .vary("Accept-Language"),
        );
        let req = Request::new(Method::Get, "/api/test");
        let resp = Response::with_status(StatusCode::OK);

        let result = run_after(&mw, &req, resp);
        let headers = result.headers();
        let vary_header = headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("vary"));
        assert!(vary_header.is_some(), "Vary header should be present");
        let (_, value) = vary_header.unwrap();
        let value_str = String::from_utf8_lossy(value);
        assert!(value_str.contains("Accept-Encoding"));
        assert!(value_str.contains("Accept-Language"));
    }

    // With `preserve_existing(true)`, a Cache-Control header already on the
    // response is neither replaced nor duplicated.
    #[test]
    fn middleware_preserves_existing_cache_control() {
        let mw = CacheControlMiddleware::with_config(
            CacheControlConfig::from_preset(CachePreset::PublicOneHour).preserve_existing(true),
        );
        let req = Request::new(Method::Get, "/api/test");
        let resp =
            Response::with_status(StatusCode::OK).header("Cache-Control", b"max-age=60".to_vec());

        let result = run_after(&mw, &req, resp);
        let headers = result.headers();
        let cc_headers: Vec<_> = headers
            .iter()
            .filter(|(name, _)| name.eq_ignore_ascii_case("cache-control"))
            .collect();
        // Exactly one header, still carrying the handler-supplied value.
        assert_eq!(cc_headers.len(), 1);
        let (_, value) = cc_headers[0];
        let value_str = String::from_utf8_lossy(value);
        assert_eq!(value_str, "max-age=60");
    }

    // Patterns without a wildcard match only the identical path.
    #[test]
    fn path_pattern_matching_exact() {
        assert!(path_matches_pattern("/api/users", "/api/users"));
        assert!(!path_matches_pattern("/api/users", "/api/items"));
    }

    // A trailing `*` matches deeper paths; a bare `*` matches this path too.
    #[test]
    fn path_pattern_matching_wildcard() {
        assert!(path_matches_pattern("/api/users/123", "/api/users/*"));
        assert!(path_matches_pattern("/static/css/style.css", "/static/*"));
        assert!(path_matches_pattern("/anything", "*"));
    }

    // The formatted date ends in " GMT" and starts with a three-letter
    // weekday abbreviation.
    #[test]
    fn date_formatting_works() {
        let now = std::time::SystemTime::now();
        let formatted = format_http_date(now);
        assert!(formatted.ends_with(" GMT"));
        let days = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"];
        assert!(days.iter().any(|d| formatted.starts_with(d)));
    }

    // Checks the Gregorian leap-year rules.
    #[test]
    fn leap_year_detection() {
        // Divisible by 100 but not by 400: not a leap year.
        assert!(!is_leap_year(1900));
        // Divisible by 400: leap year.
        assert!(is_leap_year(2000));
        // Divisible by 4 but not by 100: leap year.
        assert!(is_leap_year(2024));
        // Not divisible by 4: not a leap year.
        assert!(!is_leap_year(2023));
    }
}
7954
// Unit tests for `TraceRejectionMiddleware`: TRACE requests must be
// short-circuited with 405, every other method must pass through.
#[cfg(test)]
mod trace_rejection_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    // Builds a throwaway request context for driving the middleware.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    // Runs the middleware's `before` hook to completion on the current thread.
    fn run_before(mw: &TraceRejectionMiddleware, req: &mut Request) -> ControlFlow {
        let ctx = test_context();
        let fut = mw.before(&ctx, req);
        futures_executor::block_on(fut)
    }

    // Case-insensitive lookup of the first header called `name`.
    fn find_header<'a>(headers: &'a [(String, Vec<u8>)], name: &str) -> Option<&'a [u8]> {
        headers
            .iter()
            .find(|(n, _)| n.eq_ignore_ascii_case(name))
            .map(|(_, v)| v.as_slice())
    }

    // TRACE on the root path is rejected with 405 Method Not Allowed.
    #[test]
    fn trace_request_rejected() {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(Method::Trace, "/");

        let result = run_before(&mw, &mut req);

        match result {
            ControlFlow::Break(response) => {
                assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
            }
            ControlFlow::Continue => panic!("TRACE request should have been rejected"),
        }
    }

    // TRACE on a nested path is rejected in the same way.
    #[test]
    fn trace_request_with_path() {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(Method::Trace, "/api/users/123");

        let result = run_before(&mw, &mut req);

        match result {
            ControlFlow::Break(response) => {
                assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
            }
            ControlFlow::Continue => panic!("TRACE request should have been rejected"),
        }
    }

    // GET passes through.
    #[test]
    fn get_request_allowed() {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(Method::Get, "/");

        let result = run_before(&mw, &mut req);

        match result {
            // Expected outcome: the chain continues.
            ControlFlow::Continue => {}
            ControlFlow::Break(_) => panic!("GET request should be allowed"),
        }
    }

    // POST passes through.
    #[test]
    fn post_request_allowed() {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(Method::Post, "/api/users");

        let result = run_before(&mw, &mut req);

        match result {
            // Expected outcome: the chain continues.
            ControlFlow::Continue => {}
            ControlFlow::Break(_) => panic!("POST request should be allowed"),
        }
    }

    // PUT passes through.
    #[test]
    fn put_request_allowed() {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(Method::Put, "/api/users/1");

        let result = run_before(&mw, &mut req);

        match result {
            // Expected outcome: the chain continues.
            ControlFlow::Continue => {}
            ControlFlow::Break(_) => panic!("PUT request should be allowed"),
        }
    }

    // DELETE passes through.
    #[test]
    fn delete_request_allowed() {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(Method::Delete, "/api/users/1");

        let result = run_before(&mw, &mut req);

        match result {
            // Expected outcome: the chain continues.
            ControlFlow::Continue => {}
            ControlFlow::Break(_) => panic!("DELETE request should be allowed"),
        }
    }

    // PATCH passes through.
    #[test]
    fn patch_request_allowed() {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(Method::Patch, "/api/users/1");

        let result = run_before(&mw, &mut req);

        match result {
            // Expected outcome: the chain continues.
            ControlFlow::Continue => {}
            ControlFlow::Break(_) => panic!("PATCH request should be allowed"),
        }
    }

    // OPTIONS passes through.
    #[test]
    fn options_request_allowed() {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(Method::Options, "/api/users");

        let result = run_before(&mw, &mut req);

        match result {
            // Expected outcome: the chain continues.
            ControlFlow::Continue => {}
            ControlFlow::Break(_) => panic!("OPTIONS request should be allowed"),
        }
    }

    // HEAD passes through.
    #[test]
    fn head_request_allowed() {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(Method::Head, "/");

        let result = run_before(&mw, &mut req);

        match result {
            // Expected outcome: the chain continues.
            ControlFlow::Continue => {}
            ControlFlow::Break(_) => panic!("HEAD request should be allowed"),
        }
    }

    // The rejection response carries an `Allow` header.
    #[test]
    fn response_includes_allow_header() {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(Method::Trace, "/");

        let result = run_before(&mw, &mut req);

        match result {
            ControlFlow::Break(response) => {
                let allow_header = find_header(response.headers(), "Allow");
                assert!(
                    allow_header.is_some(),
                    "Response should include Allow header"
                );
            }
            ControlFlow::Continue => panic!("TRACE request should have been rejected"),
        }
    }

    // The rejection response is labelled `application/json`.
    #[test]
    fn response_has_json_content_type() {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(Method::Trace, "/");

        let result = run_before(&mw, &mut req);

        match result {
            ControlFlow::Break(response) => {
                let ct_header = find_header(response.headers(), "Content-Type");
                assert_eq!(ct_header, Some(b"application/json".as_slice()));
            }
            ControlFlow::Continue => panic!("TRACE request should have been rejected"),
        }
    }

    // `new()` leaves the `log_attempts` flag enabled.
    #[test]
    fn default_enables_logging() {
        let mw = TraceRejectionMiddleware::new();
        assert!(mw.log_attempts);
    }

    // The `log_attempts` builder method can clear the flag.
    #[test]
    fn log_attempts_can_be_disabled() {
        let mw = TraceRejectionMiddleware::new().log_attempts(false);
        assert!(!mw.log_attempts);
    }

    // The middleware reports the name "TraceRejection".
    #[test]
    fn middleware_name() {
        let mw = TraceRejectionMiddleware::new();
        assert_eq!(mw.name(), "TraceRejection");
    }

    // `Default::default()` also leaves the `log_attempts` flag enabled.
    #[test]
    fn default_impl() {
        let mw = TraceRejectionMiddleware::default();
        assert!(mw.log_attempts);
    }
}
8162
/// Tests for `HttpsRedirectMiddleware`: the redirect decision taken in
/// `before` and the `Strict-Transport-Security` header handled in `after`.
#[cfg(test)]
mod https_redirect_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    /// Builds a request context from the testing `Cx` and a fixed second
    /// constructor argument of `1`.
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 1)
    }

    /// Runs the middleware's `before` hook to completion on the current thread.
    fn run_before(mw: &HttpsRedirectMiddleware, req: &mut Request) -> ControlFlow {
        let ctx = test_context();
        futures_executor::block_on(mw.before(&ctx, req))
    }

    /// Runs the middleware's `after` hook to completion on the current thread.
    fn run_after(mw: &HttpsRedirectMiddleware, req: &Request, resp: Response) -> Response {
        let ctx = test_context();
        futures_executor::block_on(mw.after(&ctx, req, resp))
    }

    /// Returns the value of the first header whose name matches `name`,
    /// ignoring ASCII case, or `None` when there is no such header.
    fn find_header<'a>(headers: &'a [(String, Vec<u8>)], name: &str) -> Option<&'a [u8]> {
        headers
            .iter()
            .find_map(|(key, value)| key.eq_ignore_ascii_case(name).then(|| value.as_slice()))
    }

    /// A request without any HTTPS marker gets a 301 to the HTTPS URL.
    #[test]
    fn http_request_redirected() {
        let middleware = HttpsRedirectMiddleware::new();
        let mut request = Request::new(Method::Get, "/");
        request
            .headers_mut()
            .insert("Host", b"example.com".to_vec());

        let ControlFlow::Break(redirect) = run_before(&middleware, &mut request) else {
            panic!("HTTP request should be redirected");
        };

        assert_eq!(redirect.status(), StatusCode::MOVED_PERMANENTLY);
        assert_eq!(
            find_header(redirect.headers(), "Location"),
            Some(&b"https://example.com/"[..])
        );
    }

    /// Path and query string are carried over into the redirect target.
    #[test]
    fn http_request_with_path_and_query() {
        let middleware = HttpsRedirectMiddleware::new();
        let mut request = Request::new(Method::Get, "/api/users?page=1");
        request
            .headers_mut()
            .insert("Host", b"example.com".to_vec());

        let ControlFlow::Break(redirect) = run_before(&middleware, &mut request) else {
            panic!("HTTP request should be redirected");
        };

        assert_eq!(
            find_header(redirect.headers(), "Location"),
            Some(&b"https://example.com/api/users?page=1"[..])
        );
    }

    /// `X-Forwarded-Proto: https` marks the request as already secure.
    #[test]
    fn https_request_not_redirected() {
        let middleware = HttpsRedirectMiddleware::new();
        let mut request = Request::new(Method::Get, "/");
        request
            .headers_mut()
            .insert("Host", b"example.com".to_vec());
        request
            .headers_mut()
            .insert("X-Forwarded-Proto", b"https".to_vec());

        let outcome = run_before(&middleware, &mut request);

        assert!(
            outcome.is_continue(),
            "HTTPS request should not be redirected"
        );
    }

    /// `X-Forwarded-Ssl: on` also marks the request as already secure.
    #[test]
    fn x_forwarded_ssl_recognized() {
        let middleware = HttpsRedirectMiddleware::new();
        let mut request = Request::new(Method::Get, "/");
        request
            .headers_mut()
            .insert("Host", b"example.com".to_vec());
        request
            .headers_mut()
            .insert("X-Forwarded-Ssl", b"on".to_vec());

        let outcome = run_before(&middleware, &mut request);

        assert!(
            outcome.is_continue(),
            "Request with X-Forwarded-Ssl=on should not redirect"
        );
    }

    /// A path registered through `exclude_path` is served without a redirect.
    #[test]
    fn excluded_path_not_redirected() {
        let middleware = HttpsRedirectMiddleware::new().exclude_path("/health");
        let mut request = Request::new(Method::Get, "/health");
        request
            .headers_mut()
            .insert("Host", b"example.com".to_vec());

        let outcome = run_before(&middleware, &mut request);

        assert!(
            outcome.is_continue(),
            "Excluded path should not be redirected"
        );
    }

    /// Exclusions apply to paths below the excluded prefix as well.
    #[test]
    fn excluded_path_prefix_matches() {
        let middleware = HttpsRedirectMiddleware::new().exclude_path("/health");
        let mut request = Request::new(Method::Get, "/health/live");
        request
            .headers_mut()
            .insert("Host", b"example.com".to_vec());

        let outcome = run_before(&middleware, &mut request);

        assert!(
            outcome.is_continue(),
            "Path with excluded prefix should not be redirected"
        );
    }

    /// With `permanent_redirect(false)` the redirect status is 307.
    #[test]
    fn temporary_redirect_option() {
        let middleware = HttpsRedirectMiddleware::new().permanent_redirect(false);
        let mut request = Request::new(Method::Get, "/");
        request
            .headers_mut()
            .insert("Host", b"example.com".to_vec());

        let ControlFlow::Break(redirect) = run_before(&middleware, &mut request) else {
            panic!("HTTP request should be redirected");
        };

        assert_eq!(redirect.status(), StatusCode::TEMPORARY_REDIRECT);
    }

    /// With `redirect_enabled(false)` plain HTTP requests pass through.
    #[test]
    fn redirect_disabled() {
        let middleware = HttpsRedirectMiddleware::new().redirect_enabled(false);
        let mut request = Request::new(Method::Get, "/");
        request
            .headers_mut()
            .insert("Host", b"example.com".to_vec());

        let outcome = run_before(&middleware, &mut request);

        assert!(
            outcome.is_continue(),
            "Redirects are disabled, should continue"
        );
    }

    /// Responses to HTTPS requests receive the default one-year HSTS policy.
    #[test]
    fn hsts_header_on_https_response() {
        let middleware = HttpsRedirectMiddleware::new();
        let mut request = Request::new(Method::Get, "/");
        request
            .headers_mut()
            .insert("X-Forwarded-Proto", b"https".to_vec());

        let decorated = run_after(&middleware, &request, Response::with_status(StatusCode::OK));

        let Some(raw_policy) = find_header(decorated.headers(), "Strict-Transport-Security")
        else {
            panic!("HSTS header should be present on HTTPS response");
        };
        let policy = String::from_utf8_lossy(raw_policy);
        assert!(policy.contains("max-age=31536000"));
    }

    /// Responses to plain HTTP requests must not advertise HSTS.
    #[test]
    fn hsts_header_not_on_http_response() {
        let middleware = HttpsRedirectMiddleware::new().redirect_enabled(false);
        let request = Request::new(Method::Get, "/");

        let decorated = run_after(&middleware, &request, Response::with_status(StatusCode::OK));

        assert!(
            find_header(decorated.headers(), "Strict-Transport-Security").is_none(),
            "HSTS header should not be on HTTP response"
        );
    }

    /// `include_subdomains(true)` adds the `includeSubDomains` directive.
    #[test]
    fn hsts_with_include_subdomains() {
        let middleware = HttpsRedirectMiddleware::new().include_subdomains(true);
        let mut request = Request::new(Method::Get, "/");
        request
            .headers_mut()
            .insert("X-Forwarded-Proto", b"https".to_vec());

        let decorated = run_after(&middleware, &request, Response::with_status(StatusCode::OK));

        let raw_policy = find_header(decorated.headers(), "Strict-Transport-Security").unwrap();
        let policy = String::from_utf8_lossy(raw_policy);
        assert!(policy.contains("includeSubDomains"));
    }

    /// `preload(true)` adds the `preload` directive.
    #[test]
    fn hsts_with_preload() {
        let middleware = HttpsRedirectMiddleware::new().preload(true);
        let mut request = Request::new(Method::Get, "/");
        request
            .headers_mut()
            .insert("X-Forwarded-Proto", b"https".to_vec());

        let decorated = run_after(&middleware, &request, Response::with_status(StatusCode::OK));

        let raw_policy = find_header(decorated.headers(), "Strict-Transport-Security").unwrap();
        let policy = String::from_utf8_lossy(raw_policy);
        assert!(policy.contains("preload"));
    }

    /// A max-age of zero switches the HSTS header off entirely.
    #[test]
    fn hsts_disabled_with_zero_max_age() {
        let middleware = HttpsRedirectMiddleware::new().hsts_max_age_secs(0);
        let mut request = Request::new(Method::Get, "/");
        request
            .headers_mut()
            .insert("X-Forwarded-Proto", b"https".to_vec());

        let decorated = run_after(&middleware, &request, Response::with_status(StatusCode::OK));

        assert!(
            find_header(decorated.headers(), "Strict-Transport-Security").is_none(),
            "HSTS should be disabled with max-age=0"
        );
    }

    /// A non-default HTTPS port is written into the redirect target.
    #[test]
    fn custom_https_port() {
        let middleware = HttpsRedirectMiddleware::new().https_port(8443);
        let mut request = Request::new(Method::Get, "/");
        request
            .headers_mut()
            .insert("Host", b"example.com".to_vec());

        let ControlFlow::Break(redirect) = run_before(&middleware, &mut request) else {
            panic!("HTTP request should be redirected");
        };

        assert_eq!(
            find_header(redirect.headers(), "Location"),
            Some(&b"https://example.com:8443/"[..])
        );
    }

    /// The port in the incoming `Host` header is dropped from the target.
    #[test]
    fn host_with_port_stripped() {
        let middleware = HttpsRedirectMiddleware::new();
        let mut request = Request::new(Method::Get, "/");
        request
            .headers_mut()
            .insert("Host", b"example.com:8080".to_vec());

        let ControlFlow::Break(redirect) = run_before(&middleware, &mut request) else {
            panic!("HTTP request should be redirected");
        };

        assert_eq!(
            find_header(redirect.headers(), "Location"),
            Some(&b"https://example.com/"[..])
        );
    }

    /// `name()` reports the short identifier `"HttpsRedirect"`.
    #[test]
    fn middleware_name() {
        let middleware = HttpsRedirectMiddleware::new();
        let reported = middleware.name();

        assert_eq!(reported, "HttpsRedirect");
    }

    /// `Default` enables permanent redirects and a one-year HSTS max-age.
    #[test]
    fn default_impl() {
        let middleware = HttpsRedirectMiddleware::default();
        let config = &middleware.config;

        assert!(config.redirect_enabled);
        assert!(config.permanent_redirect);
        assert_eq!(config.hsts_max_age_secs, 31_536_000);
    }

    /// Every builder method stores its argument in the matching config field.
    #[test]
    fn config_builder() {
        let middleware = HttpsRedirectMiddleware::new()
            .redirect_enabled(false)
            .permanent_redirect(false)
            .hsts_max_age_secs(86400)
            .include_subdomains(true)
            .preload(true)
            .https_port(8443);
        let config = &middleware.config;

        assert!(!config.redirect_enabled);
        assert!(!config.permanent_redirect);
        assert_eq!(config.hsts_max_age_secs, 86400);
        assert!(config.hsts_include_subdomains);
        assert!(config.hsts_preload);
        assert_eq!(config.https_port, 8443);
    }

    /// `exclude_paths` records every path it is given.
    #[test]
    fn exclude_paths_method() {
        let requested = vec![String::from("/health"), String::from("/ready")];
        let middleware = HttpsRedirectMiddleware::new().exclude_paths(requested);
        let excluded = &middleware.config.exclude_paths;

        assert_eq!(excluded.len(), 2);
        assert!(excluded.contains(&String::from("/health")));
        assert!(excluded.contains(&String::from("/ready")));
    }
}
8481
8482#[cfg(test)]
8491mod tests {
8492 use super::*;
8493 use crate::response::{ResponseBody, StatusCode};
8494
8495 #[allow(dead_code)]
8497 struct AddHeaderMiddleware {
8498 name: &'static str,
8499 value: &'static [u8],
8500 }
8501
8502 impl Middleware for AddHeaderMiddleware {
8503 fn after<'a>(
8504 &'a self,
8505 _ctx: &'a RequestContext,
8506 _req: &'a Request,
8507 response: Response,
8508 ) -> BoxFuture<'a, Response> {
8509 Box::pin(async move { response.header(self.name, self.value.to_vec()) })
8510 }
8511 }
8512
8513 #[allow(dead_code)]
8515 struct BlockingMiddleware;
8516
8517 impl Middleware for BlockingMiddleware {
8518 fn before<'a>(
8519 &'a self,
8520 _ctx: &'a RequestContext,
8521 _req: &'a mut Request,
8522 ) -> BoxFuture<'a, ControlFlow> {
8523 Box::pin(async {
8524 ControlFlow::Break(
8525 Response::with_status(StatusCode::FORBIDDEN)
8526 .body(ResponseBody::Bytes(b"blocked".to_vec())),
8527 )
8528 })
8529 }
8530 }
8531
8532 #[allow(dead_code)]
8534 struct TrackingMiddleware {
8535 before_count: std::sync::atomic::AtomicUsize,
8536 after_count: std::sync::atomic::AtomicUsize,
8537 }
8538
8539 #[allow(dead_code)]
8540 impl TrackingMiddleware {
8541 fn new() -> Self {
8542 Self {
8543 before_count: std::sync::atomic::AtomicUsize::new(0),
8544 after_count: std::sync::atomic::AtomicUsize::new(0),
8545 }
8546 }
8547
8548 fn before_count(&self) -> usize {
8549 self.before_count.load(std::sync::atomic::Ordering::SeqCst)
8550 }
8551
8552 fn after_count(&self) -> usize {
8553 self.after_count.load(std::sync::atomic::Ordering::SeqCst)
8554 }
8555 }
8556
8557 impl Middleware for TrackingMiddleware {
8558 fn before<'a>(
8559 &'a self,
8560 _ctx: &'a RequestContext,
8561 _req: &'a mut Request,
8562 ) -> BoxFuture<'a, ControlFlow> {
8563 self.before_count
8564 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
8565 Box::pin(async { ControlFlow::Continue })
8566 }
8567
8568 fn after<'a>(
8569 &'a self,
8570 _ctx: &'a RequestContext,
8571 _req: &'a Request,
8572 response: Response,
8573 ) -> BoxFuture<'a, Response> {
8574 self.after_count
8575 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
8576 Box::pin(async move { response })
8577 }
8578 }
8579
8580 #[test]
8581 fn control_flow_variants() {
8582 let cont = ControlFlow::Continue;
8583 assert!(cont.is_continue());
8584 assert!(!cont.is_break());
8585
8586 let brk = ControlFlow::Break(Response::ok());
8587 assert!(!brk.is_continue());
8588 assert!(brk.is_break());
8589 }
8590
8591 #[test]
8592 fn middleware_stack_empty() {
8593 let stack = MiddlewareStack::new();
8594 assert!(stack.is_empty());
8595 assert_eq!(stack.len(), 0);
8596 }
8597
8598 #[test]
8599 fn middleware_stack_push() {
8600 let mut stack = MiddlewareStack::new();
8601 stack.push(NoopMiddleware);
8602 stack.push(NoopMiddleware);
8603 assert_eq!(stack.len(), 2);
8604 assert!(!stack.is_empty());
8605 }
8606
8607 #[test]
8608 fn noop_middleware_name() {
8609 let mw = NoopMiddleware;
8610 assert_eq!(mw.name(), "Noop");
8611 }
8612
8613 #[test]
8614 fn logging_redacts_sensitive_headers() {
8615 let mut headers = crate::request::Headers::new();
8616 headers.insert("Authorization", b"secret".to_vec());
8617 headers.insert("X-Request-Id", b"abc123".to_vec());
8618
8619 let redacted = super::default_redacted_headers();
8620 let formatted = super::format_headers(headers.iter(), &redacted);
8621
8622 assert!(formatted.contains("authorization=<redacted>"));
8623 assert!(formatted.contains("x-request-id=abc123"));
8624 }
8625
8626 #[test]
8627 fn logging_body_truncation() {
8628 let body = b"abcdef";
8629 let preview = super::format_bytes(body, 4);
8630 assert_eq!(preview, "abcd...");
8631
8632 let preview_full = super::format_bytes(body, 10);
8633 assert_eq!(preview_full, "abcdef");
8634 }
8635
8636 fn test_context() -> RequestContext {
8637 let cx = asupersync::Cx::for_testing();
8638 RequestContext::new(cx, 1)
8639 }
8640
8641 fn header_value(response: &Response, name: &str) -> Option<String> {
8642 response
8643 .headers()
8644 .iter()
8645 .find(|(n, _)| n.eq_ignore_ascii_case(name))
8646 .and_then(|(_, v)| std::str::from_utf8(v).ok())
8647 .map(ToString::to_string)
8648 }
8649
8650 #[test]
8651 fn cors_exact_origin_allows() {
8652 let cors = Cors::new().allow_origin("https://example.com");
8653 let ctx = test_context();
8654 let mut req = Request::new(crate::request::Method::Get, "/");
8655 req.headers_mut()
8656 .insert("origin", b"https://example.com".to_vec());
8657
8658 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8659 assert!(matches!(result, ControlFlow::Continue));
8660
8661 let response = Response::ok().body(ResponseBody::Bytes(b"ok".to_vec()));
8662 let response = futures_executor::block_on(cors.after(&ctx, &req, response));
8663
8664 assert_eq!(
8665 header_value(&response, "access-control-allow-origin"),
8666 Some("https://example.com".to_string())
8667 );
8668 assert_eq!(header_value(&response, "vary"), Some("Origin".to_string()));
8669 }
8670
8671 #[test]
8672 fn cors_wildcard_origin_allows() {
8673 let cors = Cors::new().allow_origin_wildcard("https://*.example.com");
8674 let ctx = test_context();
8675 let mut req = Request::new(crate::request::Method::Get, "/");
8676 req.headers_mut()
8677 .insert("origin", b"https://api.example.com".to_vec());
8678
8679 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8680 assert!(matches!(result, ControlFlow::Continue));
8681 }
8682
8683 #[test]
8684 fn cors_regex_origin_allows() {
8685 let cors = Cors::new().allow_origin_regex(r"^https://.*\.example\.com$");
8686 let ctx = test_context();
8687 let mut req = Request::new(crate::request::Method::Get, "/");
8688 req.headers_mut()
8689 .insert("origin", b"https://svc.example.com".to_vec());
8690
8691 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8692 assert!(matches!(result, ControlFlow::Continue));
8693 }
8694
8695 #[test]
8696 fn cors_preflight_handled() {
8697 let cors = Cors::new()
8698 .allow_any_origin()
8699 .allow_headers(["x-test", "content-type"])
8700 .max_age(600);
8701 let ctx = test_context();
8702 let mut req = Request::new(crate::request::Method::Options, "/");
8703 req.headers_mut()
8704 .insert("origin", b"https://example.com".to_vec());
8705 req.headers_mut()
8706 .insert("access-control-request-method", b"POST".to_vec());
8707 req.headers_mut().insert(
8708 "access-control-request-headers",
8709 b"x-test, content-type".to_vec(),
8710 );
8711
8712 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8713 let ControlFlow::Break(response) = result else {
8714 panic!("expected preflight break");
8715 };
8716
8717 assert_eq!(response.status().as_u16(), 204);
8718 assert_eq!(
8719 header_value(&response, "access-control-allow-origin"),
8720 Some("*".to_string())
8721 );
8722 assert_eq!(
8723 header_value(&response, "access-control-allow-methods"),
8724 Some("GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD".to_string())
8725 );
8726 assert_eq!(
8727 header_value(&response, "access-control-allow-headers"),
8728 Some("x-test, content-type".to_string())
8729 );
8730 assert_eq!(
8731 header_value(&response, "access-control-max-age"),
8732 Some("600".to_string())
8733 );
8734 }
8735
8736 #[test]
8737 fn cors_credentials_echo_origin() {
8738 let cors = Cors::new().allow_any_origin().allow_credentials(true);
8739 let ctx = test_context();
8740 let mut req = Request::new(crate::request::Method::Get, "/");
8741 req.headers_mut()
8742 .insert("origin", b"https://example.com".to_vec());
8743
8744 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8745 assert!(matches!(result, ControlFlow::Continue));
8746
8747 let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8748 assert_eq!(
8749 header_value(&response, "access-control-allow-origin"),
8750 Some("https://example.com".to_string())
8751 );
8752 assert_eq!(
8753 header_value(&response, "access-control-allow-credentials"),
8754 Some("true".to_string())
8755 );
8756 }
8757
8758 #[test]
8763 fn cors_spec_compliance_credentials_never_wildcard_origin() {
8764 let cors = Cors::new().allow_any_origin().allow_credentials(true);
8767 let ctx = test_context();
8768
8769 for origin in &[
8771 "https://example.com",
8772 "https://api.example.com",
8773 "http://localhost:3000",
8774 ] {
8775 let mut req = Request::new(crate::request::Method::Get, "/");
8776 req.headers_mut()
8777 .insert("origin", origin.as_bytes().to_vec());
8778
8779 futures_executor::block_on(cors.before(&ctx, &mut req));
8780 let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8781
8782 let allow_origin = header_value(&response, "access-control-allow-origin");
8783 assert_eq!(
8784 allow_origin,
8785 Some((*origin).to_string()),
8786 "With credentials enabled, Access-Control-Allow-Origin must echo '{}', not '*'",
8787 origin
8788 );
8789 assert_ne!(
8790 allow_origin,
8791 Some("*".to_string()),
8792 "CORS spec violation: credentials + wildcard origin is forbidden"
8793 );
8794 }
8795 }
8796
8797 #[test]
8798 fn cors_spec_compliance_preflight_with_credentials() {
8799 let cors = Cors::new()
8801 .allow_any_origin()
8802 .allow_credentials(true)
8803 .allow_headers(["content-type", "x-custom-header"]);
8804 let ctx = test_context();
8805
8806 let mut req = Request::new(crate::request::Method::Options, "/");
8807 req.headers_mut()
8808 .insert("origin", b"https://example.com".to_vec());
8809 req.headers_mut()
8810 .insert("access-control-request-method", b"POST".to_vec());
8811 req.headers_mut()
8812 .insert("access-control-request-headers", b"content-type".to_vec());
8813
8814 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8815 let ControlFlow::Break(response) = result else {
8816 panic!("expected preflight break");
8817 };
8818
8819 let allow_origin = header_value(&response, "access-control-allow-origin");
8821 assert_eq!(allow_origin, Some("https://example.com".to_string()));
8822 assert_ne!(
8823 allow_origin,
8824 Some("*".to_string()),
8825 "CORS spec violation: preflight with credentials must not use wildcard origin"
8826 );
8827
8828 assert_eq!(
8830 header_value(&response, "access-control-allow-credentials"),
8831 Some("true".to_string())
8832 );
8833 }
8834
8835 #[test]
8836 fn cors_spec_without_credentials_allows_wildcard() {
8837 let cors = Cors::new().allow_any_origin();
8839 let ctx = test_context();
8840 let mut req = Request::new(crate::request::Method::Get, "/");
8841 req.headers_mut()
8842 .insert("origin", b"https://example.com".to_vec());
8843
8844 futures_executor::block_on(cors.before(&ctx, &mut req));
8845 let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8846
8847 assert_eq!(
8849 header_value(&response, "access-control-allow-origin"),
8850 Some("*".to_string())
8851 );
8852 assert!(header_value(&response, "access-control-allow-credentials").is_none());
8854 }
8855
8856 #[test]
8857 fn cors_disallowed_preflight_forbidden() {
8858 let cors = Cors::new().allow_origin("https://good.example");
8859 let ctx = test_context();
8860 let mut req = Request::new(crate::request::Method::Options, "/");
8861 req.headers_mut()
8862 .insert("origin", b"https://evil.example".to_vec());
8863 req.headers_mut()
8864 .insert("access-control-request-method", b"GET".to_vec());
8865
8866 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8867 let ControlFlow::Break(response) = result else {
8868 panic!("expected forbidden preflight");
8869 };
8870 assert_eq!(response.status().as_u16(), 403);
8871 }
8872
8873 #[test]
8874 fn cors_simple_request_disallowed_origin_no_headers() {
8875 let cors = Cors::new().allow_origin("https://good.example");
8877 let ctx = test_context();
8878 let mut req = Request::new(crate::request::Method::Get, "/");
8879 req.headers_mut()
8880 .insert("origin", b"https://evil.example".to_vec());
8881
8882 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8883 assert!(matches!(result, ControlFlow::Continue));
8885
8886 let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8887 assert!(header_value(&response, "access-control-allow-origin").is_none());
8889 }
8890
8891 #[test]
8892 fn cors_expose_headers_configuration() {
8893 let cors = Cors::new()
8894 .allow_any_origin()
8895 .expose_headers(["x-custom-header", "x-another-header"]);
8896 let ctx = test_context();
8897 let mut req = Request::new(crate::request::Method::Get, "/");
8898 req.headers_mut()
8899 .insert("origin", b"https://example.com".to_vec());
8900
8901 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8902 assert!(matches!(result, ControlFlow::Continue));
8903
8904 let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8905 assert_eq!(
8906 header_value(&response, "access-control-expose-headers"),
8907 Some("x-custom-header, x-another-header".to_string())
8908 );
8909 }
8910
8911 #[test]
8912 fn cors_any_origin_sets_wildcard() {
8913 let cors = Cors::new().allow_any_origin();
8914 let ctx = test_context();
8915 let mut req = Request::new(crate::request::Method::Get, "/");
8916 req.headers_mut()
8917 .insert("origin", b"https://any-site.com".to_vec());
8918
8919 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8920 assert!(matches!(result, ControlFlow::Continue));
8921
8922 let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8923 assert_eq!(
8924 header_value(&response, "access-control-allow-origin"),
8925 Some("*".to_string())
8926 );
8927 }
8928
8929 #[test]
8930 fn cors_config_allows_method_override() {
8931 let cors = Cors::new()
8933 .allow_any_origin()
8934 .allow_methods([crate::request::Method::Get, crate::request::Method::Post]);
8935 let ctx = test_context();
8936 let mut req = Request::new(crate::request::Method::Options, "/");
8937 req.headers_mut()
8938 .insert("origin", b"https://example.com".to_vec());
8939 req.headers_mut()
8940 .insert("access-control-request-method", b"POST".to_vec());
8941
8942 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8943 let ControlFlow::Break(response) = result else {
8944 panic!("expected preflight break");
8945 };
8946 assert_eq!(
8947 header_value(&response, "access-control-allow-methods"),
8948 Some("GET, POST".to_string())
8949 );
8950 }
8951
8952 #[test]
8953 fn cors_no_origin_header_skips_cors() {
8954 let cors = Cors::new().allow_any_origin();
8956 let ctx = test_context();
8957 let mut req = Request::new(crate::request::Method::Get, "/");
8958
8959 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8960 assert!(matches!(result, ControlFlow::Continue));
8961
8962 let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8963 assert!(header_value(&response, "access-control-allow-origin").is_none());
8964 }
8965
8966 #[test]
8967 fn cors_middleware_name() {
8968 let cors = Cors::new();
8969 assert_eq!(cors.name(), "Cors");
8970 }
8971
8972 #[test]
8973 fn cors_empty_allowed_headers_does_not_reflect_request_headers() {
8974 let cors = Cors::new().allow_any_origin(); let ctx = test_context();
8979 let mut req = Request::new(crate::request::Method::Options, "/api");
8980 req.headers_mut()
8981 .insert("origin", b"https://example.com".to_vec());
8982 req.headers_mut()
8983 .insert("access-control-request-method", b"GET".to_vec());
8984 req.headers_mut().insert(
8985 "access-control-request-headers",
8986 b"x-evil-custom, authorization".to_vec(),
8987 );
8988
8989 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8990 if let ControlFlow::Break(response) = result {
8991 assert_eq!(
8994 header_value(&response, "access-control-allow-headers"),
8995 None,
8996 "Empty allowed_headers must not reflect request headers"
8997 );
8998 } else {
8999 panic!("Preflight should have been handled (Break)");
9000 }
9001 }
9002
9003 #[test]
9004 fn cors_explicit_allowed_headers_returned_in_preflight() {
9005 let cors = Cors::new()
9006 .allow_any_origin()
9007 .allow_headers(["x-token", "content-type"]);
9008 let ctx = test_context();
9009 let mut req = Request::new(crate::request::Method::Options, "/api");
9010 req.headers_mut()
9011 .insert("origin", b"https://example.com".to_vec());
9012 req.headers_mut()
9013 .insert("access-control-request-method", b"POST".to_vec());
9014
9015 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
9016 if let ControlFlow::Break(response) = result {
9017 let headers_val = header_value(&response, "access-control-allow-headers");
9018 assert!(headers_val.is_some());
9019 let val = headers_val.unwrap();
9020 assert!(val.contains("x-token"));
9021 assert!(val.contains("content-type"));
9022 } else {
9023 panic!("Preflight should have been handled (Break)");
9024 }
9025 }
9026
9027 #[test]
9032 fn request_id_generates_unique_ids() {
9033 let id1 = RequestId::generate();
9034 let id2 = RequestId::generate();
9035 let id3 = RequestId::generate();
9036
9037 assert_ne!(id1, id2);
9038 assert_ne!(id2, id3);
9039 assert_ne!(id1, id3);
9040
9041 assert!(!id1.as_str().is_empty());
9043 assert!(!id2.as_str().is_empty());
9044 assert!(!id3.as_str().is_empty());
9045 }
9046
9047 #[test]
9048 fn request_id_display() {
9049 let id = RequestId::new("test-request-123");
9050 assert_eq!(format!("{}", id), "test-request-123");
9051 }
9052
9053 #[test]
9054 fn request_id_from_string() {
9055 let id: RequestId = "my-id".into();
9056 assert_eq!(id.as_str(), "my-id");
9057
9058 let id2: RequestId = String::from("my-id-2").into();
9059 assert_eq!(id2.as_str(), "my-id-2");
9060 }
9061
9062 #[test]
9063 fn request_id_config_defaults() {
9064 let config = RequestIdConfig::default();
9065 assert_eq!(config.header_name, "x-request-id");
9066 assert!(config.accept_from_client);
9067 assert!(config.add_to_response);
9068 assert_eq!(config.max_client_id_length, 128);
9069 }
9070
9071 #[test]
9072 fn request_id_config_builder() {
9073 let config = RequestIdConfig::new()
9074 .header_name("X-Trace-ID")
9075 .accept_from_client(false)
9076 .add_to_response(false)
9077 .max_client_id_length(64);
9078
9079 assert_eq!(config.header_name, "X-Trace-ID");
9080 assert!(!config.accept_from_client);
9081 assert!(!config.add_to_response);
9082 assert_eq!(config.max_client_id_length, 64);
9083 }
9084
9085 #[test]
9086 fn request_id_middleware_generates_id() {
9087 let middleware = RequestIdMiddleware::new();
9088 let ctx = test_context();
9089 let mut req = Request::new(crate::request::Method::Get, "/");
9090
9091 let result = futures_executor::block_on(middleware.before(&ctx, &mut req));
9092 assert!(matches!(result, ControlFlow::Continue));
9093
9094 let stored_id = req.get_extension::<RequestId>();
9095 assert!(stored_id.is_some());
9096 assert!(!stored_id.unwrap().as_str().is_empty());
9097 }
9098
9099 #[test]
9100 fn request_id_middleware_accepts_client_id() {
9101 let middleware = RequestIdMiddleware::new();
9102 let ctx = test_context();
9103 let mut req = Request::new(crate::request::Method::Get, "/");
9104 req.headers_mut()
9105 .insert("x-request-id", b"client-provided-id-123".to_vec());
9106
9107 futures_executor::block_on(middleware.before(&ctx, &mut req));
9108
9109 let stored_id = req.get_extension::<RequestId>().unwrap();
9110 assert_eq!(stored_id.as_str(), "client-provided-id-123");
9111 }
9112
9113 #[test]
9114 fn request_id_middleware_rejects_invalid_client_id() {
9115 let middleware = RequestIdMiddleware::new();
9116 let ctx = test_context();
9117
9118 let mut req = Request::new(crate::request::Method::Get, "/");
9120 req.headers_mut()
9121 .insert("x-request-id", b"invalid<script>id".to_vec());
9122
9123 futures_executor::block_on(middleware.before(&ctx, &mut req));
9124
9125 let stored_id = req.get_extension::<RequestId>().unwrap();
9126 assert_ne!(stored_id.as_str(), "invalid<script>id");
9128 }
9129
9130 #[test]
9131 fn request_id_middleware_rejects_too_long_client_id() {
9132 let config = RequestIdConfig::new().max_client_id_length(10);
9133 let middleware = RequestIdMiddleware::with_config(config);
9134 let ctx = test_context();
9135
9136 let mut req = Request::new(crate::request::Method::Get, "/");
9137 req.headers_mut()
9138 .insert("x-request-id", b"this-id-is-way-too-long".to_vec());
9139
9140 futures_executor::block_on(middleware.before(&ctx, &mut req));
9141
9142 let stored_id = req.get_extension::<RequestId>().unwrap();
9143 assert_ne!(stored_id.as_str(), "this-id-is-way-too-long");
9145 }
9146
9147 #[test]
9148 fn request_id_middleware_adds_to_response() {
9149 let middleware = RequestIdMiddleware::new();
9150 let ctx = test_context();
9151 let mut req = Request::new(crate::request::Method::Get, "/");
9152
9153 futures_executor::block_on(middleware.before(&ctx, &mut req));
9154 let stored_id = req.get_extension::<RequestId>().unwrap().clone();
9155
9156 let response = Response::ok();
9157 let response = futures_executor::block_on(middleware.after(&ctx, &req, response));
9158
9159 let header = header_value(&response, "x-request-id");
9160 assert_eq!(header, Some(stored_id.0));
9161 }
9162
/// With `add_to_response(false)`, `after` must leave the response without an
/// `x-request-id` header.
#[test]
fn request_id_middleware_respects_add_to_response_false() {
    let middleware =
        RequestIdMiddleware::with_config(RequestIdConfig::new().add_to_response(false));
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    futures_executor::block_on(middleware.before(&ctx, &mut req));

    let response = futures_executor::block_on(middleware.after(&ctx, &req, Response::ok()));

    assert_eq!(header_value(&response, "x-request-id"), None);
}
9178
/// With `accept_from_client(false)`, an otherwise plain client-supplied ID
/// must not become the stored `RequestId`.
#[test]
fn request_id_middleware_respects_accept_from_client_false() {
    const CLIENT_ID: &str = "client-id";

    let middleware =
        RequestIdMiddleware::with_config(RequestIdConfig::new().accept_from_client(false));
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    req.headers_mut()
        .insert("x-request-id", CLIENT_ID.as_bytes().to_vec());

    futures_executor::block_on(middleware.before(&ctx, &mut req));

    let assigned = req.get_extension::<RequestId>().unwrap();
    assert_ne!(assigned.as_str(), CLIENT_ID);
}
9194
/// A configured header name (`X-Trace-ID`) is used both to read the client ID
/// in `before` and to write it back in `after`.
#[test]
fn request_id_middleware_custom_header_name() {
    const HEADER: &str = "X-Trace-ID";
    const TRACE_ID: &str = "trace-123";

    let middleware =
        RequestIdMiddleware::with_config(RequestIdConfig::new().header_name(HEADER));
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    req.headers_mut()
        .insert(HEADER, TRACE_ID.as_bytes().to_vec());

    futures_executor::block_on(middleware.before(&ctx, &mut req));

    let assigned = req.get_extension::<RequestId>().unwrap();
    assert_eq!(assigned.as_str(), TRACE_ID);

    let response = futures_executor::block_on(middleware.after(&ctx, &req, Response::ok()));

    assert_eq!(
        header_value(&response, HEADER),
        Some(String::from(TRACE_ID))
    );
}
9215
/// IDs built from ASCII letters, digits, `-`, `_` and `.` are accepted.
#[test]
fn is_valid_request_id_accepts_valid() {
    let is_valid = super::is_valid_request_id;

    assert!(is_valid("abc123"));
    assert!(is_valid("request-id-123"));
    assert!(is_valid("request_id_123"));
    assert!(is_valid("request.id.123"));
    assert!(is_valid("ABC123"));
    assert!(is_valid("a-b_c.D"));
}
9225
/// The empty string and IDs containing spaces, angle brackets, newlines,
/// semicolons or slashes are rejected.
#[test]
fn is_valid_request_id_rejects_invalid() {
    let is_valid = super::is_valid_request_id;

    assert!(!is_valid(""));
    assert!(!is_valid("id with spaces"));
    assert!(!is_valid("id<script>"));
    assert!(!is_valid("id\nwith\nnewlines"));
    assert!(!is_valid("id;with;semicolons"));
    assert!(!is_valid("id/with/slashes"));
}
9235
/// `RequestIdMiddleware` reports the name `"RequestId"`.
#[test]
fn request_id_middleware_name() {
    let reported = RequestIdMiddleware::new().name();
    assert_eq!(reported, "RequestId");
}
9241
9242 struct OrderTrackingMiddleware {
9248 id: &'static str,
9249 log: Arc<std::sync::Mutex<Vec<String>>>,
9250 }
9251
9252 impl OrderTrackingMiddleware {
9253 fn new(id: &'static str, log: Arc<std::sync::Mutex<Vec<String>>>) -> Self {
9254 Self { id, log }
9255 }
9256 }
9257
9258 impl Middleware for OrderTrackingMiddleware {
9259 fn before<'a>(
9260 &'a self,
9261 _ctx: &'a RequestContext,
9262 _req: &'a mut Request,
9263 ) -> BoxFuture<'a, ControlFlow> {
9264 self.log.lock().unwrap().push(format!("{}.before", self.id));
9265 Box::pin(async { ControlFlow::Continue })
9266 }
9267
9268 fn after<'a>(
9269 &'a self,
9270 _ctx: &'a RequestContext,
9271 _req: &'a Request,
9272 response: Response,
9273 ) -> BoxFuture<'a, Response> {
9274 self.log.lock().unwrap().push(format!("{}.after", self.id));
9275 Box::pin(async move { response })
9276 }
9277 }
9278
9279 struct ConditionalBreakMiddleware {
9281 id: &'static str,
9282 should_break: bool,
9283 log: Arc<std::sync::Mutex<Vec<String>>>,
9284 }
9285
9286 impl ConditionalBreakMiddleware {
9287 fn new(
9288 id: &'static str,
9289 should_break: bool,
9290 log: Arc<std::sync::Mutex<Vec<String>>>,
9291 ) -> Self {
9292 Self {
9293 id,
9294 should_break,
9295 log,
9296 }
9297 }
9298 }
9299
9300 impl Middleware for ConditionalBreakMiddleware {
9301 fn before<'a>(
9302 &'a self,
9303 _ctx: &'a RequestContext,
9304 _req: &'a mut Request,
9305 ) -> BoxFuture<'a, ControlFlow> {
9306 self.log.lock().unwrap().push(format!("{}.before", self.id));
9307 let should_break = self.should_break;
9308 Box::pin(async move {
9309 if should_break {
9310 ControlFlow::Break(
9311 Response::with_status(StatusCode::FORBIDDEN)
9312 .body(ResponseBody::Bytes(b"blocked".to_vec())),
9313 )
9314 } else {
9315 ControlFlow::Continue
9316 }
9317 })
9318 }
9319
9320 fn after<'a>(
9321 &'a self,
9322 _ctx: &'a RequestContext,
9323 _req: &'a Request,
9324 response: Response,
9325 ) -> BoxFuture<'a, Response> {
9326 self.log.lock().unwrap().push(format!("{}.after", self.id));
9327 Box::pin(async move { response })
9328 }
9329 }
9330
9331 struct OkHandler;
9333
9334 impl Handler for OkHandler {
9335 fn call<'a>(
9336 &'a self,
9337 _ctx: &'a RequestContext,
9338 _req: &'a mut Request,
9339 ) -> BoxFuture<'a, Response> {
9340 Box::pin(async move { Response::ok().body(ResponseBody::Bytes(b"handler".to_vec())) })
9341 }
9342 }
9343
9344 struct CheckHeaderHandler;
9346
9347 impl Handler for CheckHeaderHandler {
9348 fn call<'a>(
9349 &'a self,
9350 _ctx: &'a RequestContext,
9351 req: &'a mut Request,
9352 ) -> BoxFuture<'a, Response> {
9353 let has_header = req.headers().get("X-Modified-By").is_some();
9354 Box::pin(async move {
9355 if has_header {
9356 Response::ok().body(ResponseBody::Bytes(b"header-present".to_vec()))
9357 } else {
9358 Response::with_status(StatusCode::BAD_REQUEST)
9359 }
9360 })
9361 }
9362 }
9363
9364 struct ErrorHandler;
9366
9367 impl Handler for ErrorHandler {
9368 fn call<'a>(
9369 &'a self,
9370 _ctx: &'a RequestContext,
9371 _req: &'a mut Request,
9372 ) -> BoxFuture<'a, Response> {
9373 Box::pin(async move { Response::with_status(StatusCode::INTERNAL_SERVER_ERROR) })
9374 }
9375 }
9376
/// `before` hooks run in push order and `after` hooks run in reverse push
/// order.
#[test]
fn middleware_stack_executes_in_correct_order() {
    let log = Arc::new(std::sync::Mutex::new(Vec::<String>::new()));

    let mut stack = MiddlewareStack::new();
    for &id in &["mw1", "mw2", "mw3"] {
        stack.push(OrderTrackingMiddleware::new(id, Arc::clone(&log)));
    }

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));

    let recorded = log.lock().unwrap().clone();
    let expected = [
        "mw1.before",
        "mw2.before",
        "mw3.before",
        "mw3.after",
        "mw2.after",
        "mw1.after",
    ];
    assert_eq!(recorded, expected);
}
9406
/// When the second middleware breaks, the 403 is returned, the third
/// middleware is never called, and only the first middleware's `after` runs.
#[test]
fn middleware_stack_short_circuit_skips_later_middleware() {
    let log = Arc::new(std::sync::Mutex::new(Vec::<String>::new()));

    let mut stack = MiddlewareStack::new();
    stack.push(OrderTrackingMiddleware::new("mw1", Arc::clone(&log)));
    stack.push(ConditionalBreakMiddleware::new("mw2", true, Arc::clone(&log)));
    stack.push(OrderTrackingMiddleware::new("mw3", Arc::clone(&log)));

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
    let status = response.status().as_u16();
    assert_eq!(status, 403);

    let recorded = log.lock().unwrap().clone();
    let expected = ["mw1.before", "mw2.before", "mw1.after"];
    assert_eq!(recorded, expected);
}
9438
/// When the very first middleware breaks, the 403 is returned and the only
/// hook logged is that middleware's `before`.
#[test]
fn middleware_stack_first_middleware_breaks() {
    let log = Arc::new(std::sync::Mutex::new(Vec::<String>::new()));

    let mut stack = MiddlewareStack::new();
    stack.push(ConditionalBreakMiddleware::new("mw1", true, Arc::clone(&log)));
    stack.push(OrderTrackingMiddleware::new("mw2", Arc::clone(&log)));

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
    let status = response.status().as_u16();
    assert_eq!(status, 403);

    let recorded = log.lock().unwrap().clone();
    assert_eq!(recorded, ["mw1.before"]);
}
9459
/// When the last middleware breaks, every `before` has run, the 403 is
/// returned, and `after` is logged for the two earlier middleware only, in
/// reverse order.
#[test]
fn middleware_stack_last_middleware_breaks() {
    let log = Arc::new(std::sync::Mutex::new(Vec::<String>::new()));

    let mut stack = MiddlewareStack::new();
    stack.push(OrderTrackingMiddleware::new("mw1", Arc::clone(&log)));
    stack.push(OrderTrackingMiddleware::new("mw2", Arc::clone(&log)));
    stack.push(ConditionalBreakMiddleware::new("mw3", true, Arc::clone(&log)));

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
    let status = response.status().as_u16();
    assert_eq!(status, 403);

    let recorded = log.lock().unwrap().clone();
    let expected = [
        "mw1.before",
        "mw2.before",
        "mw3.before",
        "mw2.after",
        "mw1.after",
    ];
    assert_eq!(recorded, expected);
}
9490
/// An empty stack returns the handler's own response (status 200 here).
#[test]
fn middleware_stack_empty_executes_handler_directly() {
    let empty = MiddlewareStack::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let response = futures_executor::block_on(empty.execute(&OkHandler, &ctx, &mut req));

    let status = response.status().as_u16();
    assert_eq!(status, 200);
}
9501
/// A stack created with `with_capacity` starts out holding no middleware.
#[test]
fn middleware_stack_with_capacity() {
    let preallocated = MiddlewareStack::with_capacity(10);
    assert!(preallocated.is_empty());
    assert_eq!(preallocated.len(), 0);
}
9508
/// `push_arc` accepts an already shared `Arc<dyn Middleware>` and the stack
/// length grows to one.
#[test]
fn middleware_stack_push_arc() {
    let mut stack = MiddlewareStack::new();
    stack.push_arc(Arc::new(NoopMiddleware) as Arc<dyn Middleware>);
    assert_eq!(stack.len(), 1);
}
9516
/// `AddResponseHeader::after` sets the configured header on the response.
#[test]
fn add_response_header_adds_header() {
    let mw = AddResponseHeader::new("X-Custom", b"custom-value".to_vec());
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/");

    let response = futures_executor::block_on(mw.after(&ctx, &req, Response::ok()));

    let custom = header_value(&response, "X-Custom");
    assert_eq!(custom, Some(String::from("custom-value")));
}
9535
/// Adding a header leaves a header that was already on the response intact.
#[test]
fn add_response_header_preserves_existing_headers() {
    let mw = AddResponseHeader::new("X-New", b"new".to_vec());
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/");

    let seeded = Response::ok().header("X-Existing", b"existing".to_vec());
    let response = futures_executor::block_on(mw.after(&ctx, &req, seeded));

    let existing = header_value(&response, "X-Existing");
    let added = header_value(&response, "X-New");
    assert_eq!(existing, Some(String::from("existing")));
    assert_eq!(added, Some(String::from("new")));
}
9551
/// `AddResponseHeader` reports the name `"AddResponseHeader"`.
#[test]
fn add_response_header_name() {
    let reported = AddResponseHeader::new("X-Test", b"test".to_vec()).name();
    assert_eq!(reported, "AddResponseHeader");
}
9557
/// A request that carries the required header is allowed to continue.
#[test]
fn require_header_allows_with_header() {
    let mw = RequireHeader::new("X-Api-Key");
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    req.headers_mut()
        .insert("X-Api-Key", b"secret-key".to_vec());

    let outcome = futures_executor::block_on(mw.before(&ctx, &mut req));
    assert!(outcome.is_continue());
}
9573
/// A request without the required header is short-circuited with status 400.
#[test]
fn require_header_blocks_without_header() {
    let mw = RequireHeader::new("X-Api-Key");
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let outcome = futures_executor::block_on(mw.before(&ctx, &mut req));

    let rejection = match outcome {
        ControlFlow::Break(response) => response,
        ControlFlow::Continue => panic!("Expected Break, got Continue"),
    };
    assert_eq!(rejection.status().as_u16(), 400);
}
9589
/// `RequireHeader` reports the name `"RequireHeader"`.
#[test]
fn require_header_name() {
    let reported = RequireHeader::new("X-Test").name();
    assert_eq!(reported, "RequireHeader");
}
9595
/// A path below the configured prefix (`/api/users` under `/api`) continues.
#[test]
fn path_prefix_filter_allows_matching_path() {
    let filter = PathPrefixFilter::new("/api");
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/api/users");

    let outcome = futures_executor::block_on(filter.before(&ctx, &mut req));
    assert!(outcome.is_continue());
}
9609
/// A path equal to the configured prefix itself (`/api`) continues.
#[test]
fn path_prefix_filter_allows_exact_prefix() {
    let filter = PathPrefixFilter::new("/api");
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/api");

    let outcome = futures_executor::block_on(filter.before(&ctx, &mut req));
    assert!(outcome.is_continue());
}
9619
/// A path outside the configured prefix is short-circuited with status 404.
#[test]
fn path_prefix_filter_blocks_non_matching_path() {
    let filter = PathPrefixFilter::new("/api");
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/admin/users");

    let outcome = futures_executor::block_on(filter.before(&ctx, &mut req));

    let rejection = match outcome {
        ControlFlow::Break(response) => response,
        ControlFlow::Continue => panic!("Expected Break, got Continue"),
    };
    assert_eq!(rejection.status().as_u16(), 404);
}
9635
/// `PathPrefixFilter` reports the name `"PathPrefixFilter"`.
#[test]
fn path_prefix_filter_name() {
    let reported = PathPrefixFilter::new("/api").name();
    assert_eq!(reported, "PathPrefixFilter");
}
9641
/// When the predicate holds for the request, the response leaves `after`
/// with the first configured status (200), replacing the incoming 500.
#[test]
fn conditional_status_applies_true_status() {
    let mw = ConditionalStatus::new(
        |req| req.path() == "/health",
        StatusCode::OK,
        StatusCode::NOT_FOUND,
    );
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/health");
    let incoming = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR);

    let outgoing = futures_executor::block_on(mw.after(&ctx, &req, incoming));

    let status = outgoing.status().as_u16();
    assert_eq!(status, 200);
}
9660
/// When the predicate does not hold for the request, the response leaves
/// `after` with the second configured status (404), replacing the incoming 500.
#[test]
fn conditional_status_applies_false_status() {
    let mw = ConditionalStatus::new(
        |req| req.path() == "/health",
        StatusCode::OK,
        StatusCode::NOT_FOUND,
    );
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/other");
    let incoming = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR);

    let outgoing = futures_executor::block_on(mw.after(&ctx, &req, incoming));

    let status = outgoing.status().as_u16();
    assert_eq!(status, 404);
}
9675
/// `ConditionalStatus` reports the name `"ConditionalStatus"`.
#[test]
fn conditional_status_name() {
    let always = ConditionalStatus::new(|_| true, StatusCode::OK, StatusCode::NOT_FOUND);
    let reported = always.name();
    assert_eq!(reported, "ConditionalStatus");
}
9681
9682 #[derive(Clone)]
9687 struct LayerTestMiddleware {
9688 prefix: String,
9689 }
9690
9691 impl LayerTestMiddleware {
9692 fn new(prefix: impl Into<String>) -> Self {
9693 Self {
9694 prefix: prefix.into(),
9695 }
9696 }
9697 }
9698
9699 impl Middleware for LayerTestMiddleware {
9700 fn after<'a>(
9701 &'a self,
9702 _ctx: &'a RequestContext,
9703 _req: &'a Request,
9704 response: Response,
9705 ) -> BoxFuture<'a, Response> {
9706 let prefix = self.prefix.clone();
9707 Box::pin(async move { response.header("X-Layer", prefix.into_bytes()) })
9708 }
9709 }
9710
/// A handler wrapped by a `Layer` still answers 200, and the layer's
/// middleware has added its `X-Layer` header to the response.
#[test]
fn layer_wraps_handler() {
    let layer = Layer::new(LayerTestMiddleware::new("wrapped"));
    let handler = layer.wrap(OkHandler);

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let response = futures_executor::block_on(handler.call(&ctx, &mut req));

    let status = response.status().as_u16();
    let layer_header = header_value(&response, "X-Layer");
    assert_eq!(status, 200);
    assert_eq!(layer_header, Some(String::from("wrapped")));
}
9727
/// When a layer's middleware breaks in `before`, the wrapped handler's
/// response is replaced by the break response (401), and that middleware's
/// own `after` hook still runs on it (the `X-After` header is present).
#[test]
fn layered_handles_break() {
    /// Always breaks with 401 in `before`; marks responses in `after`.
    #[derive(Clone)]
    struct BreakingMiddleware;

    impl Middleware for BreakingMiddleware {
        fn before<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            _req: &'a mut Request,
        ) -> BoxFuture<'a, ControlFlow> {
            Box::pin(async {
                let denied = Response::with_status(StatusCode::UNAUTHORIZED);
                ControlFlow::Break(denied)
            })
        }

        fn after<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            _req: &'a Request,
            response: Response,
        ) -> BoxFuture<'a, Response> {
            Box::pin(async move {
                let marker = b"ran".to_vec();
                response.header("X-After", marker)
            })
        }
    }

    let layer = Layer::new(BreakingMiddleware);
    let handler = layer.wrap(OkHandler);

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let response = futures_executor::block_on(handler.call(&ctx, &mut req));

    let status = response.status().as_u16();
    assert_eq!(status, 401);
    let after_marker = header_value(&response, "X-After");
    assert_eq!(after_marker, Some(String::from("ran")));
}
9767
/// The default logger logs request and response headers, does not log
/// bodies, and caps body logging at 1024 bytes.
#[test]
fn request_response_logger_default() {
    let defaults = RequestResponseLogger::default();

    assert!(defaults.log_request_headers);
    assert!(defaults.log_response_headers);
    assert!(!defaults.log_body);
    assert_eq!(defaults.max_body_bytes, 1024);
}
9780
/// Each builder method on `RequestResponseLogger` is reflected in the
/// corresponding field.
#[test]
fn request_response_logger_builder() {
    let customised = RequestResponseLogger::new()
        .log_request_headers(false)
        .log_response_headers(false)
        .log_body(true)
        .max_body_bytes(2048)
        .redact_header("x-secret");

    assert!(!customised.log_request_headers);
    assert!(!customised.log_response_headers);
    assert!(customised.log_body);
    assert_eq!(customised.max_body_bytes, 2048);
    assert!(customised.redact_headers.contains("x-secret"));
}
9796
/// `RequestResponseLogger` reports the name `"RequestResponseLogger"`.
#[test]
fn request_response_logger_name() {
    let reported = RequestResponseLogger::new().name();
    assert_eq!(reported, "RequestResponseLogger");
}
9802
/// A header inserted by a middleware's `before` hook is visible to the
/// handler: `CheckHeaderHandler` answers 200 only when it sees `X-Modified-By`.
#[test]
fn middleware_stack_modifies_request_for_handler() {
    /// Adds the `X-Modified-By` header the handler looks for.
    struct RequestModifier;

    impl Middleware for RequestModifier {
        fn before<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            req: &'a mut Request,
        ) -> BoxFuture<'a, ControlFlow> {
            // The request is mutated when the hook is called, before the
            // returned future is polled.
            let headers = req.headers_mut();
            headers.insert("X-Modified-By", b"middleware".to_vec());
            Box::pin(async { ControlFlow::Continue })
        }
    }

    let mut stack = MiddlewareStack::new();
    stack.push(RequestModifier);

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let response =
        futures_executor::block_on(stack.execute(&CheckHeaderHandler, &ctx, &mut req));

    let status = response.status().as_u16();
    assert_eq!(status, 200);
}
9835
/// Headers added by several stacked `AddResponseHeader` middleware are all
/// present on the final response.
#[test]
fn middleware_stack_multiple_response_modifications() {
    let headers = [("X-First", "1"), ("X-Second", "2"), ("X-Third", "3")];

    let mut stack = MiddlewareStack::new();
    for &(name, value) in &headers {
        stack.push(AddResponseHeader::new(name, value.as_bytes().to_vec()));
    }

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));

    for &(name, value) in &headers {
        assert_eq!(header_value(&response, name), Some(value.to_string()));
    }
}
9853
/// When the only middleware breaks, the stack returns the break response
/// itself: status 403 with the byte body `blocked`.
#[test]
fn middleware_stack_handler_receives_response_after_break() {
    // This test never inspects the log; the breaker just needs one.
    let unused_log = Arc::new(std::sync::Mutex::new(Vec::new()));

    let mut stack = MiddlewareStack::new();
    stack.push(ConditionalBreakMiddleware::new("breaker", true, unused_log));

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));

    assert_eq!(response.status().as_u16(), 403);
    if let ResponseBody::Bytes(bytes) = response.body_ref() {
        assert_eq!(bytes, b"blocked");
    } else {
        panic!("Expected Bytes body");
    }
}
9876
/// An `after` hook may discard the handler's response and return its own:
/// the handler's 200 comes back as 503.
#[test]
fn middleware_after_can_change_status() {
    /// Replaces every response with `StatusCode::SERVICE_UNAVAILABLE`.
    struct StatusChanger;

    impl Middleware for StatusChanger {
        fn after<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            _req: &'a Request,
            _response: Response,
        ) -> BoxFuture<'a, Response> {
            Box::pin(async {
                let replacement = Response::with_status(StatusCode::SERVICE_UNAVAILABLE);
                replacement
            })
        }
    }

    let mut stack = MiddlewareStack::new();
    stack.push(StatusChanger);

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));

    let status = response.status().as_u16();
    assert_eq!(status, 503);
}
9907
/// A handler answering 500 does not stop the middleware's `after` hook from
/// running, and the 500 is passed through.
#[test]
fn middleware_after_runs_even_on_error_status() {
    let log = Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
    let mut stack = MiddlewareStack::new();
    stack.push(OrderTrackingMiddleware::new("mw1", Arc::clone(&log)));

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let response = futures_executor::block_on(stack.execute(&ErrorHandler, &ctx, &mut req));
    let status = response.status().as_u16();
    assert_eq!(status, 500);

    let recorded = log.lock().unwrap().clone();
    assert_eq!(recorded, ["mw1.before", "mw1.after"]);
}
9925
/// `*.example.com` matches subdomains but not the bare domain.
#[test]
fn wildcard_match_simple() {
    let is_match = super::wildcard_match;

    assert!(is_match("*.example.com", "api.example.com"));
    assert!(is_match("*.example.com", "www.example.com"));
    assert!(!is_match("*.example.com", "example.com"));
}
9936
/// A leading `*` matches any text in front of a fixed suffix.
#[test]
fn wildcard_match_suffix_pattern() {
    let is_match = super::wildcard_match;

    assert!(is_match("*.txt", "file.txt"));
    assert!(is_match("*.txt", "document.txt"));
    assert!(!is_match("*.txt", "file.doc"));
    assert!(is_match("*-suffix", "any-suffix"));
}
9945
/// A pattern without `*` matches only the identical string.
#[test]
fn wildcard_match_no_wildcard() {
    let is_match = super::wildcard_match;

    assert!(is_match("exact", "exact"));
    assert!(!is_match("exact", "different"));
}
9951
/// `^hello$` matches only the exact string, not text with extra characters
/// before or after it.
#[test]
fn regex_match_anchored() {
    let is_match = super::regex_match;

    assert!(is_match("^hello$", "hello"));
    assert!(!is_match("^hello$", "hello world"));
    assert!(!is_match("^hello$", "say hello"));
}
9958
/// `.` in a pattern stands for a single arbitrary character.
#[test]
fn regex_match_dot_wildcard() {
    let is_match = super::regex_match;

    assert!(is_match("h.llo", "hello"));
    assert!(is_match("h.llo", "hallo"));
}
9964
/// `l*` in `hel*o` matches one, zero, or many `l` characters.
#[test]
fn regex_match_star() {
    let is_match = super::regex_match;

    assert!(is_match("hel*o", "hello"));
    assert!(is_match("hel*o", "helo"));
    assert!(is_match("hel*o", "hellllllo"));
}
9971
/// A middleware that overrides nothing gets a `before` hook that resolves to
/// `ControlFlow::Continue`.
#[test]
fn middleware_default_before_continues() {
    struct DefaultBefore;
    impl Middleware for DefaultBefore {}

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let outcome = futures_executor::block_on(DefaultBefore.before(&ctx, &mut req));
    assert!(outcome.is_continue());
}
9988
/// A middleware that overrides nothing gets an `after` hook that hands the
/// response back with its status (201) intact.
#[test]
fn middleware_default_after_passes_through() {
    struct DefaultAfter;
    impl Middleware for DefaultAfter {}

    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/");
    let created = Response::with_status(StatusCode::CREATED);

    let returned = futures_executor::block_on(DefaultAfter.after(&ctx, &req, created));

    let status = returned.status().as_u16();
    assert_eq!(status, 201);
}
10002
/// The default `name` comes from `std::any::type_name`, so it contains the
/// implementing type's identifier.
#[test]
fn middleware_default_name_is_type_name() {
    struct MyCustomMiddleware;
    impl Middleware for MyCustomMiddleware {}

    let reported = MyCustomMiddleware.name();
    assert!(reported.contains("MyCustomMiddleware"));
}
10011
/// Pins the field values of `SecurityHeadersConfig::default()`: the
/// content-type, frame, XSS and referrer fields are set; CSP, HSTS and
/// permissions policy are left unset.
#[test]
fn security_headers_default_config() {
    let defaults = SecurityHeadersConfig::default();

    assert_eq!(defaults.x_content_type_options, Some("nosniff"));
    assert_eq!(defaults.x_frame_options, Some(XFrameOptions::Deny));
    assert_eq!(defaults.x_xss_protection, Some("0"));
    assert_eq!(
        defaults.referrer_policy,
        Some(ReferrerPolicy::StrictOriginWhenCrossOrigin)
    );

    assert_eq!(defaults.content_security_policy, None);
    assert_eq!(defaults.hsts, None);
    assert!(defaults.permissions_policy.is_none());
}
10030
/// `SecurityHeadersConfig::none()` leaves every header field unset.
#[test]
fn security_headers_none_config() {
    let blank = SecurityHeadersConfig::none();

    assert_eq!(blank.x_content_type_options, None);
    assert_eq!(blank.x_frame_options, None);
    assert_eq!(blank.x_xss_protection, None);
    assert_eq!(blank.content_security_policy, None);
    assert_eq!(blank.hsts, None);
    assert_eq!(blank.referrer_policy, None);
    assert!(blank.permissions_policy.is_none());
}
10042
/// Pins the field values of `SecurityHeadersConfig::strict()`, including a
/// `default-src 'self'` CSP and an HSTS tuple of `(31536000, true, false)`.
#[test]
fn security_headers_strict_config() {
    let strict = SecurityHeadersConfig::strict();

    assert_eq!(strict.x_content_type_options, Some("nosniff"));
    assert_eq!(strict.x_frame_options, Some(XFrameOptions::Deny));
    assert_eq!(
        strict.content_security_policy,
        Some(String::from("default-src 'self'"))
    );
    assert_eq!(strict.hsts, Some((31_536_000, true, false)));
    assert_eq!(strict.referrer_policy, Some(ReferrerPolicy::NoReferrer));
    assert!(strict.permissions_policy.is_some());
}
10056
/// Each builder method on `SecurityHeadersConfig` is reflected in the
/// corresponding field.
#[test]
fn security_headers_config_builder() {
    let built = SecurityHeadersConfig::new()
        .x_frame_options(Some(XFrameOptions::SameOrigin))
        .content_security_policy("default-src 'self'")
        .hsts(86400, false, false)
        .referrer_policy(Some(ReferrerPolicy::Origin));

    assert_eq!(built.x_frame_options, Some(XFrameOptions::SameOrigin));
    assert_eq!(
        built.content_security_policy,
        Some(String::from("default-src 'self'"))
    );
    assert_eq!(built.hsts, Some((86400, false, false)));
    assert_eq!(built.referrer_policy, Some(ReferrerPolicy::Origin));
}
10073
/// `build_hsts_value` renders `max-age` first, then `includeSubDomains` when
/// the second `hsts` flag is set, then `preload` when the third is set.
#[test]
fn security_headers_hsts_value_format() {
    // (second flag, third flag, expected header value)
    let cases = [
        (false, false, "max-age=3600"),
        (true, false, "max-age=3600; includeSubDomains"),
        (false, true, "max-age=3600; preload"),
        (true, true, "max-age=3600; includeSubDomains; preload"),
    ];

    for &(include_subdomains, preload, expected) in &cases {
        let config = SecurityHeadersConfig::none().hsts(3600, include_subdomains, preload);
        assert_eq!(config.build_hsts_value(), Some(expected.to_string()));
    }
}
10101
/// With the default configuration, `after` adds the content-type, frame, XSS
/// and referrer headers, and adds no CSP, HSTS or permissions-policy header.
#[test]
fn security_headers_middleware_adds_default_headers() {
    let mw = SecurityHeaders::new();
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/");

    let result = futures_executor::block_on(mw.after(&ctx, &req, Response::ok()));
    let has = |name: &str| header_value(&result, name).is_some();

    assert!(has("X-Content-Type-Options"));
    assert!(has("X-Frame-Options"));
    assert!(has("X-XSS-Protection"));
    assert!(has("Referrer-Policy"));

    assert!(!has("Content-Security-Policy"));
    assert!(!has("Strict-Transport-Security"));
    assert!(!has("Permissions-Policy"));
}
10122
/// A configured content security policy is emitted unchanged in the
/// `Content-Security-Policy` response header.
#[test]
fn security_headers_middleware_with_csp() {
    const POLICY: &str = "default-src 'self'; script-src 'self' 'unsafe-inline'";

    let mw = SecurityHeaders::with_config(
        SecurityHeadersConfig::new().content_security_policy(POLICY),
    );
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/");

    let result = futures_executor::block_on(mw.after(&ctx, &req, Response::ok()));

    let csp = header_value(&result, "Content-Security-Policy");
    assert!(csp.is_some());
    assert_eq!(csp.as_deref(), Some(POLICY));
}
10141
/// A configured HSTS policy is emitted in the `Strict-Transport-Security`
/// response header.
#[test]
fn security_headers_middleware_with_hsts() {
    let mw = SecurityHeaders::with_config(
        SecurityHeadersConfig::new().hsts(31_536_000, true, false),
    );
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/");

    let result = futures_executor::block_on(mw.after(&ctx, &req, Response::ok()));

    let hsts = header_value(&result, "Strict-Transport-Security");
    assert!(hsts.is_some());
    assert_eq!(hsts.as_deref(), Some("max-age=31536000; includeSubDomains"));
}
10156
/// `SecurityHeaders` reports the name `"SecurityHeaders"`.
#[test]
fn security_headers_middleware_name() {
    let reported = SecurityHeaders::new().name();
    assert_eq!(reported, "SecurityHeaders");
}
10162
/// Pins the header bytes produced for each `XFrameOptions` variant checked
/// here.
#[test]
fn x_frame_options_values() {
    assert_eq!(XFrameOptions::Deny.as_bytes(), &b"DENY"[..]);
    assert_eq!(XFrameOptions::SameOrigin.as_bytes(), &b"SAMEORIGIN"[..]);
}
10168
/// Pins the header bytes produced for each `ReferrerPolicy` variant checked
/// here.
#[test]
fn referrer_policy_values() {
    assert_eq!(ReferrerPolicy::NoReferrer.as_bytes(), &b"no-referrer"[..]);
    assert_eq!(
        ReferrerPolicy::NoReferrerWhenDowngrade.as_bytes(),
        &b"no-referrer-when-downgrade"[..]
    );
    assert_eq!(ReferrerPolicy::Origin.as_bytes(), &b"origin"[..]);
    assert_eq!(
        ReferrerPolicy::OriginWhenCrossOrigin.as_bytes(),
        &b"origin-when-cross-origin"[..]
    );
    assert_eq!(ReferrerPolicy::SameOrigin.as_bytes(), &b"same-origin"[..]);
    assert_eq!(
        ReferrerPolicy::StrictOrigin.as_bytes(),
        &b"strict-origin"[..]
    );
    assert_eq!(
        ReferrerPolicy::StrictOriginWhenCrossOrigin.as_bytes(),
        &b"strict-origin-when-cross-origin"[..]
    );
    assert_eq!(ReferrerPolicy::UnsafeUrl.as_bytes(), &b"unsafe-url"[..]);
}
10189
/// The strict preset's `after` hook adds all six headers checked here,
/// including CSP, HSTS and permissions policy.
#[test]
fn security_headers_strict_preset() {
    let mw = SecurityHeaders::strict();
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/");

    let result = futures_executor::block_on(mw.after(&ctx, &req, Response::ok()));
    let has = |name: &str| header_value(&result, name).is_some();

    assert!(has("X-Content-Type-Options"));
    assert!(has("X-Frame-Options"));
    assert!(has("Content-Security-Policy"));
    assert!(has("Strict-Transport-Security"));
    assert!(has("Referrer-Policy"));
    assert!(has("Permissions-Policy"));
}
10207
/// The `no_*` builder methods clear fields that the strict preset had set.
#[test]
fn security_headers_config_clearing_methods() {
    let cleared = SecurityHeadersConfig::strict()
        .no_content_security_policy()
        .no_hsts()
        .no_permissions_policy();

    assert_eq!(cleared.content_security_policy, None);
    assert_eq!(cleared.hsts, None);
    assert!(cleared.permissions_policy.is_none());
}
10219
/// Two consecutively generated CSRF tokens differ and neither is empty.
#[test]
fn csrf_token_generate_produces_unique_tokens() {
    let first = CsrfToken::generate();
    let second = CsrfToken::generate();

    assert_ne!(first, second);
    assert!(!first.as_str().is_empty());
    assert!(!second.as_str().is_empty());
}
10232
/// Formatting a `CsrfToken` with `{}` yields the raw token text.
#[test]
fn csrf_token_display() {
    let token = CsrfToken::new("test-token-123");
    let rendered = format!("{}", token);
    assert_eq!(rendered, "test-token-123");
}
10238
10239 #[test]
10240 fn csrf_config_defaults() {
10241 let config = CsrfConfig::default();
10242 assert_eq!(config.cookie_name, "csrf_token");
10243 assert_eq!(config.header_name, "x-csrf-token");
10244 assert_eq!(config.mode, CsrfMode::DoubleSubmit);
10245 assert!(!config.rotate_token);
10246 assert!(config.production);
10247 assert!(config.error_message.is_none());
10248 }
10249
10250 #[test]
10251 fn csrf_config_builder() {
10252 let config = CsrfConfig::new()
10253 .cookie_name("XSRF-TOKEN")
10254 .header_name("X-XSRF-Token")
10255 .mode(CsrfMode::HeaderOnly)
10256 .rotate_token(true)
10257 .production(false)
10258 .error_message("Custom CSRF error");
10259
10260 assert_eq!(config.cookie_name, "XSRF-TOKEN");
10261 assert_eq!(config.header_name, "X-XSRF-Token");
10262 assert_eq!(config.mode, CsrfMode::HeaderOnly);
10263 assert!(config.rotate_token);
10264 assert!(!config.production);
10265 assert_eq!(config.error_message, Some("Custom CSRF error".to_string()));
10266 }
10267
10268 #[test]
10269 fn csrf_middleware_allows_get_without_token() {
10270 let csrf = CsrfMiddleware::new();
10271 let ctx = test_context();
10272 let mut req = Request::new(crate::request::Method::Get, "/");
10273
10274 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10275 assert!(result.is_continue());
10276 assert!(req.get_extension::<CsrfToken>().is_some());
10278 }
10279
10280 #[test]
10281 fn csrf_middleware_allows_head_without_token() {
10282 let csrf = CsrfMiddleware::new();
10283 let ctx = test_context();
10284 let mut req = Request::new(crate::request::Method::Head, "/");
10285
10286 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10287 assert!(result.is_continue());
10288 }
10289
10290 #[test]
10291 fn csrf_middleware_allows_options_without_token() {
10292 let csrf = CsrfMiddleware::new();
10293 let ctx = test_context();
10294 let mut req = Request::new(crate::request::Method::Options, "/");
10295
10296 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10297 assert!(result.is_continue());
10298 }
10299
10300 #[test]
10301 fn csrf_middleware_blocks_post_without_token() {
10302 let csrf = CsrfMiddleware::new();
10303 let ctx = test_context();
10304 let mut req = Request::new(crate::request::Method::Post, "/");
10305
10306 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10307 assert!(result.is_break());
10308
10309 if let ControlFlow::Break(response) = result {
10310 assert_eq!(response.status(), StatusCode::FORBIDDEN);
10311 }
10312 }
10313
10314 #[test]
10315 fn csrf_middleware_blocks_put_without_token() {
10316 let csrf = CsrfMiddleware::new();
10317 let ctx = test_context();
10318 let mut req = Request::new(crate::request::Method::Put, "/");
10319
10320 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10321 assert!(result.is_break());
10322 }
10323
10324 #[test]
10325 fn csrf_middleware_blocks_delete_without_token() {
10326 let csrf = CsrfMiddleware::new();
10327 let ctx = test_context();
10328 let mut req = Request::new(crate::request::Method::Delete, "/");
10329
10330 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10331 assert!(result.is_break());
10332 }
10333
10334 #[test]
10335 fn csrf_middleware_blocks_patch_without_token() {
10336 let csrf = CsrfMiddleware::new();
10337 let ctx = test_context();
10338 let mut req = Request::new(crate::request::Method::Patch, "/");
10339
10340 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10341 assert!(result.is_break());
10342 }
10343
10344 #[test]
10345 fn csrf_middleware_allows_post_with_matching_tokens() {
10346 let csrf = CsrfMiddleware::new();
10347 let ctx = test_context();
10348 let mut req = Request::new(crate::request::Method::Post, "/");
10349
10350 let token = "valid-csrf-token-12345";
10352 req.headers_mut()
10353 .insert("cookie", format!("csrf_token={}", token).into_bytes());
10354 req.headers_mut()
10355 .insert("x-csrf-token", token.as_bytes().to_vec());
10356
10357 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10358 assert!(result.is_continue());
10359
10360 let stored_token = req.get_extension::<CsrfToken>().unwrap();
10362 assert_eq!(stored_token.as_str(), token);
10363 }
10364
10365 #[test]
10366 fn csrf_middleware_blocks_post_with_mismatched_tokens() {
10367 let csrf = CsrfMiddleware::new();
10368 let ctx = test_context();
10369 let mut req = Request::new(crate::request::Method::Post, "/");
10370
10371 req.headers_mut()
10373 .insert("cookie", b"csrf_token=token-in-cookie".to_vec());
10374 req.headers_mut()
10375 .insert("x-csrf-token", b"different-token".to_vec());
10376
10377 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10378 assert!(result.is_break());
10379
10380 if let ControlFlow::Break(response) = result {
10381 assert_eq!(response.status(), StatusCode::FORBIDDEN);
10382 }
10383 }
10384
10385 #[test]
10386 fn csrf_middleware_blocks_post_with_header_only_in_double_submit_mode() {
10387 let csrf = CsrfMiddleware::new();
10388 let ctx = test_context();
10389 let mut req = Request::new(crate::request::Method::Post, "/");
10390
10391 req.headers_mut()
10393 .insert("x-csrf-token", b"some-token".to_vec());
10394
10395 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10396 assert!(result.is_break());
10397 }
10398
10399 #[test]
10400 fn csrf_middleware_blocks_post_with_cookie_only_in_double_submit_mode() {
10401 let csrf = CsrfMiddleware::new();
10402 let ctx = test_context();
10403 let mut req = Request::new(crate::request::Method::Post, "/");
10404
10405 req.headers_mut()
10407 .insert("cookie", b"csrf_token=some-token".to_vec());
10408
10409 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10410 assert!(result.is_break());
10411 }
10412
10413 #[test]
10414 fn csrf_middleware_header_only_mode_accepts_header_token() {
10415 let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10416 let ctx = test_context();
10417 let mut req = Request::new(crate::request::Method::Post, "/");
10418
10419 req.headers_mut()
10420 .insert("x-csrf-token", b"valid-token".to_vec());
10421
10422 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10423 assert!(result.is_continue());
10424 }
10425
10426 #[test]
10427 fn csrf_middleware_header_only_mode_rejects_empty_header() {
10428 let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10429 let ctx = test_context();
10430 let mut req = Request::new(crate::request::Method::Post, "/");
10431
10432 req.headers_mut().insert("x-csrf-token", b"".to_vec());
10433
10434 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10435 assert!(result.is_break());
10436 }
10437
10438 #[test]
10439 fn csrf_middleware_sets_cookie_on_get() {
10440 let csrf = CsrfMiddleware::new();
10441 let ctx = test_context();
10442 let mut req = Request::new(crate::request::Method::Get, "/");
10443
10444 let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10446
10447 let response = Response::ok();
10449 let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10450
10451 let cookie_value = header_value(&result, "set-cookie");
10453 assert!(cookie_value.is_some());
10454
10455 let cookie_value = cookie_value.unwrap();
10456 assert!(cookie_value.starts_with("csrf_token="));
10457 assert!(cookie_value.contains("SameSite=Strict"));
10458 assert!(cookie_value.contains("Secure")); }
10460
10461 #[test]
10462 fn csrf_middleware_no_secure_in_dev_mode() {
10463 let csrf = CsrfMiddleware::with_config(CsrfConfig::new().production(false));
10464 let ctx = test_context();
10465 let mut req = Request::new(crate::request::Method::Get, "/");
10466
10467 let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10468
10469 let response = Response::ok();
10470 let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10471
10472 let cookie_value = header_value(&result, "set-cookie").unwrap();
10473 assert!(!cookie_value.contains("Secure")); }
10475
10476 #[test]
10477 fn csrf_middleware_does_not_set_cookie_if_already_present() {
10478 let csrf = CsrfMiddleware::new();
10479 let ctx = test_context();
10480 let mut req = Request::new(crate::request::Method::Get, "/");
10481
10482 req.headers_mut()
10484 .insert("cookie", b"csrf_token=existing-token".to_vec());
10485
10486 let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10487
10488 let response = Response::ok();
10489 let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10490
10491 assert!(header_value(&result, "set-cookie").is_none());
10493 }
10494
10495 #[test]
10496 fn csrf_middleware_rotates_token_when_configured() {
10497 let csrf = CsrfMiddleware::with_config(CsrfConfig::new().rotate_token(true));
10498 let ctx = test_context();
10499 let mut req = Request::new(crate::request::Method::Get, "/");
10500
10501 req.headers_mut()
10503 .insert("cookie", b"csrf_token=old-token".to_vec());
10504
10505 let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10506
10507 let response = Response::ok();
10508 let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10509
10510 assert!(header_value(&result, "set-cookie").is_some());
10512 }
10513
10514 #[test]
10515 fn csrf_middleware_custom_header_name() {
10516 let csrf = CsrfMiddleware::with_config(
10517 CsrfConfig::new()
10518 .header_name("X-XSRF-Token")
10519 .cookie_name("XSRF-TOKEN"),
10520 );
10521 let ctx = test_context();
10522 let mut req = Request::new(crate::request::Method::Post, "/");
10523
10524 let token = "custom-token-value";
10525 req.headers_mut()
10526 .insert("cookie", format!("XSRF-TOKEN={}", token).into_bytes());
10527 req.headers_mut()
10528 .insert("x-xsrf-token", token.as_bytes().to_vec());
10529
10530 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10531 assert!(result.is_continue());
10532 }
10533
10534 #[test]
10535 fn csrf_middleware_error_response_is_json() {
10536 let csrf = CsrfMiddleware::new();
10537 let ctx = test_context();
10538 let mut req = Request::new(crate::request::Method::Post, "/");
10539
10540 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10541
10542 if let ControlFlow::Break(response) = result {
10543 let content_type = header_value(&response, "content-type");
10544 assert_eq!(content_type, Some("application/json".to_string()));
10545
10546 if let ResponseBody::Bytes(body) = response.body_ref() {
10548 let body_str = std::str::from_utf8(body).unwrap();
10549 assert!(body_str.contains("csrf_error"));
10550 assert!(body_str.contains("x-csrf-token"));
10551 } else {
10552 panic!("Expected Bytes body");
10553 }
10554 } else {
10555 panic!("Expected Break");
10556 }
10557 }
10558
10559 #[test]
10560 fn csrf_middleware_custom_error_message() {
10561 let csrf = CsrfMiddleware::with_config(
10562 CsrfConfig::new().error_message("Access denied: invalid security token"),
10563 );
10564 let ctx = test_context();
10565 let mut req = Request::new(crate::request::Method::Post, "/");
10566
10567 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10568
10569 if let ControlFlow::Break(response) = result {
10570 if let ResponseBody::Bytes(body) = response.body_ref() {
10571 let body_str = std::str::from_utf8(body).unwrap();
10572 assert!(body_str.contains("Access denied: invalid security token"));
10573 }
10574 }
10575 }
10576
10577 #[test]
10578 fn csrf_middleware_name() {
10579 let csrf = CsrfMiddleware::new();
10580 assert_eq!(csrf.name(), "CSRF");
10581 }
10582
10583 #[test]
10584 fn csrf_middleware_parses_cookie_with_multiple_cookies() {
10585 let csrf = CsrfMiddleware::new();
10586 let ctx = test_context();
10587 let mut req = Request::new(crate::request::Method::Post, "/");
10588
10589 let token = "the-csrf-token";
10591 req.headers_mut().insert(
10592 "cookie",
10593 format!("session=abc123; csrf_token={}; user=test", token).into_bytes(),
10594 );
10595 req.headers_mut()
10596 .insert("x-csrf-token", token.as_bytes().to_vec());
10597
10598 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10599 assert!(result.is_continue());
10600 }
10601
10602 #[test]
10603 fn csrf_middleware_handles_empty_token_value() {
10604 let csrf = CsrfMiddleware::new();
10605 let ctx = test_context();
10606 let mut req = Request::new(crate::request::Method::Post, "/");
10607
10608 req.headers_mut().insert("cookie", b"csrf_token=".to_vec());
10610 req.headers_mut().insert("x-csrf-token", b"".to_vec());
10611
10612 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10613 assert!(result.is_break()); }
10615
10616 #[test]
10619 fn csrf_token_generate_many_unique() {
10620 let mut tokens = std::collections::HashSet::new();
10622 for _ in 0..100 {
10623 let token = CsrfToken::generate();
10624 assert!(
10625 tokens.insert(token.0.clone()),
10626 "Duplicate token generated: {}",
10627 token.0
10628 );
10629 }
10630 assert_eq!(tokens.len(), 100);
10631 }
10632
10633 #[test]
10634 fn csrf_token_generate_format_is_hex() {
10635 let token = CsrfToken::generate();
10636 let s = token.as_str();
10637 assert!(
10639 s.len() >= 64,
10640 "Expected at least 64 hex characters, got {} in '{s}'",
10641 s.len()
10642 );
10643 assert!(
10644 s.chars().all(|c| c.is_ascii_hexdigit()),
10645 "Non-hex character in token: {s}"
10646 );
10647 }
10648
10649 #[test]
10650 fn csrf_token_generate_minimum_length() {
10651 let token = CsrfToken::generate();
10652 assert!(
10654 token.as_str().len() >= 64,
10655 "Token too short: {} (len={})",
10656 token.as_str(),
10657 token.as_str().len()
10658 );
10659 }
10660
10661 #[test]
10662 fn csrf_token_from_str() {
10663 let token: CsrfToken = "my-token".into();
10664 assert_eq!(token.as_str(), "my-token");
10665 assert_eq!(token.0, "my-token");
10666 }
10667
10668 #[test]
10669 fn csrf_token_clone_eq() {
10670 let t1 = CsrfToken::new("abc");
10671 let t2 = t1.clone();
10672 assert_eq!(t1, t2);
10673 assert_eq!(t1.as_str(), t2.as_str());
10674 }
10675
10676 #[test]
10677 fn csrf_middleware_allows_trace_without_token() {
10678 let csrf = CsrfMiddleware::new();
10679 let ctx = test_context();
10680 let mut req = Request::new(crate::request::Method::Trace, "/");
10681
10682 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10683 assert!(result.is_continue());
10684 assert!(req.get_extension::<CsrfToken>().is_some());
10686 }
10687
10688 #[test]
10689 fn csrf_safe_method_generates_token_into_extension() {
10690 let csrf = CsrfMiddleware::new();
10691 let ctx = test_context();
10692
10693 for method in [
10694 crate::request::Method::Get,
10695 crate::request::Method::Head,
10696 crate::request::Method::Options,
10697 crate::request::Method::Trace,
10698 ] {
10699 let mut req = Request::new(method, "/test");
10700 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10701 assert!(result.is_continue());
10702 let token = req.get_extension::<CsrfToken>().expect("token missing");
10703 assert!(!token.as_str().is_empty());
10704 }
10705 }
10706
10707 #[test]
10708 fn csrf_safe_method_preserves_existing_cookie_token() {
10709 let csrf = CsrfMiddleware::new();
10710 let ctx = test_context();
10711 let mut req = Request::new(crate::request::Method::Get, "/");
10712 req.headers_mut()
10713 .insert("cookie", b"csrf_token=my-existing-token".to_vec());
10714
10715 let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10716
10717 let token = req.get_extension::<CsrfToken>().unwrap();
10719 assert_eq!(token.as_str(), "my-existing-token");
10720 }
10721
10722 #[test]
10723 fn csrf_valid_post_stores_token_in_extension() {
10724 let csrf = CsrfMiddleware::new();
10725 let ctx = test_context();
10726 let mut req = Request::new(crate::request::Method::Post, "/submit");
10727
10728 let tk = "valid-token-xyz";
10729 req.headers_mut()
10730 .insert("cookie", format!("csrf_token={}", tk).into_bytes());
10731 req.headers_mut()
10732 .insert("x-csrf-token", tk.as_bytes().to_vec());
10733
10734 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10735 assert!(result.is_continue());
10736 let stored = req.get_extension::<CsrfToken>().unwrap();
10737 assert_eq!(stored.as_str(), tk);
10738 }
10739
10740 #[test]
10741 fn csrf_double_submit_both_empty_strings_rejected() {
10742 let csrf = CsrfMiddleware::new();
10743 let ctx = test_context();
10744 let mut req = Request::new(crate::request::Method::Post, "/");
10745
10746 req.headers_mut().insert("cookie", b"csrf_token=".to_vec());
10748 req.headers_mut().insert("x-csrf-token", b"".to_vec());
10749
10750 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10751 assert!(result.is_break());
10752 }
10753
10754 #[test]
10755 fn csrf_double_submit_matching_empty_rejected() {
10756 let csrf = CsrfMiddleware::new();
10758 let ctx = test_context();
10759 let mut req = Request::new(crate::request::Method::Post, "/");
10760
10761 req.headers_mut().insert("cookie", b"csrf_token=".to_vec());
10762 req.headers_mut().insert("x-csrf-token", b"".to_vec());
10763
10764 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10765 assert!(
10766 result.is_break(),
10767 "Empty matching tokens should be rejected"
10768 );
10769 }
10770
10771 #[test]
10772 fn csrf_header_only_mode_does_not_need_cookie() {
10773 let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10774 let ctx = test_context();
10775 let mut req = Request::new(crate::request::Method::Post, "/");
10776
10777 req.headers_mut()
10779 .insert("x-csrf-token", b"header-only-token".to_vec());
10780
10781 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10782 assert!(result.is_continue());
10783 let token = req.get_extension::<CsrfToken>().unwrap();
10784 assert_eq!(token.as_str(), "header-only-token");
10785 }
10786
10787 #[test]
10788 fn csrf_header_only_mode_ignores_mismatched_cookie() {
10789 let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10791 let ctx = test_context();
10792 let mut req = Request::new(crate::request::Method::Post, "/");
10793
10794 req.headers_mut()
10795 .insert("cookie", b"csrf_token=different-value".to_vec());
10796 req.headers_mut()
10797 .insert("x-csrf-token", b"header-value".to_vec());
10798
10799 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10800 assert!(result.is_continue(), "HeaderOnly should ignore cookie");
10801 }
10802
10803 #[test]
10804 fn csrf_header_only_mode_rejects_no_header() {
10805 let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10806 let ctx = test_context();
10807 let mut req = Request::new(crate::request::Method::Post, "/");
10808 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10810 assert!(result.is_break());
10811 }
10812
10813 #[test]
10814 fn csrf_header_only_error_message_mentions_header() {
10815 let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10816 let ctx = test_context();
10817 let mut req = Request::new(crate::request::Method::Post, "/");
10818
10819 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10820 if let ControlFlow::Break(response) = result {
10821 if let ResponseBody::Bytes(body) = response.body_ref() {
10822 let body_str = std::str::from_utf8(body).unwrap();
10823 assert!(
10824 body_str.contains("missing in header"),
10825 "Expected 'missing in header' in: {}",
10826 body_str
10827 );
10828 }
10829 } else {
10830 panic!("Expected Break");
10831 }
10832 }
10833
10834 #[test]
10835 fn csrf_mismatch_error_differs_from_missing_error() {
10836 let csrf = CsrfMiddleware::new();
10837 let ctx = test_context();
10838
10839 let mut req_missing = Request::new(crate::request::Method::Post, "/");
10841 let missing_result = futures_executor::block_on(csrf.before(&ctx, &mut req_missing));
10842 let missing_body = match missing_result {
10843 ControlFlow::Break(r) => match r.body_ref() {
10844 ResponseBody::Bytes(b) => std::str::from_utf8(b).unwrap().to_string(),
10845 ResponseBody::Empty | ResponseBody::Stream(_) => panic!("Expected Bytes"),
10846 },
10847 ControlFlow::Continue => panic!("Expected Break"),
10848 };
10849
10850 let mut req_mismatch = Request::new(crate::request::Method::Post, "/");
10852 req_mismatch
10853 .headers_mut()
10854 .insert("cookie", b"csrf_token=aaa".to_vec());
10855 req_mismatch
10856 .headers_mut()
10857 .insert("x-csrf-token", b"bbb".to_vec());
10858 let mismatch_result = futures_executor::block_on(csrf.before(&ctx, &mut req_mismatch));
10859 let mismatch_body = match mismatch_result {
10860 ControlFlow::Break(r) => match r.body_ref() {
10861 ResponseBody::Bytes(b) => std::str::from_utf8(b).unwrap().to_string(),
10862 ResponseBody::Empty | ResponseBody::Stream(_) => panic!("Expected Bytes"),
10863 },
10864 ControlFlow::Continue => panic!("Expected Break"),
10865 };
10866
10867 assert_ne!(
10869 missing_body, mismatch_body,
10870 "Missing vs mismatch should have different error messages"
10871 );
10872 assert!(missing_body.contains("missing"));
10873 assert!(mismatch_body.contains("mismatch"));
10874 }
10875
10876 #[test]
10877 fn csrf_cookie_not_httponly() {
10878 let csrf = CsrfMiddleware::new();
10880 let ctx = test_context();
10881 let mut req = Request::new(crate::request::Method::Get, "/");
10882
10883 let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10884 let response = Response::ok();
10885 let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10886
10887 let cookie_value = header_value(&result, "set-cookie").unwrap();
10888 assert!(
10889 !cookie_value.to_lowercase().contains("httponly"),
10890 "CSRF cookie must NOT be HttpOnly (needs JS access), got: {}",
10891 cookie_value
10892 );
10893 }
10894
10895 #[test]
10896 fn csrf_cookie_has_path_slash() {
10897 let csrf = CsrfMiddleware::new();
10898 let ctx = test_context();
10899 let mut req = Request::new(crate::request::Method::Get, "/");
10900
10901 let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10902 let response = Response::ok();
10903 let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10904
10905 let cookie_value = header_value(&result, "set-cookie").unwrap();
10906 assert!(
10907 cookie_value.contains("Path=/"),
10908 "Cookie should have Path=/, got: {}",
10909 cookie_value
10910 );
10911 }
10912
10913 #[test]
10914 fn csrf_cookie_has_samesite_strict() {
10915 let csrf = CsrfMiddleware::new();
10916 let ctx = test_context();
10917 let mut req = Request::new(crate::request::Method::Get, "/");
10918
10919 let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10920 let response = Response::ok();
10921 let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10922
10923 let cookie_value = header_value(&result, "set-cookie").unwrap();
10924 assert!(
10925 cookie_value.contains("SameSite=Strict"),
10926 "Cookie should have SameSite=Strict, got: {}",
10927 cookie_value
10928 );
10929 }
10930
10931 #[test]
10932 fn csrf_production_mode_sets_secure_flag() {
10933 let csrf = CsrfMiddleware::with_config(CsrfConfig::new().production(true));
10934 let ctx = test_context();
10935 let mut req = Request::new(crate::request::Method::Get, "/");
10936
10937 let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10938 let response = Response::ok();
10939 let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10940
10941 let cookie_value = header_value(&result, "set-cookie").unwrap();
10942 assert!(
10943 cookie_value.contains("Secure"),
10944 "Production cookie must have Secure flag, got: {}",
10945 cookie_value
10946 );
10947 }
10948
10949 #[test]
10950 fn csrf_no_set_cookie_on_post_response() {
10951 let csrf = CsrfMiddleware::new();
10953 let ctx = test_context();
10954 let mut req = Request::new(crate::request::Method::Post, "/");
10955
10956 let token = "valid-token";
10957 req.headers_mut()
10958 .insert("cookie", format!("csrf_token={}", token).into_bytes());
10959 req.headers_mut()
10960 .insert("x-csrf-token", token.as_bytes().to_vec());
10961
10962 let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10963 let response = Response::ok();
10964 let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10965
10966 assert!(
10967 header_value(&result, "set-cookie").is_none(),
10968 "POST response should not set CSRF cookie"
10969 );
10970 }
10971
10972 #[test]
10973 fn csrf_head_method_sets_cookie() {
10974 let csrf = CsrfMiddleware::new();
10975 let ctx = test_context();
10976 let mut req = Request::new(crate::request::Method::Head, "/");
10977
10978 let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10979 let response = Response::ok();
10980 let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10981
10982 assert!(
10983 header_value(&result, "set-cookie").is_some(),
10984 "HEAD response should set CSRF cookie"
10985 );
10986 }
10987
10988 #[test]
10989 fn csrf_options_method_sets_cookie() {
10990 let csrf = CsrfMiddleware::new();
10991 let ctx = test_context();
10992 let mut req = Request::new(crate::request::Method::Options, "/");
10993
10994 let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10995 let response = Response::ok();
10996 let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10997
10998 assert!(
10999 header_value(&result, "set-cookie").is_some(),
11000 "OPTIONS response should set CSRF cookie"
11001 );
11002 }
11003
11004 #[test]
11005 fn csrf_rotation_produces_different_token_in_cookie() {
11006 let csrf = CsrfMiddleware::with_config(CsrfConfig::new().rotate_token(true));
11007 let ctx = test_context();
11008 let mut req = Request::new(crate::request::Method::Get, "/");
11009
11010 let old_token = "old-token-value";
11011 req.headers_mut()
11012 .insert("cookie", format!("csrf_token={}", old_token).into_bytes());
11013
11014 let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
11015 let response = Response::ok();
11016 let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
11017
11018 let cookie_value = header_value(&result, "set-cookie").unwrap();
11019 assert!(cookie_value.starts_with("csrf_token="));
11024 }
11025
11026 #[test]
11027 fn csrf_no_rotation_skips_set_cookie_when_present() {
11028 let csrf = CsrfMiddleware::with_config(CsrfConfig::new().rotate_token(false));
11029 let ctx = test_context();
11030 let mut req = Request::new(crate::request::Method::Get, "/");
11031
11032 req.headers_mut()
11033 .insert("cookie", b"csrf_token=existing".to_vec());
11034
11035 let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
11036 let response = Response::ok();
11037 let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
11038
11039 assert!(
11040 header_value(&result, "set-cookie").is_none(),
11041 "Without rotation, should not re-set existing cookie"
11042 );
11043 }
11044
11045 #[test]
11046 fn csrf_custom_cookie_name_in_set_cookie_response() {
11047 let csrf = CsrfMiddleware::with_config(CsrfConfig::new().cookie_name("XSRF-TOKEN"));
11048 let ctx = test_context();
11049 let mut req = Request::new(crate::request::Method::Get, "/");
11050
11051 let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
11052 let response = Response::ok();
11053 let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
11054
11055 let cookie_value = header_value(&result, "set-cookie").unwrap();
11056 assert!(
11057 cookie_value.starts_with("XSRF-TOKEN="),
11058 "Custom cookie name should appear in Set-Cookie, got: {}",
11059 cookie_value
11060 );
11061 }
11062
11063 #[test]
11064 fn csrf_custom_header_name_validated() {
11065 let csrf = CsrfMiddleware::with_config(
11066 CsrfConfig::new()
11067 .header_name("X-Custom-CSRF")
11068 .cookie_name("my_csrf"),
11069 );
11070 let ctx = test_context();
11071 let mut req = Request::new(crate::request::Method::Post, "/");
11072
11073 let token = "custom-tok";
11074 req.headers_mut()
11075 .insert("cookie", format!("my_csrf={}", token).into_bytes());
11076 req.headers_mut()
11077 .insert("x-custom-csrf", token.as_bytes().to_vec());
11078
11079 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11080 assert!(result.is_continue());
11081 }
11082
11083 #[test]
11084 fn csrf_custom_header_name_wrong_header_rejected() {
11085 let csrf = CsrfMiddleware::with_config(CsrfConfig::new().header_name("X-Custom-CSRF"));
11086 let ctx = test_context();
11087 let mut req = Request::new(crate::request::Method::Post, "/");
11088
11089 let token = "some-token";
11090 req.headers_mut()
11091 .insert("cookie", format!("csrf_token={}", token).into_bytes());
11092 req.headers_mut()
11094 .insert("x-csrf-token", token.as_bytes().to_vec());
11095
11096 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11097 assert!(result.is_break(), "Wrong header name should be rejected");
11098 }
11099
11100 #[test]
11101 fn csrf_cookie_parsing_multiple_cookies_picks_correct() {
11102 let csrf = CsrfMiddleware::new();
11103 let ctx = test_context();
11104 let mut req = Request::new(crate::request::Method::Post, "/");
11105
11106 let token = "correct-csrf";
11107 req.headers_mut().insert(
11108 "cookie",
11109 format!("session=abc; other=xyz; csrf_token={}; tracking=123", token).into_bytes(),
11110 );
11111 req.headers_mut()
11112 .insert("x-csrf-token", token.as_bytes().to_vec());
11113
11114 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11115 assert!(result.is_continue());
11116 }
11117
11118 #[test]
11119 fn csrf_cookie_parsing_spaces_around_semicolons() {
11120 let csrf = CsrfMiddleware::new();
11121 let ctx = test_context();
11122 let mut req = Request::new(crate::request::Method::Post, "/");
11123
11124 let token = "spaced-token";
11125 req.headers_mut().insert(
11126 "cookie",
11127 format!("session=abc ; csrf_token={} ; other=xyz", token).into_bytes(),
11128 );
11129 req.headers_mut()
11130 .insert("x-csrf-token", token.as_bytes().to_vec());
11131
11132 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11133 assert!(result.is_continue());
11134 }
11135
11136 #[test]
11137 fn csrf_error_response_status_is_403() {
11138 let csrf = CsrfMiddleware::new();
11139 let ctx = test_context();
11140
11141 for method in [
11143 crate::request::Method::Post,
11144 crate::request::Method::Put,
11145 crate::request::Method::Delete,
11146 crate::request::Method::Patch,
11147 ] {
11148 let mut req = Request::new(method, "/");
11149 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11150 match result {
11151 ControlFlow::Break(response) => {
11152 assert_eq!(
11153 response.status(),
11154 StatusCode::FORBIDDEN,
11155 "Expected 403 for {:?}",
11156 method
11157 );
11158 }
11159 ControlFlow::Continue => panic!("Expected Break for {:?}", method),
11160 }
11161 }
11162 }
11163
11164 #[test]
11165 fn csrf_error_body_json_structure() {
11166 let csrf = CsrfMiddleware::new();
11167 let ctx = test_context();
11168 let mut req = Request::new(crate::request::Method::Post, "/");
11169
11170 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11171 if let ControlFlow::Break(response) = result {
11172 if let ResponseBody::Bytes(body) = response.body_ref() {
11173 let body_str = std::str::from_utf8(body).unwrap();
11174 let parsed: serde_json::Value = serde_json::from_str(body_str)
11176 .unwrap_or_else(|e| panic!("Invalid JSON: {}: {}", body_str, e));
11177 assert!(parsed["detail"].is_array());
11178 let detail = &parsed["detail"][0];
11179 assert_eq!(detail["type"], "csrf_error");
11180 assert!(detail["loc"].is_array());
11181 assert_eq!(detail["loc"][0], "header");
11182 assert_eq!(detail["loc"][1], "x-csrf-token");
11183 assert!(detail["msg"].is_string());
11184 } else {
11185 panic!("Expected Bytes body");
11186 }
11187 } else {
11188 panic!("Expected Break");
11189 }
11190 }
11191
11192 #[test]
11193 fn csrf_default_trait() {
11194 let csrf = CsrfMiddleware::default();
11195 assert_eq!(csrf.name(), "CSRF");
11196 let ctx = test_context();
11198 let mut req = Request::new(crate::request::Method::Get, "/");
11199 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11200 assert!(result.is_continue());
11201 }
11202
11203 #[test]
11204 fn csrf_mode_default_is_double_submit() {
11205 assert_eq!(CsrfMode::default(), CsrfMode::DoubleSubmit);
11206 }
11207
11208 #[test]
11209 fn csrf_double_submit_both_present_same_non_empty_passes() {
11210 let csrf = CsrfMiddleware::new();
11212 let ctx = test_context();
11213
11214 let token = "a1b2c3d4e5f6";
11215 let mut req = Request::new(crate::request::Method::Delete, "/resource/1");
11216 req.headers_mut()
11217 .insert("cookie", format!("csrf_token={}", token).into_bytes());
11218 req.headers_mut()
11219 .insert("x-csrf-token", token.as_bytes().to_vec());
11220
11221 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11222 assert!(result.is_continue());
11223 }
11224
11225 #[test]
11226 fn csrf_double_submit_case_sensitive() {
11227 let csrf = CsrfMiddleware::new();
11229 let ctx = test_context();
11230 let mut req = Request::new(crate::request::Method::Post, "/");
11231
11232 req.headers_mut()
11233 .insert("cookie", b"csrf_token=AbCdEf".to_vec());
11234 req.headers_mut().insert("x-csrf-token", b"abcdef".to_vec());
11235
11236 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11237 assert!(
11238 result.is_break(),
11239 "Token comparison should be case-sensitive"
11240 );
11241 }
11242
11243 #[test]
11244 fn csrf_token_cookie_extractor_reads_csrf_cookie() {
11245 use crate::extract::{CookieName, CsrfTokenCookie};
11247 assert_eq!(CsrfTokenCookie::NAME, "csrf_token");
11248 }
11249
11250 #[test]
11251 fn csrf_make_set_cookie_header_value_production() {
11252 let value = CsrfMiddleware::make_set_cookie_header_value("csrf_token", "tok123", true);
11253 let s = std::str::from_utf8(&value).unwrap();
11254 assert!(s.contains("csrf_token=tok123"));
11255 assert!(s.contains("Path=/"));
11256 assert!(s.contains("SameSite=Strict"));
11257 assert!(s.contains("Secure"));
11258 assert!(!s.to_lowercase().contains("httponly"));
11259 }
11260
11261 #[test]
11262 fn csrf_make_set_cookie_header_value_development() {
11263 let value = CsrfMiddleware::make_set_cookie_header_value("csrf_token", "tok123", false);
11264 let s = std::str::from_utf8(&value).unwrap();
11265 assert!(s.contains("csrf_token=tok123"));
11266 assert!(s.contains("Path=/"));
11267 assert!(s.contains("SameSite=Strict"));
11268 assert!(!s.contains("Secure"));
11269 }
11270
11271 #[test]
11272 fn csrf_before_after_full_cycle_get_then_post() {
11273 let csrf = CsrfMiddleware::new();
11275 let ctx = test_context();
11276
11277 let mut get_req = Request::new(crate::request::Method::Get, "/form");
11279 let _ = futures_executor::block_on(csrf.before(&ctx, &mut get_req));
11280 let get_response = Response::ok();
11281 let get_result = futures_executor::block_on(csrf.after(&ctx, &get_req, get_response));
11282
11283 let set_cookie = header_value(&get_result, "set-cookie").expect("GET should set cookie");
11284 let token_value = set_cookie
11286 .strip_prefix("csrf_token=")
11287 .unwrap()
11288 .split(';')
11289 .next()
11290 .unwrap();
11291 assert!(!token_value.is_empty());
11292
11293 let mut post_req = Request::new(crate::request::Method::Post, "/form");
11295 post_req
11296 .headers_mut()
11297 .insert("cookie", format!("csrf_token={}", token_value).into_bytes());
11298 post_req
11299 .headers_mut()
11300 .insert("x-csrf-token", token_value.as_bytes().to_vec());
11301
11302 let result = futures_executor::block_on(csrf.before(&ctx, &mut post_req));
11303 assert!(result.is_continue(), "POST with valid token should pass");
11304 }
11305
11306 #[test]
11307 fn csrf_all_state_changing_methods_require_token() {
11308 let csrf = CsrfMiddleware::new();
11309 let ctx = test_context();
11310
11311 for method in [
11312 crate::request::Method::Post,
11313 crate::request::Method::Put,
11314 crate::request::Method::Delete,
11315 crate::request::Method::Patch,
11316 ] {
11317 let mut req = Request::new(method, "/resource");
11318 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11319 assert!(
11320 result.is_break(),
11321 "{:?} without token should be rejected",
11322 method
11323 );
11324 }
11325 }
11326
11327 #[test]
11328 fn csrf_all_safe_methods_pass_without_token() {
11329 let csrf = CsrfMiddleware::new();
11330 let ctx = test_context();
11331
11332 for method in [
11333 crate::request::Method::Get,
11334 crate::request::Method::Head,
11335 crate::request::Method::Options,
11336 crate::request::Method::Trace,
11337 ] {
11338 let mut req = Request::new(method, "/resource");
11339 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11340 assert!(
11341 result.is_continue(),
11342 "{:?} should be allowed without token",
11343 method
11344 );
11345 }
11346 }
11347
11348 #[derive(Clone)]
11355 struct OrderRecordingMiddleware {
11356 id: &'static str,
11357 log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
11358 }
11359
11360 impl OrderRecordingMiddleware {
11361 fn new(id: &'static str, log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
11362 Self { id, log }
11363 }
11364 }
11365
11366 impl Middleware for OrderRecordingMiddleware {
11367 fn before<'a>(
11368 &'a self,
11369 _ctx: &'a RequestContext,
11370 _req: &'a mut Request,
11371 ) -> BoxFuture<'a, ControlFlow> {
11372 let id = self.id;
11373 let log = self.log.clone();
11374 Box::pin(async move {
11375 log.lock().unwrap().push(format!("{id}:before"));
11376 ControlFlow::Continue
11377 })
11378 }
11379
11380 fn after<'a>(
11381 &'a self,
11382 _ctx: &'a RequestContext,
11383 _req: &'a Request,
11384 response: Response,
11385 ) -> BoxFuture<'a, Response> {
11386 let id = self.id;
11387 let log = self.log.clone();
11388 Box::pin(async move {
11389 log.lock().unwrap().push(format!("{id}:after"));
11390 response
11391 })
11392 }
11393
11394 fn name(&self) -> &'static str {
11395 "OrderRecording"
11396 }
11397 }
11398
11399 struct ShortCircuitMiddleware {
11401 id: &'static str,
11402 log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
11403 }
11404
11405 impl ShortCircuitMiddleware {
11406 fn new(id: &'static str, log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
11407 Self { id, log }
11408 }
11409 }
11410
11411 impl Middleware for ShortCircuitMiddleware {
11412 fn before<'a>(
11413 &'a self,
11414 _ctx: &'a RequestContext,
11415 _req: &'a mut Request,
11416 ) -> BoxFuture<'a, ControlFlow> {
11417 let id = self.id;
11418 let log = self.log.clone();
11419 Box::pin(async move {
11420 log.lock().unwrap().push(format!("{id}:before:break"));
11421 ControlFlow::Break(
11422 Response::with_status(StatusCode::FORBIDDEN)
11423 .body(ResponseBody::Bytes(b"short-circuited".to_vec())),
11424 )
11425 })
11426 }
11427
11428 fn after<'a>(
11429 &'a self,
11430 _ctx: &'a RequestContext,
11431 _req: &'a Request,
11432 response: Response,
11433 ) -> BoxFuture<'a, Response> {
11434 let id = self.id;
11435 let log = self.log.clone();
11436 Box::pin(async move {
11437 log.lock().unwrap().push(format!("{id}:after"));
11438 response
11439 })
11440 }
11441
11442 fn name(&self) -> &'static str {
11443 "ShortCircuit"
11444 }
11445 }
11446
11447 struct RecordingHandler {
11449 log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
11450 }
11451
11452 impl RecordingHandler {
11453 fn new(log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
11454 Self { log }
11455 }
11456 }
11457
11458 impl Handler for RecordingHandler {
11459 fn call<'a>(
11460 &'a self,
11461 _ctx: &'a RequestContext,
11462 _req: &'a mut Request,
11463 ) -> BoxFuture<'a, Response> {
11464 let log = self.log.clone();
11465 Box::pin(async move {
11466 log.lock().unwrap().push("handler".to_string());
11467 Response::ok().body(ResponseBody::Bytes(b"ok".to_vec()))
11468 })
11469 }
11470 }
11471
11472 #[test]
11473 fn middleware_stack_three_middleware_onion_order() {
11474 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11478
11479 let mut stack = MiddlewareStack::new();
11480 stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
11481 stack.push(OrderRecordingMiddleware::new("mw2", log.clone()));
11482 stack.push(OrderRecordingMiddleware::new("mw3", log.clone()));
11483
11484 let handler = RecordingHandler::new(log.clone());
11485 let ctx = test_context();
11486 let mut req = Request::new(crate::request::Method::Get, "/");
11487
11488 let _response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11489
11490 let execution_log = log.lock().unwrap().clone();
11491 assert_eq!(
11492 execution_log,
11493 vec![
11494 "mw1:before",
11495 "mw2:before",
11496 "mw3:before",
11497 "handler",
11498 "mw3:after",
11499 "mw2:after",
11500 "mw1:after",
11501 ]
11502 );
11503 }
11504
11505 #[test]
11506 fn middleware_stack_short_circuit_runs_prior_after_hooks() {
11507 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11515
11516 let mut stack = MiddlewareStack::new();
11517 stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
11518 stack.push(ShortCircuitMiddleware::new("mw2", log.clone()));
11519 stack.push(OrderRecordingMiddleware::new("mw3", log.clone()));
11520
11521 let handler = RecordingHandler::new(log.clone());
11522 let ctx = test_context();
11523 let mut req = Request::new(crate::request::Method::Get, "/");
11524
11525 let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11526
11527 assert_eq!(response.status().as_u16(), 403);
11529
11530 let execution_log = log.lock().unwrap().clone();
11531 assert_eq!(
11534 execution_log,
11535 vec!["mw1:before", "mw2:before:break", "mw1:after",]
11536 );
11537 }
11538
11539 #[test]
11540 fn middleware_stack_first_middleware_short_circuits() {
11541 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11545
11546 let mut stack = MiddlewareStack::new();
11547 stack.push(ShortCircuitMiddleware::new("mw1", log.clone()));
11548 stack.push(OrderRecordingMiddleware::new("mw2", log.clone()));
11549
11550 let handler = RecordingHandler::new(log.clone());
11551 let ctx = test_context();
11552 let mut req = Request::new(crate::request::Method::Get, "/");
11553
11554 let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11555 assert_eq!(response.status().as_u16(), 403);
11556
11557 let execution_log = log.lock().unwrap().clone();
11558 assert_eq!(execution_log, vec!["mw1:before:break",]);
11560 }
11561
11562 #[test]
11563 fn middleware_stack_empty_runs_handler_only() {
11564 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11566
11567 let stack = MiddlewareStack::new();
11568 let handler = RecordingHandler::new(log.clone());
11569 let ctx = test_context();
11570 let mut req = Request::new(crate::request::Method::Get, "/");
11571
11572 let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11573 assert_eq!(response.status().as_u16(), 200);
11574
11575 let execution_log = log.lock().unwrap().clone();
11576 assert_eq!(execution_log, vec!["handler"]);
11577 }
11578
11579 #[test]
11580 fn middleware_stack_single_middleware_ordering() {
11581 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11583
11584 let mut stack = MiddlewareStack::new();
11585 stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
11586
11587 let handler = RecordingHandler::new(log.clone());
11588 let ctx = test_context();
11589 let mut req = Request::new(crate::request::Method::Get, "/");
11590
11591 let _response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11592
11593 let execution_log = log.lock().unwrap().clone();
11594 assert_eq!(execution_log, vec!["mw1:before", "handler", "mw1:after",]);
11595 }
11596
11597 #[test]
11598 fn middleware_stack_five_middleware_onion_order() {
11599 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11601
11602 let mut stack = MiddlewareStack::new();
11603 stack.push(OrderRecordingMiddleware::new("a", log.clone()));
11604 stack.push(OrderRecordingMiddleware::new("b", log.clone()));
11605 stack.push(OrderRecordingMiddleware::new("c", log.clone()));
11606 stack.push(OrderRecordingMiddleware::new("d", log.clone()));
11607 stack.push(OrderRecordingMiddleware::new("e", log.clone()));
11608
11609 let handler = RecordingHandler::new(log.clone());
11610 let ctx = test_context();
11611 let mut req = Request::new(crate::request::Method::Get, "/");
11612
11613 let _response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11614
11615 let execution_log = log.lock().unwrap().clone();
11616 assert_eq!(
11617 execution_log,
11618 vec![
11619 "a:before", "b:before", "c:before", "d:before", "e:before", "handler", "e:after",
11620 "d:after", "c:after", "b:after", "a:after",
11621 ]
11622 );
11623 }
11624
11625 #[test]
11626 fn middleware_stack_short_circuit_at_end_runs_prior_afters() {
11627 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11634
11635 let mut stack = MiddlewareStack::new();
11636 stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
11637 stack.push(OrderRecordingMiddleware::new("mw2", log.clone()));
11638 stack.push(ShortCircuitMiddleware::new("mw3", log.clone()));
11639
11640 let handler = RecordingHandler::new(log.clone());
11641 let ctx = test_context();
11642 let mut req = Request::new(crate::request::Method::Get, "/");
11643
11644 let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11645 assert_eq!(response.status().as_u16(), 403);
11646
11647 let execution_log = log.lock().unwrap().clone();
11648 assert_eq!(
11650 execution_log,
11651 vec![
11652 "mw1:before",
11653 "mw2:before",
11654 "mw3:before:break",
11655 "mw2:after",
11656 "mw1:after",
11657 ]
11658 );
11659 }
11660
11661 struct ModifyingMiddleware {
11663 id: &'static str,
11664 log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
11665 }
11666
11667 impl ModifyingMiddleware {
11668 fn new(id: &'static str, log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
11669 Self { id, log }
11670 }
11671 }
11672
11673 impl Middleware for ModifyingMiddleware {
11674 fn before<'a>(
11675 &'a self,
11676 _ctx: &'a RequestContext,
11677 req: &'a mut Request,
11678 ) -> BoxFuture<'a, ControlFlow> {
11679 let id = self.id;
11680 let log = self.log.clone();
11681 Box::pin(async move {
11682 req.headers_mut()
11684 .insert(format!("x-{id}-before"), b"true".to_vec());
11685 log.lock().unwrap().push(format!("{id}:before"));
11686 ControlFlow::Continue
11687 })
11688 }
11689
11690 fn after<'a>(
11691 &'a self,
11692 _ctx: &'a RequestContext,
11693 _req: &'a Request,
11694 response: Response,
11695 ) -> BoxFuture<'a, Response> {
11696 let id = self.id;
11697 let log = self.log.clone();
11698 Box::pin(async move {
11699 log.lock().unwrap().push(format!("{id}:after"));
11700 response.header(format!("x-{id}-after"), b"true".to_vec())
11702 })
11703 }
11704
11705 fn name(&self) -> &'static str {
11706 "Modifying"
11707 }
11708 }
11709
11710 #[test]
11711 fn middleware_stack_modifications_accumulate_correctly() {
11712 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11715
11716 let mut stack = MiddlewareStack::new();
11717 stack.push(ModifyingMiddleware::new("mw1", log.clone()));
11718 stack.push(ModifyingMiddleware::new("mw2", log.clone()));
11719 stack.push(ModifyingMiddleware::new("mw3", log.clone()));
11720
11721 let handler = RecordingHandler::new(log.clone());
11722 let ctx = test_context();
11723 let mut req = Request::new(crate::request::Method::Get, "/");
11724
11725 let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11726
11727 assert!(header_value(&response, "x-mw1-after").is_some());
11729 assert!(header_value(&response, "x-mw2-after").is_some());
11730 assert!(header_value(&response, "x-mw3-after").is_some());
11731
11732 assert!(req.headers().contains("x-mw1-before"));
11734 assert!(req.headers().contains("x-mw2-before"));
11735 assert!(req.headers().contains("x-mw3-before"));
11736 }
11737
11738 #[test]
11739 fn layer_wrap_maintains_middleware_order() {
11740 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11742
11743 let layer = Layer::new(OrderRecordingMiddleware::new("layer", log.clone()));
11745
11746 let handler = RecordingHandler::new(log.clone());
11748 let layered_handler = layer.wrap(handler);
11749
11750 let ctx = test_context();
11751 let mut req = Request::new(crate::request::Method::Get, "/");
11752
11753 let _response = futures_executor::block_on(layered_handler.call(&ctx, &mut req));
11755
11756 let execution_log = log.lock().unwrap().clone();
11757 assert_eq!(
11758 execution_log,
11759 vec!["layer:before", "handler", "layer:after",]
11760 );
11761 }
11762}
11763
/// Tests for `CompressionConfig` and `CompressionMiddleware`; compiled only
/// for test builds with the `compression` feature enabled.
#[cfg(all(test, feature = "compression"))]
mod compression_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::ResponseBody;

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// Builds a `GET /` request whose `accept-encoding` header is `value`.
    fn request_accepting(value: &[u8]) -> Request {
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("accept-encoding", value.to_vec());
        req
    }

    /// Builds a 200 response with the given `content-type` and byte body.
    fn response_with(content_type: &[u8], body: Vec<u8>) -> Response {
        Response::ok()
            .header("content-type", content_type.to_vec())
            .body(ResponseBody::Bytes(body))
    }

    /// Returns `true` when the response carries any `content-encoding` header.
    fn has_content_encoding(response: &Response) -> bool {
        response
            .headers()
            .iter()
            .any(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
    }

    /// Middleware whose size threshold is low enough for the repeated test bodies.
    fn low_threshold_middleware() -> CompressionMiddleware {
        CompressionMiddleware::with_config(CompressionConfig::new().min_size(10))
    }

    #[test]
    fn compression_config_defaults() {
        let config = CompressionConfig::default();
        assert_eq!(config.min_size, 1024);
        assert_eq!(config.level, 6);
        assert!(!config.skip_content_types.is_empty());
    }

    #[test]
    fn compression_config_builder() {
        let config = CompressionConfig::new().min_size(512).level(9);
        assert_eq!(config.min_size, 512);
        assert_eq!(config.level, 9);
    }

    #[test]
    fn compression_level_clamped() {
        // Out-of-range levels are clamped into 1..=9.
        let too_high = CompressionConfig::new().level(100);
        assert_eq!(too_high.level, 9);

        let too_low = CompressionConfig::new().level(0);
        assert_eq!(too_low.level, 1);
    }

    #[test]
    fn skip_content_type_exact_match() {
        let config = CompressionConfig::default();
        assert!(config.should_skip_content_type("image/jpeg"));
        // Parameters after the media type do not defeat the match.
        assert!(config.should_skip_content_type("image/jpeg; charset=utf-8"));
        assert!(!config.should_skip_content_type("text/html"));
    }

    #[test]
    fn skip_content_type_prefix_match() {
        let config = CompressionConfig::default();
        assert!(config.should_skip_content_type("video/mp4"));
        assert!(config.should_skip_content_type("video/webm"));
        assert!(config.should_skip_content_type("audio/mpeg"));
    }

    #[test]
    fn compression_skips_small_responses() {
        let middleware = CompressionMiddleware::new();
        let ctx = test_context();
        let req = request_accepting(b"gzip");

        // 13 bytes: well under the default `min_size`.
        let response = response_with(b"text/plain", b"Hello, World!".to_vec());

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        assert!(
            !has_content_encoding(&result),
            "Small response should not be compressed"
        );
    }

    #[test]
    fn compression_works_for_large_responses() {
        let middleware = low_threshold_middleware();
        let ctx = test_context();
        let req = request_accepting(b"gzip");

        let body = "Hello, World! ".repeat(100);
        let original_size = body.len();
        let response = response_with(b"text/plain", body.into_bytes());

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        let (_, value) = result
            .headers()
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
            .expect("Large response should be compressed");
        assert_eq!(value, b"gzip");

        let has_vary = result
            .headers()
            .iter()
            .any(|(name, _)| name.eq_ignore_ascii_case("vary"));
        assert!(has_vary, "Should have Vary header");

        match result.body_ref() {
            ResponseBody::Bytes(compressed) => assert!(
                compressed.len() < original_size,
                "Compressed size should be smaller"
            ),
            _ => panic!("Expected Bytes body"),
        }
    }

    #[test]
    fn compression_skips_without_accept_encoding() {
        let middleware = low_threshold_middleware();
        let ctx = test_context();

        // Deliberately no `accept-encoding` header on the request.
        let req = Request::new(Method::Get, "/");

        let body = "Hello, World! ".repeat(100);
        let response = response_with(b"text/plain", body.into_bytes());

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        assert!(
            !has_content_encoding(&result),
            "Should not compress without Accept-Encoding"
        );
    }

    #[test]
    fn compression_skips_already_compressed_content() {
        let middleware = low_threshold_middleware();
        let ctx = test_context();
        let req = request_accepting(b"gzip");

        let body = "Some image data".repeat(100);
        let response = response_with(b"image/jpeg", body.into_bytes());

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        assert!(
            !has_content_encoding(&result),
            "Should not compress already-compressed content types"
        );
    }

    #[test]
    fn compression_skips_if_already_has_content_encoding() {
        let middleware = low_threshold_middleware();
        let ctx = test_context();
        let req = request_accepting(b"gzip");

        let body = "Hello, World! ".repeat(100);
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .header("content-encoding", b"br".to_vec())
            .body(ResponseBody::Bytes(body.into_bytes()));

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        let encodings: Vec<_> = result
            .headers()
            .iter()
            .filter(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
            .collect();

        // The pre-existing `br` encoding must be the only one left.
        assert_eq!(encodings.len(), 1);
        assert_eq!(encodings[0].1, b"br");
    }

    #[test]
    fn accepts_gzip_parses_header_correctly() {
        // (accept-encoding value, whether gzip is considered acceptable)
        let cases = [
            (b"gzip".as_slice(), true),
            (b"deflate, gzip, br".as_slice(), true),
            (b"gzip;q=1.0, identity;q=0.5".as_slice(), true),
            (b"*".as_slice(), true),
            (b"deflate, br".as_slice(), false),
        ];

        for (header, expected) in cases {
            let req = request_accepting(header);
            assert_eq!(
                CompressionMiddleware::accepts_gzip(&req),
                expected,
                "accept-encoding: {}",
                String::from_utf8_lossy(header)
            );
        }

        // A request with no header at all never accepts gzip.
        let req_no_header = Request::new(Method::Get, "/");
        assert!(!CompressionMiddleware::accepts_gzip(&req_no_header));
    }

    #[test]
    fn compression_middleware_name() {
        let middleware = CompressionMiddleware::new();
        assert_eq!(middleware.name(), "Compression");
    }
}
12025
12026#[cfg(test)]
12031mod request_inspection_tests {
12032 use super::*;
12033 use crate::request::Method;
12034 use crate::response::ResponseBody;
12035
12036 fn test_context() -> RequestContext {
12037 RequestContext::new(asupersync::Cx::for_testing(), 1)
12038 }
12039
12040 #[test]
12041 fn inspection_middleware_default_creates_normal_verbosity() {
12042 let mw = RequestInspectionMiddleware::new();
12043 assert_eq!(mw.verbosity, InspectionVerbosity::Normal);
12044 assert_eq!(mw.slow_threshold_ms, 1000);
12045 assert_eq!(mw.max_body_preview, 2048);
12046 assert_eq!(mw.name(), "RequestInspection");
12047 }
12048
12049 #[test]
12050 fn inspection_middleware_builder_methods() {
12051 let mw = RequestInspectionMiddleware::new()
12052 .verbosity(InspectionVerbosity::Verbose)
12053 .slow_threshold_ms(500)
12054 .max_body_preview(4096)
12055 .log_config(LogConfig::development())
12056 .redact_header("x-api-key");
12057
12058 assert_eq!(mw.verbosity, InspectionVerbosity::Verbose);
12059 assert_eq!(mw.slow_threshold_ms, 500);
12060 assert_eq!(mw.max_body_preview, 4096);
12061 assert!(mw.redact_headers.contains("x-api-key"));
12062 assert!(mw.redact_headers.contains("authorization"));
12064 assert!(mw.redact_headers.contains("cookie"));
12065 }
12066
12067 #[test]
12068 fn inspection_before_continues_processing() {
12069 let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Minimal);
12070 let ctx = test_context();
12071 let mut req = Request::new(Method::Post, "/api/users");
12072
12073 let result = futures_executor::block_on(mw.before(&ctx, &mut req));
12074 assert!(result.is_continue());
12075 }
12076
12077 #[test]
12078 fn inspection_after_returns_response_unchanged() {
12079 let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Minimal);
12080 let ctx = test_context();
12081 let mut req = Request::new(Method::Get, "/health");
12082
12083 let _ = futures_executor::block_on(mw.before(&ctx, &mut req));
12085
12086 let response = Response::ok().body(ResponseBody::Bytes(b"OK".to_vec()));
12087
12088 let result = futures_executor::block_on(mw.after(&ctx, &req, response));
12089 assert_eq!(result.status().as_u16(), 200);
12090 assert_eq!(result.body_ref().len(), 2);
12091 }
12092
12093 #[test]
12094 fn inspection_stores_start_extension() {
12095 let mw = RequestInspectionMiddleware::new();
12096 let ctx = test_context();
12097 let mut req = Request::new(Method::Get, "/");
12098
12099 let _ = futures_executor::block_on(mw.before(&ctx, &mut req));
12100
12101 assert!(req.get_extension::<InspectionStart>().is_some());
12103 }
12104
12105 #[test]
12106 fn inspection_all_verbosity_levels_continue() {
12107 for verbosity in [
12108 InspectionVerbosity::Minimal,
12109 InspectionVerbosity::Normal,
12110 InspectionVerbosity::Verbose,
12111 ] {
12112 let mw = RequestInspectionMiddleware::new().verbosity(verbosity);
12113 let ctx = test_context();
12114 let mut req = Request::new(Method::Get, "/test");
12115 req.headers_mut()
12116 .insert("content-type", b"text/plain".to_vec());
12117
12118 let result = futures_executor::block_on(mw.before(&ctx, &mut req));
12119 assert!(
12120 result.is_continue(),
12121 "Verbosity {verbosity:?} should continue"
12122 );
12123 }
12124 }
12125
12126 #[test]
12127 fn inspection_verbose_with_json_body() {
12128 let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Verbose);
12129 let ctx = test_context();
12130 let body = br#"{"name":"Alice","age":30}"#;
12131 let mut req = Request::new(Method::Post, "/api/users");
12132 req.headers_mut()
12133 .insert("content-type", b"application/json".to_vec());
12134 req.set_body(Body::Bytes(body.to_vec()));
12135
12136 let result = futures_executor::block_on(mw.before(&ctx, &mut req));
12137 assert!(result.is_continue());
12138 }
12139
12140 #[test]
12141 fn inspection_verbose_after_with_json_response() {
12142 let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Verbose);
12143 let ctx = test_context();
12144 let mut req = Request::new(Method::Get, "/api/users/1");
12145
12146 let _ = futures_executor::block_on(mw.before(&ctx, &mut req));
12147
12148 let response = Response::ok()
12149 .header("content-type", b"application/json".to_vec())
12150 .body(ResponseBody::Bytes(br#"{"id":1,"name":"Alice"}"#.to_vec()));
12151
12152 let result = futures_executor::block_on(mw.after(&ctx, &req, response));
12153 assert_eq!(result.status().as_u16(), 200);
12154 }
12155
12156 #[test]
12157 fn inspection_redacts_sensitive_headers() {
12158 let mw = RequestInspectionMiddleware::new();
12159
12160 assert!(mw.redact_headers.contains("authorization"));
12162 assert!(mw.redact_headers.contains("proxy-authorization"));
12163 assert!(mw.redact_headers.contains("cookie"));
12164 assert!(mw.redact_headers.contains("set-cookie"));
12165 }
12166
12167 #[test]
12168 fn inspection_format_headers_redacts() {
12169 let mw = RequestInspectionMiddleware::new().redact_header("x-secret");
12170
12171 let headers = vec![
12172 ("content-type", b"text/plain".as_slice()),
12173 ("x-secret", b"my-secret-value".as_slice()),
12174 ("x-normal", b"visible".as_slice()),
12175 ];
12176
12177 let output = mw.format_inspection_headers(headers.into_iter());
12178 assert!(output.contains("content-type: text/plain"));
12179 assert!(output.contains("x-secret: [REDACTED]"));
12180 assert!(output.contains("x-normal: visible"));
12181 assert!(!output.contains("my-secret-value"));
12182 }
12183
12184 #[test]
12185 fn inspection_format_body_preview_truncates() {
12186 let mw = RequestInspectionMiddleware::new().max_body_preview(10);
12187
12188 let body = b"Hello, World! This is a long body.";
12189 let result = mw.format_body_preview(body, None);
12190 assert!(result.is_some());
12191 let text = result.unwrap();
12192 assert!(text.ends_with("..."));
12193 assert!(text.len() <= 15); }
12195
12196 #[test]
12197 fn inspection_format_body_preview_empty() {
12198 let mw = RequestInspectionMiddleware::new();
12199 assert!(mw.format_body_preview(b"", None).is_none());
12200 }
12201
12202 #[test]
12203 fn inspection_format_body_preview_zero_max() {
12204 let mw = RequestInspectionMiddleware::new().max_body_preview(0);
12205 assert!(mw.format_body_preview(b"hello", None).is_none());
12206 }
12207
12208 #[test]
12209 fn inspection_format_body_preview_json_pretty() {
12210 let mw = RequestInspectionMiddleware::new();
12211 let body = br#"{"key":"value","num":42}"#;
12212 let ct = b"application/json".as_slice();
12213 let result = mw.format_body_preview(body, Some(ct));
12214 assert!(result.is_some());
12215 let text = result.unwrap();
12216 assert!(text.contains('\n'));
12218 assert!(text.contains("\"key\": \"value\""));
12219 }
12220
12221 #[test]
12222 fn inspection_format_body_preview_non_json() {
12223 let mw = RequestInspectionMiddleware::new();
12224 let body = b"Hello, World!";
12225 let ct = b"text/plain".as_slice();
12226 let result = mw.format_body_preview(body, Some(ct));
12227 assert_eq!(result.unwrap(), "Hello, World!");
12228 }
12229
12230 #[test]
12231 fn inspection_format_body_preview_binary() {
12232 let mw = RequestInspectionMiddleware::new();
12233 let body: &[u8] = &[0xFF, 0xFE, 0xFD, 0x00];
12234 let result = mw.format_body_preview(body, None);
12235 assert!(result.is_some());
12236 assert!(result.unwrap().contains("binary"));
12237 }
12238
12239 #[test]
12240 fn try_pretty_json_valid_object() {
12241 let result = try_pretty_json(r#"{"a":"b","c":1}"#);
12242 assert!(result.is_some());
12243 let pretty = result.unwrap();
12244 assert!(pretty.contains('\n'));
12245 assert!(pretty.contains(" \"a\": \"b\""));
12246 }
12247
12248 #[test]
12249 fn try_pretty_json_valid_array() {
12250 let result = try_pretty_json(r"[1,2,3]");
12251 assert!(result.is_some());
12252 let pretty = result.unwrap();
12253 assert!(pretty.contains('\n'));
12254 }
12255
12256 #[test]
12257 fn try_pretty_json_empty_object() {
12258 let result = try_pretty_json("{}");
12259 assert!(result.is_some());
12260 assert_eq!(result.unwrap(), "{}");
12261 }
12262
12263 #[test]
12264 fn try_pretty_json_empty_array() {
12265 let result = try_pretty_json("[]");
12266 assert!(result.is_some());
12267 assert_eq!(result.unwrap(), "[]");
12268 }
12269
12270 #[test]
12271 fn try_pretty_json_not_json() {
12272 assert!(try_pretty_json("hello world").is_none());
12273 assert!(try_pretty_json("12345").is_none());
12274 }
12275
12276 #[test]
12277 fn try_pretty_json_nested() {
12278 let input = r#"{"user":{"name":"Alice","roles":["admin","user"]}}"#;
12279 let result = try_pretty_json(input);
12280 assert!(result.is_some());
12281 let pretty = result.unwrap();
12282 assert!(pretty.contains("\"user\":"));
12283 assert!(pretty.contains("\"name\": \"Alice\""));
12284 assert!(pretty.contains("\"roles\":"));
12285 }
12286
12287 #[test]
12288 fn try_pretty_json_with_escapes() {
12289 let input = r#"{"msg":"hello \"world\""}"#;
12290 let result = try_pretty_json(input);
12291 assert!(result.is_some());
12292 let pretty = result.unwrap();
12293 assert!(pretty.contains(r#"\"world\""#));
12294 }
12295
12296 #[test]
12297 fn inspection_name() {
12298 let mw = RequestInspectionMiddleware::new();
12299 assert_eq!(mw.name(), "RequestInspection");
12300 }
12301
12302 #[test]
12303 fn inspection_default_via_default_trait() {
12304 let mw = RequestInspectionMiddleware::default();
12305 assert_eq!(mw.verbosity, InspectionVerbosity::Normal);
12306 assert_eq!(mw.slow_threshold_ms, 1000);
12307 }
12308
12309 #[test]
12310 fn inspection_with_query_string() {
12311 let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Minimal);
12312 let ctx = test_context();
12313 let mut req = Request::new(Method::Get, "/search");
12314 req.set_query(Some("q=rust&page=1".to_string()));
12315
12316 let result = futures_executor::block_on(mw.before(&ctx, &mut req));
12317 assert!(result.is_continue());
12318 }
12319
12320 #[test]
12321 fn inspection_response_body_stream() {
12322 let mw = RequestInspectionMiddleware::new();
12323 let result = mw.format_response_preview(&ResponseBody::Empty, None);
12324 assert!(result.is_none());
12325 }
12326}
12327
12328#[cfg(test)]
12333mod rate_limit_tests {
12334 use super::*;
12335 use crate::request::Method;
12336 use crate::response::{ResponseBody, StatusCode};
12337 use std::time::Duration;
12338
12339 fn test_context() -> RequestContext {
12340 RequestContext::new(asupersync::Cx::for_testing(), 1)
12341 }
12342
12343 fn run_rate_limit_before(mw: &RateLimitMiddleware, req: &mut Request) -> ControlFlow {
12344 let ctx = test_context();
12345 let fut = mw.before(&ctx, req);
12346 futures_executor::block_on(fut)
12347 }
12348
12349 fn run_rate_limit_after(mw: &RateLimitMiddleware, req: &Request, resp: Response) -> Response {
12350 let ctx = test_context();
12351 let fut = mw.after(&ctx, req, resp);
12352 futures_executor::block_on(fut)
12353 }
12354
12355 #[test]
12356 fn rate_limit_default_allows_requests() {
12357 let mw = RateLimitMiddleware::new();
12358 let mut req = Request::new(Method::Get, "/api/test");
12359 req.headers_mut()
12360 .insert("x-forwarded-for", b"192.168.1.1".to_vec());
12361
12362 let result = run_rate_limit_before(&mw, &mut req);
12363 assert!(result.is_continue(), "first request should be allowed");
12364 }
12365
12366 #[test]
12367 fn rate_limit_fixed_window_blocks_after_limit() {
12368 let mw = RateLimitMiddleware::builder()
12369 .requests(3)
12370 .per(Duration::from_secs(60))
12371 .algorithm(RateLimitAlgorithm::FixedWindow)
12372 .key_extractor(IpKeyExtractor)
12373 .build();
12374
12375 for i in 0..3 {
12376 let mut req = Request::new(Method::Get, "/api/test");
12377 req.headers_mut()
12378 .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12379 let result = run_rate_limit_before(&mw, &mut req);
12380 assert!(
12381 result.is_continue(),
12382 "request {i} should be allowed within limit"
12383 );
12384 }
12385
12386 let mut req = Request::new(Method::Get, "/api/test");
12388 req.headers_mut()
12389 .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12390 let result = run_rate_limit_before(&mw, &mut req);
12391 assert!(result.is_break(), "fourth request should be blocked");
12392
12393 if let ControlFlow::Break(resp) = result {
12395 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
12396 }
12397 }
12398
12399 #[test]
12400 fn rate_limit_different_keys_independent() {
12401 let mw = RateLimitMiddleware::builder()
12402 .requests(2)
12403 .per(Duration::from_secs(60))
12404 .algorithm(RateLimitAlgorithm::FixedWindow)
12405 .key_extractor(IpKeyExtractor)
12406 .build();
12407
12408 for _ in 0..2 {
12410 let mut req = Request::new(Method::Get, "/");
12411 req.headers_mut()
12412 .insert("x-forwarded-for", b"1.1.1.1".to_vec());
12413 assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12414 }
12415
12416 let mut req = Request::new(Method::Get, "/");
12418 req.headers_mut()
12419 .insert("x-forwarded-for", b"1.1.1.1".to_vec());
12420 assert!(run_rate_limit_before(&mw, &mut req).is_break());
12421
12422 let mut req = Request::new(Method::Get, "/");
12424 req.headers_mut()
12425 .insert("x-forwarded-for", b"2.2.2.2".to_vec());
12426 assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12427 }
12428
12429 #[test]
12430 fn rate_limit_token_bucket_allows_burst() {
12431 let mw = RateLimitMiddleware::builder()
12432 .requests(5)
12433 .per(Duration::from_secs(60))
12434 .algorithm(RateLimitAlgorithm::TokenBucket)
12435 .key_extractor(IpKeyExtractor)
12436 .build();
12437
12438 for i in 0..5 {
12440 let mut req = Request::new(Method::Get, "/");
12441 req.headers_mut()
12442 .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12443 let result = run_rate_limit_before(&mw, &mut req);
12444 assert!(result.is_continue(), "burst request {i} should be allowed");
12445 }
12446
12447 let mut req = Request::new(Method::Get, "/");
12449 req.headers_mut()
12450 .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12451 assert!(run_rate_limit_before(&mw, &mut req).is_break());
12452 }
12453
12454 #[test]
12455 fn rate_limit_sliding_window_basic() {
12456 let mw = RateLimitMiddleware::builder()
12457 .requests(3)
12458 .per(Duration::from_secs(60))
12459 .algorithm(RateLimitAlgorithm::SlidingWindow)
12460 .key_extractor(IpKeyExtractor)
12461 .build();
12462
12463 for i in 0..3 {
12464 let mut req = Request::new(Method::Get, "/");
12465 req.headers_mut()
12466 .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12467 assert!(
12468 run_rate_limit_before(&mw, &mut req).is_continue(),
12469 "sliding window request {i} should be allowed"
12470 );
12471 }
12472
12473 let mut req = Request::new(Method::Get, "/");
12475 req.headers_mut()
12476 .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12477 assert!(run_rate_limit_before(&mw, &mut req).is_break());
12478 }
12479
12480 #[test]
12481 fn rate_limit_header_key_extractor() {
12482 let mw = RateLimitMiddleware::builder()
12483 .requests(2)
12484 .per(Duration::from_secs(60))
12485 .algorithm(RateLimitAlgorithm::FixedWindow)
12486 .key_extractor(HeaderKeyExtractor::new("x-api-key"))
12487 .build();
12488
12489 for _ in 0..2 {
12491 let mut req = Request::new(Method::Get, "/");
12492 req.headers_mut().insert("x-api-key", b"key-abc".to_vec());
12493 assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12494 }
12495
12496 let mut req = Request::new(Method::Get, "/");
12498 req.headers_mut().insert("x-api-key", b"key-abc".to_vec());
12499 assert!(run_rate_limit_before(&mw, &mut req).is_break());
12500
12501 let mut req = Request::new(Method::Get, "/");
12503 req.headers_mut().insert("x-api-key", b"key-xyz".to_vec());
12504 assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12505 }
12506
12507 #[test]
12508 fn rate_limit_path_key_extractor() {
12509 let mw = RateLimitMiddleware::builder()
12510 .requests(1)
12511 .per(Duration::from_secs(60))
12512 .algorithm(RateLimitAlgorithm::FixedWindow)
12513 .key_extractor(PathKeyExtractor)
12514 .build();
12515
12516 let mut req = Request::new(Method::Get, "/api/a");
12517 assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12518
12519 let mut req = Request::new(Method::Get, "/api/a");
12521 assert!(run_rate_limit_before(&mw, &mut req).is_break());
12522
12523 let mut req = Request::new(Method::Get, "/api/b");
12525 assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12526 }
12527
12528 #[test]
12529 fn rate_limit_no_key_skips_limiting() {
12530 let mw = RateLimitMiddleware::builder()
12531 .requests(1)
12532 .per(Duration::from_secs(60))
12533 .algorithm(RateLimitAlgorithm::FixedWindow)
12534 .key_extractor(HeaderKeyExtractor::new("x-api-key"))
12535 .build();
12536
12537 let mut req = Request::new(Method::Get, "/");
12539 assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12540
12541 for _ in 0..10 {
12543 let mut req = Request::new(Method::Get, "/");
12544 assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12545 }
12546 }
12547
12548 #[test]
12549 fn rate_limit_response_headers_on_success() {
12550 let mw = RateLimitMiddleware::builder()
12551 .requests(10)
12552 .per(Duration::from_secs(60))
12553 .algorithm(RateLimitAlgorithm::FixedWindow)
12554 .key_extractor(IpKeyExtractor)
12555 .build();
12556
12557 let mut req = Request::new(Method::Get, "/");
12558 req.headers_mut()
12559 .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12560 let cf = run_rate_limit_before(&mw, &mut req);
12561 assert!(cf.is_continue());
12562
12563 let resp = Response::with_status(StatusCode::OK);
12564 let resp = run_rate_limit_after(&mw, &req, resp);
12565
12566 let headers = resp.headers();
12568 let has_limit = headers
12569 .iter()
12570 .any(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-limit"));
12571 let has_remaining = headers
12572 .iter()
12573 .any(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-remaining"));
12574 let has_reset = headers
12575 .iter()
12576 .any(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-reset"));
12577
12578 assert!(has_limit, "should have X-RateLimit-Limit header");
12579 assert!(has_remaining, "should have X-RateLimit-Remaining header");
12580 assert!(has_reset, "should have X-RateLimit-Reset header");
12581
12582 let limit_val = headers
12584 .iter()
12585 .find(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-limit"))
12586 .map(|(_, v)| std::str::from_utf8(v).unwrap().to_string())
12587 .unwrap();
12588 assert_eq!(limit_val, "10");
12589 }
12590
12591 #[test]
12592 fn rate_limit_429_response_has_retry_after() {
12593 let mw = RateLimitMiddleware::builder()
12594 .requests(1)
12595 .per(Duration::from_secs(60))
12596 .algorithm(RateLimitAlgorithm::FixedWindow)
12597 .key_extractor(IpKeyExtractor)
12598 .build();
12599
12600 let mut req = Request::new(Method::Get, "/");
12602 req.headers_mut()
12603 .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12604 assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12605
12606 let mut req = Request::new(Method::Get, "/");
12608 req.headers_mut()
12609 .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12610 let result = run_rate_limit_before(&mw, &mut req);
12611
12612 if let ControlFlow::Break(resp) = result {
12613 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
12614
12615 let has_retry = resp
12617 .headers()
12618 .iter()
12619 .any(|(n, _)| n.eq_ignore_ascii_case("retry-after"));
12620 assert!(has_retry, "429 response should have Retry-After header");
12621
12622 let has_ct = resp
12624 .headers()
12625 .iter()
12626 .any(|(n, v)| n.eq_ignore_ascii_case("content-type") && v == b"application/json");
12627 assert!(has_ct, "429 response should have JSON content type");
12628 } else {
12629 panic!("expected Break(429)");
12630 }
12631 }
12632
12633 #[test]
12634 fn rate_limit_no_headers_when_disabled() {
12635 let mw = RateLimitMiddleware::builder()
12636 .requests(10)
12637 .per(Duration::from_secs(60))
12638 .algorithm(RateLimitAlgorithm::FixedWindow)
12639 .key_extractor(IpKeyExtractor)
12640 .include_headers(false)
12641 .build();
12642
12643 let mut req = Request::new(Method::Get, "/");
12644 req.headers_mut()
12645 .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12646 assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12647
12648 let resp = Response::with_status(StatusCode::OK);
12649 let resp = run_rate_limit_after(&mw, &req, resp);
12650
12651 let has_limit = resp
12652 .headers()
12653 .iter()
12654 .any(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-limit"));
12655 assert!(
12656 !has_limit,
12657 "should NOT have rate limit headers when disabled"
12658 );
12659 }
12660
12661 #[test]
12662 fn rate_limit_custom_retry_message() {
12663 let mw = RateLimitMiddleware::builder()
12664 .requests(1)
12665 .per(Duration::from_secs(60))
12666 .algorithm(RateLimitAlgorithm::FixedWindow)
12667 .key_extractor(IpKeyExtractor)
12668 .retry_message("Slow down, partner!")
12669 .build();
12670
12671 let mut req = Request::new(Method::Get, "/");
12673 req.headers_mut()
12674 .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12675 run_rate_limit_before(&mw, &mut req);
12676
12677 let mut req = Request::new(Method::Get, "/");
12679 req.headers_mut()
12680 .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12681 if let ControlFlow::Break(resp) = run_rate_limit_before(&mw, &mut req) {
12682 if let ResponseBody::Bytes(body) = resp.body_ref() {
12683 let body_str = std::str::from_utf8(body).unwrap();
12684 assert!(
12685 body_str.contains("Slow down, partner!"),
12686 "expected custom message in body, got: {body_str}"
12687 );
12688 } else {
12689 panic!("expected Bytes body");
12690 }
12691 } else {
12692 panic!("expected Break(429)");
12693 }
12694 }
12695
12696 #[test]
12697 fn rate_limit_ip_extractor_x_forwarded_for() {
12698 let extractor = IpKeyExtractor;
12699 let mut req = Request::new(Method::Get, "/");
12700 req.headers_mut()
12701 .insert("x-forwarded-for", b"1.2.3.4, 5.6.7.8".to_vec());
12702 assert_eq!(extractor.extract_key(&req), Some("1.2.3.4".to_string()));
12703 }
12704
12705 #[test]
12706 fn rate_limit_ip_extractor_x_real_ip() {
12707 let extractor = IpKeyExtractor;
12708 let mut req = Request::new(Method::Get, "/");
12709 req.headers_mut().insert("x-real-ip", b"9.8.7.6".to_vec());
12710 assert_eq!(extractor.extract_key(&req), Some("9.8.7.6".to_string()));
12711 }
12712
12713 #[test]
12714 fn rate_limit_ip_extractor_fallback() {
12715 let extractor = IpKeyExtractor;
12716 let req = Request::new(Method::Get, "/");
12717 assert_eq!(extractor.extract_key(&req), Some("unknown".to_string()));
12718 }
12719
12720 #[test]
12722 fn connected_ip_extractor_with_remote_addr() {
12723 use std::net::{IpAddr, Ipv4Addr};
12724
12725 let extractor = ConnectedIpKeyExtractor;
12726 let mut req = Request::new(Method::Get, "/");
12727 req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100))));
12728
12729 assert_eq!(
12730 extractor.extract_key(&req),
12731 Some("192.168.1.100".to_string())
12732 );
12733 }
12734
12735 #[test]
12736 fn connected_ip_extractor_without_remote_addr() {
12737 let extractor = ConnectedIpKeyExtractor;
12738 let req = Request::new(Method::Get, "/");
12739
12740 assert_eq!(extractor.extract_key(&req), None);
12742 }
12743
12744 #[test]
12745 fn connected_ip_extractor_ignores_headers() {
12746 use std::net::{IpAddr, Ipv4Addr};
12747
12748 let extractor = ConnectedIpKeyExtractor;
12749 let mut req = Request::new(Method::Get, "/");
12750 req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
12751 req.headers_mut()
12753 .insert("x-forwarded-for", b"1.2.3.4".to_vec());
12754
12755 assert_eq!(extractor.extract_key(&req), Some("10.0.0.1".to_string()));
12757 }
12758
12759 #[test]
12761 fn trusted_proxy_extractor_from_trusted_proxy() {
12762 use std::net::{IpAddr, Ipv4Addr};
12763
12764 let extractor = TrustedProxyIpKeyExtractor::new().trust_cidr("10.0.0.0/8");
12765
12766 let mut req = Request::new(Method::Get, "/");
12767 req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
12769 req.headers_mut()
12771 .insert("x-forwarded-for", b"203.0.113.50".to_vec());
12772
12773 assert_eq!(
12775 extractor.extract_key(&req),
12776 Some("203.0.113.50".to_string())
12777 );
12778 }
12779
12780 #[test]
12781 fn trusted_proxy_extractor_from_untrusted_direct() {
12782 use std::net::{IpAddr, Ipv4Addr};
12783
12784 let extractor = TrustedProxyIpKeyExtractor::new().trust_cidr("10.0.0.0/8");
12785
12786 let mut req = Request::new(Method::Get, "/");
12787 req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50))));
12789 req.headers_mut()
12791 .insert("x-forwarded-for", b"1.2.3.4".to_vec());
12792
12793 assert_eq!(
12795 extractor.extract_key(&req),
12796 Some("203.0.113.50".to_string())
12797 );
12798 }
12799
12800 #[test]
12801 fn trusted_proxy_extractor_no_remote_addr() {
12802 let extractor = TrustedProxyIpKeyExtractor::new().trust_loopback();
12803
12804 let mut req = Request::new(Method::Get, "/");
12805 req.headers_mut()
12807 .insert("x-forwarded-for", b"1.2.3.4".to_vec());
12808
12809 assert_eq!(extractor.extract_key(&req), None);
12810 }
12811
12812 #[test]
12813 fn trusted_proxy_extractor_loopback_ipv4() {
12814 use std::net::{IpAddr, Ipv4Addr};
12815
12816 let extractor = TrustedProxyIpKeyExtractor::new().trust_loopback();
12817
12818 let mut req = Request::new(Method::Get, "/");
12819 req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::LOCALHOST)));
12820 req.headers_mut()
12821 .insert("x-forwarded-for", b"8.8.8.8".to_vec());
12822
12823 assert_eq!(extractor.extract_key(&req), Some("8.8.8.8".to_string()));
12824 }
12825
12826 #[test]
12827 fn trusted_proxy_extractor_loopback_ipv6() {
12828 use std::net::{IpAddr, Ipv6Addr};
12829
12830 let extractor = TrustedProxyIpKeyExtractor::new().trust_loopback();
12831
12832 let mut req = Request::new(Method::Get, "/");
12833 req.insert_extension(RemoteAddr(IpAddr::V6(Ipv6Addr::LOCALHOST)));
12834 req.headers_mut()
12835 .insert("x-forwarded-for", b"8.8.8.8".to_vec());
12836
12837 assert_eq!(extractor.extract_key(&req), Some("8.8.8.8".to_string()));
12838 }
12839
12840 #[test]
12841 fn cidr_parsing() {
12842 assert!(parse_cidr("10.0.0.0/8").is_some());
12844 assert!(parse_cidr("192.168.1.0/24").is_some());
12845 assert!(parse_cidr("0.0.0.0/0").is_some());
12846 assert!(parse_cidr("::1/128").is_some());
12847 assert!(parse_cidr("::/0").is_some());
12848
12849 assert!(parse_cidr("10.0.0.0/33").is_none()); assert!(parse_cidr("invalid").is_none());
12852 assert!(parse_cidr("10.0.0.0").is_none()); }
12854
12855 #[test]
12856 fn ip_in_cidr_matching() {
12857 use std::net::{IpAddr, Ipv4Addr};
12858
12859 let cidr_10 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0));
12860
12861 assert!(ip_in_cidr(
12863 IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
12864 cidr_10,
12865 8
12866 ));
12867 assert!(ip_in_cidr(
12868 IpAddr::V4(Ipv4Addr::new(10, 255, 255, 255)),
12869 cidr_10,
12870 8
12871 ));
12872
12873 assert!(!ip_in_cidr(
12875 IpAddr::V4(Ipv4Addr::new(11, 0, 0, 1)),
12876 cidr_10,
12877 8
12878 ));
12879 assert!(!ip_in_cidr(
12880 IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
12881 cidr_10,
12882 8
12883 ));
12884 }
12885
12886 #[test]
12887 fn rate_limit_composite_key_extractor() {
12888 let extractor =
12889 CompositeKeyExtractor::new(vec![Box::new(IpKeyExtractor), Box::new(PathKeyExtractor)]);
12890
12891 let mut req = Request::new(Method::Get, "/api/users");
12892 req.headers_mut()
12893 .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12894
12895 let key = extractor.extract_key(&req);
12896 assert_eq!(key, Some("10.0.0.1:/api/users".to_string()));
12897 }
12898
12899 #[test]
12900 fn rate_limit_builder_defaults() {
12901 let mw = RateLimitMiddleware::builder().build();
12902 assert_eq!(mw.config.max_requests, 100);
12903 assert_eq!(mw.config.window, Duration::from_secs(60));
12904 assert_eq!(mw.config.algorithm, RateLimitAlgorithm::TokenBucket);
12905 assert!(mw.config.include_headers);
12906 }
12907
12908 #[test]
12909 fn rate_limit_builder_per_minute() {
12910 let mw = RateLimitMiddleware::builder()
12911 .requests(50)
12912 .per_minute(2)
12913 .algorithm(RateLimitAlgorithm::SlidingWindow)
12914 .build();
12915 assert_eq!(mw.config.max_requests, 50);
12916 assert_eq!(mw.config.window, Duration::from_secs(120));
12917 assert_eq!(mw.config.algorithm, RateLimitAlgorithm::SlidingWindow);
12918 }
12919
12920 #[test]
12921 fn rate_limit_builder_per_hour() {
12922 let mw = RateLimitMiddleware::builder()
12923 .requests(1000)
12924 .per_hour(1)
12925 .build();
12926 assert_eq!(mw.config.window, Duration::from_secs(3600));
12927 }
12928
12929 #[test]
12930 fn rate_limit_middleware_name() {
12931 let mw = RateLimitMiddleware::new();
12932 assert_eq!(mw.name(), "RateLimit");
12933 }
12934
12935 #[test]
12936 fn rate_limit_default_via_default_trait() {
12937 let mw = RateLimitMiddleware::default();
12938 assert_eq!(mw.config.max_requests, 100);
12939 }
12940
12941 #[test]
12946 fn etag_middleware_generates_etag_for_get() {
12947 let mw = ETagMiddleware::new();
12948 let ctx = test_context();
12949 let req = Request::new(crate::request::Method::Get, "/resource");
12950
12951 let response = Response::ok()
12953 .header("content-type", b"application/json".to_vec())
12954 .body(ResponseBody::Bytes(br#"{"status":"ok"}"#.to_vec()));
12955
12956 let response = futures_executor::block_on(mw.after(&ctx, &req, response));
12957
12958 let etag = response
12960 .headers()
12961 .iter()
12962 .find(|(name, _)| name.eq_ignore_ascii_case("etag"));
12963 assert!(etag.is_some(), "Response should have ETag header");
12964
12965 let etag_value = std::str::from_utf8(&etag.unwrap().1).unwrap();
12967 assert!(etag_value.starts_with('"'), "ETag should start with quote");
12968 assert!(etag_value.ends_with('"'), "ETag should end with quote");
12969 }
12970
12971 #[test]
12972 fn etag_middleware_returns_304_on_match() {
12973 let mw = ETagMiddleware::new();
12974 let ctx = test_context();
12975
12976 let req1 = Request::new(crate::request::Method::Get, "/resource");
12978 let body = br#"{"status":"ok"}"#.to_vec();
12979 let response1 = Response::ok().body(ResponseBody::Bytes(body.clone()));
12980 let response1 = futures_executor::block_on(mw.after(&ctx, &req1, response1));
12981
12982 let etag = response1
12983 .headers()
12984 .iter()
12985 .find(|(name, _)| name.eq_ignore_ascii_case("etag"))
12986 .map(|(_, v)| std::str::from_utf8(v).unwrap().to_string())
12987 .unwrap();
12988
12989 let mut req2 = Request::new(crate::request::Method::Get, "/resource");
12991 req2.headers_mut()
12992 .insert("if-none-match", etag.as_bytes().to_vec());
12993
12994 let response2 = Response::ok().body(ResponseBody::Bytes(body));
12995 let response2 = futures_executor::block_on(mw.after(&ctx, &req2, response2));
12996
12997 assert_eq!(response2.status().as_u16(), 304);
12999 assert!(response2.body_ref().is_empty());
13000 }
13001
13002 #[test]
13003 fn etag_middleware_returns_full_response_on_mismatch() {
13004 let mw = ETagMiddleware::new();
13005 let ctx = test_context();
13006
13007 let mut req = Request::new(crate::request::Method::Get, "/resource");
13008 req.headers_mut()
13009 .insert("if-none-match", b"\"old-etag\"".to_vec());
13010
13011 let body = br#"{"status":"updated"}"#.to_vec();
13012 let response = Response::ok().body(ResponseBody::Bytes(body.clone()));
13013 let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13014
13015 assert_eq!(response.status().as_u16(), 200);
13017 assert!(!response.body_ref().is_empty());
13018 }
13019
13020 #[test]
13021 fn etag_middleware_weak_etag_generation() {
13022 let config = ETagConfig::new().weak(true);
13023 let mw = ETagMiddleware::with_config(config);
13024 let ctx = test_context();
13025 let req = Request::new(crate::request::Method::Get, "/resource");
13026
13027 let response = Response::ok().body(ResponseBody::Bytes(b"data".to_vec()));
13028 let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13029
13030 let etag = response
13031 .headers()
13032 .iter()
13033 .find(|(name, _)| name.eq_ignore_ascii_case("etag"))
13034 .map(|(_, v)| std::str::from_utf8(v).unwrap().to_string())
13035 .unwrap();
13036
13037 assert!(etag.starts_with("W/"), "Weak ETag should start with W/");
13038 }
13039
13040 #[test]
13041 fn etag_middleware_skips_post_requests() {
13042 let mw = ETagMiddleware::new();
13043 let ctx = test_context();
13044 let req = Request::new(crate::request::Method::Post, "/resource");
13045
13046 let response = Response::ok().body(ResponseBody::Bytes(b"created".to_vec()));
13047 let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13048
13049 let etag = response
13051 .headers()
13052 .iter()
13053 .find(|(name, _)| name.eq_ignore_ascii_case("etag"));
13054 assert!(etag.is_none(), "POST should not have ETag");
13055 }
13056
13057 #[test]
13058 fn etag_middleware_handles_head_requests() {
13059 let mw = ETagMiddleware::new();
13060 let ctx = test_context();
13061 let req = Request::new(crate::request::Method::Head, "/resource");
13062
13063 let response = Response::ok().body(ResponseBody::Bytes(b"data".to_vec()));
13064 let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13065
13066 let etag = response
13068 .headers()
13069 .iter()
13070 .find(|(name, _)| name.eq_ignore_ascii_case("etag"));
13071 assert!(etag.is_some(), "HEAD should have ETag");
13072 }
13073
13074 #[test]
13075 fn etag_middleware_disabled_mode() {
13076 let config = ETagConfig::new().mode(ETagMode::Disabled);
13077 let mw = ETagMiddleware::with_config(config);
13078 let ctx = test_context();
13079 let req = Request::new(crate::request::Method::Get, "/resource");
13080
13081 let response = Response::ok().body(ResponseBody::Bytes(b"data".to_vec()));
13082 let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13083
13084 let etag = response
13086 .headers()
13087 .iter()
13088 .find(|(name, _)| name.eq_ignore_ascii_case("etag"));
13089 assert!(etag.is_none(), "Disabled mode should not add ETag");
13090 }
13091
13092 #[test]
13093 fn etag_middleware_min_size_filter() {
13094 let config = ETagConfig::new().min_size(1000);
13095 let mw = ETagMiddleware::with_config(config);
13096 let ctx = test_context();
13097 let req = Request::new(crate::request::Method::Get, "/resource");
13098
13099 let response = Response::ok().body(ResponseBody::Bytes(b"small".to_vec()));
13101 let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13102
13103 let etag = response
13105 .headers()
13106 .iter()
13107 .find(|(name, _)| name.eq_ignore_ascii_case("etag"));
13108 assert!(etag.is_none(), "Small body should not get ETag");
13109 }
13110
13111 #[test]
13112 fn etag_middleware_preserves_existing_etag() {
13113 let config = ETagConfig::new().mode(ETagMode::Manual);
13114 let mw = ETagMiddleware::with_config(config);
13115 let ctx = test_context();
13116
13117 let mut req = Request::new(crate::request::Method::Get, "/resource");
13119 req.headers_mut()
13120 .insert("if-none-match", b"\"custom-etag\"".to_vec());
13121
13122 let response = Response::ok()
13124 .header("etag", b"\"custom-etag\"".to_vec())
13125 .body(ResponseBody::Bytes(b"data".to_vec()));
13126 let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13127
13128 assert_eq!(response.status().as_u16(), 304);
13130 }
13131
13132 #[test]
13133 fn etag_middleware_wildcard_if_none_match() {
13134 let mw = ETagMiddleware::new();
13135 let ctx = test_context();
13136 let mut req = Request::new(crate::request::Method::Get, "/resource");
13137 req.headers_mut().insert("if-none-match", b"*".to_vec());
13138
13139 let response = Response::ok().body(ResponseBody::Bytes(b"data".to_vec()));
13140 let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13141
13142 assert_eq!(response.status().as_u16(), 304);
13144 }
13145
13146 #[test]
13147 fn etag_middleware_weak_comparison_matches() {
13148 let mw = ETagMiddleware::new();
13149 let ctx = test_context();
13150
13151 let req1 = Request::new(crate::request::Method::Get, "/resource");
13153 let body = b"test data".to_vec();
13154 let response1 = Response::ok().body(ResponseBody::Bytes(body.clone()));
13155 let response1 = futures_executor::block_on(mw.after(&ctx, &req1, response1));
13156
13157 let etag = response1
13158 .headers()
13159 .iter()
13160 .find(|(name, _)| name.eq_ignore_ascii_case("etag"))
13161 .map(|(_, v)| std::str::from_utf8(v).unwrap().to_string())
13162 .unwrap();
13163
13164 let mut req2 = Request::new(crate::request::Method::Get, "/resource");
13166 let weak_etag = format!("W/{}", etag);
13167 req2.headers_mut()
13168 .insert("if-none-match", weak_etag.as_bytes().to_vec());
13169
13170 let response2 = Response::ok().body(ResponseBody::Bytes(body));
13171 let response2 = futures_executor::block_on(mw.after(&ctx, &req2, response2));
13172
13173 assert_eq!(response2.status().as_u16(), 304);
13175 }
13176
13177 #[test]
13178 fn etag_middleware_name() {
13179 let mw = ETagMiddleware::new();
13180 assert_eq!(mw.name(), "ETagMiddleware");
13181 }
13182
13183 #[test]
13184 fn etag_config_builder() {
13185 let config = ETagConfig::new()
13186 .mode(ETagMode::Auto)
13187 .weak(true)
13188 .min_size(512);
13189
13190 assert_eq!(config.mode, ETagMode::Auto);
13191 assert!(config.weak);
13192 assert_eq!(config.min_size, 512);
13193 }
13194
13195 #[test]
13196 fn etag_generates_consistent_hash() {
13197 let etag1 = ETagMiddleware::generate_etag(b"hello world", false);
13199 let etag2 = ETagMiddleware::generate_etag(b"hello world", false);
13200 assert_eq!(etag1, etag2);
13201
13202 let etag3 = ETagMiddleware::generate_etag(b"hello world!", false);
13204 assert_ne!(etag1, etag3);
13205 }
13206}