1use std::collections::HashSet;
54use std::future::Future;
55use std::ops::ControlFlow as StdControlFlow;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::time::Instant;
59
60use crate::context::RequestContext;
61use crate::dependency::DependencyOverrides;
62use crate::logging::{LogConfig, RequestLogger};
63use crate::request::{Body, Request};
64use crate::response::Response;
65
/// A pinned, heap-allocated, `Send` future that yields `T` and may borrow for `'a`.
///
/// This is the return type of the object-safe async hooks on `Middleware` and `Handler`.
pub type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
68
69#[derive(Debug)]
74pub enum ControlFlow {
75 Continue,
77 Break(Response),
82}
83
84impl ControlFlow {
85 #[must_use]
87 pub fn is_continue(&self) -> bool {
88 matches!(self, Self::Continue)
89 }
90
91 #[must_use]
93 pub fn is_break(&self) -> bool {
94 matches!(self, Self::Break(_))
95 }
96}
97
98impl From<ControlFlow> for StdControlFlow<Response, ()> {
99 fn from(cf: ControlFlow) -> Self {
100 match cf {
101 ControlFlow::Continue => StdControlFlow::Continue(()),
102 ControlFlow::Break(r) => StdControlFlow::Break(r),
103 }
104 }
105}
106
107pub trait Middleware: Send + Sync {
148 fn before<'a>(
164 &'a self,
165 _ctx: &'a RequestContext,
166 _req: &'a mut Request,
167 ) -> BoxFuture<'a, ControlFlow> {
168 Box::pin(async { ControlFlow::Continue })
169 }
170
171 fn after<'a>(
187 &'a self,
188 _ctx: &'a RequestContext,
189 _req: &'a Request,
190 response: Response,
191 ) -> BoxFuture<'a, Response> {
192 Box::pin(async move { response })
193 }
194
195 fn name(&self) -> &'static str {
199 std::any::type_name::<Self>()
200 }
201}
202
203pub trait Handler: Send + Sync {
208 fn call<'a>(&'a self, ctx: &'a RequestContext, req: &'a mut Request)
210 -> BoxFuture<'a, Response>;
211
212 fn dependency_overrides(&self) -> Option<Arc<DependencyOverrides>> {
216 None
217 }
218}
219
220impl<F, Fut> Handler for F
225where
226 F: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync,
227 Fut: Future<Output = Response> + Send + 'static,
228{
229 fn call<'a>(
230 &'a self,
231 ctx: &'a RequestContext,
232 req: &'a mut Request,
233 ) -> BoxFuture<'a, Response> {
234 let fut = self(ctx, req);
235 Box::pin(fut)
236 }
237}
238
239#[derive(Default)]
257pub struct MiddlewareStack {
258 middleware: Vec<Arc<dyn Middleware>>,
259}
260
261impl MiddlewareStack {
262 #[must_use]
264 pub fn new() -> Self {
265 Self {
266 middleware: Vec::new(),
267 }
268 }
269
270 #[must_use]
272 pub fn with_capacity(capacity: usize) -> Self {
273 Self {
274 middleware: Vec::with_capacity(capacity),
275 }
276 }
277
278 pub fn push<M: Middleware + 'static>(&mut self, middleware: M) {
282 self.middleware.push(Arc::new(middleware));
283 }
284
285 pub fn push_arc(&mut self, middleware: Arc<dyn Middleware>) {
289 self.middleware.push(middleware);
290 }
291
292 #[must_use]
294 pub fn len(&self) -> usize {
295 self.middleware.len()
296 }
297
298 #[must_use]
300 pub fn is_empty(&self) -> bool {
301 self.middleware.is_empty()
302 }
303
304 pub async fn execute<H: Handler>(
322 &self,
323 handler: &H,
324 ctx: &RequestContext,
325 req: &mut Request,
326 ) -> Response {
327 let mut ran_before_count = 0;
329
330 for mw in &self.middleware {
332 let _ = ctx.checkpoint();
333 match mw.before(ctx, req).await {
334 ControlFlow::Continue => {
335 ran_before_count += 1;
336 }
337 ControlFlow::Break(response) => {
338 return self
340 .run_after_hooks(ctx, req, response, ran_before_count)
341 .await;
342 }
343 }
344 }
345
346 let _ = ctx.checkpoint();
348 let response = handler.call(ctx, req).await;
349
350 self.run_after_hooks(ctx, req, response, ran_before_count)
352 .await
353 }
354
355 async fn run_after_hooks(
357 &self,
358 ctx: &RequestContext,
359 req: &Request,
360 mut response: Response,
361 count: usize,
362 ) -> Response {
363 for mw in self.middleware[..count].iter().rev() {
365 let _ = ctx.checkpoint();
366 response = mw.after(ctx, req, response).await;
367 }
368 response
369 }
370}
371
372pub struct Layer<M> {
383 middleware: M,
384}
385
386impl<M: Middleware + Clone> Layer<M> {
387 pub fn new(middleware: M) -> Self {
389 Self { middleware }
390 }
391
392 pub fn wrap<H: Handler>(&self, handler: H) -> Layered<M, H> {
394 Layered {
395 middleware: self.middleware.clone(),
396 inner: handler,
397 }
398 }
399}
400
401pub struct Layered<M, H> {
403 middleware: M,
404 inner: H,
405}
406
407impl<M: Middleware, H: Handler> Handler for Layered<M, H> {
408 fn call<'a>(
409 &'a self,
410 ctx: &'a RequestContext,
411 req: &'a mut Request,
412 ) -> BoxFuture<'a, Response> {
413 Box::pin(async move {
414 let _ = ctx.checkpoint();
416 match self.middleware.before(ctx, req).await {
417 ControlFlow::Continue => {
418 let _ = ctx.checkpoint();
420 let response = self.inner.call(ctx, req).await;
421 let _ = ctx.checkpoint();
423 self.middleware.after(ctx, req, response).await
424 }
425 ControlFlow::Break(response) => {
426 let _ = ctx.checkpoint();
428 self.middleware.after(ctx, req, response).await
429 }
430 }
431 })
432 }
433}
434
435#[derive(Debug, Clone, Copy, Default)]
443pub struct NoopMiddleware;
444
445impl Middleware for NoopMiddleware {
446 fn name(&self) -> &'static str {
447 "Noop"
448 }
449}
450
451#[derive(Debug, Clone)]
461pub struct AddResponseHeader {
462 name: String,
463 value: Vec<u8>,
464}
465
466impl AddResponseHeader {
467 pub fn new(name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
469 Self {
470 name: name.into(),
471 value: value.into(),
472 }
473 }
474}
475
476impl Middleware for AddResponseHeader {
477 fn after<'a>(
478 &'a self,
479 _ctx: &'a RequestContext,
480 _req: &'a Request,
481 response: Response,
482 ) -> BoxFuture<'a, Response> {
483 let name = self.name.clone();
484 let value = self.value.clone();
485 Box::pin(async move { response.header(name, value) })
486 }
487
488 fn name(&self) -> &'static str {
489 "AddResponseHeader"
490 }
491}
492
493#[derive(Debug, Clone)]
505pub struct RequireHeader {
506 name: String,
507}
508
509impl RequireHeader {
510 pub fn new(name: impl Into<String>) -> Self {
512 Self { name: name.into() }
513 }
514}
515
516impl Middleware for RequireHeader {
517 fn before<'a>(
518 &'a self,
519 _ctx: &'a RequestContext,
520 req: &'a mut Request,
521 ) -> BoxFuture<'a, ControlFlow> {
522 let has_header = req.headers().get(&self.name).is_some();
523 let name = self.name.clone();
524 Box::pin(async move {
525 if has_header {
526 ControlFlow::Continue
527 } else {
528 let body = format!("Missing required header: {name}");
529 ControlFlow::Break(
530 Response::with_status(crate::response::StatusCode::BAD_REQUEST)
531 .header("content-type", b"text/plain".to_vec())
532 .body(crate::response::ResponseBody::Bytes(body.into_bytes())),
533 )
534 }
535 })
536 }
537
538 fn name(&self) -> &'static str {
539 "RequireHeader"
540 }
541}
542
543#[derive(Debug, Clone)]
556pub struct PathPrefixFilter {
557 prefix: String,
558}
559
560impl PathPrefixFilter {
561 pub fn new(prefix: impl Into<String>) -> Self {
563 Self {
564 prefix: prefix.into(),
565 }
566 }
567}
568
569impl Middleware for PathPrefixFilter {
570 fn before<'a>(
571 &'a self,
572 _ctx: &'a RequestContext,
573 req: &'a mut Request,
574 ) -> BoxFuture<'a, ControlFlow> {
575 let path_matches = req.path().starts_with(&self.prefix);
576 Box::pin(async move {
577 if path_matches {
578 ControlFlow::Continue
579 } else {
580 ControlFlow::Break(Response::with_status(
581 crate::response::StatusCode::NOT_FOUND,
582 ))
583 }
584 })
585 }
586
587 fn name(&self) -> &'static str {
588 "PathPrefixFilter"
589 }
590}
591
592#[derive(Debug, Clone)]
596pub struct ConditionalStatus<F>
597where
598 F: Fn(&Request) -> bool + Send + Sync,
599{
600 condition: F,
601 status_if_true: crate::response::StatusCode,
602 status_if_false: crate::response::StatusCode,
603}
604
605impl<F> ConditionalStatus<F>
606where
607 F: Fn(&Request) -> bool + Send + Sync,
608{
609 pub fn new(
614 condition: F,
615 status_if_true: crate::response::StatusCode,
616 status_if_false: crate::response::StatusCode,
617 ) -> Self {
618 Self {
619 condition,
620 status_if_true,
621 status_if_false,
622 }
623 }
624}
625
626impl<F> Middleware for ConditionalStatus<F>
627where
628 F: Fn(&Request) -> bool + Send + Sync,
629{
630 fn after<'a>(
631 &'a self,
632 _ctx: &'a RequestContext,
633 req: &'a Request,
634 response: Response,
635 ) -> BoxFuture<'a, Response> {
636 let matches = (self.condition)(req);
637 let status = if matches {
638 self.status_if_true
639 } else {
640 self.status_if_false
641 };
642 Box::pin(async move { Response::with_status(status).body(response.body_ref().into()) })
643 }
644
645 fn name(&self) -> &'static str {
646 "ConditionalStatus"
647 }
648}
649
650#[derive(Debug, Clone)]
656pub enum OriginPattern {
657 Any,
659 Exact(String),
661 Wildcard(String),
663 Regex(String),
665}
666
667impl OriginPattern {
668 fn matches(&self, origin: &str) -> bool {
669 match self {
670 Self::Any => true,
671 Self::Exact(value) => value == origin,
672 Self::Wildcard(pattern) => wildcard_match(pattern, origin),
673 Self::Regex(pattern) => regex_match(pattern, origin),
674 }
675 }
676}
677
678#[derive(Debug, Clone)]
727pub struct CorsConfig {
728 allow_any_origin: bool,
729 allow_credentials: bool,
730 allowed_methods: Vec<crate::request::Method>,
731 allowed_headers: Vec<String>,
732 expose_headers: Vec<String>,
733 max_age: Option<u32>,
734 origins: Vec<OriginPattern>,
735}
736
737impl Default for CorsConfig {
738 fn default() -> Self {
739 Self {
740 allow_any_origin: false,
741 allow_credentials: false,
742 allowed_methods: vec![
743 crate::request::Method::Get,
744 crate::request::Method::Post,
745 crate::request::Method::Put,
746 crate::request::Method::Patch,
747 crate::request::Method::Delete,
748 crate::request::Method::Options,
749 crate::request::Method::Head,
750 ],
751 allowed_headers: Vec::new(),
752 expose_headers: Vec::new(),
753 max_age: None,
754 origins: Vec::new(),
755 }
756 }
757}
758
759#[derive(Debug, Clone)]
761pub struct Cors {
762 config: CorsConfig,
763}
764
765impl Cors {
766 #[must_use]
768 pub fn new() -> Self {
769 Self {
770 config: CorsConfig::default(),
771 }
772 }
773
774 #[must_use]
776 pub fn config(mut self, config: CorsConfig) -> Self {
777 self.config = config;
778 self
779 }
780
781 #[must_use]
783 pub fn allow_any_origin(mut self) -> Self {
784 self.config.allow_any_origin = true;
785 self
786 }
787
788 #[must_use]
790 pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
791 self.config
792 .origins
793 .push(OriginPattern::Exact(origin.into()));
794 self
795 }
796
797 #[must_use]
799 pub fn allow_origin_wildcard(mut self, pattern: impl Into<String>) -> Self {
800 self.config
801 .origins
802 .push(OriginPattern::Wildcard(pattern.into()));
803 self
804 }
805
806 #[must_use]
808 pub fn allow_origin_regex(mut self, pattern: impl Into<String>) -> Self {
809 self.config
810 .origins
811 .push(OriginPattern::Regex(pattern.into()));
812 self
813 }
814
815 #[must_use]
817 pub fn allow_credentials(mut self, allow: bool) -> Self {
818 self.config.allow_credentials = allow;
819 self
820 }
821
822 #[must_use]
824 pub fn allow_methods<I>(mut self, methods: I) -> Self
825 where
826 I: IntoIterator<Item = crate::request::Method>,
827 {
828 self.config.allowed_methods = methods.into_iter().collect();
829 self
830 }
831
832 #[must_use]
834 pub fn allow_headers<I, S>(mut self, headers: I) -> Self
835 where
836 I: IntoIterator<Item = S>,
837 S: Into<String>,
838 {
839 self.config.allowed_headers = headers.into_iter().map(Into::into).collect();
840 self
841 }
842
843 #[must_use]
845 pub fn expose_headers<I, S>(mut self, headers: I) -> Self
846 where
847 I: IntoIterator<Item = S>,
848 S: Into<String>,
849 {
850 self.config.expose_headers = headers.into_iter().map(Into::into).collect();
851 self
852 }
853
854 #[must_use]
856 pub fn max_age(mut self, seconds: u32) -> Self {
857 self.config.max_age = Some(seconds);
858 self
859 }
860
861 fn is_origin_allowed(&self, origin: &str) -> bool {
862 if self.config.allow_any_origin {
863 return true;
864 }
865 self.config
866 .origins
867 .iter()
868 .any(|pattern| pattern.matches(origin))
869 }
870
871 fn allow_origin_value(&self, origin: &str) -> Option<String> {
872 if !self.is_origin_allowed(origin) {
873 return None;
874 }
875 if self.config.allow_any_origin && !self.config.allow_credentials {
876 Some("*".to_string())
877 } else {
878 Some(origin.to_string())
879 }
880 }
881
882 fn allow_methods_value(&self) -> String {
883 self.config
884 .allowed_methods
885 .iter()
886 .map(|method| method.as_str())
887 .collect::<Vec<_>>()
888 .join(", ")
889 }
890
891 fn allow_headers_value(&self, request: &Request) -> Option<String> {
892 if !self.config.allowed_headers.is_empty() {
893 return Some(self.config.allowed_headers.join(", "));
894 }
895
896 request
897 .headers()
898 .get("access-control-request-headers")
899 .and_then(|value| std::str::from_utf8(value).ok())
900 .map(ToString::to_string)
901 }
902
903 fn apply_common_headers(&self, mut response: Response, origin: &str) -> Response {
904 if let Some(allow_origin) = self.allow_origin_value(origin) {
905 let is_wildcard = allow_origin == "*";
906 response = response.header("access-control-allow-origin", allow_origin.into_bytes());
907 if !is_wildcard {
908 response = response.header("vary", b"Origin".to_vec());
909 }
910 if self.config.allow_credentials {
911 response = response.header("access-control-allow-credentials", b"true".to_vec());
912 }
913 if !self.config.expose_headers.is_empty() {
914 response = response.header(
915 "access-control-expose-headers",
916 self.config.expose_headers.join(", ").into_bytes(),
917 );
918 }
919 }
920 response
921 }
922}
923
924impl Default for Cors {
925 fn default() -> Self {
926 Self::new()
927 }
928}
929
930#[derive(Debug, Clone)]
931struct CorsOrigin(String);
932
933impl Middleware for Cors {
934 fn before<'a>(
935 &'a self,
936 _ctx: &'a RequestContext,
937 req: &'a mut Request,
938 ) -> BoxFuture<'a, ControlFlow> {
939 let origin = req
940 .headers()
941 .get("origin")
942 .and_then(|value| std::str::from_utf8(value).ok())
943 .map(ToString::to_string);
944
945 let Some(origin) = origin else {
946 return Box::pin(async { ControlFlow::Continue });
947 };
948
949 if !self.is_origin_allowed(&origin) {
950 let is_preflight = req.method() == crate::request::Method::Options
951 && req.headers().get("access-control-request-method").is_some();
952 if is_preflight {
953 return Box::pin(async {
954 ControlFlow::Break(Response::with_status(
955 crate::response::StatusCode::FORBIDDEN,
956 ))
957 });
958 }
959 return Box::pin(async { ControlFlow::Continue });
960 }
961
962 let is_preflight = req.method() == crate::request::Method::Options
963 && req.headers().get("access-control-request-method").is_some();
964
965 if is_preflight {
966 let mut response = Response::no_content();
967 response = self.apply_common_headers(response, &origin);
968 response = response.header(
969 "access-control-allow-methods",
970 self.allow_methods_value().into_bytes(),
971 );
972
973 if let Some(value) = self.allow_headers_value(req) {
974 response = response.header("access-control-allow-headers", value.into_bytes());
975 }
976
977 if let Some(max_age) = self.config.max_age {
978 response =
979 response.header("access-control-max-age", max_age.to_string().into_bytes());
980 }
981
982 return Box::pin(async move { ControlFlow::Break(response) });
983 }
984
985 req.insert_extension(CorsOrigin(origin));
986 Box::pin(async { ControlFlow::Continue })
987 }
988
989 fn after<'a>(
990 &'a self,
991 _ctx: &'a RequestContext,
992 req: &'a Request,
993 response: Response,
994 ) -> BoxFuture<'a, Response> {
995 let origin = req.get_extension::<CorsOrigin>().map(|v| v.0.clone());
996 Box::pin(async move {
997 if let Some(origin) = origin {
998 return self.apply_common_headers(response, &origin);
999 }
1000 response
1001 })
1002 }
1003
1004 fn name(&self) -> &'static str {
1005 "Cors"
1006 }
1007}
1008
/// Matches `value` against a glob `pattern`.
///
/// In the pattern, `*` matches any (possibly empty) run of characters and every
/// other character must match literally; matching is case-sensitive.
///
/// This is the classic iterative backtracking matcher. On a mismatch, or when
/// the pattern runs out while input remains, it retries from the most recent
/// `*` with that star absorbing one more character. The previous version did
/// not backtrack once the pattern was exhausted, so `"*"` failed to match
/// `"abc"` and `"*a"` failed to match `"aa"`.
fn wildcard_match(pattern: &str, value: &str) -> bool {
    let pat: Vec<char> = pattern.chars().collect();
    let val: Vec<char> = value.chars().collect();

    let (mut p, mut v) = (0usize, 0usize);
    // (pattern index just after the last `*`, value index where that star's match ends)
    let mut backtrack: Option<(usize, usize)> = None;

    while v < val.len() {
        match pat.get(p) {
            Some(&'*') => {
                p += 1;
                backtrack = Some((p, v));
            }
            Some(&c) if c == val[v] => {
                p += 1;
                v += 1;
            }
            _ => match backtrack {
                // Let the last `*` swallow one more character and retry.
                Some((star_p, star_v)) => {
                    p = star_p;
                    v = star_v + 1;
                    backtrack = Some((star_p, star_v + 1));
                }
                None => return false,
            },
        }
    }

    // All input consumed: any remaining pattern must be only `*`s.
    pat[p..].iter().all(|&c| c == '*')
}
1062
/// Minimal regular-expression search over the bytes of `value`.
///
/// Supported syntax: a leading `^` anchors to the start, a trailing `$` anchors
/// to the end, `.` matches any single byte, and `c*` matches zero or more `c`.
/// Everything else is literal; there are no escapes, classes or groups. Without
/// `^` the pattern may match starting at any position.
fn regex_match(pattern: &str, value: &str) -> bool {
    let text = value.as_bytes();
    if let Some(anchored) = pattern.as_bytes().strip_prefix(b"^") {
        return regex_match_here(anchored, text);
    }
    let pat = pattern.as_bytes();
    (0..=text.len()).any(|start| regex_match_here(pat, &text[start..]))
}

/// Does `pattern` match a prefix of `text`?
fn regex_match_here(pattern: &[u8], text: &[u8]) -> bool {
    match pattern {
        [] => true,
        [b'$'] => text.is_empty(),
        [repeated, b'*', rest @ ..] => regex_match_star(*repeated, rest, text),
        [first, rest @ ..] => match text.split_first() {
            Some((head, tail)) if *first == b'.' || first == head => {
                regex_match_here(rest, tail)
            }
            _ => false,
        },
    }
}

/// Matches `ch*` followed by `pattern` at the start of `text`, trying the
/// shortest run of `ch` first.
fn regex_match_star(ch: u8, pattern: &[u8], text: &[u8]) -> bool {
    let mut remaining = text;
    loop {
        if regex_match_here(pattern, remaining) {
            return true;
        }
        match remaining.split_first() {
            Some((&head, tail)) if ch == b'.' || head == ch => remaining = tail,
            _ => return false,
        }
    }
}
1116
1117#[derive(Debug, Clone)]
1123pub struct RequestResponseLogger {
1124 log_config: LogConfig,
1125 redact_headers: HashSet<String>,
1126 log_request_headers: bool,
1127 log_response_headers: bool,
1128 log_body: bool,
1129 max_body_bytes: usize,
1130}
1131
1132impl Default for RequestResponseLogger {
1133 fn default() -> Self {
1134 Self {
1135 log_config: LogConfig::production(),
1136 redact_headers: default_redacted_headers(),
1137 log_request_headers: true,
1138 log_response_headers: true,
1139 log_body: false,
1140 max_body_bytes: 1024,
1141 }
1142 }
1143}
1144
1145impl RequestResponseLogger {
1146 #[must_use]
1148 pub fn new() -> Self {
1149 Self::default()
1150 }
1151
1152 #[must_use]
1154 pub fn log_config(mut self, config: LogConfig) -> Self {
1155 self.log_config = config;
1156 self
1157 }
1158
1159 #[must_use]
1161 pub fn log_request_headers(mut self, enabled: bool) -> Self {
1162 self.log_request_headers = enabled;
1163 self
1164 }
1165
1166 #[must_use]
1168 pub fn log_response_headers(mut self, enabled: bool) -> Self {
1169 self.log_response_headers = enabled;
1170 self
1171 }
1172
1173 #[must_use]
1175 pub fn log_body(mut self, enabled: bool) -> Self {
1176 self.log_body = enabled;
1177 self
1178 }
1179
1180 #[must_use]
1182 pub fn max_body_bytes(mut self, max: usize) -> Self {
1183 self.max_body_bytes = max;
1184 self
1185 }
1186
1187 #[must_use]
1189 pub fn redact_header(mut self, name: impl Into<String>) -> Self {
1190 self.redact_headers.insert(name.into().to_ascii_lowercase());
1191 self
1192 }
1193}
1194
1195#[derive(Debug, Clone)]
1196struct RequestStart(Instant);
1197
1198impl Middleware for RequestResponseLogger {
1199 fn before<'a>(
1200 &'a self,
1201 ctx: &'a RequestContext,
1202 req: &'a mut Request,
1203 ) -> BoxFuture<'a, ControlFlow> {
1204 let logger = RequestLogger::new(ctx, self.log_config.clone());
1205 req.insert_extension(RequestStart(Instant::now()));
1206
1207 let method = req.method();
1208 let path = req.path();
1209 let query = req.query();
1210 let body_bytes = body_len(req.body());
1211
1212 logger.info_with_fields("request", |entry| {
1213 let mut entry = entry
1214 .field("method", method)
1215 .field("path", path)
1216 .field("body_bytes", body_bytes);
1217
1218 if let Some(q) = query {
1219 entry = entry.field("query", q);
1220 }
1221
1222 if self.log_request_headers {
1223 let headers = format_headers(req.headers().iter(), &self.redact_headers);
1224 entry = entry.field("headers", headers);
1225 }
1226
1227 if self.log_body {
1228 if let Some(body) = preview_body(req.body(), self.max_body_bytes) {
1229 entry = entry.field("body", body);
1230 }
1231 }
1232
1233 entry
1234 });
1235
1236 Box::pin(async { ControlFlow::Continue })
1237 }
1238
1239 fn after<'a>(
1240 &'a self,
1241 ctx: &'a RequestContext,
1242 req: &'a Request,
1243 response: Response,
1244 ) -> BoxFuture<'a, Response> {
1245 let logger = RequestLogger::new(ctx, self.log_config.clone());
1246 let duration = req
1247 .get_extension::<RequestStart>()
1248 .map(|start| start.0.elapsed())
1249 .unwrap_or_default();
1250
1251 let status = response.status();
1252 let body_bytes = response.body_ref().len();
1253
1254 logger.info_with_fields("response", |entry| {
1255 let mut entry = entry
1256 .field("status", status.as_u16())
1257 .field("duration_us", duration.as_micros())
1258 .field("body_bytes", body_bytes);
1259
1260 if self.log_response_headers {
1261 let headers = format_response_headers(response.headers(), &self.redact_headers);
1262 entry = entry.field("headers", headers);
1263 }
1264
1265 if self.log_body {
1266 if let Some(body) = preview_response_body(response.body_ref(), self.max_body_bytes)
1267 {
1268 entry = entry.field("body", body);
1269 }
1270 }
1271
1272 entry
1273 });
1274
1275 Box::pin(async move { response })
1276 }
1277
1278 fn name(&self) -> &'static str {
1279 "RequestResponseLogger"
1280 }
1281}
1282
/// Header names (lower-case) whose values are never written to logs by default.
fn default_redacted_headers() -> HashSet<String> {
    const SENSITIVE: [&str; 4] = ["authorization", "proxy-authorization", "cookie", "set-cookie"];
    SENSITIVE.iter().copied().map(String::from).collect()
}
1294
1295fn body_len(body: &Body) -> usize {
1296 match body {
1297 Body::Empty => 0,
1298 Body::Bytes(bytes) => bytes.len(),
1299 Body::Stream(_) => 0, }
1301}
1302
1303fn preview_body(body: &Body, max_bytes: usize) -> Option<String> {
1304 if max_bytes == 0 {
1305 return None;
1306 }
1307 match body {
1308 Body::Empty => None,
1309 Body::Bytes(bytes) => {
1310 if bytes.is_empty() {
1311 None
1312 } else {
1313 Some(format_bytes(bytes, max_bytes))
1314 }
1315 }
1316 Body::Stream(_) => None, }
1318}
1319
1320fn preview_response_body(body: &crate::response::ResponseBody, max_bytes: usize) -> Option<String> {
1321 if max_bytes == 0 {
1322 return None;
1323 }
1324 match body {
1325 crate::response::ResponseBody::Empty => None,
1326 crate::response::ResponseBody::Bytes(bytes) => {
1327 if bytes.is_empty() {
1328 None
1329 } else {
1330 Some(format_bytes(bytes, max_bytes))
1331 }
1332 }
1333 crate::response::ResponseBody::Stream(_) => None,
1334 }
1335}
1336
/// Renders headers as `name=value, name=value, ...` for logging.
///
/// Names in `redacted` (compared lower-cased) show `<redacted>`, and non-UTF-8
/// values show `<binary>`.
fn format_headers<'a>(
    headers: impl Iterator<Item = (&'a str, &'a [u8])>,
    redacted: &HashSet<String>,
) -> String {
    headers
        .map(|(name, value)| {
            let shown = if redacted.contains(&name.to_ascii_lowercase()) {
                "<redacted>"
            } else {
                std::str::from_utf8(value).unwrap_or("<binary>")
            };
            format!("{name}={shown}")
        })
        .collect::<Vec<_>>()
        .join(", ")
}
1362
1363fn format_response_headers(headers: &[(String, Vec<u8>)], redacted: &HashSet<String>) -> String {
1364 format_headers(
1365 headers
1366 .iter()
1367 .map(|(name, value)| (name.as_str(), value.as_slice())),
1368 redacted,
1369 )
1370}
1371
/// Renders at most `max_bytes` of `bytes` as text for logging.
///
/// Returns the UTF-8 text of the first `max_bytes` bytes, with `...` appended
/// when the input was longer. Returns `<N bytes binary>` when that prefix is
/// not UTF-8. If the cut lands inside a multi-byte character, the partial
/// character is dropped; previously such a body was reported as binary.
fn format_bytes(bytes: &[u8], max_bytes: usize) -> String {
    let limit = max_bytes.min(bytes.len());
    let truncated = bytes.len() > limit;
    let prefix = &bytes[..limit];

    let text = match std::str::from_utf8(prefix) {
        Ok(text) => text,
        // `error_len() == None` means the prefix merely *ends* in an incomplete
        // sequence, i.e. our own cut split a character. Keep the valid part.
        Err(err) if truncated && err.error_len().is_none() => {
            std::str::from_utf8(&prefix[..err.valid_up_to()]).unwrap_or_default()
        }
        Err(_) => return format!("<{} bytes binary>", bytes.len()),
    };

    if truncated {
        format!("{text}...")
    } else {
        text.to_string()
    }
}
1385
1386impl From<&crate::response::ResponseBody> for crate::response::ResponseBody {
1388 fn from(body: &crate::response::ResponseBody) -> Self {
1389 match body {
1390 crate::response::ResponseBody::Empty => crate::response::ResponseBody::Empty,
1391 crate::response::ResponseBody::Bytes(b) => {
1392 crate::response::ResponseBody::Bytes(b.clone())
1393 }
1394 crate::response::ResponseBody::Stream(_) => crate::response::ResponseBody::Empty,
1395 }
1396 }
1397}
1398
/// Identifier attached to a request, either taken from the client or generated.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);

impl RequestId {
    /// Wraps an existing identifier.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        RequestId(id.into())
    }

    /// Borrows the identifier text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Generates `<hex microseconds since the Unix epoch>-<4 hex digits>`.
    ///
    /// The suffix is the low 16 bits of a process-wide counter. IDs are
    /// therefore distinct within one process unless 65 536 are made in the
    /// same microsecond. There is no host or random component, so IDs are not
    /// unique across processes.
    #[must_use]
    pub fn generate() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};
        use std::time::{SystemTime, UNIX_EPOCH};

        static NEXT: AtomicU64 = AtomicU64::new(0);

        // A clock before the epoch yields 0. The u128 -> u64 cast only
        // truncates after ~584k years.
        let micros = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_micros() as u64,
            Err(_) => 0,
        };
        let seq = NEXT.fetch_add(1, Ordering::Relaxed) & 0xFFFF;

        RequestId(format!("{micros:x}-{seq:04x}"))
    }
}

impl std::fmt::Display for RequestId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<String> for RequestId {
    fn from(s: String) -> Self {
        RequestId(s)
    }
}

impl From<&str> for RequestId {
    fn from(s: &str) -> Self {
        RequestId(s.to_owned())
    }
}
1462
/// Settings for the request-ID middleware.
#[derive(Debug, Clone)]
pub struct RequestIdConfig {
    /// Header read from requests and written to responses (default `x-request-id`).
    pub header_name: String,
    /// Reuse a well-formed client-supplied ID (default `true`).
    pub accept_from_client: bool,
    /// Echo the ID as a response header (default `true`).
    pub add_to_response: bool,
    /// Longest client-supplied ID accepted, in bytes (default `128`).
    pub max_client_id_length: usize,
}

impl Default for RequestIdConfig {
    fn default() -> Self {
        RequestIdConfig {
            header_name: String::from("x-request-id"),
            accept_from_client: true,
            add_to_response: true,
            max_client_id_length: 128,
        }
    }
}

impl RequestIdConfig {
    /// Same as [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the header name.
    #[must_use]
    pub fn header_name(self, name: impl Into<String>) -> Self {
        Self {
            header_name: name.into(),
            ..self
        }
    }

    /// Sets whether client-supplied IDs are accepted.
    #[must_use]
    pub fn accept_from_client(self, accept: bool) -> Self {
        Self {
            accept_from_client: accept,
            ..self
        }
    }

    /// Sets whether the ID is added to responses.
    #[must_use]
    pub fn add_to_response(self, add: bool) -> Self {
        Self {
            add_to_response: add,
            ..self
        }
    }

    /// Sets the maximum accepted client ID length.
    #[must_use]
    pub fn max_client_id_length(self, max: usize) -> Self {
        Self {
            max_client_id_length: max,
            ..self
        }
    }
}
1522
1523#[derive(Debug, Clone)]
1548pub struct RequestIdMiddleware {
1549 config: RequestIdConfig,
1550}
1551
1552impl Default for RequestIdMiddleware {
1553 fn default() -> Self {
1554 Self::new()
1555 }
1556}
1557
1558impl RequestIdMiddleware {
1559 #[must_use]
1561 pub fn new() -> Self {
1562 Self {
1563 config: RequestIdConfig::default(),
1564 }
1565 }
1566
1567 #[must_use]
1569 pub fn with_config(config: RequestIdConfig) -> Self {
1570 Self { config }
1571 }
1572
1573 fn get_or_generate_id(&self, req: &Request) -> RequestId {
1575 if self.config.accept_from_client {
1576 if let Some(header_value) = req.headers().get(&self.config.header_name) {
1577 if let Ok(client_id) = std::str::from_utf8(header_value) {
1578 if !client_id.is_empty()
1580 && client_id.len() <= self.config.max_client_id_length
1581 && is_valid_request_id(client_id)
1582 {
1583 return RequestId::new(client_id);
1584 }
1585 }
1586 }
1587 }
1588 RequestId::generate()
1589 }
1590}
1591
/// `true` when `id` is non-empty and every character is an ASCII letter, an
/// ASCII digit, `-`, `_` or `.`.
fn is_valid_request_id(id: &str) -> bool {
    let allowed = |c: char| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.');
    !id.is_empty() && id.chars().all(allowed)
}
1599
1600impl Middleware for RequestIdMiddleware {
1601 fn before<'a>(
1602 &'a self,
1603 _ctx: &'a RequestContext,
1604 req: &'a mut Request,
1605 ) -> BoxFuture<'a, ControlFlow> {
1606 let request_id = self.get_or_generate_id(req);
1607 req.insert_extension(request_id);
1608 Box::pin(async { ControlFlow::Continue })
1609 }
1610
1611 fn after<'a>(
1612 &'a self,
1613 _ctx: &'a RequestContext,
1614 req: &'a Request,
1615 response: Response,
1616 ) -> BoxFuture<'a, Response> {
1617 if !self.config.add_to_response {
1618 return Box::pin(async move { response });
1619 }
1620
1621 let request_id = req.get_extension::<RequestId>().cloned();
1622 let header_name = self.config.header_name.clone();
1623
1624 Box::pin(async move {
1625 if let Some(id) = request_id {
1626 response.header(header_name, id.0.into_bytes())
1627 } else {
1628 response
1629 }
1630 })
1631 }
1632
1633 fn name(&self) -> &'static str {
1634 "RequestId"
1635 }
1636}
1637
/// Allowed values of the `X-Frame-Options` header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum XFrameOptions {
    /// `DENY`
    Deny,
    /// `SAMEORIGIN`
    SameOrigin,
}

impl XFrameOptions {
    /// Header value bytes for this option.
    fn as_bytes(self) -> &'static [u8] {
        let token: &'static str = match self {
            XFrameOptions::Deny => "DENY",
            XFrameOptions::SameOrigin => "SAMEORIGIN",
        };
        token.as_bytes()
    }
}
1661
/// Allowed values of the `Referrer-Policy` header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReferrerPolicy {
    /// `no-referrer`
    NoReferrer,
    /// `no-referrer-when-downgrade`
    NoReferrerWhenDowngrade,
    /// `origin`
    Origin,
    /// `origin-when-cross-origin`
    OriginWhenCrossOrigin,
    /// `same-origin`
    SameOrigin,
    /// `strict-origin`
    StrictOrigin,
    /// `strict-origin-when-cross-origin`
    StrictOriginWhenCrossOrigin,
    /// `unsafe-url`
    UnsafeUrl,
}

impl ReferrerPolicy {
    /// Header value bytes for this policy.
    fn as_bytes(self) -> &'static [u8] {
        use ReferrerPolicy::*;

        let token: &'static str = match self {
            NoReferrer => "no-referrer",
            NoReferrerWhenDowngrade => "no-referrer-when-downgrade",
            Origin => "origin",
            OriginWhenCrossOrigin => "origin-when-cross-origin",
            SameOrigin => "same-origin",
            StrictOrigin => "strict-origin",
            StrictOriginWhenCrossOrigin => "strict-origin-when-cross-origin",
            UnsafeUrl => "unsafe-url",
        };
        token.as_bytes()
    }
}
1699
1700#[derive(Debug, Clone)]
1724pub struct SecurityHeadersConfig {
1725 pub x_content_type_options: Option<&'static str>,
1728 pub x_frame_options: Option<XFrameOptions>,
1731 pub x_xss_protection: Option<&'static str>,
1737 pub content_security_policy: Option<String>,
1740 pub hsts: Option<(u64, bool, bool)>,
1744 pub referrer_policy: Option<ReferrerPolicy>,
1747 pub permissions_policy: Option<String>,
1750}
1751
1752impl Default for SecurityHeadersConfig {
1753 fn default() -> Self {
1754 Self {
1755 x_content_type_options: Some("nosniff"),
1756 x_frame_options: Some(XFrameOptions::Deny),
1757 x_xss_protection: Some("0"),
1758 content_security_policy: None,
1759 hsts: None,
1760 referrer_policy: Some(ReferrerPolicy::StrictOriginWhenCrossOrigin),
1761 permissions_policy: None,
1762 }
1763 }
1764}
1765
1766impl SecurityHeadersConfig {
1767 #[must_use]
1769 pub fn new() -> Self {
1770 Self::default()
1771 }
1772
1773 #[must_use]
1775 pub fn none() -> Self {
1776 Self {
1777 x_content_type_options: None,
1778 x_frame_options: None,
1779 x_xss_protection: None,
1780 content_security_policy: None,
1781 hsts: None,
1782 referrer_policy: None,
1783 permissions_policy: None,
1784 }
1785 }
1786
1787 #[must_use]
1794 pub fn strict() -> Self {
1795 Self {
1796 x_content_type_options: Some("nosniff"),
1797 x_frame_options: Some(XFrameOptions::Deny),
1798 x_xss_protection: Some("0"),
1799 content_security_policy: Some("default-src 'self'".to_string()),
1800 hsts: Some((31536000, true, false)), referrer_policy: Some(ReferrerPolicy::NoReferrer),
1802 permissions_policy: Some("geolocation=(), camera=(), microphone=()".to_string()),
1803 }
1804 }
1805
1806 #[must_use]
1808 pub fn x_content_type_options(mut self, value: Option<&'static str>) -> Self {
1809 self.x_content_type_options = value;
1810 self
1811 }
1812
1813 #[must_use]
1815 pub fn x_frame_options(mut self, value: Option<XFrameOptions>) -> Self {
1816 self.x_frame_options = value;
1817 self
1818 }
1819
1820 #[must_use]
1822 pub fn x_xss_protection(mut self, value: Option<&'static str>) -> Self {
1823 self.x_xss_protection = value;
1824 self
1825 }
1826
1827 #[must_use]
1829 pub fn content_security_policy(mut self, value: impl Into<String>) -> Self {
1830 self.content_security_policy = Some(value.into());
1831 self
1832 }
1833
1834 #[must_use]
1836 pub fn no_content_security_policy(mut self) -> Self {
1837 self.content_security_policy = None;
1838 self
1839 }
1840
1841 #[must_use]
1854 pub fn hsts(mut self, max_age: u64, include_sub_domains: bool, preload: bool) -> Self {
1855 self.hsts = Some((max_age, include_sub_domains, preload));
1856 self
1857 }
1858
1859 #[must_use]
1861 pub fn no_hsts(mut self) -> Self {
1862 self.hsts = None;
1863 self
1864 }
1865
1866 #[must_use]
1868 pub fn referrer_policy(mut self, value: Option<ReferrerPolicy>) -> Self {
1869 self.referrer_policy = value;
1870 self
1871 }
1872
1873 #[must_use]
1875 pub fn permissions_policy(mut self, value: impl Into<String>) -> Self {
1876 self.permissions_policy = Some(value.into());
1877 self
1878 }
1879
1880 #[must_use]
1882 pub fn no_permissions_policy(mut self) -> Self {
1883 self.permissions_policy = None;
1884 self
1885 }
1886
1887 fn build_hsts_value(&self) -> Option<String> {
1889 self.hsts.map(|(max_age, include_sub, preload)| {
1890 let mut value = format!("max-age={}", max_age);
1891 if include_sub {
1892 value.push_str("; includeSubDomains");
1893 }
1894 if preload {
1895 value.push_str("; preload");
1896 }
1897 value
1898 })
1899 }
1900}
1901
1902#[derive(Debug, Clone)]
1933pub struct SecurityHeaders {
1934 config: SecurityHeadersConfig,
1935}
1936
1937impl Default for SecurityHeaders {
1938 fn default() -> Self {
1939 Self::new()
1940 }
1941}
1942
1943impl SecurityHeaders {
1944 #[must_use]
1946 pub fn new() -> Self {
1947 Self {
1948 config: SecurityHeadersConfig::default(),
1949 }
1950 }
1951
1952 #[must_use]
1954 pub fn with_config(config: SecurityHeadersConfig) -> Self {
1955 Self { config }
1956 }
1957
1958 #[must_use]
1960 pub fn strict() -> Self {
1961 Self {
1962 config: SecurityHeadersConfig::strict(),
1963 }
1964 }
1965}
1966
1967impl Middleware for SecurityHeaders {
1968 fn after<'a>(
1969 &'a self,
1970 _ctx: &'a RequestContext,
1971 _req: &'a Request,
1972 response: Response,
1973 ) -> BoxFuture<'a, Response> {
1974 let config = self.config.clone();
1975 Box::pin(async move {
1976 let mut resp = response;
1977
1978 if let Some(value) = config.x_content_type_options {
1980 resp = resp.header("X-Content-Type-Options", value.as_bytes().to_vec());
1981 }
1982
1983 if let Some(value) = config.x_frame_options {
1985 resp = resp.header("X-Frame-Options", value.as_bytes().to_vec());
1986 }
1987
1988 if let Some(value) = config.x_xss_protection {
1990 resp = resp.header("X-XSS-Protection", value.as_bytes().to_vec());
1991 }
1992
1993 if let Some(ref value) = config.content_security_policy {
1995 resp = resp.header("Content-Security-Policy", value.as_bytes().to_vec());
1996 }
1997
1998 if let Some(ref hsts_value) = config.build_hsts_value() {
2000 resp = resp.header("Strict-Transport-Security", hsts_value.as_bytes().to_vec());
2001 }
2002
2003 if let Some(value) = config.referrer_policy {
2005 resp = resp.header("Referrer-Policy", value.as_bytes().to_vec());
2006 }
2007
2008 if let Some(ref value) = config.permissions_policy {
2010 resp = resp.header("Permissions-Policy", value.as_bytes().to_vec());
2011 }
2012
2013 resp
2014 })
2015 }
2016
2017 fn name(&self) -> &'static str {
2018 "SecurityHeaders"
2019 }
2020}
2021
/// An opaque CSRF token string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CsrfToken(pub String);

impl CsrfToken {
    /// Wraps an existing token value.
    #[must_use]
    pub fn new(token: impl Into<String>) -> Self {
        let token: String = token.into();
        Self(token)
    }

    /// Borrows the token text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Generates a fresh token: 32 bytes read from `/dev/urandom`, rendered
    /// as 64 lowercase hexadecimal characters.
    ///
    /// # Panics
    ///
    /// Panics if `/dev/urandom` cannot be opened or fully read. Falling back
    /// to a non-cryptographic source would produce guessable tokens.
    #[must_use]
    pub fn generate() -> Self {
        let entropy = match Self::read_urandom(32) {
            Ok(entropy) => entropy,
            Err(_) => panic!(
                "FATAL: Cryptographically secure random source (/dev/urandom) is unavailable. \
                 CSRF token generation requires a CSPRNG. Cannot safely generate CSRF tokens \
                 without cryptographic entropy."
            ),
        };
        Self(Self::bytes_to_hex(&entropy))
    }

    /// Reads exactly `len` bytes from `/dev/urandom`.
    fn read_urandom(len: usize) -> std::io::Result<Vec<u8>> {
        use std::io::Read;
        let mut source = std::fs::File::open("/dev/urandom")?;
        let mut out = vec![0u8; len];
        source.read_exact(&mut out)?;
        Ok(out)
    }

    /// Lowercase hex encoding, two characters per byte.
    fn bytes_to_hex(bytes: &[u8]) -> String {
        const DIGITS: &[u8; 16] = b"0123456789abcdef";
        bytes
            .iter()
            .flat_map(|&byte| [DIGITS[usize::from(byte >> 4)], DIGITS[usize::from(byte & 0x0f)]])
            .map(char::from)
            .collect()
    }
}

impl std::fmt::Display for CsrfToken {
    /// Writes the raw token text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

impl From<&str> for CsrfToken {
    fn from(s: &str) -> Self {
        Self::new(s)
    }
}
2096
/// How `CsrfMiddleware` validates tokens on non-safe requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CsrfMode {
    /// Double-submit cookie: the header token must be non-empty and equal to
    /// the cookie token.
    #[default]
    DoubleSubmit,
    /// Only a non-empty token in the configured header is required.
    HeaderOnly,
}

/// Settings for `CsrfMiddleware`.
#[derive(Debug, Clone)]
pub struct CsrfConfig {
    /// Name of the cookie that carries the token.
    pub cookie_name: String,
    /// Name of the request header that carries the token.
    pub header_name: String,
    /// Validation strategy.
    pub mode: CsrfMode,
    /// When `true`, the token cookie is re-sent on every safe request, not
    /// only when the client has none.
    pub rotate_token: bool,
    /// When `true`, the token cookie gets the `Secure` attribute.
    pub production: bool,
    /// Overrides the default message in 403 error bodies.
    pub error_message: Option<String>,
}

impl Default for CsrfConfig {
    /// Defaults:
    /// - cookie `csrf_token`, header `x-csrf-token`
    /// - double-submit mode
    /// - no rotation
    /// - `Secure` cookies
    /// - built-in error messages
    fn default() -> Self {
        Self {
            mode: CsrfMode::default(),
            cookie_name: String::from("csrf_token"),
            header_name: String::from("x-csrf-token"),
            production: true,
            rotate_token: false,
            error_message: None,
        }
    }
}

impl CsrfConfig {
    /// Returns the default configuration.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets the token cookie name.
    #[must_use]
    pub fn cookie_name(self, name: impl Into<String>) -> Self {
        Self { cookie_name: name.into(), ..self }
    }

    /// Sets the token header name.
    #[must_use]
    pub fn header_name(self, name: impl Into<String>) -> Self {
        Self { header_name: name.into(), ..self }
    }

    /// Sets the validation mode.
    #[must_use]
    pub fn mode(self, mode: CsrfMode) -> Self {
        Self { mode, ..self }
    }

    /// Enables or disables re-sending the token cookie on safe requests.
    #[must_use]
    pub fn rotate_token(self, rotate: bool) -> Self {
        Self { rotate_token: rotate, ..self }
    }

    /// Enables or disables the `Secure` cookie attribute.
    #[must_use]
    pub fn production(self, production: bool) -> Self {
        Self { production, ..self }
    }

    /// Sets a custom message for CSRF error responses.
    #[must_use]
    pub fn error_message(self, message: impl Into<String>) -> Self {
        Self { error_message: Some(message.into()), ..self }
    }
}
2187
2188#[derive(Debug, Clone)]
2218pub struct CsrfMiddleware {
2219 config: CsrfConfig,
2220}
2221
2222impl Default for CsrfMiddleware {
2223 fn default() -> Self {
2224 Self::new()
2225 }
2226}
2227
2228impl CsrfMiddleware {
2229 #[must_use]
2231 pub fn new() -> Self {
2232 Self {
2233 config: CsrfConfig::default(),
2234 }
2235 }
2236
2237 #[must_use]
2239 pub fn with_config(config: CsrfConfig) -> Self {
2240 Self { config }
2241 }
2242
2243 fn is_safe_method(method: crate::request::Method) -> bool {
2245 matches!(
2246 method,
2247 crate::request::Method::Get
2248 | crate::request::Method::Head
2249 | crate::request::Method::Options
2250 | crate::request::Method::Trace
2251 )
2252 }
2253
2254 fn get_cookie_token(&self, req: &Request) -> Option<String> {
2256 let cookie_header = req.headers().get("cookie")?;
2257 let cookie_str = std::str::from_utf8(cookie_header).ok()?;
2258
2259 for part in cookie_str.split(';') {
2261 let part = part.trim();
2262 if let Some((name, value)) = part.split_once('=') {
2263 if name.trim() == self.config.cookie_name {
2264 return Some(value.trim().to_string());
2265 }
2266 }
2267 }
2268 None
2269 }
2270
2271 fn get_header_token(&self, req: &Request) -> Option<String> {
2273 let header_value = req.headers().get(&self.config.header_name)?;
2274 std::str::from_utf8(header_value)
2275 .ok()
2276 .map(|s| s.trim().to_string())
2277 }
2278
2279 fn validate_token(&self, req: &Request) -> Result<Option<CsrfToken>, Response> {
2281 let header_token = self.get_header_token(req);
2282
2283 match self.config.mode {
2284 CsrfMode::DoubleSubmit => {
2285 let cookie_token = self.get_cookie_token(req);
2286
2287 match (header_token, cookie_token) {
2288 (Some(header), Some(cookie))
2289 if !header.is_empty()
2290 && crate::extract::constant_time_eq(
2291 header.as_bytes(),
2292 cookie.as_bytes(),
2293 ) =>
2294 {
2295 Ok(Some(CsrfToken::new(header)))
2296 }
2297 (None, _) | (_, None) => Err(self.csrf_error_response("CSRF token missing")),
2298 _ => Err(self.csrf_error_response("CSRF token mismatch")),
2299 }
2300 }
2301 CsrfMode::HeaderOnly => match header_token {
2302 Some(token) if !token.is_empty() => Ok(Some(CsrfToken::new(token))),
2303 _ => Err(self.csrf_error_response("CSRF token missing in header")),
2304 },
2305 }
2306 }
2307
2308 fn csrf_error_response(&self, default_message: &str) -> Response {
2310 let message = self
2311 .config
2312 .error_message
2313 .as_deref()
2314 .unwrap_or(default_message);
2315
2316 let body = format!(
2318 r#"{{"detail":[{{"type":"csrf_error","loc":["header","{}"],"msg":"{}"}}]}}"#,
2319 self.config.header_name, message
2320 );
2321
2322 Response::with_status(crate::response::StatusCode::FORBIDDEN)
2323 .header("content-type", b"application/json".to_vec())
2324 .body(crate::response::ResponseBody::Bytes(body.into_bytes()))
2325 }
2326
2327 fn make_set_cookie_header_value(cookie_name: &str, token: &str, production: bool) -> Vec<u8> {
2329 let mut cookie = format!("{}={}; Path=/; SameSite=Strict", cookie_name, token);
2330
2331 if production {
2332 cookie.push_str("; Secure");
2333 }
2334
2335 cookie.into_bytes()
2338 }
2339}
2340
2341impl Middleware for CsrfMiddleware {
2342 fn before<'a>(
2343 &'a self,
2344 _ctx: &'a RequestContext,
2345 req: &'a mut Request,
2346 ) -> BoxFuture<'a, ControlFlow> {
2347 Box::pin(async move {
2348 if Self::is_safe_method(req.method()) {
2349 let existing_token = self.get_cookie_token(req);
2351 let token = existing_token
2352 .map(CsrfToken::new)
2353 .unwrap_or_else(CsrfToken::generate);
2354 req.insert_extension(token);
2355 ControlFlow::Continue
2356 } else {
2357 match self.validate_token(req) {
2359 Ok(Some(token)) => {
2360 req.insert_extension(token);
2361 ControlFlow::Continue
2362 }
2363 Ok(None) => ControlFlow::Continue,
2364 Err(response) => ControlFlow::Break(response),
2365 }
2366 }
2367 })
2368 }
2369
2370 fn after<'a>(
2371 &'a self,
2372 _ctx: &'a RequestContext,
2373 req: &'a Request,
2374 response: Response,
2375 ) -> BoxFuture<'a, Response> {
2376 let config = self.config.clone();
2377 let is_safe = Self::is_safe_method(req.method());
2378 let existing_cookie_token = self.get_cookie_token(req);
2379 let token = req.get_extension::<CsrfToken>().cloned();
2380
2381 Box::pin(async move {
2382 if is_safe {
2386 let should_set_cookie = existing_cookie_token.is_none() || config.rotate_token;
2387
2388 if should_set_cookie {
2389 if let Some(token) = token {
2390 let cookie_value = Self::make_set_cookie_header_value(
2391 &config.cookie_name,
2392 token.as_str(),
2393 config.production,
2394 );
2395 return response.header("set-cookie", cookie_value);
2396 }
2397 }
2398 }
2399 response
2400 })
2401 }
2402
2403 fn name(&self) -> &'static str {
2404 "CSRF"
2405 }
2406}
2407
/// Settings for `CompressionMiddleware`.
#[cfg(feature = "compression")]
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Bodies smaller than this many bytes are sent uncompressed.
    pub min_size: usize,
    /// gzip compression level (the `level` builder clamps it to `1..=9`).
    pub level: u32,
    /// Content types that are never compressed. An entry ending in `/`
    /// (e.g. `"video/"`) matches every subtype.
    pub skip_content_types: Vec<&'static str>,
}

#[cfg(feature = "compression")]
impl Default for CompressionConfig {
    /// Defaults: 1024-byte minimum, level 6, and a skip list of formats
    /// that are typically already compressed (images, audio/video,
    /// archives, PDF, web fonts).
    fn default() -> Self {
        const ALREADY_COMPRESSED: [&str; 19] = [
            "image/jpeg",
            "image/png",
            "image/gif",
            "image/webp",
            "image/avif",
            "video/",
            "audio/",
            "application/zip",
            "application/gzip",
            "application/x-gzip",
            "application/x-bzip2",
            "application/x-xz",
            "application/x-7z-compressed",
            "application/x-rar-compressed",
            "application/pdf",
            "application/woff",
            "application/woff2",
            "font/woff",
            "font/woff2",
        ];
        Self {
            min_size: 1024,
            level: 6,
            skip_content_types: ALREADY_COMPRESSED.to_vec(),
        }
    }
}
2479
#[cfg(feature = "compression")]
impl CompressionConfig {
    /// Returns the default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the minimum body size, in bytes, eligible for compression.
    #[must_use]
    pub fn min_size(mut self, size: usize) -> Self {
        self.min_size = size;
        self
    }

    /// Sets the gzip level, clamped to `1..=9`.
    #[must_use]
    pub fn level(mut self, level: u32) -> Self {
        self.level = level.clamp(1, 9);
        self
    }

    /// Adds a content type to the skip list. An entry ending in `/`
    /// matches every subtype of that type.
    #[must_use]
    pub fn skip_content_type(mut self, content_type: &'static str) -> Self {
        self.skip_content_types.push(content_type);
        self
    }

    /// Returns `true` if `content_type` matches any entry of the skip list.
    ///
    /// Only the media type is compared. Parameters such as
    /// `; charset=utf-8` are ignored, as is the optional whitespace that
    /// RFC 9110 allows around the `;`. The comparison is ASCII
    /// case-insensitive on both sides, so mixed-case entries added via
    /// `skip_content_type` also match.
    fn should_skip_content_type(&self, content_type: &str) -> bool {
        let media_type = content_type.split(';').next().unwrap_or_default().trim();

        self.skip_content_types.iter().any(|skip| {
            if skip.ends_with('/') {
                // Prefix entry such as "video/": any subtype matches.
                media_type
                    .as_bytes()
                    .get(..skip.len())
                    .map_or(false, |head| head.eq_ignore_ascii_case(skip.as_bytes()))
            } else {
                media_type.eq_ignore_ascii_case(skip)
            }
        })
    }
}
2539
/// Middleware that gzip-compresses eligible responses for clients that
/// accept gzip.
#[cfg(feature = "compression")]
#[derive(Debug, Clone)]
pub struct CompressionMiddleware {
    config: CompressionConfig,
}

#[cfg(feature = "compression")]
impl Default for CompressionMiddleware {
    fn default() -> Self {
        Self {
            config: CompressionConfig::default(),
        }
    }
}
2584
#[cfg(feature = "compression")]
impl CompressionMiddleware {
    /// Creates the middleware with `CompressionConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: CompressionConfig::default(),
        }
    }

    /// Creates the middleware with a caller-supplied configuration.
    #[must_use]
    pub fn with_config(config: CompressionConfig) -> Self {
        Self { config }
    }

    /// Returns `true` when the request's `Accept-Encoding` header permits
    /// gzip. A missing or non-UTF-8 header means "no gzip".
    fn accepts_gzip(req: &Request) -> bool {
        let Some(raw) = req.headers().get("accept-encoding") else {
            return false;
        };
        std::str::from_utf8(raw).map_or(false, Self::gzip_acceptable)
    }

    /// Decides whether an `Accept-Encoding` value permits gzip.
    ///
    /// An explicit `gzip` entry takes precedence over the `*` wildcard. A
    /// quality value of zero (`q=0`) marks a coding as "not acceptable"
    /// (RFC 9110 §12.5.3). Entries without a parseable `q` count as
    /// acceptable.
    fn gzip_acceptable(accept_encoding: &str) -> bool {
        let mut wildcard: Option<bool> = None;
        for entry in accept_encoding.split(',') {
            let mut pieces = entry.split(';');
            let coding = pieces.next().unwrap_or("").trim();
            if coding.eq_ignore_ascii_case("gzip") {
                return Self::quality_is_nonzero(pieces);
            }
            if coding == "*" {
                wildcard = Some(Self::quality_is_nonzero(pieces));
            }
        }
        wildcard.unwrap_or(false)
    }

    /// Returns `false` only when `params` contains a `q=` value that parses
    /// to a number that is not greater than zero.
    fn quality_is_nonzero<'s>(params: impl Iterator<Item = &'s str>) -> bool {
        params
            .filter_map(|param| param.split_once('='))
            .find(|(key, _)| key.trim().eq_ignore_ascii_case("q"))
            .and_then(|(_, value)| value.trim().parse::<f32>().ok())
            .map_or(true, |q| q > 0.0)
    }

    /// Returns the first `Content-Type` header value (case-insensitive
    /// name match), if it is valid UTF-8.
    fn get_content_type(headers: &[(String, Vec<u8>)]) -> Option<String> {
        for (name, value) in headers {
            if name.eq_ignore_ascii_case("content-type") {
                return std::str::from_utf8(value).ok().map(String::from);
            }
        }
        None
    }

    /// Returns `true` if any `Content-Encoding` header is already present.
    fn has_content_encoding(headers: &[(String, Vec<u8>)]) -> bool {
        headers
            .iter()
            .any(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
    }

    /// gzip-compresses `data` at the given level.
    fn compress_gzip(data: &[u8], level: u32) -> Result<Vec<u8>, std::io::Error> {
        use flate2::Compression;
        use flate2::write::GzEncoder;
        use std::io::Write;

        let mut encoder = GzEncoder::new(Vec::new(), Compression::new(level));
        encoder.write_all(data)?;
        encoder.finish()
    }
}
2650
#[cfg(feature = "compression")]
impl Middleware for CompressionMiddleware {
    /// gzip-compresses eligible response bodies.
    ///
    /// The response passes through (headers rebuilt as-is) when any of the
    /// following holds:
    /// - the client does not accept gzip;
    /// - a `Content-Encoding` is already set;
    /// - the body is not a `ResponseBody::Bytes`;
    /// - the body is shorter than `min_size`;
    /// - the content type is on the skip list;
    /// - compression fails;
    /// - the compressed output is not smaller than the original.
    ///
    /// When it does compress, `Content-Length` headers are dropped, and
    /// `Content-Encoding: gzip` and `Vary: Accept-Encoding` are added.
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        // Borrow the config for the future's lifetime `'a` instead of
        // cloning it (including the skip-list `Vec`) on every request.
        let config = &self.config;

        Box::pin(async move {
            if !Self::accepts_gzip(req) {
                return response;
            }

            let (status, headers, body) = response.into_parts();

            // Already encoded by the handler: never double-compress.
            if Self::has_content_encoding(&headers) {
                return Response::with_status(status)
                    .body(body)
                    .rebuild_with_headers(headers);
            }

            // Only in-memory `Bytes` bodies are compressed; other variants
            // pass through untouched.
            let body_bytes = match body {
                crate::response::ResponseBody::Bytes(bytes) => bytes,
                other => {
                    return Response::with_status(status)
                        .body(other)
                        .rebuild_with_headers(headers);
                }
            };

            if body_bytes.len() < config.min_size {
                return Response::with_status(status)
                    .body(crate::response::ResponseBody::Bytes(body_bytes))
                    .rebuild_with_headers(headers);
            }

            if let Some(content_type) = Self::get_content_type(&headers) {
                if config.should_skip_content_type(&content_type) {
                    return Response::with_status(status)
                        .body(crate::response::ResponseBody::Bytes(body_bytes))
                        .rebuild_with_headers(headers);
                }
            }

            match Self::compress_gzip(&body_bytes, config.level) {
                Ok(compressed) => {
                    // Not worth it: keep the original representation.
                    if compressed.len() >= body_bytes.len() {
                        return Response::with_status(status)
                            .body(crate::response::ResponseBody::Bytes(body_bytes))
                            .rebuild_with_headers(headers);
                    }

                    let mut resp = Response::with_status(status)
                        .body(crate::response::ResponseBody::Bytes(compressed));

                    // The original Content-Length describes the uncompressed
                    // body, so it must not be carried over.
                    for (name, value) in headers {
                        if !name.eq_ignore_ascii_case("content-length") {
                            resp = resp.header(name, value);
                        }
                    }

                    resp = resp.header("Content-Encoding", b"gzip".to_vec());
                    resp = resp.header("Vary", b"Accept-Encoding".to_vec());

                    resp
                }
                // Best effort: on encoder failure send the body uncompressed.
                Err(_) => Response::with_status(status)
                    .body(crate::response::ResponseBody::Bytes(body_bytes))
                    .rebuild_with_headers(headers),
            }
        })
    }

    fn name(&self) -> &'static str {
        "Compression"
    }
}
2745
2746use parking_lot::Mutex;
2751use std::collections::HashMap as StdHashMap;
2752use std::time::Duration;
2753
/// Algorithm used by `InMemoryRateLimitStore::check` to admit or reject a
/// request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitAlgorithm {
    /// A bucket holding up to `max_requests` tokens, refilled continuously at
    /// `max_requests / window` tokens per second. Each request spends one
    /// token.
    TokenBucket,
    /// At most `max_requests` per window. The counter resets once a full
    /// window has elapsed since the window started.
    FixedWindow,
    /// Approximate sliding window: the previous window's count, weighted by
    /// the fraction of the current window not yet elapsed, plus the current
    /// window's count.
    SlidingWindow,
}

/// Outcome of one rate-limit check.
#[derive(Debug, Clone)]
pub struct RateLimitResult {
    /// Whether the request may proceed.
    pub allowed: bool,
    /// The configured maximum (`max_requests`).
    pub limit: u64,
    /// Requests still available after this one (0 when rejected).
    pub remaining: u64,
    /// Seconds until capacity is restored. The exact meaning depends on the
    /// algorithm; see `InMemoryRateLimitStore`.
    pub reset_after_secs: u64,
}

/// Derives the rate-limit bucket key for a request.
///
/// Returning `None` means the request gets no key; `RateLimitMiddleware`
/// lets such requests through without counting them.
pub trait KeyExtractor: Send + Sync {
    fn extract_key(&self, req: &Request) -> Option<String>;
}
2788
/// IP address of the directly connected peer, read from request extensions
/// by the key extractors below.
///
/// NOTE(review): presumably inserted by the server's connection layer;
/// confirm where.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RemoteAddr(pub std::net::IpAddr);

impl std::fmt::Display for RemoteAddr {
    /// Writes the bare address (e.g. `127.0.0.1`, `::1`). Width and fill
    /// flags from the caller are not forwarded.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let RemoteAddr(ip) = self;
        f.write_fmt(format_args!("{ip}"))
    }
}
2814
2815#[derive(Debug, Clone)]
2836pub struct ConnectedIpKeyExtractor;
2837
2838impl KeyExtractor for ConnectedIpKeyExtractor {
2839 fn extract_key(&self, req: &Request) -> Option<String> {
2840 req.get_extension::<RemoteAddr>().map(ToString::to_string)
2841 }
2842}
2843
2844#[derive(Debug, Clone)]
2873pub struct IpKeyExtractor;
2874
2875impl KeyExtractor for IpKeyExtractor {
2876 fn extract_key(&self, req: &Request) -> Option<String> {
2877 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
2879 if let Ok(s) = std::str::from_utf8(forwarded) {
2880 if let Some(ip) = s.split(',').next() {
2882 return Some(ip.trim().to_string());
2883 }
2884 }
2885 }
2886 if let Some(real_ip) = req.headers().get("x-real-ip") {
2887 if let Ok(s) = std::str::from_utf8(real_ip) {
2888 return Some(s.trim().to_string());
2889 }
2890 }
2891 Some("unknown".to_string())
2892 }
2893}
2894
2895#[derive(Debug, Clone)]
2926pub struct TrustedProxyIpKeyExtractor {
2927 trusted_cidrs: Vec<(std::net::IpAddr, u8)>,
2929}
2930
2931impl TrustedProxyIpKeyExtractor {
2932 #[must_use]
2934 pub fn new() -> Self {
2935 Self {
2936 trusted_cidrs: Vec::new(),
2937 }
2938 }
2939
2940 #[must_use]
2946 pub fn trust_cidr(mut self, cidr: &str) -> Self {
2947 let (ip, prefix) = parse_cidr(cidr).expect("invalid CIDR notation");
2948 self.trusted_cidrs.push((ip, prefix));
2949 self
2950 }
2951
2952 #[must_use]
2954 pub fn trust_loopback(mut self) -> Self {
2955 self.trusted_cidrs.push((
2956 std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 0)),
2957 8,
2958 ));
2959 self.trusted_cidrs
2960 .push((std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST), 128));
2961 self
2962 }
2963
2964 fn is_trusted(&self, ip: std::net::IpAddr) -> bool {
2966 self.trusted_cidrs
2967 .iter()
2968 .any(|(cidr_ip, prefix)| ip_in_cidr(ip, *cidr_ip, *prefix))
2969 }
2970
2971 fn extract_from_header(&self, req: &Request) -> Option<String> {
2973 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
2974 if let Ok(s) = std::str::from_utf8(forwarded) {
2975 if let Some(ip) = s.split(',').next() {
2976 return Some(ip.trim().to_string());
2977 }
2978 }
2979 }
2980 if let Some(real_ip) = req.headers().get("x-real-ip") {
2981 if let Ok(s) = std::str::from_utf8(real_ip) {
2982 return Some(s.trim().to_string());
2983 }
2984 }
2985 None
2986 }
2987}
2988
2989impl Default for TrustedProxyIpKeyExtractor {
2990 fn default() -> Self {
2991 Self::new()
2992 }
2993}
2994
2995impl KeyExtractor for TrustedProxyIpKeyExtractor {
2996 fn extract_key(&self, req: &Request) -> Option<String> {
2997 let remote = req.get_extension::<RemoteAddr>()?;
2998
2999 if self.is_trusted(remote.0) {
3000 self.extract_from_header(req)
3002 .or_else(|| Some(remote.to_string()))
3003 } else {
3004 Some(remote.to_string())
3006 }
3007 }
3008}
3009
/// Parses `address/prefix` CIDR notation.
///
/// Returns `None` when the `/` is missing, either part fails to parse, or
/// the prefix exceeds the address width (32 for IPv4, 128 for IPv6).
fn parse_cidr(cidr: &str) -> Option<(std::net::IpAddr, u8)> {
    let (addr_part, len_part) = cidr.split_once('/')?;
    let addr = addr_part.parse::<std::net::IpAddr>().ok()?;
    let prefix_len = len_part.parse::<u8>().ok()?;
    let width: u8 = if addr.is_ipv4() { 32 } else { 128 };
    (prefix_len <= width).then_some((addr, prefix_len))
}
3027
/// Returns `true` if `ip` lies inside the network `cidr_ip/prefix`.
///
/// Addresses of different families never match, with one exception. An
/// IPv4-mapped IPv6 address (`::ffff:a.b.c.d`, which dual-stack sockets
/// report for IPv4 peers) is compared against IPv4 networks in its IPv4
/// form. A prefix wider than the address family is clamped to the full
/// width; without the clamp the shift amount would underflow and panic in
/// debug builds.
fn ip_in_cidr(ip: std::net::IpAddr, cidr_ip: std::net::IpAddr, prefix: u8) -> bool {
    use std::net::IpAddr;

    // Network mask for a `width`-bit address, right-aligned in a u128.
    // A zero-length prefix yields an empty mask, which matches everything.
    let mask = |width: u32| -> u128 {
        let bits = u32::from(prefix).min(width);
        if bits == 0 {
            0
        } else {
            (u128::MAX << (128 - bits)) >> (128 - width)
        }
    };

    match (ip, cidr_ip) {
        (IpAddr::V4(addr), IpAddr::V4(net)) => {
            let m = mask(32);
            (u128::from(u32::from(addr)) & m) == (u128::from(u32::from(net)) & m)
        }
        (IpAddr::V6(addr), IpAddr::V6(net)) => {
            let m = mask(128);
            (u128::from(addr) & m) == (u128::from(net) & m)
        }
        (IpAddr::V6(addr), IpAddr::V4(_)) => addr
            .to_ipv4_mapped()
            .map_or(false, |v4| ip_in_cidr(IpAddr::V4(v4), cidr_ip, prefix)),
        (IpAddr::V4(_), IpAddr::V6(_)) => false,
    }
}
3052
3053#[derive(Debug, Clone)]
3055pub struct HeaderKeyExtractor {
3056 header_name: String,
3057}
3058
3059impl HeaderKeyExtractor {
3060 #[must_use]
3062 pub fn new(header_name: impl Into<String>) -> Self {
3063 Self {
3064 header_name: header_name.into(),
3065 }
3066 }
3067}
3068
3069impl KeyExtractor for HeaderKeyExtractor {
3070 fn extract_key(&self, req: &Request) -> Option<String> {
3071 req.headers()
3072 .get(&self.header_name)
3073 .and_then(|v| std::str::from_utf8(v).ok())
3074 .map(str::to_string)
3075 }
3076}
3077
3078#[derive(Debug, Clone)]
3080pub struct PathKeyExtractor;
3081
3082impl KeyExtractor for PathKeyExtractor {
3083 fn extract_key(&self, req: &Request) -> Option<String> {
3084 Some(req.path().to_string())
3085 }
3086}
3087
3088pub struct CompositeKeyExtractor {
3093 extractors: Vec<Box<dyn KeyExtractor>>,
3094}
3095
3096impl CompositeKeyExtractor {
3097 #[must_use]
3099 pub fn new(extractors: Vec<Box<dyn KeyExtractor>>) -> Self {
3100 Self { extractors }
3101 }
3102}
3103
3104impl KeyExtractor for CompositeKeyExtractor {
3105 fn extract_key(&self, req: &Request) -> Option<String> {
3106 let parts: Vec<String> = self
3107 .extractors
3108 .iter()
3109 .filter_map(|e| e.extract_key(req))
3110 .collect();
3111 if parts.is_empty() {
3112 None
3113 } else {
3114 Some(parts.join(":"))
3115 }
3116 }
3117}
3118
/// Per-key state for `RateLimitAlgorithm::TokenBucket`.
#[derive(Debug, Clone)]
struct TokenBucketState {
    /// Tokens currently available; fractional because refill is continuous.
    tokens: f64,
    /// When `tokens` was last topped up.
    last_refill: Instant,
}

/// Per-key state for `RateLimitAlgorithm::FixedWindow`.
#[derive(Debug, Clone)]
struct FixedWindowState {
    /// Requests admitted in the current window.
    count: u64,
    /// Start of the current window.
    window_start: Instant,
}

/// Per-key state for `RateLimitAlgorithm::SlidingWindow`.
#[derive(Debug, Clone)]
struct SlidingWindowState {
    /// Requests admitted in the current window.
    current_count: u64,
    /// Requests admitted in the window before the current one.
    previous_count: u64,
    /// Start of the current window.
    current_window_start: Instant,
}
3140
/// Process-local rate-limit state.
///
/// There is one map per algorithm, keyed by the extractor-provided key, and
/// each map is guarded by its own `parking_lot` mutex.
///
/// NOTE(review): the visible impl only ever inserts entries; no eviction is
/// shown. Memory therefore grows with the number of distinct keys, which may
/// be client-influenced (e.g. header-based extractors). Confirm whether
/// cleanup happens elsewhere.
pub struct InMemoryRateLimitStore {
    token_buckets: Mutex<StdHashMap<String, TokenBucketState>>,
    fixed_windows: Mutex<StdHashMap<String, FixedWindowState>>,
    sliding_windows: Mutex<StdHashMap<String, SlidingWindowState>>,
}
3151
3152impl InMemoryRateLimitStore {
3153 #[must_use]
3155 pub fn new() -> Self {
3156 Self {
3157 token_buckets: Mutex::new(StdHashMap::new()),
3158 fixed_windows: Mutex::new(StdHashMap::new()),
3159 sliding_windows: Mutex::new(StdHashMap::new()),
3160 }
3161 }
3162
3163 #[allow(clippy::cast_precision_loss, clippy::cast_sign_loss)]
3164 fn check_token_bucket(
3165 &self,
3166 key: &str,
3167 max_tokens: u64,
3168 refill_rate: f64,
3169 window: Duration,
3170 ) -> RateLimitResult {
3171 let mut buckets = self.token_buckets.lock();
3172 let now = Instant::now();
3173
3174 let state = buckets
3175 .entry(key.to_string())
3176 .or_insert_with(|| TokenBucketState {
3177 tokens: max_tokens as f64,
3178 last_refill: now,
3179 });
3180
3181 let elapsed = now.duration_since(state.last_refill);
3183 let refill = elapsed.as_secs_f64() * refill_rate;
3184 state.tokens = (state.tokens + refill).min(max_tokens as f64);
3185 state.last_refill = now;
3186
3187 if state.tokens >= 1.0 {
3188 state.tokens -= 1.0;
3189 RateLimitResult {
3190 allowed: true,
3191 limit: max_tokens,
3192 remaining: state.tokens as u64,
3193 reset_after_secs: if state.tokens < max_tokens as f64 {
3194 ((max_tokens as f64 - state.tokens) / refill_rate).ceil() as u64
3195 } else {
3196 window.as_secs()
3197 },
3198 }
3199 } else {
3200 let wait_secs = ((1.0 - state.tokens) / refill_rate).ceil() as u64;
3201 RateLimitResult {
3202 allowed: false,
3203 limit: max_tokens,
3204 remaining: 0,
3205 reset_after_secs: wait_secs,
3206 }
3207 }
3208 }
3209
3210 fn check_fixed_window(
3211 &self,
3212 key: &str,
3213 max_requests: u64,
3214 window: Duration,
3215 ) -> RateLimitResult {
3216 let mut windows = self.fixed_windows.lock();
3217 let now = Instant::now();
3218
3219 let state = windows
3220 .entry(key.to_string())
3221 .or_insert_with(|| FixedWindowState {
3222 count: 0,
3223 window_start: now,
3224 });
3225
3226 let elapsed = now.duration_since(state.window_start);
3228 if elapsed >= window {
3229 state.count = 0;
3230 state.window_start = now;
3231 }
3232
3233 let remaining_time = window
3234 .checked_sub(now.duration_since(state.window_start))
3235 .unwrap_or(Duration::ZERO);
3236
3237 if state.count < max_requests {
3238 state.count += 1;
3239 RateLimitResult {
3240 allowed: true,
3241 limit: max_requests,
3242 remaining: max_requests - state.count,
3243 reset_after_secs: remaining_time.as_secs(),
3244 }
3245 } else {
3246 RateLimitResult {
3247 allowed: false,
3248 limit: max_requests,
3249 remaining: 0,
3250 reset_after_secs: remaining_time.as_secs(),
3251 }
3252 }
3253 }
3254
3255 #[allow(clippy::cast_precision_loss, clippy::cast_sign_loss)]
3256 fn check_sliding_window(
3257 &self,
3258 key: &str,
3259 max_requests: u64,
3260 window: Duration,
3261 ) -> RateLimitResult {
3262 let mut windows = self.sliding_windows.lock();
3263 let now = Instant::now();
3264
3265 let state = windows
3266 .entry(key.to_string())
3267 .or_insert_with(|| SlidingWindowState {
3268 current_count: 0,
3269 previous_count: 0,
3270 current_window_start: now,
3271 });
3272
3273 let elapsed = now.duration_since(state.current_window_start);
3275 if elapsed >= window {
3276 state.previous_count = state.current_count;
3278 state.current_count = 0;
3279 state.current_window_start = now;
3280 }
3281
3282 let window_elapsed = now.duration_since(state.current_window_start);
3285 let window_fraction = window_elapsed.as_secs_f64() / window.as_secs_f64();
3286 let previous_weight = 1.0 - window_fraction;
3287 let weighted_count =
3288 (state.previous_count as f64 * previous_weight) + state.current_count as f64;
3289
3290 let remaining_time = window.checked_sub(window_elapsed).unwrap_or(Duration::ZERO);
3291
3292 if weighted_count < max_requests as f64 {
3293 state.current_count += 1;
3294 let new_weighted =
3295 (state.previous_count as f64 * previous_weight) + state.current_count as f64;
3296 let remaining = (max_requests as f64 - new_weighted).max(0.0) as u64;
3297 RateLimitResult {
3298 allowed: true,
3299 limit: max_requests,
3300 remaining,
3301 reset_after_secs: remaining_time.as_secs(),
3302 }
3303 } else {
3304 RateLimitResult {
3305 allowed: false,
3306 limit: max_requests,
3307 remaining: 0,
3308 reset_after_secs: remaining_time.as_secs(),
3309 }
3310 }
3311 }
3312
3313 #[allow(clippy::cast_precision_loss)]
3315 pub fn check(
3316 &self,
3317 key: &str,
3318 algorithm: RateLimitAlgorithm,
3319 max_requests: u64,
3320 window: Duration,
3321 ) -> RateLimitResult {
3322 match algorithm {
3323 RateLimitAlgorithm::TokenBucket => {
3324 let refill_rate = max_requests as f64 / window.as_secs_f64();
3325 self.check_token_bucket(key, max_requests, refill_rate, window)
3326 }
3327 RateLimitAlgorithm::FixedWindow => self.check_fixed_window(key, max_requests, window),
3328 RateLimitAlgorithm::SlidingWindow => {
3329 self.check_sliding_window(key, max_requests, window)
3330 }
3331 }
3332 }
3333}
3334
3335impl Default for InMemoryRateLimitStore {
3336 fn default() -> Self {
3337 Self::new()
3338 }
3339}
3340
3341#[derive(Clone)]
3375pub struct RateLimitConfig {
3376 pub max_requests: u64,
3378 pub window: Duration,
3380 pub algorithm: RateLimitAlgorithm,
3382 pub include_headers: bool,
3384 pub retry_message: String,
3386}
3387
3388impl Default for RateLimitConfig {
3389 fn default() -> Self {
3390 Self {
3391 max_requests: 100,
3392 window: Duration::from_secs(60),
3393 algorithm: RateLimitAlgorithm::TokenBucket,
3394 include_headers: true,
3395 retry_message: "Rate limit exceeded. Please retry later.".to_string(),
3396 }
3397 }
3398}
3399
3400pub struct RateLimitBuilder {
3402 config: RateLimitConfig,
3403 key_extractor: Option<Box<dyn KeyExtractor>>,
3404}
3405
3406impl RateLimitBuilder {
3407 #[must_use]
3409 pub fn new() -> Self {
3410 Self {
3411 config: RateLimitConfig::default(),
3412 key_extractor: None,
3413 }
3414 }
3415
3416 #[must_use]
3418 pub fn requests(mut self, max: u64) -> Self {
3419 self.config.max_requests = max;
3420 self
3421 }
3422
3423 #[must_use]
3425 pub fn per(mut self, window: Duration) -> Self {
3426 self.config.window = window;
3427 self
3428 }
3429
3430 #[must_use]
3432 pub fn per_second(self, secs: u64) -> Self {
3433 self.per(Duration::from_secs(secs))
3434 }
3435
3436 #[must_use]
3438 pub fn per_minute(self, minutes: u64) -> Self {
3439 self.per(Duration::from_secs(minutes * 60))
3440 }
3441
3442 #[must_use]
3444 pub fn per_hour(self, hours: u64) -> Self {
3445 self.per(Duration::from_secs(hours * 3600))
3446 }
3447
3448 #[must_use]
3450 pub fn algorithm(mut self, algo: RateLimitAlgorithm) -> Self {
3451 self.config.algorithm = algo;
3452 self
3453 }
3454
3455 #[must_use]
3457 pub fn key_extractor(mut self, extractor: impl KeyExtractor + 'static) -> Self {
3458 self.key_extractor = Some(Box::new(extractor));
3459 self
3460 }
3461
3462 #[must_use]
3464 pub fn include_headers(mut self, include: bool) -> Self {
3465 self.config.include_headers = include;
3466 self
3467 }
3468
3469 #[must_use]
3471 pub fn retry_message(mut self, msg: impl Into<String>) -> Self {
3472 self.config.retry_message = msg.into();
3473 self
3474 }
3475
3476 #[must_use]
3478 pub fn build(self) -> RateLimitMiddleware {
3479 let key_extractor = self
3480 .key_extractor
3481 .unwrap_or_else(|| Box::new(IpKeyExtractor));
3482 RateLimitMiddleware {
3483 config: self.config,
3484 store: Arc::new(InMemoryRateLimitStore::new()),
3485 key_extractor: Arc::from(key_extractor),
3486 }
3487 }
3488}
3489
3490impl Default for RateLimitBuilder {
3491 fn default() -> Self {
3492 Self::new()
3493 }
3494}
3495
3496#[derive(Debug, Clone)]
3498struct RateLimitInfo {
3499 result: RateLimitResult,
3500}
3501
3502pub struct RateLimitMiddleware {
3525 config: RateLimitConfig,
3526 store: Arc<InMemoryRateLimitStore>,
3527 key_extractor: Arc<dyn KeyExtractor>,
3528}
3529
3530impl RateLimitMiddleware {
3531 #[must_use]
3533 pub fn new() -> Self {
3534 Self::builder().build()
3535 }
3536
3537 #[must_use]
3539 pub fn builder() -> RateLimitBuilder {
3540 RateLimitBuilder::new()
3541 }
3542
3543 fn too_many_requests_body(&self, result: &RateLimitResult) -> Vec<u8> {
3545 format!(
3546 r#"{{"detail":"{}","retry_after_secs":{}}}"#,
3547 self.config.retry_message, result.reset_after_secs
3548 )
3549 .into_bytes()
3550 }
3551
3552 fn add_headers(&self, response: Response, result: &RateLimitResult) -> Response {
3554 response
3555 .header("X-RateLimit-Limit", result.limit.to_string().into_bytes())
3556 .header(
3557 "X-RateLimit-Remaining",
3558 result.remaining.to_string().into_bytes(),
3559 )
3560 .header(
3561 "X-RateLimit-Reset",
3562 result.reset_after_secs.to_string().into_bytes(),
3563 )
3564 }
3565}
3566
3567impl Default for RateLimitMiddleware {
3568 fn default() -> Self {
3569 Self::new()
3570 }
3571}
3572
3573impl Middleware for RateLimitMiddleware {
3574 fn before<'a>(
3575 &'a self,
3576 _ctx: &'a RequestContext,
3577 req: &'a mut Request,
3578 ) -> BoxFuture<'a, ControlFlow> {
3579 Box::pin(async move {
3580 let Some(key) = self.key_extractor.extract_key(req) else {
3582 return ControlFlow::Continue;
3584 };
3585
3586 let result = self.store.check(
3588 &key,
3589 self.config.algorithm,
3590 self.config.max_requests,
3591 self.config.window,
3592 );
3593
3594 if result.allowed {
3595 req.insert_extension(RateLimitInfo { result });
3597 ControlFlow::Continue
3598 } else {
3599 let body = self.too_many_requests_body(&result);
3601 let mut response =
3602 Response::with_status(crate::response::StatusCode::TOO_MANY_REQUESTS)
3603 .header("Content-Type", b"application/json".to_vec())
3604 .header(
3605 "Retry-After",
3606 result.reset_after_secs.to_string().into_bytes(),
3607 )
3608 .body(crate::response::ResponseBody::Bytes(body));
3609
3610 if self.config.include_headers {
3611 response = self.add_headers(response, &result);
3612 }
3613
3614 ControlFlow::Break(response)
3615 }
3616 })
3617 }
3618
3619 fn after<'a>(
3620 &'a self,
3621 _ctx: &'a RequestContext,
3622 req: &'a Request,
3623 response: Response,
3624 ) -> BoxFuture<'a, Response> {
3625 Box::pin(async move {
3626 if !self.config.include_headers {
3627 return response;
3628 }
3629
3630 if let Some(info) = req.get_extension::<RateLimitInfo>() {
3632 self.add_headers(response, &info.result)
3633 } else {
3634 response
3635 }
3636 })
3637 }
3638
3639 fn name(&self) -> &'static str {
3640 "RateLimit"
3641 }
3642}
3643
/// How much detail [`RequestInspectionMiddleware`] logs per request and response.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InspectionVerbosity {
    /// Only the request/response summary line.
    Minimal,
    /// Summary line plus headers (sensitive headers redacted).
    Normal,
    /// Summary line, headers, and a body preview.
    Verbose,
}
3673
3674pub struct RequestInspectionMiddleware {
3713 log_config: LogConfig,
3714 verbosity: InspectionVerbosity,
3715 redact_headers: HashSet<String>,
3716 slow_threshold_ms: u64,
3717 max_body_preview: usize,
3718}
3719
3720impl Default for RequestInspectionMiddleware {
3721 fn default() -> Self {
3722 Self {
3723 log_config: LogConfig::development(),
3724 verbosity: InspectionVerbosity::Normal,
3725 redact_headers: default_redacted_headers(),
3726 slow_threshold_ms: 1000,
3727 max_body_preview: 2048,
3728 }
3729 }
3730}
3731
3732impl RequestInspectionMiddleware {
3733 #[must_use]
3735 pub fn new() -> Self {
3736 Self::default()
3737 }
3738
3739 #[must_use]
3741 pub fn log_config(mut self, config: LogConfig) -> Self {
3742 self.log_config = config;
3743 self
3744 }
3745
3746 #[must_use]
3748 pub fn verbosity(mut self, level: InspectionVerbosity) -> Self {
3749 self.verbosity = level;
3750 self
3751 }
3752
3753 #[must_use]
3755 pub fn slow_threshold_ms(mut self, ms: u64) -> Self {
3756 self.slow_threshold_ms = ms;
3757 self
3758 }
3759
3760 #[must_use]
3762 pub fn max_body_preview(mut self, max: usize) -> Self {
3763 self.max_body_preview = max;
3764 self
3765 }
3766
3767 #[must_use]
3769 pub fn redact_header(mut self, name: impl Into<String>) -> Self {
3770 self.redact_headers.insert(name.into().to_ascii_lowercase());
3771 self
3772 }
3773
3774 fn format_body_preview(&self, bytes: &[u8], content_type: Option<&[u8]>) -> Option<String> {
3776 if bytes.is_empty() || self.max_body_preview == 0 {
3777 return None;
3778 }
3779
3780 let is_json = content_type
3781 .and_then(|ct| std::str::from_utf8(ct).ok())
3782 .is_some_and(|ct| ct.contains("application/json"));
3783
3784 let limit = self.max_body_preview.min(bytes.len());
3785 let truncated = bytes.len() > self.max_body_preview;
3786
3787 match std::str::from_utf8(&bytes[..limit]) {
3788 Ok(text) => {
3789 if is_json {
3790 if let Some(pretty) = try_pretty_json(text) {
3792 let mut output = pretty;
3793 if truncated {
3794 output.push_str("\n ... (truncated)");
3795 }
3796 return Some(output);
3797 }
3798 }
3799 let mut output = text.to_string();
3800 if truncated {
3801 output.push_str("...");
3802 }
3803 Some(output)
3804 }
3805 Err(_) => Some(format!("<{} bytes binary>", bytes.len())),
3806 }
3807 }
3808
3809 fn format_response_preview(
3811 &self,
3812 body: &crate::response::ResponseBody,
3813 content_type: Option<&[u8]>,
3814 ) -> Option<String> {
3815 match body {
3816 crate::response::ResponseBody::Empty => None,
3817 crate::response::ResponseBody::Bytes(bytes) => {
3818 self.format_body_preview(bytes, content_type)
3819 }
3820 crate::response::ResponseBody::Stream(_) => Some("<streaming body>".to_string()),
3821 }
3822 }
3823
3824 fn format_inspection_headers<'a>(
3826 &self,
3827 headers: impl Iterator<Item = (&'a str, &'a [u8])>,
3828 ) -> String {
3829 let mut out = String::new();
3830 for (name, value) in headers {
3831 out.push_str("\n ");
3832 out.push_str(name);
3833 out.push_str(": ");
3834
3835 let lowered = name.to_ascii_lowercase();
3836 if self.redact_headers.contains(&lowered) {
3837 out.push_str("[REDACTED]");
3838 } else {
3839 match std::str::from_utf8(value) {
3840 Ok(text) => out.push_str(text),
3841 Err(_) => out.push_str("<binary>"),
3842 }
3843 }
3844 }
3845 out
3846 }
3847
3848 fn format_response_inspection_headers(&self, headers: &[(String, Vec<u8>)]) -> String {
3850 self.format_inspection_headers(
3851 headers
3852 .iter()
3853 .map(|(name, value)| (name.as_str(), value.as_slice())),
3854 )
3855 }
3856}
3857
/// Request extension holding the instant inspection began, used to time the response.
#[derive(Clone, Debug)]
struct InspectionStart(Instant);
3861
3862impl Middleware for RequestInspectionMiddleware {
3863 fn before<'a>(
3864 &'a self,
3865 ctx: &'a RequestContext,
3866 req: &'a mut Request,
3867 ) -> BoxFuture<'a, ControlFlow> {
3868 let logger = RequestLogger::new(ctx, self.log_config.clone());
3869 req.insert_extension(InspectionStart(Instant::now()));
3870
3871 let method = req.method();
3872 let path = req.path();
3873 let query = req.query();
3874
3875 let mut request_line = format!("--> {method} {path}");
3877 if let Some(q) = query {
3878 request_line.push('?');
3879 request_line.push_str(q);
3880 }
3881
3882 let body_size = body_len(req.body());
3883 if body_size > 0 {
3884 request_line.push_str(&format!(" ({body_size} bytes)"));
3885 }
3886
3887 match self.verbosity {
3888 InspectionVerbosity::Minimal => {
3889 logger.info(request_line);
3890 }
3891 InspectionVerbosity::Normal => {
3892 let headers = self.format_inspection_headers(req.headers().iter());
3893 logger.info(format!("{request_line}{headers}"));
3894 }
3895 InspectionVerbosity::Verbose => {
3896 let headers = self.format_inspection_headers(req.headers().iter());
3897 let content_type = req.headers().get("content-type");
3898 let body_preview = match req.body() {
3899 Body::Empty => None,
3900 Body::Bytes(bytes) => self.format_body_preview(bytes, content_type),
3901 Body::Stream(_) => Some("<streaming body>".to_string()),
3902 };
3903
3904 let mut output = format!("{request_line}{headers}");
3905 if let Some(body) = body_preview {
3906 output.push_str("\n ");
3907 output.push_str(&body.replace('\n', "\n "));
3909 }
3910 logger.info(output);
3911 }
3912 }
3913
3914 Box::pin(async { ControlFlow::Continue })
3915 }
3916
3917 fn after<'a>(
3918 &'a self,
3919 ctx: &'a RequestContext,
3920 req: &'a Request,
3921 response: Response,
3922 ) -> BoxFuture<'a, Response> {
3923 let logger = RequestLogger::new(ctx, self.log_config.clone());
3924 let duration = req
3925 .get_extension::<InspectionStart>()
3926 .map(|start| start.0.elapsed())
3927 .unwrap_or_default();
3928
3929 let status = response.status();
3930 let duration_ms = duration.as_millis();
3931
3932 let mut response_line = format!(
3934 "<-- {} {} ({duration_ms}ms)",
3935 status.as_u16(),
3936 status.canonical_reason(),
3937 );
3938
3939 if duration_ms >= u128::from(self.slow_threshold_ms) {
3941 response_line.push_str(" [SLOW]");
3942 }
3943
3944 match self.verbosity {
3945 InspectionVerbosity::Minimal => {
3946 if duration_ms >= u128::from(self.slow_threshold_ms) {
3947 logger.warn(response_line);
3948 } else {
3949 logger.info(response_line);
3950 }
3951 }
3952 InspectionVerbosity::Normal => {
3953 let headers = self.format_response_inspection_headers(response.headers());
3954 let output = format!("{response_line}{headers}");
3955 if duration_ms >= u128::from(self.slow_threshold_ms) {
3956 logger.warn(output);
3957 } else {
3958 logger.info(output);
3959 }
3960 }
3961 InspectionVerbosity::Verbose => {
3962 let headers = self.format_response_inspection_headers(response.headers());
3963
3964 let resp_content_type: Option<&[u8]> = response
3966 .headers()
3967 .iter()
3968 .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
3969 .map(|(_, value)| value.as_slice());
3970
3971 let body_preview =
3972 self.format_response_preview(response.body_ref(), resp_content_type);
3973
3974 let mut output = format!("{response_line}{headers}");
3975 if let Some(body) = body_preview {
3976 output.push_str("\n ");
3977 output.push_str(&body.replace('\n', "\n "));
3978 }
3979
3980 if duration_ms >= u128::from(self.slow_threshold_ms) {
3981 logger.warn(output);
3982 } else {
3983 logger.info(output);
3984 }
3985 }
3986 }
3987
3988 Box::pin(async move { response })
3989 }
3990
3991 fn name(&self) -> &'static str {
3992 "RequestInspection"
3993 }
3994}
3995
/// Attempts to pretty-print `input` as JSON for log previews.
///
/// Returns `None` unless the trimmed input starts with `{` or `[`, all strings are
/// terminated, and brackets are balanced and correctly paired. This is a lightweight
/// re-indenter, not a validating parser: scalar tokens are copied through unchecked.
fn try_pretty_json(input: &str) -> Option<String> {
    let trimmed = input.trim();
    if !trimmed.starts_with('{') && !trimmed.starts_with('[') {
        return None;
    }

    let mut output = String::with_capacity(trimmed.len() * 2);
    json_pretty_format(trimmed, &mut output).ok()?;
    Some(output)
}

/// Re-indents JSON text from `input` into `output`.
///
/// Works character by character, so multi-byte UTF-8 inside strings is copied intact.
/// Returns `Err(())` on an unterminated string, or on a closer that does not match the
/// innermost open bracket. It also errors if any bracket is left open at the end.
fn json_pretty_format(input: &str, output: &mut String) -> Result<(), ()> {
    let bytes = input.as_bytes();
    let mut pos = 0;
    // Stack of expected closing brackets; its depth is the indent level.
    let mut open: Vec<char> = Vec::new();
    let mut in_string = false;
    let mut escape_next = false;

    while pos < input.len() {
        // `pos` always sits on a char boundary. It advances by whole characters, or
        // jumps just past an ASCII closing bracket.
        let Some(ch) = input[pos..].chars().next() else {
            break;
        };
        let width = ch.len_utf8();

        if escape_next {
            output.push(ch);
            escape_next = false;
            pos += width;
            continue;
        }

        if in_string {
            output.push(ch);
            if ch == '\\' {
                escape_next = true;
            } else if ch == '"' {
                in_string = false;
            }
            pos += width;
            continue;
        }

        match ch {
            '"' => {
                in_string = true;
                output.push('"');
            }
            '{' | '[' => {
                output.push(ch);
                let closing = if ch == '{' { '}' } else { ']' };
                // Render empty containers compactly as `{}` / `[]`.
                let peek = skip_whitespace(bytes, pos + 1);
                if bytes.get(peek) == Some(&(closing as u8)) {
                    output.push(closing);
                    pos = peek + 1;
                    continue;
                }
                open.push(closing);
                output.push('\n');
                push_indent(output, open.len());
            }
            '}' | ']' => {
                // A closer that does not match the innermost opener means the
                // input is not well-formed JSON.
                if open.pop() != Some(ch) {
                    return Err(());
                }
                output.push('\n');
                push_indent(output, open.len());
                output.push(ch);
            }
            ':' => {
                output.push_str(": ");
            }
            ',' => {
                output.push(',');
                output.push('\n');
                push_indent(output, open.len());
            }
            // Insignificant whitespace outside strings is dropped; layout is regenerated.
            c if c.is_ascii_whitespace() => {}
            _ => {
                output.push(ch);
            }
        }

        pos += width;
    }

    if in_string || !open.is_empty() {
        return Err(());
    }

    Ok(())
}

/// Returns the index of the first non-ASCII-whitespace byte at or after `start`.
fn skip_whitespace(bytes: &[u8], start: usize) -> usize {
    let mut i = start;
    while i < bytes.len() && (bytes[i] as char).is_ascii_whitespace() {
        i += 1;
    }
    i
}

/// Appends `level` indentation units to `output`.
fn push_indent(output: &mut String, level: usize) {
    for _ in 0..level {
        output.push_str(" ");
    }
}
4111
/// Strategy [`ETagMiddleware`] uses to obtain an entity tag.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ETagMode {
    /// Use the handler's `ETag` header if present, otherwise hash the buffered body.
    Auto,
    /// Only use an `ETag` header the handler already set; never generate one.
    Manual,
    /// Pass responses through untouched.
    Disabled,
}
4132
4133impl Default for ETagMode {
4134 fn default() -> Self {
4135 Self::Auto
4136 }
4137}
4138
4139#[derive(Debug, Clone)]
4141pub struct ETagConfig {
4142 pub mode: ETagMode,
4144 pub weak: bool,
4147 pub min_size: usize,
4150}
4151
4152impl Default for ETagConfig {
4153 fn default() -> Self {
4154 Self {
4155 mode: ETagMode::Auto,
4156 weak: false,
4157 min_size: 0,
4158 }
4159 }
4160}
4161
4162impl ETagConfig {
4163 #[must_use]
4165 pub fn new() -> Self {
4166 Self::default()
4167 }
4168
4169 #[must_use]
4171 pub fn mode(mut self, mode: ETagMode) -> Self {
4172 self.mode = mode;
4173 self
4174 }
4175
4176 #[must_use]
4178 pub fn weak(mut self, weak: bool) -> Self {
4179 self.weak = weak;
4180 self
4181 }
4182
4183 #[must_use]
4185 pub fn min_size(mut self, size: usize) -> Self {
4186 self.min_size = size;
4187 self
4188 }
4189}
4190
4191pub struct ETagMiddleware {
4232 config: ETagConfig,
4233}
4234
4235impl Default for ETagMiddleware {
4236 fn default() -> Self {
4237 Self::new()
4238 }
4239}
4240
4241impl ETagMiddleware {
4242 #[must_use]
4244 pub fn new() -> Self {
4245 Self {
4246 config: ETagConfig::default(),
4247 }
4248 }
4249
4250 #[must_use]
4252 pub fn with_config(config: ETagConfig) -> Self {
4253 Self { config }
4254 }
4255
4256 fn generate_etag(data: &[u8], weak: bool) -> String {
4263 const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
4265 const FNV_PRIME: u64 = 0x100000001b3;
4266
4267 let mut hash = FNV_OFFSET_BASIS;
4268 for &byte in data {
4269 hash ^= u64::from(byte);
4270 hash = hash.wrapping_mul(FNV_PRIME);
4271 }
4272
4273 if weak {
4275 format!("W/\"{:016x}\"", hash)
4276 } else {
4277 format!("\"{:016x}\"", hash)
4278 }
4279 }
4280
4281 fn parse_if_none_match(value: &str) -> Vec<String> {
4289 let trimmed = value.trim();
4290
4291 if trimmed == "*" {
4293 return vec!["*".to_string()];
4294 }
4295
4296 let mut etags = Vec::new();
4297 let mut current = String::new();
4298 let mut in_quote = false;
4299 let mut prev_char = '\0';
4300
4301 for ch in trimmed.chars() {
4302 match ch {
4303 '"' if prev_char != '\\' => {
4304 current.push(ch);
4305 if in_quote {
4306 let etag = current.trim().to_string();
4308 if !etag.is_empty() {
4309 etags.push(etag);
4310 }
4311 current.clear();
4312 }
4313 in_quote = !in_quote;
4314 }
4315 ',' if !in_quote => {
4316 current.clear();
4318 }
4319 _ => {
4320 current.push(ch);
4321 }
4322 }
4323 prev_char = ch;
4324 }
4325
4326 etags
4327 }
4328
4329 fn etags_match_weak(etag1: &str, etag2: &str) -> bool {
4337 let e1 = Self::strip_weak_prefix(etag1);
4339 let e2 = Self::strip_weak_prefix(etag2);
4340 e1 == e2
4341 }
4342
4343 fn strip_weak_prefix(s: &str) -> &str {
4345 if s.starts_with("W/") || s.starts_with("w/") {
4346 &s[2..]
4347 } else {
4348 s
4349 }
4350 }
4351
4352 fn is_cacheable_method(method: crate::request::Method) -> bool {
4354 matches!(
4355 method,
4356 crate::request::Method::Get | crate::request::Method::Head
4357 )
4358 }
4359
4360 fn get_existing_etag(headers: &[(String, Vec<u8>)]) -> Option<String> {
4362 for (name, value) in headers {
4363 if name.eq_ignore_ascii_case("etag") {
4364 return std::str::from_utf8(value).ok().map(String::from);
4365 }
4366 }
4367 None
4368 }
4369}
4370
4371impl Middleware for ETagMiddleware {
4372 fn after<'a>(
4373 &'a self,
4374 _ctx: &'a RequestContext,
4375 req: &'a Request,
4376 response: Response,
4377 ) -> BoxFuture<'a, Response> {
4378 let config = self.config.clone();
4379
4380 Box::pin(async move {
4381 if config.mode == ETagMode::Disabled {
4383 return response;
4384 }
4385
4386 if !Self::is_cacheable_method(req.method()) {
4388 return response;
4389 }
4390
4391 let (status, headers, body) = response.into_parts();
4393
4394 let existing_etag = Self::get_existing_etag(&headers);
4396
4397 let body_bytes = match &body {
4399 crate::response::ResponseBody::Bytes(bytes) => Some(bytes.clone()),
4400 crate::response::ResponseBody::Empty => Some(Vec::new()),
4401 crate::response::ResponseBody::Stream(_) => None,
4402 };
4403
4404 let etag = if let Some(existing) = existing_etag {
4406 Some(existing)
4407 } else if config.mode == ETagMode::Auto {
4408 if let Some(ref bytes) = body_bytes {
4409 if bytes.len() >= config.min_size {
4410 Some(Self::generate_etag(bytes, config.weak))
4411 } else {
4412 None
4413 }
4414 } else {
4415 None
4416 }
4417 } else {
4418 None
4419 };
4420
4421 if let Some(ref etag_value) = etag {
4423 if let Some(if_none_match) = req.headers().get("if-none-match") {
4424 if let Ok(value) = std::str::from_utf8(if_none_match) {
4425 let client_etags = Self::parse_if_none_match(value);
4426
4427 let matches = client_etags.iter().any(|client_etag| {
4429 client_etag == "*" || Self::etags_match_weak(client_etag, etag_value)
4430 });
4431
4432 if matches {
4433 return Response::with_status(
4435 crate::response::StatusCode::NOT_MODIFIED,
4436 )
4437 .header("etag", etag_value.as_bytes().to_vec());
4438 }
4439 }
4440 }
4441 }
4442
4443 let mut new_response = Response::with_status(status)
4445 .body(body)
4446 .rebuild_with_headers(headers);
4447
4448 if let Some(etag_value) = etag {
4449 new_response = new_response.header("etag", etag_value.into_bytes());
4450 }
4451
4452 new_response
4453 })
4454 }
4455
4456 fn name(&self) -> &'static str {
4457 "ETagMiddleware"
4458 }
4459}
4460
/// Individual `Cache-Control` directive tokens.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CacheDirective {
    /// `public`
    Public,
    /// `private`
    Private,
    /// `no-store`
    NoStore,
    /// `no-cache`
    NoCache,
    /// `no-transform`
    NoTransform,
    /// `must-revalidate`
    MustRevalidate,
    /// `proxy-revalidate`
    ProxyRevalidate,
    /// `stale-if-error`
    StaleIfError,
    /// `stale-while-revalidate`
    StaleWhileRevalidate,
    /// `s-maxage`
    SMaxAge,
    /// `only-if-cached`
    OnlyIfCached,
    /// `immutable`
    Immutable,
}
4496
4497impl CacheDirective {
4498 fn as_str(self) -> &'static str {
4500 match self {
4501 Self::Public => "public",
4502 Self::Private => "private",
4503 Self::NoStore => "no-store",
4504 Self::NoCache => "no-cache",
4505 Self::NoTransform => "no-transform",
4506 Self::MustRevalidate => "must-revalidate",
4507 Self::ProxyRevalidate => "proxy-revalidate",
4508 Self::StaleIfError => "stale-if-error",
4509 Self::StaleWhileRevalidate => "stale-while-revalidate",
4510 Self::SMaxAge => "s-maxage",
4511 Self::OnlyIfCached => "only-if-cached",
4512 Self::Immutable => "immutable",
4513 }
4514 }
4515}
4516
4517#[derive(Debug, Clone, Default)]
4547pub struct CacheControlBuilder {
4548 directives: Vec<CacheDirective>,
4549 max_age: Option<u32>,
4550 s_maxage: Option<u32>,
4551 stale_while_revalidate: Option<u32>,
4552 stale_if_error: Option<u32>,
4553}
4554
4555impl CacheControlBuilder {
4556 #[must_use]
4558 pub fn new() -> Self {
4559 Self::default()
4560 }
4561
4562 #[must_use]
4564 pub fn public(mut self) -> Self {
4565 self.directives.push(CacheDirective::Public);
4566 self
4567 }
4568
4569 #[must_use]
4571 pub fn private(mut self) -> Self {
4572 self.directives.push(CacheDirective::Private);
4573 self
4574 }
4575
4576 #[must_use]
4578 pub fn no_store(mut self) -> Self {
4579 self.directives.push(CacheDirective::NoStore);
4580 self
4581 }
4582
4583 #[must_use]
4585 pub fn no_cache(mut self) -> Self {
4586 self.directives.push(CacheDirective::NoCache);
4587 self
4588 }
4589
4590 #[must_use]
4592 pub fn no_transform(mut self) -> Self {
4593 self.directives.push(CacheDirective::NoTransform);
4594 self
4595 }
4596
4597 #[must_use]
4599 pub fn must_revalidate(mut self) -> Self {
4600 self.directives.push(CacheDirective::MustRevalidate);
4601 self
4602 }
4603
4604 #[must_use]
4606 pub fn proxy_revalidate(mut self) -> Self {
4607 self.directives.push(CacheDirective::ProxyRevalidate);
4608 self
4609 }
4610
4611 #[must_use]
4613 pub fn immutable(mut self) -> Self {
4614 self.directives.push(CacheDirective::Immutable);
4615 self
4616 }
4617
4618 #[must_use]
4620 pub fn max_age_secs(mut self, seconds: u32) -> Self {
4621 self.max_age = Some(seconds);
4622 self
4623 }
4624
4625 #[must_use]
4627 pub fn max_age(self, duration: std::time::Duration) -> Self {
4628 self.max_age_secs(duration.as_secs() as u32)
4629 }
4630
4631 #[must_use]
4633 pub fn s_maxage_secs(mut self, seconds: u32) -> Self {
4634 self.s_maxage = Some(seconds);
4635 self
4636 }
4637
4638 #[must_use]
4640 pub fn s_maxage(self, duration: std::time::Duration) -> Self {
4641 self.s_maxage_secs(duration.as_secs() as u32)
4642 }
4643
4644 #[must_use]
4646 pub fn stale_while_revalidate_secs(mut self, seconds: u32) -> Self {
4647 self.stale_while_revalidate = Some(seconds);
4648 self
4649 }
4650
4651 #[must_use]
4653 pub fn stale_if_error_secs(mut self, seconds: u32) -> Self {
4654 self.stale_if_error = Some(seconds);
4655 self
4656 }
4657
4658 #[must_use]
4660 pub fn build(&self) -> String {
4661 let mut parts = Vec::new();
4662
4663 for directive in &self.directives {
4665 parts.push(directive.as_str().to_string());
4666 }
4667
4668 if let Some(age) = self.max_age {
4670 parts.push(format!("max-age={age}"));
4671 }
4672
4673 if let Some(age) = self.s_maxage {
4675 parts.push(format!("s-maxage={age}"));
4676 }
4677
4678 if let Some(seconds) = self.stale_while_revalidate {
4680 parts.push(format!("stale-while-revalidate={seconds}"));
4681 }
4682
4683 if let Some(seconds) = self.stale_if_error {
4685 parts.push(format!("stale-if-error={seconds}"));
4686 }
4687
4688 parts.join(", ")
4689 }
4690
4691 #[must_use]
4693 pub fn is_no_cache(&self) -> bool {
4694 self.directives.contains(&CacheDirective::NoStore)
4695 || self.directives.contains(&CacheDirective::NoCache)
4696 }
4697}
4698
/// Ready-made `Cache-Control` policies.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CachePreset {
    /// `no-store, no-cache, must-revalidate`
    NoCache,
    /// `private, max-age=0, must-revalidate`
    PrivateNoCache,
    /// `public, max-age=3600`
    PublicOneHour,
    /// `public, max-age=31536000, immutable`
    Immutable,
    /// `public, max-age=60, s-maxage=3600`
    CdnFriendly,
    /// `public, max-age=86400`
    StaticAssets,
}
4715
4716impl CachePreset {
4717 #[must_use]
4719 pub fn to_header_value(&self) -> String {
4720 match self {
4721 Self::NoCache => "no-store, no-cache, must-revalidate".to_string(),
4722 Self::PrivateNoCache => "private, max-age=0, must-revalidate".to_string(),
4723 Self::PublicOneHour => "public, max-age=3600".to_string(),
4724 Self::Immutable => "public, max-age=31536000, immutable".to_string(),
4725 Self::CdnFriendly => "public, max-age=60, s-maxage=3600".to_string(),
4726 Self::StaticAssets => "public, max-age=86400".to_string(),
4727 }
4728 }
4729
4730 #[must_use]
4732 pub fn to_builder(&self) -> CacheControlBuilder {
4733 match self {
4734 Self::NoCache => CacheControlBuilder::new()
4735 .no_store()
4736 .no_cache()
4737 .must_revalidate(),
4738 Self::PrivateNoCache => CacheControlBuilder::new()
4739 .private()
4740 .max_age_secs(0)
4741 .must_revalidate(),
4742 Self::PublicOneHour => CacheControlBuilder::new().public().max_age_secs(3600),
4743 Self::Immutable => CacheControlBuilder::new()
4744 .public()
4745 .max_age_secs(31536000)
4746 .immutable(),
4747 Self::CdnFriendly => CacheControlBuilder::new()
4748 .public()
4749 .max_age_secs(60)
4750 .s_maxage_secs(3600),
4751 Self::StaticAssets => CacheControlBuilder::new().public().max_age_secs(86400),
4752 }
4753 }
4754}
4755
4756#[derive(Debug, Clone)]
4758pub struct CacheControlConfig {
4759 pub cache_control: String,
4761 pub vary: Vec<String>,
4763 pub set_expires: bool,
4765 pub preserve_existing: bool,
4767 pub methods: Vec<crate::request::Method>,
4769 pub path_patterns: Vec<String>,
4771 pub cacheable_statuses: Vec<u16>,
4773}
4774
4775impl Default for CacheControlConfig {
4776 fn default() -> Self {
4777 Self {
4778 cache_control: CachePreset::NoCache.to_header_value(),
4779 vary: Vec::new(),
4780 set_expires: false,
4781 preserve_existing: true,
4782 methods: vec![crate::request::Method::Get, crate::request::Method::Head],
4783 path_patterns: Vec::new(),
4784 cacheable_statuses: (200..300).collect(),
4785 }
4786 }
4787}
4788
4789impl CacheControlConfig {
4790 #[must_use]
4792 pub fn new() -> Self {
4793 Self::default()
4794 }
4795
4796 #[must_use]
4798 pub fn from_preset(preset: CachePreset) -> Self {
4799 Self {
4800 cache_control: preset.to_header_value(),
4801 ..Self::default()
4802 }
4803 }
4804
4805 #[must_use]
4807 pub fn from_builder(builder: CacheControlBuilder) -> Self {
4808 Self {
4809 cache_control: builder.build(),
4810 ..Self::default()
4811 }
4812 }
4813
4814 #[must_use]
4816 pub fn cache_control(mut self, value: impl Into<String>) -> Self {
4817 self.cache_control = value.into();
4818 self
4819 }
4820
4821 #[must_use]
4823 pub fn vary(mut self, header: impl Into<String>) -> Self {
4824 self.vary.push(header.into());
4825 self
4826 }
4827
4828 #[must_use]
4830 pub fn vary_headers(mut self, headers: Vec<String>) -> Self {
4831 self.vary.extend(headers);
4832 self
4833 }
4834
4835 #[must_use]
4837 pub fn with_expires(mut self, enable: bool) -> Self {
4838 self.set_expires = enable;
4839 self
4840 }
4841
4842 #[must_use]
4844 pub fn preserve_existing(mut self, preserve: bool) -> Self {
4845 self.preserve_existing = preserve;
4846 self
4847 }
4848
4849 #[must_use]
4851 pub fn methods(mut self, methods: Vec<crate::request::Method>) -> Self {
4852 self.methods = methods;
4853 self
4854 }
4855
4856 #[must_use]
4858 pub fn path_patterns(mut self, patterns: Vec<String>) -> Self {
4859 self.path_patterns = patterns;
4860 self
4861 }
4862
4863 #[must_use]
4865 pub fn cacheable_statuses(mut self, statuses: Vec<u16>) -> Self {
4866 self.cacheable_statuses = statuses;
4867 self
4868 }
4869}
4870
4871pub struct CacheControlMiddleware {
4918 config: CacheControlConfig,
4919}
4920
4921impl Default for CacheControlMiddleware {
4922 fn default() -> Self {
4923 Self::new()
4924 }
4925}
4926
4927impl CacheControlMiddleware {
4928 #[must_use]
4932 pub fn new() -> Self {
4933 Self {
4934 config: CacheControlConfig::default(),
4935 }
4936 }
4937
4938 #[must_use]
4940 pub fn with_preset(preset: CachePreset) -> Self {
4941 Self {
4942 config: CacheControlConfig::from_preset(preset),
4943 }
4944 }
4945
4946 #[must_use]
4948 pub fn with_config(config: CacheControlConfig) -> Self {
4949 Self { config }
4950 }
4951
4952 fn is_cacheable_method(&self, method: crate::request::Method) -> bool {
4954 self.config.methods.contains(&method)
4955 }
4956
4957 fn is_cacheable_status(&self, status: u16) -> bool {
4959 self.config.cacheable_statuses.contains(&status)
4960 }
4961
4962 fn matches_path(&self, path: &str) -> bool {
4964 if self.config.path_patterns.is_empty() {
4965 return true; }
4967
4968 for pattern in &self.config.path_patterns {
4969 if path_matches_pattern(path, pattern) {
4970 return true;
4971 }
4972 }
4973 false
4974 }
4975
4976 fn has_cache_control(headers: &[(String, Vec<u8>)]) -> bool {
4978 headers
4979 .iter()
4980 .any(|(name, _)| name.eq_ignore_ascii_case("cache-control"))
4981 }
4982
4983 fn calculate_expires(cache_control: &str) -> Option<String> {
4985 for directive in cache_control.split(',') {
4987 let directive = directive.trim();
4988 if directive.starts_with("max-age=") {
4989 if let Ok(seconds) = directive[8..].parse::<u64>() {
4990 let now = std::time::SystemTime::now();
4992 if let Some(expires) = now.checked_add(std::time::Duration::from_secs(seconds))
4993 {
4994 return Some(format_http_date(expires));
4995 }
4996 }
4997 }
4998 }
4999 None
5000 }
5001}
5002
/// Matches `path` against a glob where `*` matches any (possibly empty) substring.
///
/// A pattern without `*` must equal the path exactly. Otherwise the text before the
/// first `*` must be a prefix and the text after the last `*` a suffix. Every literal
/// segment in between must appear in order. Segments never overlap each other, so
/// `/a` does not match `/a*a`.
fn path_matches_pattern(path: &str, pattern: &str) -> bool {
    if !pattern.contains('*') {
        return path == pattern;
    }

    let mut segments: Vec<&str> = pattern.split('*').collect();
    // `pattern` contains '*', so `split` yields at least two segments.
    let last = segments.pop().unwrap_or("");
    let first = if segments.is_empty() {
        ""
    } else {
        segments.remove(0)
    };

    let Some(mut rest) = path.strip_prefix(first) else {
        return false;
    };

    // Leftmost matching of the middle segments leaves the longest possible tail,
    // which is the best chance for the anchored suffix to fit.
    for segment in segments.into_iter().filter(|s| !s.is_empty()) {
        match rest.find(segment) {
            Some(idx) => rest = &rest[idx + segment.len()..],
            None => return false,
        }
    }

    rest.ends_with(last)
}
5031
/// Formats `time` as an IMF-fixdate (e.g. `Sun, 06 Nov 1994 08:49:37 GMT`).
///
/// Times before the Unix epoch render as the epoch itself.
fn format_http_date(time: std::time::SystemTime) -> String {
    const WEEKDAYS: [&str; 7] = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"];
    const MONTHS: [&str; 12] = [
        "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
    ];

    let Ok(since_epoch) = time.duration_since(std::time::UNIX_EPOCH) else {
        return "Thu, 01 Jan 1970 00:00:00 GMT".to_string();
    };

    let total = since_epoch.as_secs();
    let (days, secs_of_day) = (total / 86400, total % 86400);
    let (h, m, s) = (secs_of_day / 3600, secs_of_day % 3600 / 60, secs_of_day % 60);

    // 1970-01-01 was a Thursday (index 4 with Sunday = 0).
    let weekday = WEEKDAYS[((days + 4) % 7) as usize];
    let (year, month, day) = days_to_date(days);
    let month_name = MONTHS[(month - 1) as usize];

    format!("{weekday}, {day:02} {month_name} {year} {h:02}:{m:02}:{s:02} GMT")
}

/// Converts a count of days since 1970-01-01 into a `(year, month, day)` triple.
///
/// Month and day are 1-based.
fn days_to_date(days: u64) -> (u64, u64, u64) {
    let year_len = |y: u64| if is_leap_year(y) { 366 } else { 365 };

    let mut year = 1970u64;
    let mut left = days;
    while left >= year_len(year) {
        left -= year_len(year);
        year += 1;
    }

    let february = if is_leap_year(year) { 29 } else { 28 };
    let lengths: [u64; 12] = [31, february, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];

    let mut month = 1u64;
    for len in lengths {
        if left < len {
            break;
        }
        left -= len;
        month += 1;
    }

    (year, month, left + 1)
}

/// Gregorian leap-year rule.
fn is_leap_year(year: u64) -> bool {
    year % 400 == 0 || (year % 4 == 0 && year % 100 != 0)
}
5109
5110impl Middleware for CacheControlMiddleware {
5111 fn after<'a>(
5112 &'a self,
5113 _ctx: &'a RequestContext,
5114 req: &'a Request,
5115 response: Response,
5116 ) -> BoxFuture<'a, Response> {
5117 let config = self.config.clone();
5118
5119 Box::pin(async move {
5120 if !self.is_cacheable_method(req.method()) {
5122 return response;
5123 }
5124
5125 if !self.is_cacheable_status(response.status().as_u16()) {
5126 return response;
5127 }
5128
5129 if !self.matches_path(req.path()) {
5130 return response;
5131 }
5132
5133 let (status, mut headers, body) = response.into_parts();
5135
5136 if config.preserve_existing && Self::has_cache_control(&headers) {
5138 let mut resp = Response::with_status(status);
5140 for (name, value) in headers {
5141 resp = resp.header(name, value);
5142 }
5143 return resp.body(body);
5144 }
5145
5146 headers.push((
5148 "Cache-Control".to_string(),
5149 config.cache_control.as_bytes().to_vec(),
5150 ));
5151
5152 if !config.vary.is_empty() {
5154 let vary_value = config.vary.join(", ");
5155 headers.push(("Vary".to_string(), vary_value.into_bytes()));
5156 }
5157
5158 if config.set_expires {
5160 if let Some(expires) = Self::calculate_expires(&config.cache_control) {
5161 headers.push(("Expires".to_string(), expires.into_bytes()));
5162 }
5163 }
5164
5165 let mut resp = Response::with_status(status);
5167 for (name, value) in headers {
5168 resp = resp.header(name, value);
5169 }
5170 resp.body(body)
5171 })
5172 }
5173
5174 fn name(&self) -> &'static str {
5175 "CacheControlMiddleware"
5176 }
5177}
5178
/// Middleware that rejects HTTP `TRACE` requests with `405 Method Not Allowed`.
#[derive(Debug, Clone)]
pub struct TraceRejectionMiddleware {
    /// Whether blocked attempts are written to stderr.
    log_attempts: bool,
}
5218
5219impl Default for TraceRejectionMiddleware {
5220 fn default() -> Self {
5221 Self::new()
5222 }
5223}
5224
5225impl TraceRejectionMiddleware {
5226 #[must_use]
5230 pub fn new() -> Self {
5231 Self { log_attempts: true }
5232 }
5233
5234 #[must_use]
5239 pub fn log_attempts(mut self, log: bool) -> Self {
5240 self.log_attempts = log;
5241 self
5242 }
5243
5244 fn rejection_response(path: &str) -> Response {
5246 let body = format!(
5247 r#"{{"detail":"HTTP TRACE method is not allowed","path":"{}"}}"#,
5248 path.replace('"', "\\\"")
5249 );
5250 Response::with_status(crate::response::StatusCode::METHOD_NOT_ALLOWED)
5251 .header("Content-Type", b"application/json".to_vec())
5252 .header(
5253 "Allow",
5254 b"GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD".to_vec(),
5255 )
5256 .body(crate::response::ResponseBody::Bytes(body.into_bytes()))
5257 }
5258}
5259
5260impl Middleware for TraceRejectionMiddleware {
5261 fn before<'a>(
5262 &'a self,
5263 _ctx: &'a RequestContext,
5264 req: &'a mut Request,
5265 ) -> BoxFuture<'a, ControlFlow> {
5266 Box::pin(async move {
5267 if req.method() == crate::request::Method::Trace {
5268 if self.log_attempts {
5269 let path = req.path();
5271 let remote_ip = req
5272 .headers()
5273 .get("X-Forwarded-For")
5274 .or_else(|| req.headers().get("X-Real-IP"))
5275 .map(|v| String::from_utf8_lossy(v).to_string())
5276 .unwrap_or_else(|| "unknown".to_string());
5277
5278 eprintln!(
5279 "[SECURITY] TRACE request blocked: path={}, remote_ip={}",
5280 path, remote_ip
5281 );
5282 }
5283
5284 return ControlFlow::Break(Self::rejection_response(req.path()));
5285 }
5286
5287 ControlFlow::Continue
5288 })
5289 }
5290
5291 fn name(&self) -> &'static str {
5292 "TraceRejection"
5293 }
5294}
5295
/// Configuration for [`HttpsRedirectMiddleware`].
#[derive(Debug, Clone)]
// Each flag is an independent toggle; grouping them into enums would not add clarity.
#[allow(clippy::struct_excessive_bools)]
pub struct HttpsRedirectConfig {
    /// Redirect plain-HTTP requests to HTTPS.
    pub redirect_enabled: bool,
    /// Use `301 Moved Permanently` when `true`, `307 Temporary Redirect` otherwise.
    pub permanent_redirect: bool,
    /// `max-age` (seconds) of the `Strict-Transport-Security` header; `0` disables HSTS.
    pub hsts_max_age_secs: u64,
    /// Append `includeSubDomains` to the HSTS header.
    pub hsts_include_subdomains: bool,
    /// Append `preload` to the HSTS header.
    pub hsts_preload: bool,
    /// Path prefixes that are never redirected.
    pub exclude_paths: Vec<String>,
    /// Port used in redirect URLs; `443` is omitted from the URL.
    pub https_port: u16,
}

impl Default for HttpsRedirectConfig {
    /// Redirects enabled (permanent), one-year HSTS without subdomains/preload,
    /// no exclusions, standard HTTPS port.
    fn default() -> Self {
        Self {
            redirect_enabled: true,
            permanent_redirect: true,
            hsts_max_age_secs: 31_536_000,
            hsts_include_subdomains: false,
            hsts_preload: false,
            exclude_paths: Vec::new(),
            https_port: 443,
        }
    }
}
5337
5338#[derive(Debug, Clone)]
5378pub struct HttpsRedirectMiddleware {
5379 config: HttpsRedirectConfig,
5380}
5381
5382impl Default for HttpsRedirectMiddleware {
5383 fn default() -> Self {
5384 Self::new()
5385 }
5386}
5387
5388impl HttpsRedirectMiddleware {
5389 #[must_use]
5391 pub fn new() -> Self {
5392 Self {
5393 config: HttpsRedirectConfig::default(),
5394 }
5395 }
5396
5397 #[must_use]
5399 pub fn redirect_enabled(mut self, enabled: bool) -> Self {
5400 self.config.redirect_enabled = enabled;
5401 self
5402 }
5403
5404 #[must_use]
5408 pub fn permanent_redirect(mut self, permanent: bool) -> Self {
5409 self.config.permanent_redirect = permanent;
5410 self
5411 }
5412
5413 #[must_use]
5418 pub fn hsts_max_age_secs(mut self, secs: u64) -> Self {
5419 self.config.hsts_max_age_secs = secs;
5420 self
5421 }
5422
5423 #[must_use]
5425 pub fn include_subdomains(mut self, include: bool) -> Self {
5426 self.config.hsts_include_subdomains = include;
5427 self
5428 }
5429
5430 #[must_use]
5435 pub fn preload(mut self, preload: bool) -> Self {
5436 self.config.hsts_preload = preload;
5437 self
5438 }
5439
5440 #[must_use]
5445 pub fn exclude_path(mut self, path: impl Into<String>) -> Self {
5446 self.config.exclude_paths.push(path.into());
5447 self
5448 }
5449
5450 #[must_use]
5452 pub fn exclude_paths(mut self, paths: Vec<String>) -> Self {
5453 self.config.exclude_paths = paths;
5454 self
5455 }
5456
5457 #[must_use]
5459 pub fn https_port(mut self, port: u16) -> Self {
5460 self.config.https_port = port;
5461 self
5462 }
5463
5464 fn is_secure(&self, req: &Request) -> bool {
5469 if let Some(proto) = req.headers().get("X-Forwarded-Proto") {
5471 return proto.eq_ignore_ascii_case(b"https");
5472 }
5473
5474 if let Some(ssl) = req.headers().get("X-Forwarded-Ssl") {
5476 return ssl.eq_ignore_ascii_case(b"on");
5477 }
5478
5479 if let Some(https) = req.headers().get("Front-End-Https") {
5481 return https.eq_ignore_ascii_case(b"on");
5482 }
5483
5484 false
5487 }
5488
5489 fn is_excluded(&self, path: &str) -> bool {
5491 self.config
5492 .exclude_paths
5493 .iter()
5494 .any(|p| path.starts_with(p))
5495 }
5496
5497 fn build_hsts_header(&self) -> Option<Vec<u8>> {
5499 if self.config.hsts_max_age_secs == 0 {
5500 return None;
5501 }
5502
5503 let mut value = format!("max-age={}", self.config.hsts_max_age_secs);
5504
5505 if self.config.hsts_include_subdomains {
5506 value.push_str("; includeSubDomains");
5507 }
5508
5509 if self.config.hsts_preload {
5510 value.push_str("; preload");
5511 }
5512
5513 Some(value.into_bytes())
5514 }
5515
5516 fn build_redirect_url(&self, req: &Request) -> String {
5518 let host = req
5519 .headers()
5520 .get("Host")
5521 .map(|h| String::from_utf8_lossy(h).to_string())
5522 .unwrap_or_else(|| "localhost".to_string());
5523
5524 let host_without_port = host.split(':').next().unwrap_or(&host);
5526
5527 let path = req.path();
5528 let query = req.query();
5529
5530 if self.config.https_port == 443 {
5531 match query {
5532 Some(q) => format!("https://{}{}?{}", host_without_port, path, q),
5533 None => format!("https://{}{}", host_without_port, path),
5534 }
5535 } else {
5536 match query {
5537 Some(q) => format!(
5538 "https://{}:{}{}?{}",
5539 host_without_port, self.config.https_port, path, q
5540 ),
5541 None => format!(
5542 "https://{}:{}{}",
5543 host_without_port, self.config.https_port, path
5544 ),
5545 }
5546 }
5547 }
5548}
5549
5550impl Middleware for HttpsRedirectMiddleware {
5551 fn before<'a>(
5552 &'a self,
5553 _ctx: &'a RequestContext,
5554 req: &'a mut Request,
5555 ) -> BoxFuture<'a, ControlFlow> {
5556 Box::pin(async move {
5557 if !self.config.redirect_enabled {
5559 return ControlFlow::Continue;
5560 }
5561
5562 if self.is_secure(req) {
5564 return ControlFlow::Continue;
5565 }
5566
5567 if self.is_excluded(req.path()) {
5569 return ControlFlow::Continue;
5570 }
5571
5572 let redirect_url = self.build_redirect_url(req);
5574
5575 let status = if self.config.permanent_redirect {
5577 crate::response::StatusCode::MOVED_PERMANENTLY
5578 } else {
5579 crate::response::StatusCode::TEMPORARY_REDIRECT
5580 };
5581
5582 let response = Response::with_status(status)
5584 .header("Location", redirect_url.into_bytes())
5585 .header("Content-Type", b"text/plain".to_vec())
5586 .body(crate::response::ResponseBody::Bytes(
5587 b"Redirecting to HTTPS...".to_vec(),
5588 ));
5589
5590 ControlFlow::Break(response)
5591 })
5592 }
5593
5594 fn after<'a>(
5595 &'a self,
5596 _ctx: &'a RequestContext,
5597 req: &'a Request,
5598 response: Response,
5599 ) -> BoxFuture<'a, Response> {
5600 Box::pin(async move {
5601 if !self.is_secure(req) {
5603 return response;
5604 }
5605
5606 if let Some(hsts_value) = self.build_hsts_header() {
5608 response.header("Strict-Transport-Security", hsts_value)
5609 } else {
5610 response
5611 }
5612 })
5613 }
5614
5615 fn name(&self) -> &'static str {
5616 "HttpsRedirect"
5617 }
5618}
5619
/// Hook that can inspect and rewrite a [`Response`] after the handler produced it.
///
/// Implementations receive a [`ResponseInterceptorContext`] describing the
/// originating request and take ownership of the response, returning the
/// response to pass further along.
pub trait ResponseInterceptor: Send + Sync {
    /// Processes `response` for the request described by `ctx` and returns the
    /// (possibly modified) response.
    fn intercept<'a>(
        &'a self,
        ctx: &'a ResponseInterceptorContext<'a>,
        response: Response,
    ) -> BoxFuture<'a, Response>;

    /// Identifier for this interceptor; defaults to the implementing type's name.
    fn name(&self) -> &'static str {
        std::any::type_name::<Self>()
    }
}
5680
5681#[derive(Debug)]
5686pub struct ResponseInterceptorContext<'a> {
5687 pub request: &'a Request,
5689 pub start_time: Instant,
5691 pub request_ctx: &'a RequestContext,
5693}
5694
5695impl<'a> ResponseInterceptorContext<'a> {
5696 pub fn new(request: &'a Request, request_ctx: &'a RequestContext, start_time: Instant) -> Self {
5698 Self {
5699 request,
5700 start_time,
5701 request_ctx,
5702 }
5703 }
5704
5705 pub fn elapsed(&self) -> std::time::Duration {
5707 self.start_time.elapsed()
5708 }
5709
5710 pub fn elapsed_ms(&self) -> u128 {
5712 self.start_time.elapsed().as_millis()
5713 }
5714}
5715
5716#[derive(Default)]
5731pub struct ResponseInterceptorStack {
5732 interceptors: Vec<Arc<dyn ResponseInterceptor>>,
5733}
5734
5735impl ResponseInterceptorStack {
5736 #[must_use]
5738 pub fn new() -> Self {
5739 Self {
5740 interceptors: Vec::new(),
5741 }
5742 }
5743
5744 #[must_use]
5746 pub fn with_capacity(capacity: usize) -> Self {
5747 Self {
5748 interceptors: Vec::with_capacity(capacity),
5749 }
5750 }
5751
5752 pub fn push<I: ResponseInterceptor + 'static>(&mut self, interceptor: I) {
5754 self.interceptors.push(Arc::new(interceptor));
5755 }
5756
5757 pub fn push_arc(&mut self, interceptor: Arc<dyn ResponseInterceptor>) {
5759 self.interceptors.push(interceptor);
5760 }
5761
5762 #[must_use]
5764 pub fn len(&self) -> usize {
5765 self.interceptors.len()
5766 }
5767
5768 #[must_use]
5770 pub fn is_empty(&self) -> bool {
5771 self.interceptors.is_empty()
5772 }
5773
5774 pub async fn process(
5776 &self,
5777 ctx: &ResponseInterceptorContext<'_>,
5778 mut response: Response,
5779 ) -> Response {
5780 for interceptor in &self.interceptors {
5781 let _ = ctx.request_ctx.checkpoint();
5782 response = interceptor.intercept(ctx, response).await;
5783 }
5784 response
5785 }
5786}
5787
5788#[derive(Debug, Clone)]
5805pub struct TimingInterceptor {
5806 header_name: String,
5808 include_server_timing: bool,
5810 server_timing_name: String,
5812}
5813
5814impl Default for TimingInterceptor {
5815 fn default() -> Self {
5816 Self::new()
5817 }
5818}
5819
5820impl TimingInterceptor {
5821 #[must_use]
5823 pub fn new() -> Self {
5824 Self {
5825 header_name: "X-Response-Time".to_string(),
5826 include_server_timing: false,
5827 server_timing_name: "total".to_string(),
5828 }
5829 }
5830
5831 #[must_use]
5833 pub fn with_server_timing(mut self, metric_name: impl Into<String>) -> Self {
5834 self.include_server_timing = true;
5835 self.server_timing_name = metric_name.into();
5836 self
5837 }
5838
5839 #[must_use]
5841 pub fn header_name(mut self, name: impl Into<String>) -> Self {
5842 self.header_name = name.into();
5843 self
5844 }
5845}
5846
5847impl ResponseInterceptor for TimingInterceptor {
5848 fn intercept<'a>(
5849 &'a self,
5850 ctx: &'a ResponseInterceptorContext<'a>,
5851 response: Response,
5852 ) -> BoxFuture<'a, Response> {
5853 Box::pin(async move {
5854 let elapsed_ms = ctx.elapsed_ms();
5855 let timing_value = format!("{}ms", elapsed_ms);
5856
5857 let response = response.header(&self.header_name, timing_value.clone().into_bytes());
5858
5859 if self.include_server_timing {
5860 let server_timing = format!("{};dur={}", self.server_timing_name, elapsed_ms);
5862 response.header("Server-Timing", server_timing.into_bytes())
5863 } else {
5864 response
5865 }
5866 })
5867 }
5868
5869 fn name(&self) -> &'static str {
5870 "TimingInterceptor"
5871 }
5872}
5873
5874#[derive(Debug, Clone)]
5898#[allow(clippy::struct_excessive_bools)]
5899pub struct DebugInfoInterceptor {
5900 include_path: bool,
5902 include_method: bool,
5904 include_request_id: bool,
5906 include_timing: bool,
5908 header_prefix: String,
5910}
5911
5912impl Default for DebugInfoInterceptor {
5913 fn default() -> Self {
5914 Self::new()
5915 }
5916}
5917
5918impl DebugInfoInterceptor {
5919 #[must_use]
5921 pub fn new() -> Self {
5922 Self {
5923 include_path: true,
5924 include_method: true,
5925 include_request_id: true,
5926 include_timing: true,
5927 header_prefix: "X-Debug-".to_string(),
5928 }
5929 }
5930
5931 #[must_use]
5933 pub fn include_path(mut self, include: bool) -> Self {
5934 self.include_path = include;
5935 self
5936 }
5937
5938 #[must_use]
5940 pub fn include_method(mut self, include: bool) -> Self {
5941 self.include_method = include;
5942 self
5943 }
5944
5945 #[must_use]
5947 pub fn include_request_id(mut self, include: bool) -> Self {
5948 self.include_request_id = include;
5949 self
5950 }
5951
5952 #[must_use]
5954 pub fn include_timing(mut self, include: bool) -> Self {
5955 self.include_timing = include;
5956 self
5957 }
5958
5959 #[must_use]
5961 pub fn header_prefix(mut self, prefix: impl Into<String>) -> Self {
5962 self.header_prefix = prefix.into();
5963 self
5964 }
5965}
5966
5967impl ResponseInterceptor for DebugInfoInterceptor {
5968 fn intercept<'a>(
5969 &'a self,
5970 ctx: &'a ResponseInterceptorContext<'a>,
5971 response: Response,
5972 ) -> BoxFuture<'a, Response> {
5973 Box::pin(async move {
5974 let mut resp = response;
5975
5976 if self.include_path {
5977 let header_name = format!("{}Path", self.header_prefix);
5978 resp = resp.header(header_name, ctx.request.path().as_bytes().to_vec());
5979 }
5980
5981 if self.include_method {
5982 let header_name = format!("{}Method", self.header_prefix);
5983 resp = resp.header(
5984 header_name,
5985 ctx.request.method().as_str().as_bytes().to_vec(),
5986 );
5987 }
5988
5989 if self.include_request_id {
5990 if let Some(request_id) = ctx.request.get_extension::<RequestId>() {
5991 let header_name = format!("{}Request-Id", self.header_prefix);
5992 resp = resp.header(header_name, request_id.0.as_bytes().to_vec());
5993 }
5994 }
5995
5996 if self.include_timing {
5997 let header_name = format!("{}Handler-Time", self.header_prefix);
5998 let timing = format!("{}ms", ctx.elapsed_ms());
5999 resp = resp.header(header_name, timing.into_bytes());
6000 }
6001
6002 resp
6003 })
6004 }
6005
6006 fn name(&self) -> &'static str {
6007 "DebugInfoInterceptor"
6008 }
6009}
6010
6011pub struct ResponseBodyTransform<F>
6032where
6033 F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync,
6034{
6035 transform_fn: F,
6036 content_type_filter: Option<String>,
6038}
6039
6040impl<F> ResponseBodyTransform<F>
6041where
6042 F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync,
6043{
6044 pub fn new(transform_fn: F) -> Self {
6046 Self {
6047 transform_fn,
6048 content_type_filter: None,
6049 }
6050 }
6051
6052 #[must_use]
6054 pub fn for_content_type(mut self, content_type: impl Into<String>) -> Self {
6055 self.content_type_filter = Some(content_type.into());
6056 self
6057 }
6058
6059 fn should_transform(&self, response: &Response) -> bool {
6060 match &self.content_type_filter {
6061 Some(filter) => response
6062 .headers()
6063 .iter()
6064 .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
6065 .and_then(|(_, ct)| std::str::from_utf8(ct).ok())
6066 .map(|ct| ct.starts_with(filter))
6067 .unwrap_or(false),
6068 None => true,
6069 }
6070 }
6071}
6072
6073impl<F> ResponseInterceptor for ResponseBodyTransform<F>
6074where
6075 F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync,
6076{
6077 fn intercept<'a>(
6078 &'a self,
6079 _ctx: &'a ResponseInterceptorContext<'a>,
6080 response: Response,
6081 ) -> BoxFuture<'a, Response> {
6082 Box::pin(async move {
6083 if !self.should_transform(&response) {
6084 return response;
6085 }
6086
6087 let body_bytes = match response.body_ref() {
6089 crate::response::ResponseBody::Empty => Vec::new(),
6090 crate::response::ResponseBody::Bytes(b) => b.clone(),
6091 crate::response::ResponseBody::Stream(_) => {
6092 return response;
6094 }
6095 };
6096
6097 let transformed = (self.transform_fn)(body_bytes);
6099
6100 response.body(crate::response::ResponseBody::Bytes(transformed))
6102 })
6103 }
6104
6105 fn name(&self) -> &'static str {
6106 "ResponseBodyTransform"
6107 }
6108}
6109
6110#[derive(Debug, Clone, Default)]
6127pub struct HeaderTransformInterceptor {
6128 add_headers: Vec<(String, Vec<u8>)>,
6130 remove_headers: Vec<String>,
6132 rename_headers: Vec<(String, String)>,
6134}
6135
6136impl HeaderTransformInterceptor {
6137 #[must_use]
6139 pub fn new() -> Self {
6140 Self::default()
6141 }
6142
6143 #[must_use]
6145 pub fn add(mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
6146 self.add_headers.push((name.into(), value.into()));
6147 self
6148 }
6149
6150 #[must_use]
6152 pub fn remove(mut self, name: impl Into<String>) -> Self {
6153 self.remove_headers.push(name.into());
6154 self
6155 }
6156
6157 #[must_use]
6159 pub fn rename(mut self, old_name: impl Into<String>, new_name: impl Into<String>) -> Self {
6160 self.rename_headers.push((old_name.into(), new_name.into()));
6161 self
6162 }
6163}
6164
6165impl ResponseInterceptor for HeaderTransformInterceptor {
6166 fn intercept<'a>(
6167 &'a self,
6168 _ctx: &'a ResponseInterceptorContext<'a>,
6169 response: Response,
6170 ) -> BoxFuture<'a, Response> {
6171 let add_headers = self.add_headers.clone();
6172 let remove_headers = self.remove_headers.clone();
6173 let rename_headers = self.rename_headers.clone();
6174
6175 Box::pin(async move {
6176 let mut resp = response;
6177
6178 for (old_name, new_name) in &rename_headers {
6180 let header_value = resp
6181 .headers()
6182 .iter()
6183 .find(|(name, _)| name.eq_ignore_ascii_case(old_name))
6184 .map(|(_, v)| v.clone());
6185
6186 if let Some(value) = header_value {
6187 resp = resp.header(new_name, value);
6188 }
6191 }
6192
6193 for (name, value) in add_headers {
6195 resp = resp.header(name, value);
6196 }
6197
6198 let _ = remove_headers;
6201
6202 resp
6203 })
6204 }
6205
6206 fn name(&self) -> &'static str {
6207 "HeaderTransformInterceptor"
6208 }
6209}
6210
6211pub struct ConditionalInterceptor<I, F>
6227where
6228 I: ResponseInterceptor,
6229 F: Fn(&ResponseInterceptorContext, &Response) -> bool + Send + Sync,
6230{
6231 inner: I,
6232 condition: F,
6233}
6234
6235impl<I, F> ConditionalInterceptor<I, F>
6236where
6237 I: ResponseInterceptor,
6238 F: Fn(&ResponseInterceptorContext, &Response) -> bool + Send + Sync,
6239{
6240 pub fn new(inner: I, condition: F) -> Self {
6242 Self { inner, condition }
6243 }
6244}
6245
6246impl<I, F> ResponseInterceptor for ConditionalInterceptor<I, F>
6247where
6248 I: ResponseInterceptor,
6249 F: Fn(&ResponseInterceptorContext, &Response) -> bool + Send + Sync,
6250{
6251 fn intercept<'a>(
6252 &'a self,
6253 ctx: &'a ResponseInterceptorContext<'a>,
6254 response: Response,
6255 ) -> BoxFuture<'a, Response> {
6256 Box::pin(async move {
6257 if (self.condition)(ctx, &response) {
6258 self.inner.intercept(ctx, response).await
6259 } else {
6260 response
6261 }
6262 })
6263 }
6264
6265 fn name(&self) -> &'static str {
6266 "ConditionalInterceptor"
6267 }
6268}
6269
6270#[derive(Debug, Clone)]
6289pub struct ErrorResponseTransformer {
6290 status_codes: HashSet<u16>,
6292 replacement_body: Option<Vec<u8>>,
6294 add_error_id: bool,
6296}
6297
6298impl Default for ErrorResponseTransformer {
6299 fn default() -> Self {
6300 Self::new()
6301 }
6302}
6303
6304impl ErrorResponseTransformer {
6305 #[must_use]
6307 pub fn new() -> Self {
6308 Self {
6309 status_codes: HashSet::new(),
6310 replacement_body: None,
6311 add_error_id: false,
6312 }
6313 }
6314
6315 #[must_use]
6317 pub fn hide_details_for_status(mut self, status: crate::response::StatusCode) -> Self {
6318 self.status_codes.insert(status.as_u16());
6319 self
6320 }
6321
6322 #[must_use]
6324 pub fn with_replacement_body(mut self, body: impl Into<Vec<u8>>) -> Self {
6325 self.replacement_body = Some(body.into());
6326 self
6327 }
6328
6329 #[must_use]
6331 pub fn add_error_id(mut self, enable: bool) -> Self {
6332 self.add_error_id = enable;
6333 self
6334 }
6335}
6336
6337impl ResponseInterceptor for ErrorResponseTransformer {
6338 fn intercept<'a>(
6339 &'a self,
6340 ctx: &'a ResponseInterceptorContext<'a>,
6341 response: Response,
6342 ) -> BoxFuture<'a, Response> {
6343 Box::pin(async move {
6344 let status_code = response.status().as_u16();
6345
6346 if !self.status_codes.contains(&status_code) {
6347 return response;
6348 }
6349
6350 let mut resp = response;
6351
6352 if let Some(ref replacement) = self.replacement_body {
6354 resp = resp.body(crate::response::ResponseBody::Bytes(replacement.clone()));
6355 }
6356
6357 if self.add_error_id {
6359 let error_id = ctx
6361 .request
6362 .get_extension::<RequestId>()
6363 .map(|r| r.0.clone())
6364 .unwrap_or_else(|| format!("err-{}", ctx.elapsed_ms()));
6365 resp = resp.header("X-Error-Id", error_id.into_bytes());
6366 }
6367
6368 resp
6369 })
6370 }
6371
6372 fn name(&self) -> &'static str {
6373 "ErrorResponseTransformer"
6374 }
6375}
6376
6377pub struct ResponseInterceptorMiddleware<I>
6393where
6394 I: ResponseInterceptor,
6395{
6396 interceptor: I,
6397}
6398
6399impl<I> ResponseInterceptorMiddleware<I>
6400where
6401 I: ResponseInterceptor,
6402{
6403 pub fn new(interceptor: I) -> Self {
6405 Self { interceptor }
6406 }
6407}
6408
6409impl<I> Middleware for ResponseInterceptorMiddleware<I>
6410where
6411 I: ResponseInterceptor,
6412{
6413 fn before<'a>(
6414 &'a self,
6415 _ctx: &'a RequestContext,
6416 req: &'a mut Request,
6417 ) -> BoxFuture<'a, ControlFlow> {
6418 req.insert_extension(InterceptorStartTime(Instant::now()));
6420 Box::pin(async { ControlFlow::Continue })
6421 }
6422
6423 fn after<'a>(
6424 &'a self,
6425 ctx: &'a RequestContext,
6426 req: &'a Request,
6427 response: Response,
6428 ) -> BoxFuture<'a, Response> {
6429 Box::pin(async move {
6430 let start_time = req
6432 .get_extension::<InterceptorStartTime>()
6433 .map(|t| t.0)
6434 .unwrap_or_else(Instant::now);
6435
6436 let interceptor_ctx = ResponseInterceptorContext::new(req, ctx, start_time);
6437 self.interceptor.intercept(&interceptor_ctx, response).await
6438 })
6439 }
6440
6441 fn name(&self) -> &'static str {
6442 self.interceptor.name()
6443 }
6444}
6445
6446#[derive(Debug, Clone, Copy)]
6448struct InterceptorStartTime(Instant);
6449
/// One metric of a `Server-Timing` header, rendered as
/// `name;dur=<ms>` or `name;dur=<ms>;desc="<description>"`.
#[derive(Debug, Clone)]
pub struct ServerTimingEntry {
    /// Metric name.
    name: String,
    /// Duration in milliseconds, rendered with three decimals.
    duration_ms: f64,
    /// Optional human-readable description.
    description: Option<String>,
}

impl ServerTimingEntry {
    /// Creates an entry without a description.
    #[must_use]
    pub fn new(name: impl Into<String>, duration_ms: f64) -> Self {
        Self {
            name: name.into(),
            duration_ms,
            description: None,
        }
    }

    /// Attaches a description (emitted as the `desc` parameter).
    #[must_use]
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }

    /// Renders this entry as a `Server-Timing` list member.
    ///
    /// The description is emitted as an HTTP quoted-string, escaped with
    /// `escape_quoted`.
    #[must_use]
    pub fn to_header_value(&self) -> String {
        match &self.description {
            Some(desc) => format!(
                "{};dur={:.3};desc=\"{}\"",
                self.name,
                self.duration_ms,
                Self::escape_quoted(desc)
            ),
            None => format!("{};dur={:.3}", self.name, self.duration_ms),
        }
    }

    /// Makes `raw` safe inside a double-quoted header parameter.
    ///
    /// `"` and `\` become quoted-pairs. Control characters other than tab are
    /// not allowed in a quoted-string and are replaced with a space, which also
    /// prevents CR/LF header injection.
    fn escape_quoted(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for c in raw.chars() {
            match c {
                '"' | '\\' => {
                    out.push('\\');
                    out.push(c);
                }
                '\t' => out.push(c),
                c if c.is_control() => out.push(' '),
                c => out.push(c),
            }
        }
        out
    }
}
6521
6522#[derive(Debug, Clone, Default)]
6538pub struct ServerTimingBuilder {
6539 entries: Vec<ServerTimingEntry>,
6540}
6541
6542impl ServerTimingBuilder {
6543 #[must_use]
6545 pub fn new() -> Self {
6546 Self::default()
6547 }
6548
6549 #[must_use]
6551 pub fn add(mut self, name: impl Into<String>, duration_ms: f64) -> Self {
6552 self.entries.push(ServerTimingEntry::new(name, duration_ms));
6553 self
6554 }
6555
6556 #[must_use]
6558 pub fn add_with_desc(
6559 mut self,
6560 name: impl Into<String>,
6561 duration_ms: f64,
6562 description: impl Into<String>,
6563 ) -> Self {
6564 self.entries
6565 .push(ServerTimingEntry::new(name, duration_ms).with_description(description));
6566 self
6567 }
6568
6569 #[must_use]
6571 pub fn add_entry(mut self, entry: ServerTimingEntry) -> Self {
6572 self.entries.push(entry);
6573 self
6574 }
6575
6576 #[must_use]
6578 pub fn build(&self) -> String {
6579 self.entries
6580 .iter()
6581 .map(ServerTimingEntry::to_header_value)
6582 .collect::<Vec<_>>()
6583 .join(", ")
6584 }
6585
6586 #[must_use]
6588 pub fn is_empty(&self) -> bool {
6589 self.entries.is_empty()
6590 }
6591
6592 #[must_use]
6594 pub fn len(&self) -> usize {
6595 self.entries.len()
6596 }
6597}
6598
6599#[derive(Debug, Clone)]
6615pub struct TimingMetrics {
6616 pub start_time: Instant,
6618 pub first_byte_time: Option<Instant>,
6620 pub custom_metrics: Vec<(String, f64, Option<String>)>,
6622}
6623
6624impl TimingMetrics {
6625 #[must_use]
6627 pub fn new() -> Self {
6628 Self {
6629 start_time: Instant::now(),
6630 first_byte_time: None,
6631 custom_metrics: Vec::new(),
6632 }
6633 }
6634
6635 #[must_use]
6637 pub fn with_start_time(start_time: Instant) -> Self {
6638 Self {
6639 start_time,
6640 first_byte_time: None,
6641 custom_metrics: Vec::new(),
6642 }
6643 }
6644
6645 pub fn mark_first_byte(&mut self) {
6647 self.first_byte_time = Some(Instant::now());
6648 }
6649
6650 pub fn add_metric(&mut self, name: impl Into<String>, duration_ms: f64) {
6652 self.custom_metrics.push((name.into(), duration_ms, None));
6653 }
6654
6655 pub fn add_metric_with_desc(
6657 &mut self,
6658 name: impl Into<String>,
6659 duration_ms: f64,
6660 desc: impl Into<String>,
6661 ) {
6662 self.custom_metrics
6663 .push((name.into(), duration_ms, Some(desc.into())));
6664 }
6665
6666 #[must_use]
6668 pub fn total_ms(&self) -> f64 {
6669 self.start_time.elapsed().as_secs_f64() * 1000.0
6670 }
6671
6672 #[must_use]
6674 pub fn ttfb_ms(&self) -> Option<f64> {
6675 self.first_byte_time
6676 .map(|t| t.duration_since(self.start_time).as_secs_f64() * 1000.0)
6677 }
6678
6679 #[must_use]
6681 pub fn to_server_timing(&self) -> ServerTimingBuilder {
6682 let mut builder = ServerTimingBuilder::new().add_with_desc(
6683 "total",
6684 self.total_ms(),
6685 "Total request time",
6686 );
6687
6688 if let Some(ttfb) = self.ttfb_ms() {
6689 builder = builder.add_with_desc("ttfb", ttfb, "Time to first byte");
6690 }
6691
6692 for (name, duration, desc) in &self.custom_metrics {
6693 match desc {
6694 Some(d) => builder = builder.add_with_desc(name, *duration, d),
6695 None => builder = builder.add(name, *duration),
6696 }
6697 }
6698
6699 builder
6700 }
6701}
6702
6703impl Default for TimingMetrics {
6704 fn default() -> Self {
6705 Self::new()
6706 }
6707}
6708
/// Switches controlling which timing headers [`TimingMetricsMiddleware`] emits.
#[derive(Debug, Clone)]
// Each flag independently toggles one header or header section.
#[allow(clippy::struct_excessive_bools)]
pub struct TimingMetricsConfig {
    /// Emit a `Server-Timing` header.
    pub add_server_timing_header: bool,
    /// Emit the response-time header named by `response_time_header_name`.
    pub add_response_time_header: bool,
    /// Name of the response-time header.
    pub response_time_header_name: String,
    /// Include custom metrics in `Server-Timing`.
    pub include_custom_metrics: bool,
    /// Include a `ttfb` entry in `Server-Timing` when available.
    pub include_ttfb: bool,
}

impl Default for TimingMetricsConfig {
    /// Everything enabled; response time reported in `X-Response-Time`.
    fn default() -> Self {
        Self {
            add_server_timing_header: true,
            add_response_time_header: true,
            response_time_header_name: String::from("X-Response-Time"),
            include_custom_metrics: true,
            include_ttfb: true,
        }
    }
}

impl TimingMetricsConfig {
    /// Same as [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Toggles the `Server-Timing` header.
    #[must_use]
    pub fn server_timing(self, enabled: bool) -> Self {
        Self {
            add_server_timing_header: enabled,
            ..self
        }
    }

    /// Toggles the response-time header.
    #[must_use]
    pub fn response_time(self, enabled: bool) -> Self {
        Self {
            add_response_time_header: enabled,
            ..self
        }
    }

    /// Renames the response-time header.
    #[must_use]
    pub fn response_time_header(self, name: impl Into<String>) -> Self {
        Self {
            response_time_header_name: name.into(),
            ..self
        }
    }

    /// Toggles custom metrics in `Server-Timing`.
    #[must_use]
    pub fn custom_metrics(self, enabled: bool) -> Self {
        Self {
            include_custom_metrics: enabled,
            ..self
        }
    }

    /// Toggles the `ttfb` entry in `Server-Timing`.
    #[must_use]
    pub fn ttfb(self, enabled: bool) -> Self {
        Self {
            include_ttfb: enabled,
            ..self
        }
    }

    /// Only the response-time header; no `Server-Timing` details.
    #[must_use]
    pub fn production() -> Self {
        Self {
            add_server_timing_header: false,
            include_custom_metrics: false,
            include_ttfb: false,
            ..Self::default()
        }
    }

    /// Everything enabled (same as the default).
    #[must_use]
    pub fn development() -> Self {
        Self::new()
    }
}
6797
6798#[derive(Debug, Clone)]
6817pub struct TimingMetricsMiddleware {
6818 config: TimingMetricsConfig,
6819}
6820
6821impl TimingMetricsMiddleware {
6822 #[must_use]
6824 pub fn new() -> Self {
6825 Self {
6826 config: TimingMetricsConfig::default(),
6827 }
6828 }
6829
6830 #[must_use]
6832 pub fn with_config(config: TimingMetricsConfig) -> Self {
6833 Self { config }
6834 }
6835
6836 #[must_use]
6838 pub fn production() -> Self {
6839 Self {
6840 config: TimingMetricsConfig::production(),
6841 }
6842 }
6843
6844 #[must_use]
6846 pub fn development() -> Self {
6847 Self {
6848 config: TimingMetricsConfig::development(),
6849 }
6850 }
6851}
6852
6853impl Default for TimingMetricsMiddleware {
6854 fn default() -> Self {
6855 Self::new()
6856 }
6857}
6858
6859impl Middleware for TimingMetricsMiddleware {
6860 fn before<'a>(
6861 &'a self,
6862 _ctx: &'a RequestContext,
6863 req: &'a mut Request,
6864 ) -> BoxFuture<'a, ControlFlow> {
6865 req.insert_extension(TimingMetrics::new());
6867 Box::pin(async { ControlFlow::Continue })
6868 }
6869
6870 fn after<'a>(
6871 &'a self,
6872 _ctx: &'a RequestContext,
6873 req: &'a Request,
6874 response: Response,
6875 ) -> BoxFuture<'a, Response> {
6876 let config = self.config.clone();
6877
6878 Box::pin(async move {
6879 let mut resp = response;
6880
6881 let metrics = req.get_extension::<TimingMetrics>();
6883
6884 match metrics {
6885 Some(metrics) => {
6886 if config.add_response_time_header {
6888 let timing = format!("{:.3}ms", metrics.total_ms());
6889 resp = resp.header(&config.response_time_header_name, timing.into_bytes());
6890 }
6891
6892 if config.add_server_timing_header {
6894 let mut builder = ServerTimingBuilder::new().add_with_desc(
6895 "total",
6896 metrics.total_ms(),
6897 "Total request time",
6898 );
6899
6900 if config.include_ttfb {
6902 if let Some(ttfb) = metrics.ttfb_ms() {
6903 builder = builder.add_with_desc("ttfb", ttfb, "Time to first byte");
6904 }
6905 }
6906
6907 if config.include_custom_metrics {
6909 for (name, duration, desc) in &metrics.custom_metrics {
6910 match desc {
6911 Some(d) => builder = builder.add_with_desc(name, *duration, d),
6912 None => builder = builder.add(name, *duration),
6913 }
6914 }
6915 }
6916
6917 let header_value = builder.build();
6918 resp = resp.header("Server-Timing", header_value.into_bytes());
6919 }
6920 }
6921 None => {
6922 if config.add_response_time_header {
6925 resp = resp.header(&config.response_time_header_name, b"0.000ms".to_vec());
6926 }
6927 }
6928 }
6929
6930 resp
6931 })
6932 }
6933
6934 fn name(&self) -> &'static str {
6935 "TimingMetrics"
6936 }
6937}
6938
/// Snapshot of one histogram bucket.
#[derive(Debug, Clone)]
pub struct TimingHistogramBucket {
    /// Inclusive upper bound of the bucket, in milliseconds.
    pub le: f64,
    /// Number of observations `<= le`.
    pub count: u64,
}

/// Cumulative latency histogram: every observation is counted in each bucket
/// whose upper bound is at least the observed value.
#[derive(Debug, Clone)]
pub struct TimingHistogram {
    bucket_bounds: Vec<f64>,
    bucket_counts: Vec<u64>,
    sum: f64,
    count: u64,
}

impl TimingHistogram {
    /// Creates a histogram with the given upper bounds (milliseconds).
    #[must_use]
    pub fn with_buckets(bucket_bounds: Vec<f64>) -> Self {
        let slots = bucket_bounds.len();
        Self {
            bucket_bounds,
            bucket_counts: vec![0; slots],
            sum: 0.0,
            count: 0,
        }
    }

    /// Creates a histogram with bounds from 1 ms to 10 s suited to HTTP latency.
    #[must_use]
    pub fn http_latency() -> Self {
        const BOUNDS_MS: [f64; 12] = [
            1.0, 5.0, 10.0, 25.0, 50.0, 100.0, 250.0, 500.0, 1000.0, 2500.0, 5000.0, 10000.0,
        ];
        Self::with_buckets(BOUNDS_MS.to_vec())
    }

    /// Records one observation of `value_ms` milliseconds.
    pub fn observe(&mut self, value_ms: f64) {
        self.sum += value_ms;
        self.count += 1;

        for (bound, slot) in self.bucket_bounds.iter().zip(&mut self.bucket_counts) {
            if value_ms <= *bound {
                *slot += 1;
            }
        }
    }

    /// Total number of observations.
    #[must_use]
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Sum of all observed values.
    #[must_use]
    pub fn sum(&self) -> f64 {
        self.sum
    }

    /// Arithmetic mean of the observations, or `0.0` when there are none.
    #[must_use]
    // u64 -> f64 loses precision only beyond 2^53 observations.
    #[allow(clippy::cast_precision_loss)]
    pub fn mean(&self) -> f64 {
        match self.count {
            0 => 0.0,
            n => self.sum / n as f64,
        }
    }

    /// Returns a snapshot of every bucket, in bound order as configured.
    #[must_use]
    pub fn buckets(&self) -> Vec<TimingHistogramBucket> {
        self.bucket_bounds
            .iter()
            .copied()
            .zip(self.bucket_counts.iter().copied())
            .map(|(le, count)| TimingHistogramBucket { le, count })
            .collect()
    }

    /// Clears all observations while keeping the bucket bounds.
    pub fn reset(&mut self) {
        self.sum = 0.0;
        self.count = 0;
        self.bucket_counts.fill(0);
    }
}

impl Default for TimingHistogram {
    fn default() -> Self {
        TimingHistogram::http_latency()
    }
}
7068
#[cfg(test)]
mod timing_metrics_tests {
    //! Tests for the Server-Timing helpers, `TimingMetrics`,
    //! `TimingMetricsMiddleware`, and `TimingHistogram`.

    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    /// Builds a throwaway request context for driving middleware hooks.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// A plain `GET /test` request.
    fn test_request() -> Request {
        Request::new(Method::Get, "/test")
    }

    /// Runs `mw.before` to completion on the current thread.
    fn run_middleware_before(mw: &impl Middleware, req: &mut Request) -> ControlFlow {
        let ctx = test_context();
        futures_executor::block_on(mw.before(&ctx, req))
    }

    /// Runs `mw.after` to completion on the current thread.
    fn run_middleware_after(mw: &impl Middleware, req: &Request, resp: Response) -> Response {
        let ctx = test_context();
        futures_executor::block_on(mw.after(&ctx, req, resp))
    }

    /// Returns `true` if `resp` carries a header named exactly `name`
    /// (case-sensitive comparison).
    fn has_header(resp: &Response, name: &str) -> bool {
        resp.headers().iter().any(|(n, _)| n == name)
    }

    #[test]
    fn server_timing_entry_basic() {
        let entry = ServerTimingEntry::new("db", 42.5);
        assert_eq!(entry.to_header_value(), "db;dur=42.500");
    }

    #[test]
    fn server_timing_entry_with_description() {
        let entry = ServerTimingEntry::new("db", 42.5).with_description("Database query");
        assert_eq!(
            entry.to_header_value(),
            "db;dur=42.500;desc=\"Database query\""
        );
    }

    #[test]
    fn server_timing_builder_single_entry() {
        let timing = ServerTimingBuilder::new().add("total", 150.0).build();
        assert_eq!(timing, "total;dur=150.000");
    }

    #[test]
    fn server_timing_builder_multiple_entries() {
        let timing = ServerTimingBuilder::new()
            .add("total", 150.0)
            .add_with_desc("db", 42.0, "Database")
            .add("cache", 5.0)
            .build();

        assert!(timing.contains("total;dur=150.000"));
        assert!(timing.contains("db;dur=42.000;desc=\"Database\""));
        assert!(timing.contains("cache;dur=5.000"));
        // Entries are expected to be joined with a ", " separator.
        assert!(timing.contains(", "));
    }

    #[test]
    fn server_timing_builder_empty() {
        let builder = ServerTimingBuilder::new();
        assert!(builder.is_empty());
        assert_eq!(builder.len(), 0);
        assert_eq!(builder.build(), "");
    }

    #[test]
    fn timing_metrics_basic() {
        let metrics = TimingMetrics::new();
        std::thread::sleep(std::time::Duration::from_millis(5));

        let total = metrics.total_ms();
        assert!(total >= 5.0, "Total should be at least 5ms");
        assert!(metrics.ttfb_ms().is_none(), "TTFB should not be set");
    }

    #[test]
    fn timing_metrics_custom_metrics() {
        let mut metrics = TimingMetrics::new();
        metrics.add_metric("db", 42.5);
        metrics.add_metric_with_desc("cache", 5.0, "Cache lookup");

        let timing = metrics.to_server_timing();
        // The two custom metrics plus the "total" entry.
        assert_eq!(timing.len(), 3);

        let header = timing.build();
        assert!(header.contains("total"));
        assert!(header.contains("db;dur=42.500"));
        assert!(header.contains("cache;dur=5.000;desc=\"Cache lookup\""));
    }

    #[test]
    fn timing_metrics_ttfb() {
        let mut metrics = TimingMetrics::new();
        std::thread::sleep(std::time::Duration::from_millis(5));
        metrics.mark_first_byte();

        let ttfb = metrics
            .ttfb_ms()
            .expect("TTFB should be set after mark_first_byte");
        assert!(ttfb >= 5.0, "TTFB should be at least 5ms");
    }

    #[test]
    fn timing_metrics_config_default() {
        let config = TimingMetricsConfig::default();
        assert!(config.add_server_timing_header);
        assert!(config.add_response_time_header);
        assert!(config.include_custom_metrics);
        assert!(config.include_ttfb);
    }

    #[test]
    fn timing_metrics_config_production() {
        let config = TimingMetricsConfig::production();
        assert!(!config.add_server_timing_header);
        assert!(config.add_response_time_header);
        assert!(!config.include_custom_metrics);
    }

    #[test]
    fn timing_middleware_adds_metrics_to_request() {
        let mw = TimingMetricsMiddleware::new();
        let mut req = test_request();

        let result = run_middleware_before(&mw, &mut req);
        assert!(result.is_continue());

        let metrics = req.get_extension::<TimingMetrics>();
        assert!(metrics.is_some(), "TimingMetrics should be in extensions");
    }

    #[test]
    fn timing_middleware_adds_response_time_header() {
        let mw = TimingMetricsMiddleware::new();
        let mut req = test_request();

        // Run `before` first so the request carries its TimingMetrics extension.
        run_middleware_before(&mw, &mut req);

        let resp = Response::with_status(StatusCode::OK);
        let result = run_middleware_after(&mw, &req, resp);

        assert!(
            has_header(&result, "X-Response-Time"),
            "Should have X-Response-Time header"
        );
    }

    #[test]
    fn timing_middleware_adds_server_timing_header() {
        let mw = TimingMetricsMiddleware::new();
        let mut req = test_request();

        run_middleware_before(&mw, &mut req);

        let resp = Response::with_status(StatusCode::OK);
        let result = run_middleware_after(&mw, &req, resp);

        let header = result
            .headers()
            .iter()
            .find(|(name, _)| name == "Server-Timing")
            .map(|(_, v)| String::from_utf8_lossy(v).to_string())
            .expect("Should have Server-Timing header");
        assert!(header.contains("total"), "Should have total timing");
    }

    #[test]
    fn timing_middleware_production_mode() {
        let mw = TimingMetricsMiddleware::production();
        let mut req = test_request();

        run_middleware_before(&mw, &mut req);

        let resp = Response::with_status(StatusCode::OK);
        let result = run_middleware_after(&mw, &req, resp);

        // Production config keeps the X-Response-Time header...
        assert!(has_header(&result, "X-Response-Time"));
        // ...but must not emit the detailed Server-Timing header.
        assert!(!has_header(&result, "Server-Timing"));
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn timing_histogram_basic() {
        let mut histogram = TimingHistogram::http_latency();
        assert_eq!(histogram.count(), 0);
        assert_eq!(histogram.sum(), 0.0);

        histogram.observe(42.0);
        histogram.observe(150.0);
        histogram.observe(5.0);

        assert_eq!(histogram.count(), 3);
        assert_eq!(histogram.sum(), 197.0);
        assert!((histogram.mean() - 65.666).abs() < 0.01);
    }

    #[test]
    fn timing_histogram_buckets() {
        let mut histogram = TimingHistogram::with_buckets(vec![10.0, 50.0, 100.0]);

        histogram.observe(5.0); // within every bucket
        histogram.observe(25.0); // within the 50 and 100 buckets
        histogram.observe(75.0); // within the 100 bucket only
        histogram.observe(150.0); // above every bound: no bucket

        let buckets = histogram.buckets();
        assert_eq!(buckets.len(), 3);

        // Bucket counts are cumulative (`value <= le`).
        assert_eq!(buckets[0].count, 1); // le=10:  {5}
        assert_eq!(buckets[1].count, 2); // le=50:  {5, 25}
        assert_eq!(buckets[2].count, 3); // le=100: {5, 25, 75}
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn timing_histogram_reset() {
        let mut histogram = TimingHistogram::http_latency();
        histogram.observe(100.0);
        histogram.observe(200.0);

        assert_eq!(histogram.count(), 2);

        histogram.reset();

        assert_eq!(histogram.count(), 0);
        assert_eq!(histogram.sum(), 0.0);
    }
}
7318
#[cfg(test)]
mod response_interceptor_tests {
    //! Behavioural tests for the response interceptors (timing, debug info,
    //! header/body transforms, error masking, conditional wrapping) and for
    //! `ResponseInterceptorStack`.

    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    /// Fresh request context for a single interceptor run.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// A plain `GET /test` request.
    fn test_request() -> Request {
        Request::new(Method::Get, "/test")
    }

    /// Runs one interceptor to completion, treating "now" as the request start.
    fn run_interceptor<I: ResponseInterceptor>(
        interceptor: &I,
        req: &Request,
        resp: Response,
    ) -> Response {
        let ctx = test_context();
        let interceptor_ctx = ResponseInterceptorContext::new(req, &ctx, Instant::now());
        futures_executor::block_on(interceptor.intercept(&interceptor_ctx, resp))
    }

    /// Whether `resp` has a header whose name is exactly `name`.
    fn has_header(resp: &Response, name: &str) -> bool {
        resp.headers().iter().any(|(header, _)| header == name)
    }

    /// The raw bytes of a `ResponseBody::Bytes` body; any other body panics.
    fn bytes_body(resp: &Response) -> &[u8] {
        match resp.body_ref() {
            crate::response::ResponseBody::Bytes(bytes) => bytes.as_slice(),
            _ => panic!("Expected bytes body"),
        }
    }

    #[test]
    fn timing_interceptor_adds_header() {
        let result = run_interceptor(
            &TimingInterceptor::new(),
            &test_request(),
            Response::with_status(StatusCode::OK),
        );

        assert!(
            has_header(&result, "X-Response-Time"),
            "Should have X-Response-Time header"
        );
    }

    #[test]
    fn timing_interceptor_with_server_timing() {
        let result = run_interceptor(
            &TimingInterceptor::new().with_server_timing("app"),
            &test_request(),
            Response::with_status(StatusCode::OK),
        );

        assert!(
            has_header(&result, "Server-Timing"),
            "Should have Server-Timing header"
        );
    }

    #[test]
    fn timing_interceptor_custom_header_name() {
        let result = run_interceptor(
            &TimingInterceptor::new().header_name("X-Custom-Time"),
            &test_request(),
            Response::with_status(StatusCode::OK),
        );

        assert!(
            has_header(&result, "X-Custom-Time"),
            "Should have X-Custom-Time header"
        );
    }

    #[test]
    fn debug_info_interceptor_adds_headers() {
        let result = run_interceptor(
            &DebugInfoInterceptor::new(),
            &test_request(),
            Response::with_status(StatusCode::OK),
        );

        assert!(
            has_header(&result, "X-Debug-Path"),
            "Should have X-Debug-Path header"
        );
        assert!(
            has_header(&result, "X-Debug-Method"),
            "Should have X-Debug-Method header"
        );
        assert!(
            has_header(&result, "X-Debug-Handler-Time"),
            "Should have X-Debug-Handler-Time header"
        );
    }

    #[test]
    fn debug_info_interceptor_custom_prefix() {
        let result = run_interceptor(
            &DebugInfoInterceptor::new().header_prefix("X-Trace-"),
            &test_request(),
            Response::with_status(StatusCode::OK),
        );

        assert!(
            has_header(&result, "X-Trace-Path"),
            "Should have X-Trace-Path header"
        );
    }

    #[test]
    fn debug_info_interceptor_selective_options() {
        // Only the path header is enabled.
        let only_path = DebugInfoInterceptor::new()
            .include_path(true)
            .include_method(false)
            .include_timing(false)
            .include_request_id(false);

        let result = run_interceptor(
            &only_path,
            &test_request(),
            Response::with_status(StatusCode::OK),
        );

        assert!(
            has_header(&result, "X-Debug-Path"),
            "Should have X-Debug-Path header"
        );
        assert!(
            !has_header(&result, "X-Debug-Method"),
            "Should NOT have X-Debug-Method header"
        );
    }

    #[test]
    fn header_transform_adds_headers() {
        let headers = HeaderTransformInterceptor::new()
            .add("X-Powered-By", b"fastapi_rust".to_vec())
            .add("X-Version", b"1.0".to_vec());

        let result = run_interceptor(
            &headers,
            &test_request(),
            Response::with_status(StatusCode::OK),
        );

        assert!(
            has_header(&result, "X-Powered-By"),
            "Should have X-Powered-By header"
        );
        assert!(
            has_header(&result, "X-Version"),
            "Should have X-Version header"
        );
    }

    #[test]
    fn response_body_transform_modifies_body() {
        // Wraps the body in square brackets.
        let transformer = ResponseBodyTransform::new(|body| {
            let mut wrapped = b"[".to_vec();
            wrapped.extend_from_slice(&body);
            wrapped.extend_from_slice(b"]");
            wrapped
        });
        let resp = Response::with_status(StatusCode::OK)
            .body(crate::response::ResponseBody::Bytes(b"hello".to_vec()));

        let result = run_interceptor(&transformer, &test_request(), resp);

        assert_eq!(bytes_body(&result), b"[hello]");
    }

    #[test]
    fn response_body_transform_with_content_type_filter() {
        let transformer =
            ResponseBodyTransform::new(|_| b"transformed".to_vec()).for_content_type("text/plain");
        let req = test_request();

        // A non-matching content type must pass through untouched.
        let json_resp = Response::with_status(StatusCode::OK)
            .header("content-type", b"application/json".to_vec())
            .body(crate::response::ResponseBody::Bytes(b"original".to_vec()));
        let json_result = run_interceptor(&transformer, &req, json_resp);
        assert_eq!(
            bytes_body(&json_result),
            b"original",
            "JSON should not be transformed"
        );

        // A matching content type is rewritten.
        let text_resp = Response::with_status(StatusCode::OK)
            .header("content-type", b"text/plain".to_vec())
            .body(crate::response::ResponseBody::Bytes(b"original".to_vec()));
        let text_result = run_interceptor(&transformer, &req, text_resp);
        assert_eq!(
            bytes_body(&text_result),
            b"transformed",
            "Text should be transformed"
        );
    }

    #[test]
    fn error_response_transformer_hides_details() {
        let transformer = ErrorResponseTransformer::new()
            .hide_details_for_status(StatusCode::INTERNAL_SERVER_ERROR)
            .with_replacement_body(b"An error occurred");
        let req = test_request();

        // A 500 response has its body replaced.
        let error_resp = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR).body(
            crate::response::ResponseBody::Bytes(b"Sensitive error details".to_vec()),
        );
        let masked = run_interceptor(&transformer, &req, error_resp);
        assert_eq!(bytes_body(&masked), b"An error occurred");

        // A 200 response is left alone.
        let ok_resp = Response::with_status(StatusCode::OK)
            .body(crate::response::ResponseBody::Bytes(b"Success".to_vec()));
        let untouched = run_interceptor(&transformer, &req, ok_resp);
        assert_eq!(bytes_body(&untouched), b"Success");
    }

    #[test]
    fn response_interceptor_stack_chains_interceptors() {
        let mut stack = ResponseInterceptorStack::new();
        stack.push(TimingInterceptor::new());
        stack.push(HeaderTransformInterceptor::new().add("X-Extra", b"value".to_vec()));

        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);
        let ctx = test_context();
        let interceptor_ctx = ResponseInterceptorContext::new(&req, &ctx, Instant::now());
        let result = futures_executor::block_on(stack.process(&interceptor_ctx, resp));

        assert!(
            has_header(&result, "X-Response-Time"),
            "Should have timing header from first interceptor"
        );
        assert!(
            has_header(&result, "X-Extra"),
            "Should have extra header from second interceptor"
        );
    }

    #[test]
    fn response_interceptor_stack_empty_is_noop() {
        let stack = ResponseInterceptorStack::new();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);

        let req = test_request();
        let resp = Response::with_status(StatusCode::OK)
            .body(crate::response::ResponseBody::Bytes(b"unchanged".to_vec()));
        let ctx = test_context();
        let interceptor_ctx = ResponseInterceptorContext::new(&req, &ctx, Instant::now());
        let result = futures_executor::block_on(stack.process(&interceptor_ctx, resp));

        assert_eq!(bytes_body(&result), b"unchanged");
    }

    #[test]
    fn interceptor_context_provides_timing() {
        let ctx = test_context();
        let req = test_request();
        let started = Instant::now();
        std::thread::sleep(std::time::Duration::from_millis(5));

        let interceptor_ctx = ResponseInterceptorContext::new(&req, &ctx, started);

        assert!(
            interceptor_ctx.elapsed_ms() >= 5,
            "Elapsed time should be at least 5ms"
        );
        assert!(interceptor_ctx.elapsed().as_millis() >= 5);
    }

    #[test]
    fn conditional_interceptor_applies_conditionally() {
        let inner = HeaderTransformInterceptor::new().add("X-Success", b"true".to_vec());
        // Decorate only responses whose status is exactly 200.
        let conditional =
            ConditionalInterceptor::new(inner, |_ctx, resp| resp.status().as_u16() == 200);
        let req = test_request();

        let ok = run_interceptor(&conditional, &req, Response::with_status(StatusCode::OK));
        assert!(
            has_header(&ok, "X-Success"),
            "200 response should get X-Success header"
        );

        let missing = run_interceptor(
            &conditional,
            &req,
            Response::with_status(StatusCode::NOT_FOUND),
        );
        assert!(
            !has_header(&missing, "X-Success"),
            "404 response should NOT get X-Success header"
        );
    }
}
7659
#[cfg(test)]
mod cache_control_tests {
    //! Tests for Cache-Control directives, presets, the builder, and
    //! `CacheControlMiddleware`, plus the path-pattern and HTTP-date helpers.

    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    /// Builds a throwaway request context for driving middleware hooks.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// Runs `mw.after` to completion on the current thread.
    fn run_after(mw: &CacheControlMiddleware, req: &Request, resp: Response) -> Response {
        let ctx = test_context();
        let fut = mw.after(&ctx, req, resp);
        futures_executor::block_on(fut)
    }

    /// Value of the first header named `name` (ASCII case-insensitive),
    /// decoded lossily as UTF-8; `None` if the header is absent.
    fn header_value(resp: &Response, name: &str) -> Option<String> {
        resp.headers()
            .iter()
            .find(|(n, _)| n.eq_ignore_ascii_case(name))
            .map(|(_, v)| String::from_utf8_lossy(v).into_owned())
    }

    #[test]
    fn cache_directive_as_str_works() {
        assert_eq!(CacheDirective::Public.as_str(), "public");
        assert_eq!(CacheDirective::Private.as_str(), "private");
        assert_eq!(CacheDirective::NoStore.as_str(), "no-store");
        assert_eq!(CacheDirective::NoCache.as_str(), "no-cache");
        assert_eq!(CacheDirective::MustRevalidate.as_str(), "must-revalidate");
        assert_eq!(CacheDirective::Immutable.as_str(), "immutable");
    }

    #[test]
    fn cache_control_builder_basic() {
        let cc = CacheControlBuilder::new()
            .public()
            .max_age_secs(3600)
            .build();
        assert!(cc.contains("public"));
        assert!(cc.contains("max-age=3600"));
    }

    #[test]
    fn cache_control_builder_complex() {
        let cc = CacheControlBuilder::new()
            .public()
            .max_age_secs(60)
            .s_maxage_secs(3600)
            .stale_while_revalidate_secs(86400)
            .build();
        assert!(cc.contains("public"));
        assert!(cc.contains("max-age=60"));
        assert!(cc.contains("s-maxage=3600"));
        assert!(cc.contains("stale-while-revalidate=86400"));
    }

    #[test]
    fn cache_control_builder_no_cache() {
        let cc = CacheControlBuilder::new()
            .no_store()
            .no_cache()
            .must_revalidate()
            .build();
        assert!(cc.contains("no-store"));
        assert!(cc.contains("no-cache"));
        assert!(cc.contains("must-revalidate"));
    }

    #[test]
    fn cache_preset_no_cache() {
        let value = CachePreset::NoCache.to_header_value();
        assert!(value.contains("no-store"));
        assert!(value.contains("no-cache"));
        assert!(value.contains("must-revalidate"));
    }

    #[test]
    fn cache_preset_immutable() {
        let value = CachePreset::Immutable.to_header_value();
        assert!(value.contains("public"));
        assert!(value.contains("max-age=31536000"));
        assert!(value.contains("immutable"));
    }

    #[test]
    fn cache_preset_static_assets() {
        let value = CachePreset::StaticAssets.to_header_value();
        assert!(value.contains("public"));
        assert!(value.contains("max-age=86400"));
    }

    #[test]
    fn middleware_adds_cache_control_header() {
        let mw = CacheControlMiddleware::with_preset(CachePreset::PublicOneHour);
        let req = Request::new(Method::Get, "/api/test");
        let resp = Response::with_status(StatusCode::OK);

        let result = run_after(&mw, &req, resp);
        let value = header_value(&result, "cache-control")
            .expect("Cache-Control header should be present");
        assert!(value.contains("public"));
        assert!(value.contains("max-age=3600"));
    }

    #[test]
    fn middleware_skips_post_requests() {
        let mw = CacheControlMiddleware::with_preset(CachePreset::PublicOneHour);
        let req = Request::new(Method::Post, "/api/test");
        let resp = Response::with_status(StatusCode::OK);

        let result = run_after(&mw, &req, resp);
        assert!(
            header_value(&result, "cache-control").is_none(),
            "Cache-Control should not be added for POST"
        );
    }

    #[test]
    fn middleware_skips_error_responses() {
        let mw = CacheControlMiddleware::with_preset(CachePreset::PublicOneHour);
        let req = Request::new(Method::Get, "/api/test");
        let resp = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR);

        let result = run_after(&mw, &req, resp);
        assert!(
            header_value(&result, "cache-control").is_none(),
            "Cache-Control should not be added for error responses"
        );
    }

    #[test]
    fn middleware_with_vary_header() {
        let mw = CacheControlMiddleware::with_config(
            CacheControlConfig::from_preset(CachePreset::PublicOneHour)
                .vary("Accept-Encoding")
                .vary("Accept-Language"),
        );
        let req = Request::new(Method::Get, "/api/test");
        let resp = Response::with_status(StatusCode::OK);

        let result = run_after(&mw, &req, resp);
        let value = header_value(&result, "vary").expect("Vary header should be present");
        assert!(value.contains("Accept-Encoding"));
        assert!(value.contains("Accept-Language"));
    }

    #[test]
    fn middleware_preserves_existing_cache_control() {
        let mw = CacheControlMiddleware::with_config(
            CacheControlConfig::from_preset(CachePreset::PublicOneHour).preserve_existing(true),
        );
        let req = Request::new(Method::Get, "/api/test");
        let resp =
            Response::with_status(StatusCode::OK).header("Cache-Control", b"max-age=60".to_vec());

        let result = run_after(&mw, &req, resp);
        let cc_values: Vec<String> = result
            .headers()
            .iter()
            .filter(|(name, _)| name.eq_ignore_ascii_case("cache-control"))
            .map(|(_, v)| String::from_utf8_lossy(v).into_owned())
            .collect();
        // Exactly one Cache-Control header: the handler's own, unchanged.
        assert_eq!(cc_values.len(), 1);
        assert_eq!(cc_values[0], "max-age=60");
    }

    #[test]
    fn path_pattern_matching_exact() {
        assert!(path_matches_pattern("/api/users", "/api/users"));
        assert!(!path_matches_pattern("/api/users", "/api/items"));
    }

    #[test]
    fn path_pattern_matching_wildcard() {
        assert!(path_matches_pattern("/api/users/123", "/api/users/*"));
        assert!(path_matches_pattern("/static/css/style.css", "/static/*"));
        assert!(path_matches_pattern("/anything", "*"));
    }

    #[test]
    fn date_formatting_works() {
        let now = std::time::SystemTime::now();
        let formatted = format_http_date(now);

        // HTTP dates are expressed in GMT...
        assert!(formatted.ends_with(" GMT"));

        // ...and start with a three-letter day-of-week name.
        let days = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"];
        assert!(days.iter().any(|d| formatted.starts_with(d)));
    }

    #[test]
    fn leap_year_detection() {
        assert!(!is_leap_year(1900)); // divisible by 100 but not by 400
        assert!(is_leap_year(2000)); // divisible by 400
        assert!(is_leap_year(2024)); // divisible by 4, not by 100
        assert!(!is_leap_year(2023)); // not divisible by 4
    }
}
7877
#[cfg(test)]
mod trace_rejection_tests {
    //! Tests for `TraceRejectionMiddleware`: TRACE requests must be answered
    //! with 405 Method Not Allowed, while every other method passes through.

    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    /// Builds a throwaway request context for driving middleware hooks.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// Runs `mw.before` to completion on the current thread.
    fn run_before(mw: &TraceRejectionMiddleware, req: &mut Request) -> ControlFlow {
        let ctx = test_context();
        let fut = mw.before(&ctx, req);
        futures_executor::block_on(fut)
    }

    /// Value of the first header named `name` (ASCII case-insensitive).
    fn find_header<'a>(headers: &'a [(String, Vec<u8>)], name: &str) -> Option<&'a [u8]> {
        headers
            .iter()
            .find(|(n, _)| n.eq_ignore_ascii_case(name))
            .map(|(_, v)| v.as_slice())
    }

    /// Sends `method path` through a default middleware and returns the
    /// short-circuit response, panicking if the request was let through.
    fn expect_rejected(method: Method, path: &'static str) -> Response {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(method, path);
        match run_before(&mw, &mut req) {
            ControlFlow::Break(response) => response,
            ControlFlow::Continue => panic!("TRACE request should have been rejected"),
        }
    }

    /// Sends `method path` through a default middleware and panics with
    /// "`label` request should be allowed" if the middleware short-circuits.
    fn assert_allowed(method: Method, path: &'static str, label: &str) {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(method, path);
        match run_before(&mw, &mut req) {
            ControlFlow::Continue => {}
            ControlFlow::Break(_) => panic!("{label} request should be allowed"),
        }
    }

    #[test]
    fn trace_request_rejected() {
        let response = expect_rejected(Method::Trace, "/");
        assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn trace_request_with_path() {
        let response = expect_rejected(Method::Trace, "/api/users/123");
        assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn get_request_allowed() {
        assert_allowed(Method::Get, "/", "GET");
    }

    #[test]
    fn post_request_allowed() {
        assert_allowed(Method::Post, "/api/users", "POST");
    }

    #[test]
    fn put_request_allowed() {
        assert_allowed(Method::Put, "/api/users/1", "PUT");
    }

    #[test]
    fn delete_request_allowed() {
        assert_allowed(Method::Delete, "/api/users/1", "DELETE");
    }

    #[test]
    fn patch_request_allowed() {
        assert_allowed(Method::Patch, "/api/users/1", "PATCH");
    }

    #[test]
    fn options_request_allowed() {
        assert_allowed(Method::Options, "/api/users", "OPTIONS");
    }

    #[test]
    fn head_request_allowed() {
        assert_allowed(Method::Head, "/", "HEAD");
    }

    #[test]
    fn response_includes_allow_header() {
        let response = expect_rejected(Method::Trace, "/");
        assert!(
            find_header(response.headers(), "Allow").is_some(),
            "Response should include Allow header"
        );
    }

    #[test]
    fn response_has_json_content_type() {
        let response = expect_rejected(Method::Trace, "/");
        assert_eq!(
            find_header(response.headers(), "Content-Type"),
            Some(b"application/json".as_slice())
        );
    }

    #[test]
    fn default_enables_logging() {
        let mw = TraceRejectionMiddleware::new();
        assert!(mw.log_attempts);
    }

    #[test]
    fn log_attempts_can_be_disabled() {
        let mw = TraceRejectionMiddleware::new().log_attempts(false);
        assert!(!mw.log_attempts);
    }

    #[test]
    fn middleware_name() {
        let mw = TraceRejectionMiddleware::new();
        assert_eq!(mw.name(), "TraceRejection");
    }

    #[test]
    fn default_impl() {
        let mw = TraceRejectionMiddleware::default();
        assert!(mw.log_attempts);
    }
}
8085
#[cfg(test)]
mod https_redirect_tests {
    //! Unit tests for `HttpsRedirectMiddleware`: the redirect decision made in
    //! `before` and the `Strict-Transport-Security` header added in `after`.

    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    /// Builds a request context suitable for driving one middleware hook.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// Drives `mw.before` to completion on the current thread.
    fn run_before(mw: &HttpsRedirectMiddleware, req: &mut Request) -> ControlFlow {
        let ctx = test_context();
        let fut = mw.before(&ctx, req);
        futures_executor::block_on(fut)
    }

    /// Drives `mw.after` to completion on the current thread.
    fn run_after(mw: &HttpsRedirectMiddleware, req: &Request, resp: Response) -> Response {
        let ctx = test_context();
        let fut = mw.after(&ctx, req, resp);
        futures_executor::block_on(fut)
    }

    /// Returns the value of the first header named `name` (ASCII case-insensitive).
    fn find_header<'a>(headers: &'a [(String, Vec<u8>)], name: &str) -> Option<&'a [u8]> {
        headers
            .iter()
            .find(|(n, _)| n.eq_ignore_ascii_case(name))
            .map(|(_, v)| v.as_slice())
    }

    /// Returns the `Strict-Transport-Security` header as (lossily decoded) text.
    fn hsts_value(response: &Response) -> Option<String> {
        find_header(response.headers(), "Strict-Transport-Security")
            .map(|v| String::from_utf8_lossy(v).into_owned())
    }

    /// Unwraps a short-circuiting `before` result.
    ///
    /// Panics with `msg` (reported at the caller's location) if the middleware
    /// let the request continue instead.
    #[track_caller]
    fn expect_break(result: ControlFlow, msg: &str) -> Response {
        match result {
            ControlFlow::Break(response) => response,
            ControlFlow::Continue => panic!("{}", msg),
        }
    }

    /// Panics with `msg` (reported at the caller's location) if `before`
    /// short-circuited instead of letting the request continue.
    #[track_caller]
    fn expect_continue(result: ControlFlow, msg: &str) {
        match result {
            ControlFlow::Continue => {}
            ControlFlow::Break(_) => panic!("{}", msg),
        }
    }

    #[test]
    fn http_request_redirected() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("Host", b"example.com".to_vec());

        let response = expect_break(
            run_before(&mw, &mut req),
            "HTTP request should be redirected",
        );

        assert_eq!(response.status(), StatusCode::MOVED_PERMANENTLY);
        let location = find_header(response.headers(), "Location");
        assert_eq!(location, Some(b"https://example.com/".as_slice()));
    }

    #[test]
    fn http_request_with_path_and_query() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = Request::new(Method::Get, "/api/users?page=1");
        req.headers_mut().insert("Host", b"example.com".to_vec());

        let response = expect_break(
            run_before(&mw, &mut req),
            "HTTP request should be redirected",
        );

        // Path and query string must be carried over to the HTTPS location.
        let location = find_header(response.headers(), "Location");
        assert_eq!(
            location,
            Some(b"https://example.com/api/users?page=1".as_slice())
        );
    }

    #[test]
    fn https_request_not_redirected() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("Host", b"example.com".to_vec());
        req.headers_mut()
            .insert("X-Forwarded-Proto", b"https".to_vec());

        expect_continue(
            run_before(&mw, &mut req),
            "HTTPS request should not be redirected",
        );
    }

    #[test]
    fn x_forwarded_ssl_recognized() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("Host", b"example.com".to_vec());
        req.headers_mut().insert("X-Forwarded-Ssl", b"on".to_vec());

        expect_continue(
            run_before(&mw, &mut req),
            "Request with X-Forwarded-Ssl=on should not redirect",
        );
    }

    #[test]
    fn excluded_path_not_redirected() {
        let mw = HttpsRedirectMiddleware::new().exclude_path("/health");
        let mut req = Request::new(Method::Get, "/health");
        req.headers_mut().insert("Host", b"example.com".to_vec());

        expect_continue(
            run_before(&mw, &mut req),
            "Excluded path should not be redirected",
        );
    }

    #[test]
    fn excluded_path_prefix_matches() {
        let mw = HttpsRedirectMiddleware::new().exclude_path("/health");
        let mut req = Request::new(Method::Get, "/health/live");
        req.headers_mut().insert("Host", b"example.com".to_vec());

        expect_continue(
            run_before(&mw, &mut req),
            "Path with excluded prefix should not be redirected",
        );
    }

    #[test]
    fn temporary_redirect_option() {
        let mw = HttpsRedirectMiddleware::new().permanent_redirect(false);
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("Host", b"example.com".to_vec());

        let response = expect_break(
            run_before(&mw, &mut req),
            "HTTP request should be redirected",
        );
        assert_eq!(response.status(), StatusCode::TEMPORARY_REDIRECT);
    }

    #[test]
    fn redirect_disabled() {
        let mw = HttpsRedirectMiddleware::new().redirect_enabled(false);
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("Host", b"example.com".to_vec());

        expect_continue(
            run_before(&mw, &mut req),
            "Redirects are disabled, should continue",
        );
    }

    #[test]
    fn hsts_header_on_https_response() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("X-Forwarded-Proto", b"https".to_vec());

        let result = run_after(&mw, &req, Response::with_status(StatusCode::OK));

        let hsts =
            hsts_value(&result).expect("HSTS header should be present on HTTPS response");
        assert!(hsts.contains("max-age=31536000"));
    }

    #[test]
    fn hsts_header_not_on_http_response() {
        let mw = HttpsRedirectMiddleware::new().redirect_enabled(false);
        // Plain request: no X-Forwarded-Proto / X-Forwarded-Ssl headers.
        let req = Request::new(Method::Get, "/");
        let response = Response::with_status(StatusCode::OK);
        let result = run_after(&mw, &req, response);

        let hsts = find_header(result.headers(), "Strict-Transport-Security");
        assert!(hsts.is_none(), "HSTS header should not be on HTTP response");
    }

    #[test]
    fn hsts_with_include_subdomains() {
        let mw = HttpsRedirectMiddleware::new().include_subdomains(true);
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("X-Forwarded-Proto", b"https".to_vec());

        let result = run_after(&mw, &req, Response::with_status(StatusCode::OK));

        let hsts =
            hsts_value(&result).expect("HSTS header should be present on HTTPS response");
        assert!(hsts.contains("includeSubDomains"));
    }

    #[test]
    fn hsts_with_preload() {
        let mw = HttpsRedirectMiddleware::new().preload(true);
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("X-Forwarded-Proto", b"https".to_vec());

        let result = run_after(&mw, &req, Response::with_status(StatusCode::OK));

        let hsts =
            hsts_value(&result).expect("HSTS header should be present on HTTPS response");
        assert!(hsts.contains("preload"));
    }

    #[test]
    fn hsts_disabled_with_zero_max_age() {
        let mw = HttpsRedirectMiddleware::new().hsts_max_age_secs(0);
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("X-Forwarded-Proto", b"https".to_vec());

        let result = run_after(&mw, &req, Response::with_status(StatusCode::OK));

        let hsts = find_header(result.headers(), "Strict-Transport-Security");
        assert!(hsts.is_none(), "HSTS should be disabled with max-age=0");
    }

    #[test]
    fn custom_https_port() {
        let mw = HttpsRedirectMiddleware::new().https_port(8443);
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("Host", b"example.com".to_vec());

        let response = expect_break(
            run_before(&mw, &mut req),
            "HTTP request should be redirected",
        );

        // A non-default HTTPS port must appear explicitly in the location.
        let location = find_header(response.headers(), "Location");
        assert_eq!(location, Some(b"https://example.com:8443/".as_slice()));
    }

    #[test]
    fn host_with_port_stripped() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("Host", b"example.com:8080".to_vec());

        let response = expect_break(
            run_before(&mw, &mut req),
            "HTTP request should be redirected",
        );

        // The incoming (HTTP) port must not leak into the HTTPS location.
        let location = find_header(response.headers(), "Location");
        assert_eq!(location, Some(b"https://example.com/".as_slice()));
    }

    #[test]
    fn middleware_name() {
        let mw = HttpsRedirectMiddleware::new();
        assert_eq!(mw.name(), "HttpsRedirect");
    }

    #[test]
    fn default_impl() {
        let mw = HttpsRedirectMiddleware::default();
        assert!(mw.config.redirect_enabled);
        assert!(mw.config.permanent_redirect);
        assert_eq!(mw.config.hsts_max_age_secs, 31_536_000);
    }

    #[test]
    fn config_builder() {
        let mw = HttpsRedirectMiddleware::new()
            .redirect_enabled(false)
            .permanent_redirect(false)
            .hsts_max_age_secs(86400)
            .include_subdomains(true)
            .preload(true)
            .https_port(8443);

        assert!(!mw.config.redirect_enabled);
        assert!(!mw.config.permanent_redirect);
        assert_eq!(mw.config.hsts_max_age_secs, 86400);
        assert!(mw.config.hsts_include_subdomains);
        assert!(mw.config.hsts_preload);
        assert_eq!(mw.config.https_port, 8443);
    }

    #[test]
    fn exclude_paths_method() {
        let mw = HttpsRedirectMiddleware::new()
            .exclude_paths(vec!["/health".to_string(), "/ready".to_string()]);

        assert_eq!(mw.config.exclude_paths.len(), 2);
        assert!(mw.config.exclude_paths.contains(&"/health".to_string()));
        assert!(mw.config.exclude_paths.contains(&"/ready".to_string()));
    }
}
8404
8405#[cfg(test)]
8414mod tests {
8415 use super::*;
8416 use crate::response::{ResponseBody, StatusCode};
8417
    /// Test middleware whose `after` hook adds the header `name: value` to every
    /// response via `Response::header`.
    // NOTE(review): `dead_code` is allowed, presumably because not every test
    // constructs this fixture; confirm it is still needed.
    #[allow(dead_code)]
    struct AddHeaderMiddleware {
        /// Header name to add.
        name: &'static str,
        /// Raw header value bytes (copied into the response on each call).
        value: &'static [u8],
    }
8424
8425 impl Middleware for AddHeaderMiddleware {
8426 fn after<'a>(
8427 &'a self,
8428 _ctx: &'a RequestContext,
8429 _req: &'a Request,
8430 response: Response,
8431 ) -> BoxFuture<'a, Response> {
8432 Box::pin(async move { response.header(self.name, self.value.to_vec()) })
8433 }
8434 }
8435
    /// Test middleware whose `before` hook always short-circuits with a
    /// `403 Forbidden` response carrying the body `blocked`.
    // NOTE(review): `dead_code` is allowed, presumably because not every test
    // constructs this fixture; confirm it is still needed.
    #[allow(dead_code)]
    struct BlockingMiddleware;
8439
8440 impl Middleware for BlockingMiddleware {
8441 fn before<'a>(
8442 &'a self,
8443 _ctx: &'a RequestContext,
8444 _req: &'a mut Request,
8445 ) -> BoxFuture<'a, ControlFlow> {
8446 Box::pin(async {
8447 ControlFlow::Break(
8448 Response::with_status(StatusCode::FORBIDDEN)
8449 .body(ResponseBody::Bytes(b"blocked".to_vec())),
8450 )
8451 })
8452 }
8453 }
8454
    /// Test middleware that counts how many times its `before` and `after`
    /// hooks are invoked.
    // NOTE(review): `dead_code` is allowed, presumably because not every test
    // constructs this fixture; confirm it is still needed.
    #[allow(dead_code)]
    struct TrackingMiddleware {
        /// Number of `before` invocations so far.
        before_count: std::sync::atomic::AtomicUsize,
        /// Number of `after` invocations so far.
        after_count: std::sync::atomic::AtomicUsize,
    }
8461
8462 #[allow(dead_code)]
8463 impl TrackingMiddleware {
8464 fn new() -> Self {
8465 Self {
8466 before_count: std::sync::atomic::AtomicUsize::new(0),
8467 after_count: std::sync::atomic::AtomicUsize::new(0),
8468 }
8469 }
8470
8471 fn before_count(&self) -> usize {
8472 self.before_count.load(std::sync::atomic::Ordering::SeqCst)
8473 }
8474
8475 fn after_count(&self) -> usize {
8476 self.after_count.load(std::sync::atomic::Ordering::SeqCst)
8477 }
8478 }
8479
8480 impl Middleware for TrackingMiddleware {
8481 fn before<'a>(
8482 &'a self,
8483 _ctx: &'a RequestContext,
8484 _req: &'a mut Request,
8485 ) -> BoxFuture<'a, ControlFlow> {
8486 self.before_count
8487 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
8488 Box::pin(async { ControlFlow::Continue })
8489 }
8490
8491 fn after<'a>(
8492 &'a self,
8493 _ctx: &'a RequestContext,
8494 _req: &'a Request,
8495 response: Response,
8496 ) -> BoxFuture<'a, Response> {
8497 self.after_count
8498 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
8499 Box::pin(async move { response })
8500 }
8501 }
8502
8503 #[test]
8504 fn control_flow_variants() {
8505 let cont = ControlFlow::Continue;
8506 assert!(cont.is_continue());
8507 assert!(!cont.is_break());
8508
8509 let brk = ControlFlow::Break(Response::ok());
8510 assert!(!brk.is_continue());
8511 assert!(brk.is_break());
8512 }
8513
8514 #[test]
8515 fn middleware_stack_empty() {
8516 let stack = MiddlewareStack::new();
8517 assert!(stack.is_empty());
8518 assert_eq!(stack.len(), 0);
8519 }
8520
8521 #[test]
8522 fn middleware_stack_push() {
8523 let mut stack = MiddlewareStack::new();
8524 stack.push(NoopMiddleware);
8525 stack.push(NoopMiddleware);
8526 assert_eq!(stack.len(), 2);
8527 assert!(!stack.is_empty());
8528 }
8529
8530 #[test]
8531 fn noop_middleware_name() {
8532 let mw = NoopMiddleware;
8533 assert_eq!(mw.name(), "Noop");
8534 }
8535
8536 #[test]
8537 fn logging_redacts_sensitive_headers() {
8538 let mut headers = crate::request::Headers::new();
8539 headers.insert("Authorization", b"secret".to_vec());
8540 headers.insert("X-Request-Id", b"abc123".to_vec());
8541
8542 let redacted = super::default_redacted_headers();
8543 let formatted = super::format_headers(headers.iter(), &redacted);
8544
8545 assert!(formatted.contains("authorization=<redacted>"));
8546 assert!(formatted.contains("x-request-id=abc123"));
8547 }
8548
8549 #[test]
8550 fn logging_body_truncation() {
8551 let body = b"abcdef";
8552 let preview = super::format_bytes(body, 4);
8553 assert_eq!(preview, "abcd...");
8554
8555 let preview_full = super::format_bytes(body, 10);
8556 assert_eq!(preview_full, "abcdef");
8557 }
8558
8559 fn test_context() -> RequestContext {
8560 let cx = asupersync::Cx::for_testing();
8561 RequestContext::new(cx, 1)
8562 }
8563
8564 fn header_value(response: &Response, name: &str) -> Option<String> {
8565 response
8566 .headers()
8567 .iter()
8568 .find(|(n, _)| n.eq_ignore_ascii_case(name))
8569 .and_then(|(_, v)| std::str::from_utf8(v).ok())
8570 .map(ToString::to_string)
8571 }
8572
8573 #[test]
8574 fn cors_exact_origin_allows() {
8575 let cors = Cors::new().allow_origin("https://example.com");
8576 let ctx = test_context();
8577 let mut req = Request::new(crate::request::Method::Get, "/");
8578 req.headers_mut()
8579 .insert("origin", b"https://example.com".to_vec());
8580
8581 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8582 assert!(matches!(result, ControlFlow::Continue));
8583
8584 let response = Response::ok().body(ResponseBody::Bytes(b"ok".to_vec()));
8585 let response = futures_executor::block_on(cors.after(&ctx, &req, response));
8586
8587 assert_eq!(
8588 header_value(&response, "access-control-allow-origin"),
8589 Some("https://example.com".to_string())
8590 );
8591 assert_eq!(header_value(&response, "vary"), Some("Origin".to_string()));
8592 }
8593
8594 #[test]
8595 fn cors_wildcard_origin_allows() {
8596 let cors = Cors::new().allow_origin_wildcard("https://*.example.com");
8597 let ctx = test_context();
8598 let mut req = Request::new(crate::request::Method::Get, "/");
8599 req.headers_mut()
8600 .insert("origin", b"https://api.example.com".to_vec());
8601
8602 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8603 assert!(matches!(result, ControlFlow::Continue));
8604 }
8605
8606 #[test]
8607 fn cors_regex_origin_allows() {
8608 let cors = Cors::new().allow_origin_regex(r"^https://.*\.example\.com$");
8609 let ctx = test_context();
8610 let mut req = Request::new(crate::request::Method::Get, "/");
8611 req.headers_mut()
8612 .insert("origin", b"https://svc.example.com".to_vec());
8613
8614 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8615 assert!(matches!(result, ControlFlow::Continue));
8616 }
8617
8618 #[test]
8619 fn cors_preflight_handled() {
8620 let cors = Cors::new()
8621 .allow_any_origin()
8622 .allow_headers(["x-test", "content-type"])
8623 .max_age(600);
8624 let ctx = test_context();
8625 let mut req = Request::new(crate::request::Method::Options, "/");
8626 req.headers_mut()
8627 .insert("origin", b"https://example.com".to_vec());
8628 req.headers_mut()
8629 .insert("access-control-request-method", b"POST".to_vec());
8630 req.headers_mut().insert(
8631 "access-control-request-headers",
8632 b"x-test, content-type".to_vec(),
8633 );
8634
8635 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8636 let ControlFlow::Break(response) = result else {
8637 panic!("expected preflight break");
8638 };
8639
8640 assert_eq!(response.status().as_u16(), 204);
8641 assert_eq!(
8642 header_value(&response, "access-control-allow-origin"),
8643 Some("*".to_string())
8644 );
8645 assert_eq!(
8646 header_value(&response, "access-control-allow-methods"),
8647 Some("GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD".to_string())
8648 );
8649 assert_eq!(
8650 header_value(&response, "access-control-allow-headers"),
8651 Some("x-test, content-type".to_string())
8652 );
8653 assert_eq!(
8654 header_value(&response, "access-control-max-age"),
8655 Some("600".to_string())
8656 );
8657 }
8658
8659 #[test]
8660 fn cors_credentials_echo_origin() {
8661 let cors = Cors::new().allow_any_origin().allow_credentials(true);
8662 let ctx = test_context();
8663 let mut req = Request::new(crate::request::Method::Get, "/");
8664 req.headers_mut()
8665 .insert("origin", b"https://example.com".to_vec());
8666
8667 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8668 assert!(matches!(result, ControlFlow::Continue));
8669
8670 let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8671 assert_eq!(
8672 header_value(&response, "access-control-allow-origin"),
8673 Some("https://example.com".to_string())
8674 );
8675 assert_eq!(
8676 header_value(&response, "access-control-allow-credentials"),
8677 Some("true".to_string())
8678 );
8679 }
8680
8681 #[test]
8686 fn cors_spec_compliance_credentials_never_wildcard_origin() {
8687 let cors = Cors::new().allow_any_origin().allow_credentials(true);
8690 let ctx = test_context();
8691
8692 for origin in &[
8694 "https://example.com",
8695 "https://api.example.com",
8696 "http://localhost:3000",
8697 ] {
8698 let mut req = Request::new(crate::request::Method::Get, "/");
8699 req.headers_mut()
8700 .insert("origin", origin.as_bytes().to_vec());
8701
8702 futures_executor::block_on(cors.before(&ctx, &mut req));
8703 let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8704
8705 let allow_origin = header_value(&response, "access-control-allow-origin");
8706 assert_eq!(
8707 allow_origin,
8708 Some((*origin).to_string()),
8709 "With credentials enabled, Access-Control-Allow-Origin must echo '{}', not '*'",
8710 origin
8711 );
8712 assert_ne!(
8713 allow_origin,
8714 Some("*".to_string()),
8715 "CORS spec violation: credentials + wildcard origin is forbidden"
8716 );
8717 }
8718 }
8719
8720 #[test]
8721 fn cors_spec_compliance_preflight_with_credentials() {
8722 let cors = Cors::new()
8724 .allow_any_origin()
8725 .allow_credentials(true)
8726 .allow_headers(["content-type", "x-custom-header"]);
8727 let ctx = test_context();
8728
8729 let mut req = Request::new(crate::request::Method::Options, "/");
8730 req.headers_mut()
8731 .insert("origin", b"https://example.com".to_vec());
8732 req.headers_mut()
8733 .insert("access-control-request-method", b"POST".to_vec());
8734 req.headers_mut()
8735 .insert("access-control-request-headers", b"content-type".to_vec());
8736
8737 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8738 let ControlFlow::Break(response) = result else {
8739 panic!("expected preflight break");
8740 };
8741
8742 let allow_origin = header_value(&response, "access-control-allow-origin");
8744 assert_eq!(allow_origin, Some("https://example.com".to_string()));
8745 assert_ne!(
8746 allow_origin,
8747 Some("*".to_string()),
8748 "CORS spec violation: preflight with credentials must not use wildcard origin"
8749 );
8750
8751 assert_eq!(
8753 header_value(&response, "access-control-allow-credentials"),
8754 Some("true".to_string())
8755 );
8756 }
8757
8758 #[test]
8759 fn cors_spec_without_credentials_allows_wildcard() {
8760 let cors = Cors::new().allow_any_origin();
8762 let ctx = test_context();
8763 let mut req = Request::new(crate::request::Method::Get, "/");
8764 req.headers_mut()
8765 .insert("origin", b"https://example.com".to_vec());
8766
8767 futures_executor::block_on(cors.before(&ctx, &mut req));
8768 let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8769
8770 assert_eq!(
8772 header_value(&response, "access-control-allow-origin"),
8773 Some("*".to_string())
8774 );
8775 assert!(header_value(&response, "access-control-allow-credentials").is_none());
8777 }
8778
8779 #[test]
8780 fn cors_disallowed_preflight_forbidden() {
8781 let cors = Cors::new().allow_origin("https://good.example");
8782 let ctx = test_context();
8783 let mut req = Request::new(crate::request::Method::Options, "/");
8784 req.headers_mut()
8785 .insert("origin", b"https://evil.example".to_vec());
8786 req.headers_mut()
8787 .insert("access-control-request-method", b"GET".to_vec());
8788
8789 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8790 let ControlFlow::Break(response) = result else {
8791 panic!("expected forbidden preflight");
8792 };
8793 assert_eq!(response.status().as_u16(), 403);
8794 }
8795
8796 #[test]
8797 fn cors_simple_request_disallowed_origin_no_headers() {
8798 let cors = Cors::new().allow_origin("https://good.example");
8800 let ctx = test_context();
8801 let mut req = Request::new(crate::request::Method::Get, "/");
8802 req.headers_mut()
8803 .insert("origin", b"https://evil.example".to_vec());
8804
8805 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8806 assert!(matches!(result, ControlFlow::Continue));
8808
8809 let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8810 assert!(header_value(&response, "access-control-allow-origin").is_none());
8812 }
8813
8814 #[test]
8815 fn cors_expose_headers_configuration() {
8816 let cors = Cors::new()
8817 .allow_any_origin()
8818 .expose_headers(["x-custom-header", "x-another-header"]);
8819 let ctx = test_context();
8820 let mut req = Request::new(crate::request::Method::Get, "/");
8821 req.headers_mut()
8822 .insert("origin", b"https://example.com".to_vec());
8823
8824 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8825 assert!(matches!(result, ControlFlow::Continue));
8826
8827 let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8828 assert_eq!(
8829 header_value(&response, "access-control-expose-headers"),
8830 Some("x-custom-header, x-another-header".to_string())
8831 );
8832 }
8833
8834 #[test]
8835 fn cors_any_origin_sets_wildcard() {
8836 let cors = Cors::new().allow_any_origin();
8837 let ctx = test_context();
8838 let mut req = Request::new(crate::request::Method::Get, "/");
8839 req.headers_mut()
8840 .insert("origin", b"https://any-site.com".to_vec());
8841
8842 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8843 assert!(matches!(result, ControlFlow::Continue));
8844
8845 let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8846 assert_eq!(
8847 header_value(&response, "access-control-allow-origin"),
8848 Some("*".to_string())
8849 );
8850 }
8851
8852 #[test]
8853 fn cors_config_allows_method_override() {
8854 let cors = Cors::new()
8856 .allow_any_origin()
8857 .allow_methods([crate::request::Method::Get, crate::request::Method::Post]);
8858 let ctx = test_context();
8859 let mut req = Request::new(crate::request::Method::Options, "/");
8860 req.headers_mut()
8861 .insert("origin", b"https://example.com".to_vec());
8862 req.headers_mut()
8863 .insert("access-control-request-method", b"POST".to_vec());
8864
8865 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8866 let ControlFlow::Break(response) = result else {
8867 panic!("expected preflight break");
8868 };
8869 assert_eq!(
8870 header_value(&response, "access-control-allow-methods"),
8871 Some("GET, POST".to_string())
8872 );
8873 }
8874
8875 #[test]
8876 fn cors_no_origin_header_skips_cors() {
8877 let cors = Cors::new().allow_any_origin();
8879 let ctx = test_context();
8880 let mut req = Request::new(crate::request::Method::Get, "/");
8881
8882 let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8883 assert!(matches!(result, ControlFlow::Continue));
8884
8885 let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8886 assert!(header_value(&response, "access-control-allow-origin").is_none());
8887 }
8888
8889 #[test]
8890 fn cors_middleware_name() {
8891 let cors = Cors::new();
8892 assert_eq!(cors.name(), "Cors");
8893 }
8894
8895 #[test]
8900 fn request_id_generates_unique_ids() {
8901 let id1 = RequestId::generate();
8902 let id2 = RequestId::generate();
8903 let id3 = RequestId::generate();
8904
8905 assert_ne!(id1, id2);
8906 assert_ne!(id2, id3);
8907 assert_ne!(id1, id3);
8908
8909 assert!(!id1.as_str().is_empty());
8911 assert!(!id2.as_str().is_empty());
8912 assert!(!id3.as_str().is_empty());
8913 }
8914
8915 #[test]
8916 fn request_id_display() {
8917 let id = RequestId::new("test-request-123");
8918 assert_eq!(format!("{}", id), "test-request-123");
8919 }
8920
8921 #[test]
8922 fn request_id_from_string() {
8923 let id: RequestId = "my-id".into();
8924 assert_eq!(id.as_str(), "my-id");
8925
8926 let id2: RequestId = String::from("my-id-2").into();
8927 assert_eq!(id2.as_str(), "my-id-2");
8928 }
8929
8930 #[test]
8931 fn request_id_config_defaults() {
8932 let config = RequestIdConfig::default();
8933 assert_eq!(config.header_name, "x-request-id");
8934 assert!(config.accept_from_client);
8935 assert!(config.add_to_response);
8936 assert_eq!(config.max_client_id_length, 128);
8937 }
8938
8939 #[test]
8940 fn request_id_config_builder() {
8941 let config = RequestIdConfig::new()
8942 .header_name("X-Trace-ID")
8943 .accept_from_client(false)
8944 .add_to_response(false)
8945 .max_client_id_length(64);
8946
8947 assert_eq!(config.header_name, "X-Trace-ID");
8948 assert!(!config.accept_from_client);
8949 assert!(!config.add_to_response);
8950 assert_eq!(config.max_client_id_length, 64);
8951 }
8952
8953 #[test]
8954 fn request_id_middleware_generates_id() {
8955 let middleware = RequestIdMiddleware::new();
8956 let ctx = test_context();
8957 let mut req = Request::new(crate::request::Method::Get, "/");
8958
8959 let result = futures_executor::block_on(middleware.before(&ctx, &mut req));
8960 assert!(matches!(result, ControlFlow::Continue));
8961
8962 let stored_id = req.get_extension::<RequestId>();
8963 assert!(stored_id.is_some());
8964 assert!(!stored_id.unwrap().as_str().is_empty());
8965 }
8966
8967 #[test]
8968 fn request_id_middleware_accepts_client_id() {
8969 let middleware = RequestIdMiddleware::new();
8970 let ctx = test_context();
8971 let mut req = Request::new(crate::request::Method::Get, "/");
8972 req.headers_mut()
8973 .insert("x-request-id", b"client-provided-id-123".to_vec());
8974
8975 futures_executor::block_on(middleware.before(&ctx, &mut req));
8976
8977 let stored_id = req.get_extension::<RequestId>().unwrap();
8978 assert_eq!(stored_id.as_str(), "client-provided-id-123");
8979 }
8980
8981 #[test]
8982 fn request_id_middleware_rejects_invalid_client_id() {
8983 let middleware = RequestIdMiddleware::new();
8984 let ctx = test_context();
8985
8986 let mut req = Request::new(crate::request::Method::Get, "/");
8988 req.headers_mut()
8989 .insert("x-request-id", b"invalid<script>id".to_vec());
8990
8991 futures_executor::block_on(middleware.before(&ctx, &mut req));
8992
8993 let stored_id = req.get_extension::<RequestId>().unwrap();
8994 assert_ne!(stored_id.as_str(), "invalid<script>id");
8996 }
8997
8998 #[test]
8999 fn request_id_middleware_rejects_too_long_client_id() {
9000 let config = RequestIdConfig::new().max_client_id_length(10);
9001 let middleware = RequestIdMiddleware::with_config(config);
9002 let ctx = test_context();
9003
9004 let mut req = Request::new(crate::request::Method::Get, "/");
9005 req.headers_mut()
9006 .insert("x-request-id", b"this-id-is-way-too-long".to_vec());
9007
9008 futures_executor::block_on(middleware.before(&ctx, &mut req));
9009
9010 let stored_id = req.get_extension::<RequestId>().unwrap();
9011 assert_ne!(stored_id.as_str(), "this-id-is-way-too-long");
9013 }
9014
9015 #[test]
9016 fn request_id_middleware_adds_to_response() {
9017 let middleware = RequestIdMiddleware::new();
9018 let ctx = test_context();
9019 let mut req = Request::new(crate::request::Method::Get, "/");
9020
9021 futures_executor::block_on(middleware.before(&ctx, &mut req));
9022 let stored_id = req.get_extension::<RequestId>().unwrap().clone();
9023
9024 let response = Response::ok();
9025 let response = futures_executor::block_on(middleware.after(&ctx, &req, response));
9026
9027 let header = header_value(&response, "x-request-id");
9028 assert_eq!(header, Some(stored_id.0));
9029 }
9030
9031 #[test]
9032 fn request_id_middleware_respects_add_to_response_false() {
9033 let config = RequestIdConfig::new().add_to_response(false);
9034 let middleware = RequestIdMiddleware::with_config(config);
9035 let ctx = test_context();
9036 let mut req = Request::new(crate::request::Method::Get, "/");
9037
9038 futures_executor::block_on(middleware.before(&ctx, &mut req));
9039
9040 let response = Response::ok();
9041 let response = futures_executor::block_on(middleware.after(&ctx, &req, response));
9042
9043 let header = header_value(&response, "x-request-id");
9044 assert!(header.is_none());
9045 }
9046
9047 #[test]
9048 fn request_id_middleware_respects_accept_from_client_false() {
9049 let config = RequestIdConfig::new().accept_from_client(false);
9050 let middleware = RequestIdMiddleware::with_config(config);
9051 let ctx = test_context();
9052 let mut req = Request::new(crate::request::Method::Get, "/");
9053 req.headers_mut()
9054 .insert("x-request-id", b"client-id".to_vec());
9055
9056 futures_executor::block_on(middleware.before(&ctx, &mut req));
9057
9058 let stored_id = req.get_extension::<RequestId>().unwrap();
9059 assert_ne!(stored_id.as_str(), "client-id");
9061 }
9062
9063 #[test]
9064 fn request_id_middleware_custom_header_name() {
9065 let config = RequestIdConfig::new().header_name("X-Trace-ID");
9066 let middleware = RequestIdMiddleware::with_config(config);
9067 let ctx = test_context();
9068 let mut req = Request::new(crate::request::Method::Get, "/");
9069 req.headers_mut()
9070 .insert("X-Trace-ID", b"trace-123".to_vec());
9071
9072 futures_executor::block_on(middleware.before(&ctx, &mut req));
9073
9074 let stored_id = req.get_extension::<RequestId>().unwrap();
9075 assert_eq!(stored_id.as_str(), "trace-123");
9076
9077 let response = Response::ok();
9078 let response = futures_executor::block_on(middleware.after(&ctx, &req, response));
9079
9080 let header = header_value(&response, "X-Trace-ID");
9081 assert_eq!(header, Some("trace-123".to_string()));
9082 }
9083
9084 #[test]
9085 fn is_valid_request_id_accepts_valid() {
9086 assert!(super::is_valid_request_id("abc123"));
9087 assert!(super::is_valid_request_id("request-id-123"));
9088 assert!(super::is_valid_request_id("request_id_123"));
9089 assert!(super::is_valid_request_id("request.id.123"));
9090 assert!(super::is_valid_request_id("ABC123"));
9091 assert!(super::is_valid_request_id("a-b_c.D"));
9092 }
9093
9094 #[test]
9095 fn is_valid_request_id_rejects_invalid() {
9096 assert!(!super::is_valid_request_id(""));
9097 assert!(!super::is_valid_request_id("id with spaces"));
9098 assert!(!super::is_valid_request_id("id<script>"));
9099 assert!(!super::is_valid_request_id("id\nwith\nnewlines"));
9100 assert!(!super::is_valid_request_id("id;with;semicolons"));
9101 assert!(!super::is_valid_request_id("id/with/slashes"));
9102 }
9103
9104 #[test]
9105 fn request_id_middleware_name() {
9106 let middleware = RequestIdMiddleware::new();
9107 assert_eq!(middleware.name(), "RequestId");
9108 }
9109
    /// Test middleware that appends `"<id>.before"` / `"<id>.after"` to a shared
    /// log each time its hooks are invoked, so tests can assert execution order.
    struct OrderTrackingMiddleware {
        /// Label written into each log entry.
        id: &'static str,
        /// Log shared between the test and every middleware instance.
        log: Arc<std::sync::Mutex<Vec<String>>>,
    }
9119
9120 impl OrderTrackingMiddleware {
9121 fn new(id: &'static str, log: Arc<std::sync::Mutex<Vec<String>>>) -> Self {
9122 Self { id, log }
9123 }
9124 }
9125
9126 impl Middleware for OrderTrackingMiddleware {
9127 fn before<'a>(
9128 &'a self,
9129 _ctx: &'a RequestContext,
9130 _req: &'a mut Request,
9131 ) -> BoxFuture<'a, ControlFlow> {
9132 self.log.lock().unwrap().push(format!("{}.before", self.id));
9133 Box::pin(async { ControlFlow::Continue })
9134 }
9135
9136 fn after<'a>(
9137 &'a self,
9138 _ctx: &'a RequestContext,
9139 _req: &'a Request,
9140 response: Response,
9141 ) -> BoxFuture<'a, Response> {
9142 self.log.lock().unwrap().push(format!("{}.after", self.id));
9143 Box::pin(async move { response })
9144 }
9145 }
9146
/// Test middleware that logs `"<id>.before"` / `"<id>.after"` and, when
/// `should_break` is set, short-circuits with `403 Forbidden` and body `blocked`.
struct ConditionalBreakMiddleware {
    /// Label prefixed to every log entry.
    id: &'static str,
    /// Whether `before` should stop the chain.
    should_break: bool,
    /// Shared, mutex-guarded record of hook invocations.
    log: Arc<std::sync::Mutex<Vec<String>>>,
}

impl ConditionalBreakMiddleware {
    /// Creates a middleware labelled `id` that breaks iff `should_break`.
    fn new(
        id: &'static str,
        should_break: bool,
        log: Arc<std::sync::Mutex<Vec<String>>>,
    ) -> Self {
        Self {
            log,
            should_break,
            id,
        }
    }
}
9167
9168 impl Middleware for ConditionalBreakMiddleware {
9169 fn before<'a>(
9170 &'a self,
9171 _ctx: &'a RequestContext,
9172 _req: &'a mut Request,
9173 ) -> BoxFuture<'a, ControlFlow> {
9174 self.log.lock().unwrap().push(format!("{}.before", self.id));
9175 let should_break = self.should_break;
9176 Box::pin(async move {
9177 if should_break {
9178 ControlFlow::Break(
9179 Response::with_status(StatusCode::FORBIDDEN)
9180 .body(ResponseBody::Bytes(b"blocked".to_vec())),
9181 )
9182 } else {
9183 ControlFlow::Continue
9184 }
9185 })
9186 }
9187
9188 fn after<'a>(
9189 &'a self,
9190 _ctx: &'a RequestContext,
9191 _req: &'a Request,
9192 response: Response,
9193 ) -> BoxFuture<'a, Response> {
9194 self.log.lock().unwrap().push(format!("{}.after", self.id));
9195 Box::pin(async move { response })
9196 }
9197 }
9198
9199 struct OkHandler;
9201
9202 impl Handler for OkHandler {
9203 fn call<'a>(
9204 &'a self,
9205 _ctx: &'a RequestContext,
9206 _req: &'a mut Request,
9207 ) -> BoxFuture<'a, Response> {
9208 Box::pin(async move { Response::ok().body(ResponseBody::Bytes(b"handler".to_vec())) })
9209 }
9210 }
9211
9212 struct CheckHeaderHandler;
9214
9215 impl Handler for CheckHeaderHandler {
9216 fn call<'a>(
9217 &'a self,
9218 _ctx: &'a RequestContext,
9219 req: &'a mut Request,
9220 ) -> BoxFuture<'a, Response> {
9221 let has_header = req.headers().get("X-Modified-By").is_some();
9222 Box::pin(async move {
9223 if has_header {
9224 Response::ok().body(ResponseBody::Bytes(b"header-present".to_vec()))
9225 } else {
9226 Response::with_status(StatusCode::BAD_REQUEST)
9227 }
9228 })
9229 }
9230 }
9231
9232 struct ErrorHandler;
9234
9235 impl Handler for ErrorHandler {
9236 fn call<'a>(
9237 &'a self,
9238 _ctx: &'a RequestContext,
9239 _req: &'a mut Request,
9240 ) -> BoxFuture<'a, Response> {
9241 Box::pin(async move { Response::with_status(StatusCode::INTERNAL_SERVER_ERROR) })
9242 }
9243 }
9244
#[test]
fn middleware_stack_executes_in_correct_order() {
    // Three pass-through trackers around OkHandler: `before` hooks must run in
    // push order and `after` hooks in reverse push order.
    let log = Arc::new(std::sync::Mutex::new(Vec::new()));

    let mut stack = MiddlewareStack::new();
    stack.push(OrderTrackingMiddleware::new("mw1", log.clone()));
    stack.push(OrderTrackingMiddleware::new("mw2", log.clone()));
    stack.push(OrderTrackingMiddleware::new("mw3", log.clone()));

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));

    let calls = log.lock().unwrap().clone();
    assert_eq!(
        calls,
        vec![
            "mw1.before",
            "mw2.before",
            "mw3.before",
            "mw3.after",
            "mw2.after",
            "mw1.after",
        ]
    );
    // The trackers pass responses through untouched, so the handler's 200
    // must be what the caller receives.
    assert_eq!(response.status().as_u16(), 200);
}
9274
#[test]
fn middleware_stack_short_circuit_skips_later_middleware() {
    let log = Arc::new(std::sync::Mutex::new(Vec::new()));
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    // mw2 breaks, so mw3 and the handler never run.
    let mut stack = MiddlewareStack::new();
    stack.push(OrderTrackingMiddleware::new("mw1", Arc::clone(&log)));
    stack.push(ConditionalBreakMiddleware::new("mw2", true, Arc::clone(&log)));
    stack.push(OrderTrackingMiddleware::new("mw3", Arc::clone(&log)));

    let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
    assert_eq!(response.status().as_u16(), 403);

    // Only the layer outside the breaker unwinds through its `after` hook.
    let expected = vec!["mw1.before", "mw2.before", "mw1.after"];
    assert_eq!(*log.lock().unwrap(), expected);
}
9306
#[test]
fn middleware_stack_first_middleware_breaks() {
    let log = Arc::new(std::sync::Mutex::new(Vec::new()));

    let mut stack = MiddlewareStack::new();
    stack.push(ConditionalBreakMiddleware::new("mw1", true, Arc::clone(&log)));
    stack.push(OrderTrackingMiddleware::new("mw2", Arc::clone(&log)));

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let status = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req))
        .status()
        .as_u16();
    assert_eq!(status, 403);

    // Breaking at the outermost layer means no other hook fires at all.
    assert_eq!(*log.lock().unwrap(), vec!["mw1.before"]);
}
9327
#[test]
fn middleware_stack_last_middleware_breaks() {
    let log = Arc::new(std::sync::Mutex::new(Vec::new()));
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let mut stack = MiddlewareStack::new();
    stack.push(OrderTrackingMiddleware::new("mw1", Arc::clone(&log)));
    stack.push(OrderTrackingMiddleware::new("mw2", Arc::clone(&log)));
    stack.push(ConditionalBreakMiddleware::new("mw3", true, Arc::clone(&log)));

    let status = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req))
        .status()
        .as_u16();
    assert_eq!(status, 403);

    // The innermost layer breaks: its own `after` is skipped while the outer
    // layers unwind in reverse order.
    let expected = vec![
        "mw1.before",
        "mw2.before",
        "mw3.before",
        "mw2.after",
        "mw1.after",
    ];
    assert_eq!(*log.lock().unwrap(), expected);
}
9358
#[test]
fn middleware_stack_empty_executes_handler_directly() {
    // With no middleware, the stack must delegate straight to the handler.
    let stack = MiddlewareStack::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));

    assert_eq!(response.status().as_u16(), 200);
    // A bare 200 could come from anywhere; the body proves OkHandler itself ran.
    match response.body_ref() {
        ResponseBody::Bytes(b) => assert_eq!(b, b"handler"),
        _ => panic!("Expected Bytes body"),
    }
}
9369
#[test]
fn middleware_stack_with_capacity() {
    // Reserving capacity must not register any middleware.
    let stack = MiddlewareStack::with_capacity(10);
    assert_eq!(stack.len(), 0);
    assert!(stack.is_empty());
}
9376
#[test]
fn middleware_stack_push_arc() {
    // An `Arc<dyn Middleware>` can be pushed directly.
    let shared: Arc<dyn Middleware> = Arc::new(NoopMiddleware);
    let mut stack = MiddlewareStack::new();
    stack.push_arc(shared);
    assert_eq!(stack.len(), 1);
}
9384
#[test]
fn add_response_header_adds_header() {
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/");
    let mw = AddResponseHeader::new("X-Custom", b"custom-value".to_vec());

    let decorated = futures_executor::block_on(mw.after(&ctx, &req, Response::ok()));

    assert_eq!(
        header_value(&decorated, "X-Custom").as_deref(),
        Some("custom-value")
    );
}
9403
#[test]
fn add_response_header_preserves_existing_headers() {
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/");
    let mw = AddResponseHeader::new("X-New", b"new".to_vec());
    let seeded = Response::ok().header("X-Existing", b"existing".to_vec());

    let decorated = futures_executor::block_on(mw.after(&ctx, &req, seeded));

    // The pre-existing header survives alongside the newly added one.
    assert_eq!(
        header_value(&decorated, "X-Existing").as_deref(),
        Some("existing")
    );
    assert_eq!(header_value(&decorated, "X-New").as_deref(), Some("new"));
}
9419
#[test]
fn add_response_header_name() {
    assert_eq!(
        AddResponseHeader::new("X-Test", b"test".to_vec()).name(),
        "AddResponseHeader"
    );
}
9425
#[test]
fn require_header_allows_with_header() {
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let api_key = b"secret-key".to_vec();
    req.headers_mut().insert("X-Api-Key", api_key);

    let guard = RequireHeader::new("X-Api-Key");
    let outcome = futures_executor::block_on(guard.before(&ctx, &mut req));
    assert!(outcome.is_continue());
}
9441
#[test]
fn require_header_blocks_without_header() {
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let guard = RequireHeader::new("X-Api-Key");

    // No X-Api-Key header was set, so the middleware must reject with 400.
    let outcome = futures_executor::block_on(guard.before(&ctx, &mut req));
    if let ControlFlow::Break(response) = outcome {
        assert_eq!(response.status().as_u16(), 400);
    } else {
        panic!("Expected Break, got Continue");
    }
}
9457
#[test]
fn require_header_name() {
    assert_eq!(RequireHeader::new("X-Test").name(), "RequireHeader");
}
9463
#[test]
fn path_prefix_filter_allows_matching_path() {
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/api/users");
    let filter = PathPrefixFilter::new("/api");

    // "/api/users" lies under "/api", so the request passes through.
    let outcome = futures_executor::block_on(filter.before(&ctx, &mut req));
    assert!(outcome.is_continue());
}
9477
#[test]
fn path_prefix_filter_allows_exact_prefix() {
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/api");
    let filter = PathPrefixFilter::new("/api");

    // A path identical to the prefix counts as a match.
    let outcome = futures_executor::block_on(filter.before(&ctx, &mut req));
    assert!(outcome.is_continue());
}
9487
#[test]
fn path_prefix_filter_blocks_non_matching_path() {
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/admin/users");
    let filter = PathPrefixFilter::new("/api");

    // Paths outside the prefix are answered with 404.
    let outcome = futures_executor::block_on(filter.before(&ctx, &mut req));
    if let ControlFlow::Break(response) = outcome {
        assert_eq!(response.status().as_u16(), 404);
    } else {
        panic!("Expected Break, got Continue");
    }
}
9503
#[test]
fn path_prefix_filter_name() {
    assert_eq!(PathPrefixFilter::new("/api").name(), "PathPrefixFilter");
}
9509
#[test]
fn conditional_status_applies_true_status() {
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/health");
    let mw = ConditionalStatus::new(
        |r| r.path() == "/health",
        StatusCode::OK,
        StatusCode::NOT_FOUND,
    );

    // The predicate holds, so the incoming 500 is replaced by the "true" status.
    let incoming = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR);
    let outgoing = futures_executor::block_on(mw.after(&ctx, &req, incoming));
    assert_eq!(outgoing.status().as_u16(), 200);
}
9528
#[test]
fn conditional_status_applies_false_status() {
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/other");
    let mw = ConditionalStatus::new(
        |r| r.path() == "/health",
        StatusCode::OK,
        StatusCode::NOT_FOUND,
    );

    // The predicate fails, so the incoming 500 is replaced by the "false" status.
    let incoming = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR);
    let outgoing = futures_executor::block_on(mw.after(&ctx, &req, incoming));
    assert_eq!(outgoing.status().as_u16(), 404);
}
9543
#[test]
fn conditional_status_name() {
    assert_eq!(
        ConditionalStatus::new(|_| true, StatusCode::OK, StatusCode::NOT_FOUND).name(),
        "ConditionalStatus"
    );
}
9549
/// Middleware used by the `Layer` tests; its `after` hook stamps an `X-Layer`
/// header carrying `prefix`.
#[derive(Clone)]
struct LayerTestMiddleware {
    /// Value written into the `X-Layer` header.
    prefix: String,
}

impl LayerTestMiddleware {
    /// Creates the middleware from anything convertible into a `String`.
    fn new(prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        Self { prefix }
    }
}
9566
9567 impl Middleware for LayerTestMiddleware {
9568 fn after<'a>(
9569 &'a self,
9570 _ctx: &'a RequestContext,
9571 _req: &'a Request,
9572 response: Response,
9573 ) -> BoxFuture<'a, Response> {
9574 let prefix = self.prefix.clone();
9575 Box::pin(async move { response.header("X-Layer", prefix.into_bytes()) })
9576 }
9577 }
9578
#[test]
fn layer_wraps_handler() {
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let layer = Layer::new(LayerTestMiddleware::new("wrapped"));
    let wrapped = layer.wrap(OkHandler);
    let response = futures_executor::block_on(wrapped.call(&ctx, &mut req));

    // The handler's 200 passes through and the layer's `after` hook tags it.
    assert_eq!(response.status().as_u16(), 200);
    assert_eq!(
        header_value(&response, "X-Layer").as_deref(),
        Some("wrapped")
    );
}
9595
#[test]
fn layered_handles_break() {
    // Always rejects with 401 in `before`; tags every response with
    // `X-After: ran` in `after`.
    #[derive(Clone)]
    struct BreakingMiddleware;

    impl Middleware for BreakingMiddleware {
        fn before<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            _req: &'a mut Request,
        ) -> BoxFuture<'a, ControlFlow> {
            Box::pin(async {
                let denied = Response::with_status(StatusCode::UNAUTHORIZED);
                ControlFlow::Break(denied)
            })
        }

        fn after<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            _req: &'a Request,
            response: Response,
        ) -> BoxFuture<'a, Response> {
            let marker = b"ran".to_vec();
            Box::pin(async move { response.header("X-After", marker) })
        }
    }

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let layer = Layer::new(BreakingMiddleware);
    let wrapped = layer.wrap(OkHandler);
    let response = futures_executor::block_on(wrapped.call(&ctx, &mut req));

    // The short-circuit response is returned, and for `Layer` the same
    // middleware's `after` hook still runs on it.
    assert_eq!(response.status().as_u16(), 401);
    assert_eq!(header_value(&response, "X-After").as_deref(), Some("ran"));
}
9635
#[test]
fn request_response_logger_default() {
    // Defaults: headers logged in both directions, bodies off, and a
    // `max_body_bytes` of 1024.
    let defaults = RequestResponseLogger::default();
    assert_eq!(defaults.max_body_bytes, 1024);
    assert!(!defaults.log_body);
    assert!(defaults.log_response_headers);
    assert!(defaults.log_request_headers);
}
9648
#[test]
fn request_response_logger_builder() {
    let logger = RequestResponseLogger::new()
        .log_request_headers(false)
        .log_response_headers(false)
        .log_body(true)
        .max_body_bytes(2048)
        .redact_header("x-secret");

    // Every builder call is reflected in the corresponding field.
    assert_eq!(
        (
            logger.log_request_headers,
            logger.log_response_headers,
            logger.log_body
        ),
        (false, false, true)
    );
    assert_eq!(logger.max_body_bytes, 2048);
    assert!(logger.redact_headers.contains("x-secret"));
}
9664
#[test]
fn request_response_logger_name() {
    assert_eq!(
        RequestResponseLogger::new().name(),
        "RequestResponseLogger"
    );
}
9670
#[test]
fn middleware_stack_modifies_request_for_handler() {
    // Stamps `X-Modified-By` onto the request before the handler sees it.
    struct RequestModifier;

    impl Middleware for RequestModifier {
        fn before<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            req: &'a mut Request,
        ) -> BoxFuture<'a, ControlFlow> {
            let marker = b"middleware".to_vec();
            req.headers_mut().insert("X-Modified-By", marker);
            Box::pin(async { ControlFlow::Continue })
        }
    }

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let mut stack = MiddlewareStack::new();
    stack.push(RequestModifier);

    // CheckHeaderHandler answers 200 only if it sees the injected header.
    let status = futures_executor::block_on(stack.execute(&CheckHeaderHandler, &ctx, &mut req))
        .status()
        .as_u16();
    assert_eq!(status, 200);
}
9703
#[test]
fn middleware_stack_multiple_response_modifications() {
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let mut stack = MiddlewareStack::new();
    stack.push(AddResponseHeader::new("X-First", b"1".to_vec()));
    stack.push(AddResponseHeader::new("X-Second", b"2".to_vec()));
    stack.push(AddResponseHeader::new("X-Third", b"3".to_vec()));

    let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));

    // Each layer's `after` hook contributes its own header.
    let expected = [("X-First", "1"), ("X-Second", "2"), ("X-Third", "3")];
    for &(name, value) in expected.iter() {
        assert_eq!(header_value(&response, name).as_deref(), Some(value));
    }
}
9721
#[test]
fn middleware_stack_handler_receives_response_after_break() {
    // Despite the name, the handler must NOT run here: the caller receives the
    // breaking middleware's response instead of OkHandler's.
    let log = Arc::new(std::sync::Mutex::new(Vec::new()));
    let mut stack = MiddlewareStack::new();
    stack.push(ConditionalBreakMiddleware::new("breaker", true, log.clone()));

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));

    assert_eq!(response.status().as_u16(), 403);
    match response.body_ref() {
        ResponseBody::Bytes(b) => assert_eq!(b, b"blocked"),
        _ => panic!("Expected Bytes body"),
    }
    // Only the breaker's `before` hook fires; its own `after` hook is skipped.
    let calls = log.lock().unwrap().clone();
    assert_eq!(calls, vec!["breaker.before"]);
}
9744
#[test]
fn middleware_after_can_change_status() {
    // Discards the downstream response and substitutes a fresh 503.
    struct StatusChanger;

    impl Middleware for StatusChanger {
        fn after<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            _req: &'a Request,
            _discarded: Response,
        ) -> BoxFuture<'a, Response> {
            let status = StatusCode::SERVICE_UNAVAILABLE;
            Box::pin(async move { Response::with_status(status) })
        }
    }

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let mut stack = MiddlewareStack::new();
    stack.push(StatusChanger);

    let status = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req))
        .status()
        .as_u16();
    assert_eq!(status, 503);
}
9775
#[test]
fn middleware_after_runs_even_on_error_status() {
    let log = Arc::new(std::sync::Mutex::new(Vec::new()));
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let mut stack = MiddlewareStack::new();
    stack.push(OrderTrackingMiddleware::new("mw1", Arc::clone(&log)));

    let response = futures_executor::block_on(stack.execute(&ErrorHandler, &ctx, &mut req));
    assert_eq!(response.status().as_u16(), 500);

    // A 5xx from the handler is still an ordinary response: `after` runs.
    assert_eq!(*log.lock().unwrap(), vec!["mw1.before", "mw1.after"]);
}
9793
#[test]
fn wildcard_match_simple() {
    // Subdomains match "*.example.com"; the bare domain does not.
    let cases = [
        ("*.example.com", "api.example.com", true),
        ("*.example.com", "www.example.com", true),
        ("*.example.com", "example.com", false),
    ];
    for &(pattern, input, expected) in cases.iter() {
        assert_eq!(
            super::wildcard_match(pattern, input),
            expected,
            "wildcard_match({:?}, {:?})",
            pattern,
            input
        );
    }
}
9804
#[test]
fn wildcard_match_suffix_pattern() {
    // A leading '*' followed by a literal suffix.
    let cases = [
        ("*.txt", "file.txt", true),
        ("*.txt", "document.txt", true),
        ("*.txt", "file.doc", false),
        ("*-suffix", "any-suffix", true),
    ];
    for &(pattern, input, expected) in cases.iter() {
        assert_eq!(
            super::wildcard_match(pattern, input),
            expected,
            "wildcard_match({:?}, {:?})",
            pattern,
            input
        );
    }
}
9813
#[test]
fn wildcard_match_no_wildcard() {
    // A pattern without '*' matches itself but not an unrelated string.
    let pattern = "exact";
    assert!(super::wildcard_match(pattern, "exact"));
    assert!(!super::wildcard_match(pattern, "different"));
}
9819
#[test]
fn regex_match_anchored() {
    // '^' and '$' pin the match to the whole input.
    let pattern = "^hello$";
    assert!(super::regex_match(pattern, "hello"));
    for &input in ["hello world", "say hello"].iter() {
        assert!(!super::regex_match(pattern, input), "{:?} should not match", input);
    }
}
9826
#[test]
fn regex_match_dot_wildcard() {
    // '.' stands for any single character.
    for &input in ["hello", "hallo"].iter() {
        assert!(super::regex_match("h.llo", input), "{:?} should match", input);
    }
}
9832
#[test]
fn regex_match_star() {
    // 'l*' absorbs a run of one or more 'l's in these inputs.
    for &input in ["hello", "helo", "hellllllo"].iter() {
        assert!(super::regex_match("hel*o", input), "{:?} should match", input);
    }
}
9839
#[test]
fn middleware_default_before_continues() {
    // No hooks overridden: this exercises the trait's default `before`.
    struct DefaultBefore;
    impl Middleware for DefaultBefore {}

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let outcome = futures_executor::block_on(DefaultBefore.before(&ctx, &mut req));
    assert!(outcome.is_continue());
}
9856
#[test]
fn middleware_default_after_passes_through() {
    // No hooks overridden: this exercises the trait's default `after`.
    struct DefaultAfter;
    impl Middleware for DefaultAfter {}

    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/");
    let created = Response::with_status(StatusCode::CREATED);

    let passed = futures_executor::block_on(DefaultAfter.after(&ctx, &req, created));
    // The response is returned untouched, so the 201 survives.
    assert_eq!(passed.status().as_u16(), 201);
}
9870
#[test]
fn middleware_default_name_is_type_name() {
    struct MyCustomMiddleware;
    impl Middleware for MyCustomMiddleware {}

    // The default `name()` comes from `std::any::type_name`, so the (possibly
    // path-qualified) result contains the struct's identifier.
    let reported = MyCustomMiddleware.name();
    assert!(reported.contains("MyCustomMiddleware"));
}
9879
#[test]
fn security_headers_default_config() {
    let config = SecurityHeadersConfig::default();

    // Headers enabled out of the box.
    assert_eq!(config.x_content_type_options, Some("nosniff"));
    assert_eq!(config.x_frame_options, Some(XFrameOptions::Deny));
    assert_eq!(config.x_xss_protection, Some("0"));
    assert_eq!(
        config.referrer_policy,
        Some(ReferrerPolicy::StrictOriginWhenCrossOrigin)
    );

    // Headers that must be opted into explicitly.
    assert!(config.content_security_policy.is_none());
    assert!(config.hsts.is_none());
    assert!(config.permissions_policy.is_none());
}
9898
#[test]
fn security_headers_none_config() {
    // `none()` leaves every header unset.
    let config = SecurityHeadersConfig::none();
    assert_eq!(config.x_content_type_options, None);
    assert_eq!(config.x_frame_options, None);
    assert_eq!(config.x_xss_protection, None);
    assert_eq!(config.content_security_policy, None);
    assert_eq!(config.hsts, None);
    assert_eq!(config.referrer_policy, None);
    assert!(config.permissions_policy.is_none());
}
9910
#[test]
fn security_headers_strict_config() {
    let config = SecurityHeadersConfig::strict();

    // HSTS: one year, subdomains included, no preload.
    let (max_age, include_subdomains, preload) =
        config.hsts.expect("strict preset should enable HSTS");
    assert_eq!(max_age, 31_536_000);
    assert!(include_subdomains);
    assert!(!preload);

    assert_eq!(config.x_content_type_options, Some("nosniff"));
    assert_eq!(config.x_frame_options, Some(XFrameOptions::Deny));
    assert_eq!(
        config.content_security_policy.as_deref(),
        Some("default-src 'self'")
    );
    assert_eq!(config.referrer_policy, Some(ReferrerPolicy::NoReferrer));
    assert!(config.permissions_policy.is_some());
}
9924
#[test]
fn security_headers_config_builder() {
    let config = SecurityHeadersConfig::new()
        .x_frame_options(Some(XFrameOptions::SameOrigin))
        .content_security_policy("default-src 'self'")
        .hsts(86400, false, false)
        .referrer_policy(Some(ReferrerPolicy::Origin));

    // Each setter is reflected in the resulting config.
    assert!(matches!(
        config.x_frame_options,
        Some(XFrameOptions::SameOrigin)
    ));
    assert_eq!(
        config.content_security_policy.as_deref(),
        Some("default-src 'self'")
    );
    assert_eq!(config.hsts, Some((86400, false, false)));
    assert!(matches!(config.referrer_policy, Some(ReferrerPolicy::Origin)));
}
9941
#[test]
fn security_headers_hsts_value_format() {
    // (includeSubDomains, preload) -> expected header value; directives are
    // appended in that order after max-age.
    let cases = [
        (false, false, "max-age=3600"),
        (true, false, "max-age=3600; includeSubDomains"),
        (false, true, "max-age=3600; preload"),
        (true, true, "max-age=3600; includeSubDomains; preload"),
    ];
    for &(include_subdomains, preload, expected) in cases.iter() {
        let config = SecurityHeadersConfig::none().hsts(3600, include_subdomains, preload);
        assert_eq!(config.build_hsts_value().as_deref(), Some(expected));
    }
}
9969
#[test]
fn security_headers_middleware_adds_default_headers() {
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/");
    let mw = SecurityHeaders::new();

    let result = futures_executor::block_on(mw.after(&ctx, &req, Response::ok()));

    // Default configuration: these four are emitted...
    let present = [
        "X-Content-Type-Options",
        "X-Frame-Options",
        "X-XSS-Protection",
        "Referrer-Policy",
    ];
    // ...while the opt-in headers stay off.
    let absent = [
        "Content-Security-Policy",
        "Strict-Transport-Security",
        "Permissions-Policy",
    ];
    for &name in present.iter() {
        assert!(header_value(&result, name).is_some(), "expected {} to be set", name);
    }
    for &name in absent.iter() {
        assert!(header_value(&result, name).is_none(), "expected {} to be absent", name);
    }
}
9990
#[test]
fn security_headers_middleware_with_csp() {
    let config = SecurityHeadersConfig::new()
        .content_security_policy("default-src 'self'; script-src 'self' 'unsafe-inline'");
    let mw = SecurityHeaders::with_config(config);
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/");
    let response = Response::ok();

    let result = futures_executor::block_on(mw.after(&ctx, &req, response));

    // Compare the Option directly instead of `is_some()` + `unwrap()`: a missing
    // header then fails with a readable `None` vs `Some(..)` diff.
    assert_eq!(
        header_value(&result, "Content-Security-Policy").as_deref(),
        Some("default-src 'self'; script-src 'self' 'unsafe-inline'")
    );
}
10009
#[test]
fn security_headers_middleware_with_hsts() {
    let config = SecurityHeadersConfig::new().hsts(31536000, true, false);
    let mw = SecurityHeaders::with_config(config);
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/");
    let response = Response::ok();

    let result = futures_executor::block_on(mw.after(&ctx, &req, response));

    // Compare the Option directly instead of `is_some()` + `unwrap()`.
    assert_eq!(
        header_value(&result, "Strict-Transport-Security").as_deref(),
        Some("max-age=31536000; includeSubDomains")
    );
}
10024
#[test]
fn security_headers_middleware_name() {
    assert_eq!(SecurityHeaders::new().name(), "SecurityHeaders");
}
10030
#[test]
fn x_frame_options_values() {
    // Renders header bytes as text so a mismatch prints readable strings.
    fn to_text<B: AsRef<[u8]>>(bytes: B) -> String {
        std::str::from_utf8(bytes.as_ref())
            .expect("X-Frame-Options values are ASCII")
            .to_owned()
    }

    assert_eq!(to_text(XFrameOptions::Deny.as_bytes()), "DENY");
    assert_eq!(to_text(XFrameOptions::SameOrigin.as_bytes()), "SAMEORIGIN");
}
10036
#[test]
fn referrer_policy_values() {
    // Renders header bytes as text so a mismatch prints readable strings.
    fn to_text<B: AsRef<[u8]>>(bytes: B) -> String {
        std::str::from_utf8(bytes.as_ref())
            .expect("Referrer-Policy tokens are ASCII")
            .to_owned()
    }

    // Each variant renders as its lowercase, hyphenated policy token.
    assert_eq!(to_text(ReferrerPolicy::NoReferrer.as_bytes()), "no-referrer");
    assert_eq!(
        to_text(ReferrerPolicy::NoReferrerWhenDowngrade.as_bytes()),
        "no-referrer-when-downgrade"
    );
    assert_eq!(to_text(ReferrerPolicy::Origin.as_bytes()), "origin");
    assert_eq!(
        to_text(ReferrerPolicy::OriginWhenCrossOrigin.as_bytes()),
        "origin-when-cross-origin"
    );
    assert_eq!(to_text(ReferrerPolicy::SameOrigin.as_bytes()), "same-origin");
    assert_eq!(to_text(ReferrerPolicy::StrictOrigin.as_bytes()), "strict-origin");
    assert_eq!(
        to_text(ReferrerPolicy::StrictOriginWhenCrossOrigin.as_bytes()),
        "strict-origin-when-cross-origin"
    );
    assert_eq!(to_text(ReferrerPolicy::UnsafeUrl.as_bytes()), "unsafe-url");
}
10057
#[test]
fn security_headers_strict_preset() {
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/");
    let mw = SecurityHeaders::strict();

    let result = futures_executor::block_on(mw.after(&ctx, &req, Response::ok()));

    // The strict preset emits all of these, including the opt-in headers.
    let expected = [
        "X-Content-Type-Options",
        "X-Frame-Options",
        "Content-Security-Policy",
        "Strict-Transport-Security",
        "Referrer-Policy",
        "Permissions-Policy",
    ];
    for &name in expected.iter() {
        assert!(header_value(&result, name).is_some(), "missing {}", name);
    }
}
10075
#[test]
fn security_headers_config_clearing_methods() {
    // Start from the strict preset (which sets CSP, HSTS and a permissions
    // policy) and switch those three off again.
    let config = SecurityHeadersConfig::strict()
        .no_content_security_policy()
        .no_hsts()
        .no_permissions_policy();

    assert_eq!(config.content_security_policy, None);
    assert_eq!(config.hsts, None);
    assert!(config.permissions_policy.is_none());
}
10087
#[test]
fn csrf_token_generate_produces_unique_tokens() {
    let first = CsrfToken::generate();
    let second = CsrfToken::generate();

    // Neither token may be blank, and two successive draws must differ.
    assert!(!first.as_str().is_empty());
    assert!(!second.as_str().is_empty());
    assert_ne!(first, second);
}
10100
#[test]
fn csrf_token_display() {
    // `Display` renders the raw token value.
    let rendered = CsrfToken::new("test-token-123").to_string();
    assert_eq!(rendered, "test-token-123");
}
10106
#[test]
fn csrf_config_defaults() {
    let config = CsrfConfig::default();

    // Where the token is carried.
    assert_eq!(config.cookie_name, "csrf_token");
    assert_eq!(config.header_name, "x-csrf-token");

    // Behavioural switches.
    assert_eq!(config.mode, CsrfMode::DoubleSubmit);
    assert!(config.production);
    assert!(!config.rotate_token);
    assert_eq!(config.error_message, None);
}
10117
#[test]
fn csrf_config_builder() {
    let config = CsrfConfig::new()
        .cookie_name("XSRF-TOKEN")
        .header_name("X-XSRF-Token")
        .mode(CsrfMode::HeaderOnly)
        .rotate_token(true)
        .production(false)
        .error_message("Custom CSRF error");

    // Names are stored as given.
    assert_eq!(config.cookie_name, "XSRF-TOKEN");
    assert_eq!(config.header_name, "X-XSRF-Token");

    // Switches and the custom message are applied.
    assert_eq!(config.mode, CsrfMode::HeaderOnly);
    assert!(!config.production);
    assert!(config.rotate_token);
    assert_eq!(config.error_message.as_deref(), Some("Custom CSRF error"));
}
10135
#[test]
fn csrf_middleware_allows_get_without_token() {
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let csrf = CsrfMiddleware::new();

    let outcome = futures_executor::block_on(csrf.before(&ctx, &mut req));

    // GET passes without a token, yet a token is still attached to the request.
    assert!(outcome.is_continue());
    assert!(req.get_extension::<CsrfToken>().is_some());
}
10147
#[test]
fn csrf_middleware_allows_head_without_token() {
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Head, "/");

    // HEAD is a safe method, so no token is required.
    let outcome = futures_executor::block_on(CsrfMiddleware::new().before(&ctx, &mut req));
    assert!(outcome.is_continue());
}
10157
#[test]
fn csrf_middleware_allows_options_without_token() {
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Options, "/");

    // OPTIONS is a safe method, so no token is required.
    let outcome = futures_executor::block_on(CsrfMiddleware::new().before(&ctx, &mut req));
    assert!(outcome.is_continue());
}
10167
#[test]
fn csrf_middleware_blocks_post_without_token() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");

    let result = futures_executor::block_on(csrf.before(&ctx, &mut req));

    // Match exhaustively so the 403 check can never be silently skipped.
    match result {
        ControlFlow::Break(response) => assert_eq!(response.status(), StatusCode::FORBIDDEN),
        ControlFlow::Continue => panic!("Expected Break, got Continue"),
    }
}
10181
#[test]
fn csrf_middleware_blocks_put_without_token() {
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Put, "/");

    // PUT is state-changing; without a token it must be rejected.
    let outcome = futures_executor::block_on(CsrfMiddleware::new().before(&ctx, &mut req));
    assert!(outcome.is_break());
}
10191
#[test]
fn csrf_middleware_blocks_delete_without_token() {
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Delete, "/");

    // DELETE is state-changing; without a token it must be rejected.
    let outcome = futures_executor::block_on(CsrfMiddleware::new().before(&ctx, &mut req));
    assert!(outcome.is_break());
}
10201
#[test]
fn csrf_middleware_blocks_patch_without_token() {
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Patch, "/");

    // PATCH is state-changing; without a token it must be rejected.
    let outcome = futures_executor::block_on(CsrfMiddleware::new().before(&ctx, &mut req));
    assert!(outcome.is_break());
}
10211
#[test]
fn csrf_middleware_allows_post_with_matching_tokens() {
    let token = "valid-csrf-token-12345";
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");

    // Double-submit: the same value travels in the cookie and in the header.
    let cookie = format!("csrf_token={}", token);
    let headers = req.headers_mut();
    headers.insert("cookie", cookie.into_bytes());
    headers.insert("x-csrf-token", token.as_bytes().to_vec());

    let csrf = CsrfMiddleware::new();
    let outcome = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(outcome.is_continue());

    // The accepted token is stored on the request as an extension.
    let stored = req
        .get_extension::<CsrfToken>()
        .expect("CSRF token should be stored on the request");
    assert_eq!(stored.as_str(), token);
}
10232
/// Cookie and header tokens that disagree cause a 403 rejection.
#[test]
fn csrf_middleware_blocks_post_with_mismatched_tokens() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");

    {
        let headers = request.headers_mut();
        headers.insert("cookie", b"csrf_token=token-in-cookie".to_vec());
        headers.insert("x-csrf-token", b"different-token".to_vec());
    }

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_break());

    match flow {
        ControlFlow::Break(rejection) => {
            assert_eq!(rejection.status(), StatusCode::FORBIDDEN);
        }
        ControlFlow::Continue => {}
    }
}
10252
/// In the default double-submit mode, a header token without a cookie is rejected.
#[test]
fn csrf_middleware_blocks_post_with_header_only_in_double_submit_mode() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");
    request
        .headers_mut()
        .insert("x-csrf-token", b"some-token".to_vec());

    let rejected =
        futures_executor::block_on(middleware.before(&context, &mut request)).is_break();
    assert!(rejected);
}
10266
/// In the default double-submit mode, a cookie token without a header is rejected.
#[test]
fn csrf_middleware_blocks_post_with_cookie_only_in_double_submit_mode() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");
    request
        .headers_mut()
        .insert("cookie", b"csrf_token=some-token".to_vec());

    let rejected =
        futures_executor::block_on(middleware.before(&context, &mut request)).is_break();
    assert!(rejected);
}
10280
/// `CsrfMode::HeaderOnly` accepts a POST that carries only the header token.
#[test]
fn csrf_middleware_header_only_mode_accepts_header_token() {
    use crate::request::Method;

    let config = CsrfConfig::new().mode(CsrfMode::HeaderOnly);
    let middleware = CsrfMiddleware::with_config(config);
    let context = test_context();

    let mut request = Request::new(Method::Post, "/");
    request
        .headers_mut()
        .insert("x-csrf-token", b"valid-token".to_vec());

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_continue());
}
10293
/// `CsrfMode::HeaderOnly` rejects a POST whose token header is present but empty.
#[test]
fn csrf_middleware_header_only_mode_rejects_empty_header() {
    use crate::request::Method;

    let config = CsrfConfig::new().mode(CsrfMode::HeaderOnly);
    let middleware = CsrfMiddleware::with_config(config);
    let context = test_context();

    let mut request = Request::new(Method::Post, "/");
    request.headers_mut().insert("x-csrf-token", Vec::new());

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_break());
}
10305
/// A GET without an existing cookie gets a `csrf_token` Set-Cookie header that is
/// `SameSite=Strict` and, under the default config, `Secure`.
#[test]
fn csrf_middleware_sets_cookie_on_get() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Get, "/");

    let _ = futures_executor::block_on(middleware.before(&context, &mut request));
    let response =
        futures_executor::block_on(middleware.after(&context, &request, Response::ok()));

    let set_cookie = header_value(&response, "set-cookie");
    assert!(set_cookie.is_some());

    let set_cookie = set_cookie.unwrap();
    assert!(set_cookie.starts_with("csrf_token="));
    assert!(set_cookie.contains("SameSite=Strict"));
    assert!(set_cookie.contains("Secure"));
}
10328
/// With `production(false)` the issued cookie does not carry the `Secure` attribute.
#[test]
fn csrf_middleware_no_secure_in_dev_mode() {
    use crate::request::Method;

    let config = CsrfConfig::new().production(false);
    let middleware = CsrfMiddleware::with_config(config);
    let context = test_context();
    let mut request = Request::new(Method::Get, "/");

    let _ = futures_executor::block_on(middleware.before(&context, &mut request));
    let response =
        futures_executor::block_on(middleware.after(&context, &request, Response::ok()));

    let set_cookie = header_value(&response, "set-cookie").unwrap();
    assert!(!set_cookie.contains("Secure"));
}
10343
/// A GET that already carries a CSRF cookie does not get a new Set-Cookie header.
#[test]
fn csrf_middleware_does_not_set_cookie_if_already_present() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Get, "/");
    request
        .headers_mut()
        .insert("cookie", b"csrf_token=existing-token".to_vec());

    let _ = futures_executor::block_on(middleware.before(&context, &mut request));
    let response =
        futures_executor::block_on(middleware.after(&context, &request, Response::ok()));

    let set_cookie = header_value(&response, "set-cookie");
    assert!(set_cookie.is_none());
}
10362
/// With `rotate_token(true)`, a Set-Cookie header is emitted even when the request
/// already carries a CSRF cookie.
#[test]
fn csrf_middleware_rotates_token_when_configured() {
    use crate::request::Method;

    let config = CsrfConfig::new().rotate_token(true);
    let middleware = CsrfMiddleware::with_config(config);
    let context = test_context();
    let mut request = Request::new(Method::Get, "/");
    request
        .headers_mut()
        .insert("cookie", b"csrf_token=old-token".to_vec());

    let _ = futures_executor::block_on(middleware.before(&context, &mut request));
    let response =
        futures_executor::block_on(middleware.after(&context, &request, Response::ok()));

    let set_cookie = header_value(&response, "set-cookie");
    assert!(set_cookie.is_some());
}
10381
/// Custom header and cookie names are honoured during double-submit validation.
/// The request uses the lowercase header name.
#[test]
fn csrf_middleware_custom_header_name() {
    use crate::request::Method;

    const TOKEN: &str = "custom-token-value";

    let config = CsrfConfig::new()
        .header_name("X-XSRF-Token")
        .cookie_name("XSRF-TOKEN");
    let middleware = CsrfMiddleware::with_config(config);
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");

    {
        let headers = request.headers_mut();
        headers.insert("cookie", format!("XSRF-TOKEN={TOKEN}").into_bytes());
        headers.insert("x-xsrf-token", TOKEN.as_bytes().to_vec());
    }

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_continue());
}
10401
/// The rejection is an `application/json` response with a bytes body that mentions
/// `csrf_error` and the offending header name.
#[test]
fn csrf_middleware_error_response_is_json() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    let rejection = match flow {
        ControlFlow::Break(rejection) => rejection,
        ControlFlow::Continue => panic!("Expected Break"),
    };

    assert_eq!(
        header_value(&rejection, "content-type"),
        Some("application/json".to_string())
    );

    match rejection.body_ref() {
        ResponseBody::Bytes(raw) => {
            let text = std::str::from_utf8(raw).unwrap();
            assert!(text.contains("csrf_error"));
            assert!(text.contains("x-csrf-token"));
        }
        ResponseBody::Empty | ResponseBody::Stream(_) => panic!("Expected Bytes body"),
    }
}
10426
/// A custom `error_message` configured on `CsrfConfig` appears in the rejection body.
#[test]
fn csrf_middleware_custom_error_message() {
    let csrf = CsrfMiddleware::with_config(
        CsrfConfig::new().error_message("Access denied: invalid security token"),
    );
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");

    let result = futures_executor::block_on(csrf.before(&ctx, &mut req));

    // Match exhaustively so the test fails, rather than passing silently, when the
    // request is not rejected or the body is not buffered bytes.
    let response = match result {
        ControlFlow::Break(response) => response,
        ControlFlow::Continue => panic!("Expected Break"),
    };
    let body_str = match response.body_ref() {
        ResponseBody::Bytes(body) => std::str::from_utf8(body).unwrap(),
        ResponseBody::Empty | ResponseBody::Stream(_) => panic!("Expected Bytes body"),
    };
    assert!(
        body_str.contains("Access denied: invalid security token"),
        "Custom error message missing from body: {}",
        body_str
    );
}
10444
/// The middleware reports its name as "CSRF".
#[test]
fn csrf_middleware_name() {
    let reported = CsrfMiddleware::new().name();
    assert_eq!(reported, "CSRF");
}
10450
/// The CSRF cookie is located correctly among several cookies in one header.
#[test]
fn csrf_middleware_parses_cookie_with_multiple_cookies() {
    use crate::request::Method;

    const TOKEN: &str = "the-csrf-token";

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");

    let cookie_header = format!("session=abc123; csrf_token={TOKEN}; user=test");
    {
        let headers = request.headers_mut();
        headers.insert("cookie", cookie_header.into_bytes());
        headers.insert("x-csrf-token", TOKEN.as_bytes().to_vec());
    }

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_continue());
}
10469
/// An empty cookie value paired with an empty header is rejected.
#[test]
fn csrf_middleware_handles_empty_token_value() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");
    {
        let headers = request.headers_mut();
        headers.insert("cookie", b"csrf_token=".to_vec());
        headers.insert("x-csrf-token", Vec::new());
    }

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_break());
}
10483
/// One hundred generated tokens are pairwise distinct.
#[test]
fn csrf_token_generate_many_unique() {
    const COUNT: usize = 100;

    let mut seen = std::collections::HashSet::with_capacity(COUNT);
    for _ in 0..COUNT {
        let token = CsrfToken::generate();
        let fresh = seen.insert(token.0.clone());
        assert!(fresh, "Duplicate token generated: {}", token.0);
    }
    assert_eq!(seen.len(), COUNT);
}
10500
/// A generated token is at least 64 characters and consists solely of hex digits.
#[test]
fn csrf_token_generate_format_is_hex() {
    let token = CsrfToken::generate();
    let s = token.as_str();

    let long_enough = s.len() >= 64;
    assert!(
        long_enough,
        "Expected at least 64 hex characters, got {} in '{s}'",
        s.len()
    );

    let only_hex = s.bytes().all(|b| b.is_ascii_hexdigit());
    assert!(only_hex, "Non-hex character in token: {s}");
}
10516
/// A generated token is at least 64 bytes long.
#[test]
fn csrf_token_generate_minimum_length() {
    let token = CsrfToken::generate();
    let text = token.as_str();
    let length = text.len();
    assert!(length >= 64, "Token too short: {} (len={})", text, length);
}
10528
/// Converting a `&str` into a `CsrfToken` keeps the text unchanged.
#[test]
fn csrf_token_from_str() {
    let expected = "my-token";
    let token: CsrfToken = expected.into();
    assert_eq!(token.0, expected);
    assert_eq!(token.as_str(), expected);
}
10535
/// A cloned token compares equal to, and has the same text as, its source.
#[test]
fn csrf_token_clone_eq() {
    let original = CsrfToken::new("abc");
    let duplicate = original.clone();
    assert_eq!(original.as_str(), duplicate.as_str());
    assert_eq!(original, duplicate);
}
10543
/// TRACE is a safe method: it passes without a token and still receives a
/// `CsrfToken` extension.
#[test]
fn csrf_middleware_allows_trace_without_token() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Trace, "/");

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_continue());

    let has_token = request.get_extension::<CsrfToken>().is_some();
    assert!(has_token);
}
10555
/// Every safe method passes and leaves a non-empty `CsrfToken` extension on the request.
#[test]
fn csrf_safe_method_generates_token_into_extension() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let safe_methods = [Method::Get, Method::Head, Method::Options, Method::Trace];

    for method in safe_methods {
        let mut request = Request::new(method, "/test");
        let flow = futures_executor::block_on(middleware.before(&context, &mut request));
        assert!(flow.is_continue());

        let issued = request
            .get_extension::<CsrfToken>()
            .expect("token missing");
        assert!(!issued.as_str().is_empty());
    }
}
10574
/// Under the default config, a GET that already carries a CSRF cookie keeps that
/// token in the request extension.
#[test]
fn csrf_safe_method_preserves_existing_cookie_token() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Get, "/");
    request
        .headers_mut()
        .insert("cookie", b"csrf_token=my-existing-token".to_vec());

    let _ = futures_executor::block_on(middleware.before(&context, &mut request));

    let kept = request.get_extension::<CsrfToken>().unwrap();
    assert_eq!(kept.as_str(), "my-existing-token");
}
10589
/// A valid double-submit POST stores the validated token as a request extension.
#[test]
fn csrf_valid_post_stores_token_in_extension() {
    use crate::request::Method;

    const TOKEN: &str = "valid-token-xyz";

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Post, "/submit");
    {
        let headers = request.headers_mut();
        headers.insert("cookie", format!("csrf_token={TOKEN}").into_bytes());
        headers.insert("x-csrf-token", TOKEN.as_bytes().to_vec());
    }

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_continue());

    let stored = request.get_extension::<CsrfToken>().unwrap();
    assert_eq!(stored.as_str(), TOKEN);
}
10607
/// Double-submit with both the cookie value and the header empty is rejected.
#[test]
fn csrf_double_submit_both_empty_strings_rejected() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");
    {
        let headers = request.headers_mut();
        headers.insert("cookie", b"csrf_token=".to_vec());
        headers.insert("x-csrf-token", Vec::new());
    }

    let rejected =
        futures_executor::block_on(middleware.before(&context, &mut request)).is_break();
    assert!(rejected);
}
10621
/// Equal but empty tokens must not satisfy the double-submit check.
#[test]
fn csrf_double_submit_matching_empty_rejected() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");
    {
        let headers = request.headers_mut();
        headers.insert("cookie", b"csrf_token=".to_vec());
        headers.insert("x-csrf-token", Vec::new());
    }

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_break(), "Empty matching tokens should be rejected");
}
10638
/// `CsrfMode::HeaderOnly` accepts a header token with no cookie and stores it as
/// the request extension.
#[test]
fn csrf_header_only_mode_does_not_need_cookie() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");
    request
        .headers_mut()
        .insert("x-csrf-token", b"header-only-token".to_vec());

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_continue());

    let stored = request.get_extension::<CsrfToken>().unwrap();
    assert_eq!(stored.as_str(), "header-only-token");
}
10654
/// `CsrfMode::HeaderOnly` does not compare against the cookie, so a differing
/// cookie value does not cause rejection.
#[test]
fn csrf_header_only_mode_ignores_mismatched_cookie() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");
    {
        let headers = request.headers_mut();
        headers.insert("cookie", b"csrf_token=different-value".to_vec());
        headers.insert("x-csrf-token", b"header-value".to_vec());
    }

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_continue(), "HeaderOnly should ignore cookie");
}
10670
/// `CsrfMode::HeaderOnly` rejects a POST that has no token header at all.
#[test]
fn csrf_header_only_mode_rejects_no_header() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");

    let rejected =
        futures_executor::block_on(middleware.before(&context, &mut request)).is_break();
    assert!(rejected);
}
10680
/// In `CsrfMode::HeaderOnly`, the rejection body for a missing token says
/// "missing in header".
#[test]
fn csrf_header_only_error_message_mentions_header() {
    let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");

    let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
    let response = match result {
        ControlFlow::Break(response) => response,
        ControlFlow::Continue => panic!("Expected Break"),
    };

    // Fail explicitly on non-bytes bodies so the message assertion is never skipped.
    let body_str = match response.body_ref() {
        ResponseBody::Bytes(body) => std::str::from_utf8(body).unwrap(),
        ResponseBody::Empty | ResponseBody::Stream(_) => panic!("Expected Bytes body"),
    };
    assert!(
        body_str.contains("missing in header"),
        "Expected 'missing in header' in: {}",
        body_str
    );
}
10701
/// A missing token and a mismatched token produce distinguishable error bodies:
/// one mentions "missing", the other "mismatch".
#[test]
fn csrf_mismatch_error_differs_from_missing_error() {
    use crate::request::Method;

    /// Unwraps a rejection and returns its body as text, panicking otherwise.
    fn rejection_body(flow: ControlFlow) -> String {
        match flow {
            ControlFlow::Continue => panic!("Expected Break"),
            ControlFlow::Break(rejection) => match rejection.body_ref() {
                ResponseBody::Bytes(raw) => std::str::from_utf8(raw).unwrap().to_string(),
                ResponseBody::Empty | ResponseBody::Stream(_) => panic!("Expected Bytes"),
            },
        }
    }

    let middleware = CsrfMiddleware::new();
    let context = test_context();

    let mut no_token = Request::new(Method::Post, "/");
    let missing_body = rejection_body(futures_executor::block_on(
        middleware.before(&context, &mut no_token),
    ));

    let mut wrong_token = Request::new(Method::Post, "/");
    {
        let headers = wrong_token.headers_mut();
        headers.insert("cookie", b"csrf_token=aaa".to_vec());
        headers.insert("x-csrf-token", b"bbb".to_vec());
    }
    let mismatch_body = rejection_body(futures_executor::block_on(
        middleware.before(&context, &mut wrong_token),
    ));

    assert_ne!(
        missing_body, mismatch_body,
        "Missing vs mismatch should have different error messages"
    );
    assert!(missing_body.contains("missing"));
    assert!(mismatch_body.contains("mismatch"));
}
10743
/// The CSRF cookie must not be HttpOnly, so client-side script can read it and
/// echo it back in the header.
#[test]
fn csrf_cookie_not_httponly() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Get, "/");

    let _ = futures_executor::block_on(middleware.before(&context, &mut request));
    let response =
        futures_executor::block_on(middleware.after(&context, &request, Response::ok()));

    let set_cookie = header_value(&response, "set-cookie").unwrap();
    let is_http_only = set_cookie.to_lowercase().contains("httponly");
    assert!(
        !is_http_only,
        "CSRF cookie must NOT be HttpOnly (needs JS access), got: {}",
        set_cookie
    );
}
10762
/// The issued CSRF cookie is scoped to `Path=/`.
#[test]
fn csrf_cookie_has_path_slash() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Get, "/");

    let _ = futures_executor::block_on(middleware.before(&context, &mut request));
    let response =
        futures_executor::block_on(middleware.after(&context, &request, Response::ok()));

    let set_cookie = header_value(&response, "set-cookie").unwrap();
    let has_root_path = set_cookie.contains("Path=/");
    assert!(has_root_path, "Cookie should have Path=/, got: {}", set_cookie);
}
10780
/// The issued CSRF cookie uses `SameSite=Strict`.
#[test]
fn csrf_cookie_has_samesite_strict() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Get, "/");

    let _ = futures_executor::block_on(middleware.before(&context, &mut request));
    let response =
        futures_executor::block_on(middleware.after(&context, &request, Response::ok()));

    let set_cookie = header_value(&response, "set-cookie").unwrap();
    let is_strict = set_cookie.contains("SameSite=Strict");
    assert!(is_strict, "Cookie should have SameSite=Strict, got: {}", set_cookie);
}
10798
/// With `production(true)` the issued cookie carries the `Secure` attribute.
#[test]
fn csrf_production_mode_sets_secure_flag() {
    use crate::request::Method;

    let config = CsrfConfig::new().production(true);
    let middleware = CsrfMiddleware::with_config(config);
    let context = test_context();
    let mut request = Request::new(Method::Get, "/");

    let _ = futures_executor::block_on(middleware.before(&context, &mut request));
    let response =
        futures_executor::block_on(middleware.after(&context, &request, Response::ok()));

    let set_cookie = header_value(&response, "set-cookie").unwrap();
    let is_secure = set_cookie.contains("Secure");
    assert!(
        is_secure,
        "Production cookie must have Secure flag, got: {}",
        set_cookie
    );
}
10816
/// A successfully validated POST does not cause the response to set a CSRF cookie.
#[test]
fn csrf_no_set_cookie_on_post_response() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");

    let token = "valid-token";
    req.headers_mut()
        .insert("cookie", format!("csrf_token={}", token).into_bytes());
    req.headers_mut()
        .insert("x-csrf-token", token.as_bytes().to_vec());

    // Confirm the request passed validation. Otherwise a missing Set-Cookie could be
    // a side effect of rejection rather than of the request being a POST.
    let flow = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(
        flow.is_continue(),
        "POST with matching tokens should pass validation"
    );

    let response = Response::ok();
    let result = futures_executor::block_on(csrf.after(&ctx, &req, response));

    assert!(
        header_value(&result, "set-cookie").is_none(),
        "POST response should not set CSRF cookie"
    );
}
10839
/// HEAD is treated like other safe methods: its response sets the CSRF cookie.
#[test]
fn csrf_head_method_sets_cookie() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Head, "/");

    let _ = futures_executor::block_on(middleware.before(&context, &mut request));
    let response =
        futures_executor::block_on(middleware.after(&context, &request, Response::ok()));

    let has_cookie = header_value(&response, "set-cookie").is_some();
    assert!(has_cookie, "HEAD response should set CSRF cookie");
}
10855
/// OPTIONS is treated like other safe methods: its response sets the CSRF cookie.
#[test]
fn csrf_options_method_sets_cookie() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Options, "/");

    let _ = futures_executor::block_on(middleware.before(&context, &mut request));
    let response =
        futures_executor::block_on(middleware.after(&context, &request, Response::ok()));

    let has_cookie = header_value(&response, "set-cookie").is_some();
    assert!(has_cookie, "OPTIONS response should set CSRF cookie");
}
10871
/// With `rotate_token(true)`, a GET carrying an old CSRF cookie receives a
/// Set-Cookie header with a different, non-empty token.
#[test]
fn csrf_rotation_produces_different_token_in_cookie() {
    let csrf = CsrfMiddleware::with_config(CsrfConfig::new().rotate_token(true));
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");

    let old_token = "old-token-value";
    req.headers_mut()
        .insert("cookie", format!("csrf_token={}", old_token).into_bytes());

    let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
    let response = Response::ok();
    let result = futures_executor::block_on(csrf.after(&ctx, &req, response));

    let cookie_value = header_value(&result, "set-cookie").unwrap();
    assert!(cookie_value.starts_with("csrf_token="));

    // The token is the value up to the first ';' attribute separator. Compare it with
    // the request's cookie so the test checks what its name promises.
    let issued = cookie_value
        .strip_prefix("csrf_token=")
        .and_then(|rest| rest.split(';').next())
        .unwrap_or_default();
    assert!(
        !issued.is_empty(),
        "Rotated cookie should carry a token, got: {}",
        cookie_value
    );
    assert_ne!(
        issued, old_token,
        "Rotation should issue a new token, got: {}",
        cookie_value
    );
}
10893
/// With rotation disabled, an existing CSRF cookie is not re-sent.
#[test]
fn csrf_no_rotation_skips_set_cookie_when_present() {
    use crate::request::Method;

    let config = CsrfConfig::new().rotate_token(false);
    let middleware = CsrfMiddleware::with_config(config);
    let context = test_context();
    let mut request = Request::new(Method::Get, "/");
    request
        .headers_mut()
        .insert("cookie", b"csrf_token=existing".to_vec());

    let _ = futures_executor::block_on(middleware.before(&context, &mut request));
    let response =
        futures_executor::block_on(middleware.after(&context, &request, Response::ok()));

    let resent = header_value(&response, "set-cookie").is_some();
    assert!(!resent, "Without rotation, should not re-set existing cookie");
}
10912
/// A configured cookie name is used as the name in the Set-Cookie header.
#[test]
fn csrf_custom_cookie_name_in_set_cookie_response() {
    use crate::request::Method;

    let config = CsrfConfig::new().cookie_name("XSRF-TOKEN");
    let middleware = CsrfMiddleware::with_config(config);
    let context = test_context();
    let mut request = Request::new(Method::Get, "/");

    let _ = futures_executor::block_on(middleware.before(&context, &mut request));
    let response =
        futures_executor::block_on(middleware.after(&context, &request, Response::ok()));

    let set_cookie = header_value(&response, "set-cookie").unwrap();
    let uses_custom_name = set_cookie.starts_with("XSRF-TOKEN=");
    assert!(
        uses_custom_name,
        "Custom cookie name should appear in Set-Cookie, got: {}",
        set_cookie
    );
}
10930
/// A POST using a custom header name and a custom cookie name passes validation.
#[test]
fn csrf_custom_header_name_validated() {
    use crate::request::Method;

    const TOKEN: &str = "custom-tok";

    let config = CsrfConfig::new()
        .header_name("X-Custom-CSRF")
        .cookie_name("my_csrf");
    let middleware = CsrfMiddleware::with_config(config);
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");
    {
        let headers = request.headers_mut();
        headers.insert("cookie", format!("my_csrf={TOKEN}").into_bytes());
        headers.insert("x-custom-csrf", TOKEN.as_bytes().to_vec());
    }

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_continue());
}
10950
/// Once a custom header name is configured, a token sent under the default header
/// name is not accepted.
#[test]
fn csrf_custom_header_name_wrong_header_rejected() {
    use crate::request::Method;

    const TOKEN: &str = "some-token";

    let middleware = CsrfMiddleware::with_config(CsrfConfig::new().header_name("X-Custom-CSRF"));
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");
    {
        let headers = request.headers_mut();
        headers.insert("cookie", format!("csrf_token={TOKEN}").into_bytes());
        headers.insert("x-csrf-token", TOKEN.as_bytes().to_vec());
    }

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_break(), "Wrong header name should be rejected");
}
10967
/// Among four cookies, the one named `csrf_token` is used for validation.
#[test]
fn csrf_cookie_parsing_multiple_cookies_picks_correct() {
    use crate::request::Method;

    const TOKEN: &str = "correct-csrf";

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");

    let cookie_header = format!("session=abc; other=xyz; csrf_token={TOKEN}; tracking=123");
    {
        let headers = request.headers_mut();
        headers.insert("cookie", cookie_header.into_bytes());
        headers.insert("x-csrf-token", TOKEN.as_bytes().to_vec());
    }

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_continue());
}
10985
/// Extra whitespace around cookie separators does not break token extraction.
#[test]
fn csrf_cookie_parsing_spaces_around_semicolons() {
    use crate::request::Method;

    const TOKEN: &str = "spaced-token";

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");

    let cookie_header = format!("session=abc ; csrf_token={TOKEN} ; other=xyz");
    {
        let headers = request.headers_mut();
        headers.insert("cookie", cookie_header.into_bytes());
        headers.insert("x-csrf-token", TOKEN.as_bytes().to_vec());
    }

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_continue());
}
11003
/// Every state-changing method without a token is rejected with 403 Forbidden.
#[test]
fn csrf_error_response_status_is_403() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let unsafe_methods = [Method::Post, Method::Put, Method::Delete, Method::Patch];

    for method in unsafe_methods {
        let mut request = Request::new(method, "/");
        let flow = futures_executor::block_on(middleware.before(&context, &mut request));

        let rejection = match flow {
            ControlFlow::Break(rejection) => rejection,
            ControlFlow::Continue => panic!("Expected Break for {:?}", method),
        };
        assert_eq!(
            rejection.status(),
            StatusCode::FORBIDDEN,
            "Expected 403 for {:?}",
            method
        );
    }
}
11031
/// The rejection body is JSON shaped as
/// `{"detail": [{"type": "csrf_error", "loc": ["header", "x-csrf-token"], "msg": ...}]}`.
#[test]
fn csrf_error_body_json_structure() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    let rejection = match flow {
        ControlFlow::Break(rejection) => rejection,
        ControlFlow::Continue => panic!("Expected Break"),
    };
    let text = match rejection.body_ref() {
        ResponseBody::Bytes(raw) => std::str::from_utf8(raw).unwrap(),
        ResponseBody::Empty | ResponseBody::Stream(_) => panic!("Expected Bytes body"),
    };

    let parsed: serde_json::Value = serde_json::from_str(text)
        .unwrap_or_else(|e| panic!("Invalid JSON: {}: {}", text, e));
    assert!(parsed["detail"].is_array());

    let first = &parsed["detail"][0];
    assert_eq!(first["type"], "csrf_error");

    let location = &first["loc"];
    assert!(location.is_array());
    assert_eq!(location[0], "header");
    assert_eq!(location[1], "x-csrf-token");

    assert!(first["msg"].is_string());
}
11059
/// `CsrfMiddleware::default()` has the "CSRF" name and lets a plain GET through.
#[test]
fn csrf_default_trait() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::default();
    assert_eq!(middleware.name(), "CSRF");

    let context = test_context();
    let mut request = Request::new(Method::Get, "/");
    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_continue());
}
11070
/// The default validation mode is double-submit.
#[test]
fn csrf_mode_default_is_double_submit() {
    let mode = CsrfMode::default();
    assert_eq!(mode, CsrfMode::DoubleSubmit);
}
11075
/// A DELETE carrying the same non-empty token in the cookie and the header passes.
#[test]
fn csrf_double_submit_both_present_same_non_empty_passes() {
    use crate::request::Method;

    const TOKEN: &str = "a1b2c3d4e5f6";

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Delete, "/resource/1");
    {
        let headers = request.headers_mut();
        headers.insert("cookie", format!("csrf_token={TOKEN}").into_bytes());
        headers.insert("x-csrf-token", TOKEN.as_bytes().to_vec());
    }

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_continue());
}
11092
/// Tokens that differ only in letter case do not match.
#[test]
fn csrf_double_submit_case_sensitive() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(Method::Post, "/");
    {
        let headers = request.headers_mut();
        headers.insert("cookie", b"csrf_token=AbCdEf".to_vec());
        headers.insert("x-csrf-token", b"abcdef".to_vec());
    }

    let flow = futures_executor::block_on(middleware.before(&context, &mut request));
    assert!(flow.is_break(), "Token comparison should be case-sensitive");
}
11110
/// The `CsrfTokenCookie` extractor's cookie name matches the middleware's default
/// cookie name, `csrf_token`.
#[test]
fn csrf_token_cookie_extractor_reads_csrf_cookie() {
    use crate::extract::{CookieName, CsrfTokenCookie};

    let cookie_name = <CsrfTokenCookie as CookieName>::NAME;
    assert_eq!(cookie_name, "csrf_token");
}
11117
/// In production mode the Set-Cookie value includes the token, `Path=/`,
/// `SameSite=Strict` and `Secure`, and is not HttpOnly.
#[test]
fn csrf_make_set_cookie_header_value_production() {
    let raw = CsrfMiddleware::make_set_cookie_header_value("csrf_token", "tok123", true);
    let cookie = std::str::from_utf8(&raw).unwrap();

    for fragment in ["csrf_token=tok123", "Path=/", "SameSite=Strict", "Secure"] {
        assert!(
            cookie.contains(fragment),
            "missing {fragment:?} in {cookie}"
        );
    }
    assert!(!cookie.to_lowercase().contains("httponly"));
}
11128
/// In development mode the Set-Cookie value keeps the token, path and SameSite
/// attributes but omits `Secure`.
#[test]
fn csrf_make_set_cookie_header_value_development() {
    let raw = CsrfMiddleware::make_set_cookie_header_value("csrf_token", "tok123", false);
    let cookie = std::str::from_utf8(&raw).unwrap();

    for fragment in ["csrf_token=tok123", "Path=/", "SameSite=Strict"] {
        assert!(
            cookie.contains(fragment),
            "missing {fragment:?} in {cookie}"
        );
    }
    assert!(!cookie.contains("Secure"));
}
11138
/// End-to-end flow: a GET issues a token via Set-Cookie, and a follow-up POST that
/// echoes the token in both the cookie and the header is accepted.
#[test]
fn csrf_before_after_full_cycle_get_then_post() {
    use crate::request::Method;

    let middleware = CsrfMiddleware::new();
    let context = test_context();

    // Step 1: the safe request obtains a token.
    let mut form_request = Request::new(Method::Get, "/form");
    let _ = futures_executor::block_on(middleware.before(&context, &mut form_request));
    let form_response =
        futures_executor::block_on(middleware.after(&context, &form_request, Response::ok()));

    let set_cookie =
        header_value(&form_response, "set-cookie").expect("GET should set cookie");
    let issued = set_cookie
        .strip_prefix("csrf_token=")
        .and_then(|rest| rest.split(';').next())
        .unwrap();
    assert!(!issued.is_empty());

    // Step 2: the state-changing request submits the token back.
    let mut submit_request = Request::new(Method::Post, "/form");
    {
        let headers = submit_request.headers_mut();
        headers.insert("cookie", format!("csrf_token={issued}").into_bytes());
        headers.insert("x-csrf-token", issued.as_bytes().to_vec());
    }

    let flow = futures_executor::block_on(middleware.before(&context, &mut submit_request));
    assert!(flow.is_continue(), "POST with valid token should pass");
}
11173
11174 #[test]
11175 fn csrf_all_state_changing_methods_require_token() {
11176 let csrf = CsrfMiddleware::new();
11177 let ctx = test_context();
11178
11179 for method in [
11180 crate::request::Method::Post,
11181 crate::request::Method::Put,
11182 crate::request::Method::Delete,
11183 crate::request::Method::Patch,
11184 ] {
11185 let mut req = Request::new(method, "/resource");
11186 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11187 assert!(
11188 result.is_break(),
11189 "{:?} without token should be rejected",
11190 method
11191 );
11192 }
11193 }
11194
11195 #[test]
11196 fn csrf_all_safe_methods_pass_without_token() {
11197 let csrf = CsrfMiddleware::new();
11198 let ctx = test_context();
11199
11200 for method in [
11201 crate::request::Method::Get,
11202 crate::request::Method::Head,
11203 crate::request::Method::Options,
11204 crate::request::Method::Trace,
11205 ] {
11206 let mut req = Request::new(method, "/resource");
11207 let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11208 assert!(
11209 result.is_continue(),
11210 "{:?} should be allowed without token",
11211 method
11212 );
11213 }
11214 }
11215
11216 #[derive(Clone)]
11223 struct OrderRecordingMiddleware {
11224 id: &'static str,
11225 log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
11226 }
11227
11228 impl OrderRecordingMiddleware {
11229 fn new(id: &'static str, log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
11230 Self { id, log }
11231 }
11232 }
11233
11234 impl Middleware for OrderRecordingMiddleware {
11235 fn before<'a>(
11236 &'a self,
11237 _ctx: &'a RequestContext,
11238 _req: &'a mut Request,
11239 ) -> BoxFuture<'a, ControlFlow> {
11240 let id = self.id;
11241 let log = self.log.clone();
11242 Box::pin(async move {
11243 log.lock().unwrap().push(format!("{id}:before"));
11244 ControlFlow::Continue
11245 })
11246 }
11247
11248 fn after<'a>(
11249 &'a self,
11250 _ctx: &'a RequestContext,
11251 _req: &'a Request,
11252 response: Response,
11253 ) -> BoxFuture<'a, Response> {
11254 let id = self.id;
11255 let log = self.log.clone();
11256 Box::pin(async move {
11257 log.lock().unwrap().push(format!("{id}:after"));
11258 response
11259 })
11260 }
11261
11262 fn name(&self) -> &'static str {
11263 "OrderRecording"
11264 }
11265 }
11266
11267 struct ShortCircuitMiddleware {
11269 id: &'static str,
11270 log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
11271 }
11272
11273 impl ShortCircuitMiddleware {
11274 fn new(id: &'static str, log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
11275 Self { id, log }
11276 }
11277 }
11278
11279 impl Middleware for ShortCircuitMiddleware {
11280 fn before<'a>(
11281 &'a self,
11282 _ctx: &'a RequestContext,
11283 _req: &'a mut Request,
11284 ) -> BoxFuture<'a, ControlFlow> {
11285 let id = self.id;
11286 let log = self.log.clone();
11287 Box::pin(async move {
11288 log.lock().unwrap().push(format!("{id}:before:break"));
11289 ControlFlow::Break(
11290 Response::with_status(StatusCode::FORBIDDEN)
11291 .body(ResponseBody::Bytes(b"short-circuited".to_vec())),
11292 )
11293 })
11294 }
11295
11296 fn after<'a>(
11297 &'a self,
11298 _ctx: &'a RequestContext,
11299 _req: &'a Request,
11300 response: Response,
11301 ) -> BoxFuture<'a, Response> {
11302 let id = self.id;
11303 let log = self.log.clone();
11304 Box::pin(async move {
11305 log.lock().unwrap().push(format!("{id}:after"));
11306 response
11307 })
11308 }
11309
11310 fn name(&self) -> &'static str {
11311 "ShortCircuit"
11312 }
11313 }
11314
11315 struct RecordingHandler {
11317 log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
11318 }
11319
11320 impl RecordingHandler {
11321 fn new(log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
11322 Self { log }
11323 }
11324 }
11325
11326 impl Handler for RecordingHandler {
11327 fn call<'a>(
11328 &'a self,
11329 _ctx: &'a RequestContext,
11330 _req: &'a mut Request,
11331 ) -> BoxFuture<'a, Response> {
11332 let log = self.log.clone();
11333 Box::pin(async move {
11334 log.lock().unwrap().push("handler".to_string());
11335 Response::ok().body(ResponseBody::Bytes(b"ok".to_vec()))
11336 })
11337 }
11338 }
11339
11340 #[test]
11341 fn middleware_stack_three_middleware_onion_order() {
11342 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11346
11347 let mut stack = MiddlewareStack::new();
11348 stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
11349 stack.push(OrderRecordingMiddleware::new("mw2", log.clone()));
11350 stack.push(OrderRecordingMiddleware::new("mw3", log.clone()));
11351
11352 let handler = RecordingHandler::new(log.clone());
11353 let ctx = test_context();
11354 let mut req = Request::new(crate::request::Method::Get, "/");
11355
11356 let _response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11357
11358 let execution_log = log.lock().unwrap().clone();
11359 assert_eq!(
11360 execution_log,
11361 vec![
11362 "mw1:before",
11363 "mw2:before",
11364 "mw3:before",
11365 "handler",
11366 "mw3:after",
11367 "mw2:after",
11368 "mw1:after",
11369 ]
11370 );
11371 }
11372
11373 #[test]
11374 fn middleware_stack_short_circuit_runs_prior_after_hooks() {
11375 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11383
11384 let mut stack = MiddlewareStack::new();
11385 stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
11386 stack.push(ShortCircuitMiddleware::new("mw2", log.clone()));
11387 stack.push(OrderRecordingMiddleware::new("mw3", log.clone()));
11388
11389 let handler = RecordingHandler::new(log.clone());
11390 let ctx = test_context();
11391 let mut req = Request::new(crate::request::Method::Get, "/");
11392
11393 let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11394
11395 assert_eq!(response.status().as_u16(), 403);
11397
11398 let execution_log = log.lock().unwrap().clone();
11399 assert_eq!(
11402 execution_log,
11403 vec!["mw1:before", "mw2:before:break", "mw1:after",]
11404 );
11405 }
11406
11407 #[test]
11408 fn middleware_stack_first_middleware_short_circuits() {
11409 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11413
11414 let mut stack = MiddlewareStack::new();
11415 stack.push(ShortCircuitMiddleware::new("mw1", log.clone()));
11416 stack.push(OrderRecordingMiddleware::new("mw2", log.clone()));
11417
11418 let handler = RecordingHandler::new(log.clone());
11419 let ctx = test_context();
11420 let mut req = Request::new(crate::request::Method::Get, "/");
11421
11422 let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11423 assert_eq!(response.status().as_u16(), 403);
11424
11425 let execution_log = log.lock().unwrap().clone();
11426 assert_eq!(execution_log, vec!["mw1:before:break",]);
11428 }
11429
11430 #[test]
11431 fn middleware_stack_empty_runs_handler_only() {
11432 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11434
11435 let stack = MiddlewareStack::new();
11436 let handler = RecordingHandler::new(log.clone());
11437 let ctx = test_context();
11438 let mut req = Request::new(crate::request::Method::Get, "/");
11439
11440 let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11441 assert_eq!(response.status().as_u16(), 200);
11442
11443 let execution_log = log.lock().unwrap().clone();
11444 assert_eq!(execution_log, vec!["handler"]);
11445 }
11446
11447 #[test]
11448 fn middleware_stack_single_middleware_ordering() {
11449 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11451
11452 let mut stack = MiddlewareStack::new();
11453 stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
11454
11455 let handler = RecordingHandler::new(log.clone());
11456 let ctx = test_context();
11457 let mut req = Request::new(crate::request::Method::Get, "/");
11458
11459 let _response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11460
11461 let execution_log = log.lock().unwrap().clone();
11462 assert_eq!(execution_log, vec!["mw1:before", "handler", "mw1:after",]);
11463 }
11464
11465 #[test]
11466 fn middleware_stack_five_middleware_onion_order() {
11467 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11469
11470 let mut stack = MiddlewareStack::new();
11471 stack.push(OrderRecordingMiddleware::new("a", log.clone()));
11472 stack.push(OrderRecordingMiddleware::new("b", log.clone()));
11473 stack.push(OrderRecordingMiddleware::new("c", log.clone()));
11474 stack.push(OrderRecordingMiddleware::new("d", log.clone()));
11475 stack.push(OrderRecordingMiddleware::new("e", log.clone()));
11476
11477 let handler = RecordingHandler::new(log.clone());
11478 let ctx = test_context();
11479 let mut req = Request::new(crate::request::Method::Get, "/");
11480
11481 let _response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11482
11483 let execution_log = log.lock().unwrap().clone();
11484 assert_eq!(
11485 execution_log,
11486 vec![
11487 "a:before", "b:before", "c:before", "d:before", "e:before", "handler", "e:after",
11488 "d:after", "c:after", "b:after", "a:after",
11489 ]
11490 );
11491 }
11492
11493 #[test]
11494 fn middleware_stack_short_circuit_at_end_runs_prior_afters() {
11495 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11502
11503 let mut stack = MiddlewareStack::new();
11504 stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
11505 stack.push(OrderRecordingMiddleware::new("mw2", log.clone()));
11506 stack.push(ShortCircuitMiddleware::new("mw3", log.clone()));
11507
11508 let handler = RecordingHandler::new(log.clone());
11509 let ctx = test_context();
11510 let mut req = Request::new(crate::request::Method::Get, "/");
11511
11512 let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11513 assert_eq!(response.status().as_u16(), 403);
11514
11515 let execution_log = log.lock().unwrap().clone();
11516 assert_eq!(
11518 execution_log,
11519 vec![
11520 "mw1:before",
11521 "mw2:before",
11522 "mw3:before:break",
11523 "mw2:after",
11524 "mw1:after",
11525 ]
11526 );
11527 }
11528
11529 struct ModifyingMiddleware {
11531 id: &'static str,
11532 log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
11533 }
11534
11535 impl ModifyingMiddleware {
11536 fn new(id: &'static str, log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
11537 Self { id, log }
11538 }
11539 }
11540
11541 impl Middleware for ModifyingMiddleware {
11542 fn before<'a>(
11543 &'a self,
11544 _ctx: &'a RequestContext,
11545 req: &'a mut Request,
11546 ) -> BoxFuture<'a, ControlFlow> {
11547 let id = self.id;
11548 let log = self.log.clone();
11549 Box::pin(async move {
11550 req.headers_mut()
11552 .insert(format!("x-{id}-before"), b"true".to_vec());
11553 log.lock().unwrap().push(format!("{id}:before"));
11554 ControlFlow::Continue
11555 })
11556 }
11557
11558 fn after<'a>(
11559 &'a self,
11560 _ctx: &'a RequestContext,
11561 _req: &'a Request,
11562 response: Response,
11563 ) -> BoxFuture<'a, Response> {
11564 let id = self.id;
11565 let log = self.log.clone();
11566 Box::pin(async move {
11567 log.lock().unwrap().push(format!("{id}:after"));
11568 response.header(format!("x-{id}-after"), b"true".to_vec())
11570 })
11571 }
11572
11573 fn name(&self) -> &'static str {
11574 "Modifying"
11575 }
11576 }
11577
11578 #[test]
11579 fn middleware_stack_modifications_accumulate_correctly() {
11580 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11583
11584 let mut stack = MiddlewareStack::new();
11585 stack.push(ModifyingMiddleware::new("mw1", log.clone()));
11586 stack.push(ModifyingMiddleware::new("mw2", log.clone()));
11587 stack.push(ModifyingMiddleware::new("mw3", log.clone()));
11588
11589 let handler = RecordingHandler::new(log.clone());
11590 let ctx = test_context();
11591 let mut req = Request::new(crate::request::Method::Get, "/");
11592
11593 let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11594
11595 assert!(header_value(&response, "x-mw1-after").is_some());
11597 assert!(header_value(&response, "x-mw2-after").is_some());
11598 assert!(header_value(&response, "x-mw3-after").is_some());
11599
11600 assert!(req.headers().contains("x-mw1-before"));
11602 assert!(req.headers().contains("x-mw2-before"));
11603 assert!(req.headers().contains("x-mw3-before"));
11604 }
11605
11606 #[test]
11607 fn layer_wrap_maintains_middleware_order() {
11608 let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11610
11611 let layer = Layer::new(OrderRecordingMiddleware::new("layer", log.clone()));
11613
11614 let handler = RecordingHandler::new(log.clone());
11616 let layered_handler = layer.wrap(handler);
11617
11618 let ctx = test_context();
11619 let mut req = Request::new(crate::request::Method::Get, "/");
11620
11621 let _response = futures_executor::block_on(layered_handler.call(&ctx, &mut req));
11623
11624 let execution_log = log.lock().unwrap().clone();
11625 assert_eq!(
11626 execution_log,
11627 vec!["layer:before", "handler", "layer:after",]
11628 );
11629 }
11630}
11631
#[cfg(all(test, feature = "compression"))]
mod compression_tests {
    //! Tests for `CompressionConfig` and `CompressionMiddleware` (gzip response
    //! compression negotiated via `Accept-Encoding`).

    use super::*;
    use crate::request::Method;
    use crate::response::ResponseBody;

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// Builds a `GET /` request whose `accept-encoding` header is set to `value`.
    fn request_accepting(value: &[u8]) -> Request {
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("accept-encoding", value.to_vec());
        req
    }

    /// Builds a `200 OK` response with the given `content-type` and a bytes body.
    fn response_with(content_type: &[u8], body: Vec<u8>) -> Response {
        Response::ok()
            .header("content-type", content_type.to_vec())
            .body(ResponseBody::Bytes(body))
    }

    /// True when `response` has at least one header named `name`, compared
    /// case-insensitively for ASCII.
    fn has_header(response: &Response, name: &str) -> bool {
        response
            .headers()
            .iter()
            .any(|(header, _)| header.eq_ignore_ascii_case(name))
    }

    #[test]
    fn compression_config_defaults() {
        let config = CompressionConfig::default();
        assert_eq!(config.min_size, 1024);
        assert_eq!(config.level, 6);
        assert!(!config.skip_content_types.is_empty());
    }

    #[test]
    fn compression_config_builder() {
        let config = CompressionConfig::new().min_size(512).level(9);
        assert_eq!(config.min_size, 512);
        assert_eq!(config.level, 9);
    }

    /// Out-of-range levels are clamped into 1..=9 rather than rejected.
    #[test]
    fn compression_level_clamped() {
        let config = CompressionConfig::new().level(100);
        assert_eq!(config.level, 9);

        let config = CompressionConfig::new().level(0);
        assert_eq!(config.level, 1);
    }

    /// Parameters after `;` are ignored when matching skipped content types.
    #[test]
    fn skip_content_type_exact_match() {
        let config = CompressionConfig::default();
        assert!(config.should_skip_content_type("image/jpeg"));
        assert!(config.should_skip_content_type("image/jpeg; charset=utf-8"));
        assert!(!config.should_skip_content_type("text/html"));
    }

    /// Whole media families (video/*, audio/*) are skipped by prefix.
    #[test]
    fn skip_content_type_prefix_match() {
        let config = CompressionConfig::default();
        assert!(config.should_skip_content_type("video/mp4"));
        assert!(config.should_skip_content_type("video/webm"));
        assert!(config.should_skip_content_type("audio/mpeg"));
    }

    #[test]
    fn compression_skips_small_responses() {
        // NOTE: assumes `CompressionMiddleware::new()` uses the default config,
        // whose `min_size` (1024) dwarfs this 13-byte body.
        let middleware = CompressionMiddleware::new();
        let ctx = test_context();
        let req = request_accepting(b"gzip");
        let response = response_with(b"text/plain", b"Hello, World!".to_vec());

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        assert!(
            !has_header(&result, "content-encoding"),
            "Small response should not be compressed"
        );
    }

    #[test]
    fn compression_works_for_large_responses() {
        let config = CompressionConfig::new().min_size(10);
        let middleware = CompressionMiddleware::with_config(config);
        let ctx = test_context();
        let req = request_accepting(b"gzip");

        // Highly repetitive text, so gzip output should be much smaller than the input.
        let body = "Hello, World! ".repeat(100);
        let original_size = body.len();
        let response = response_with(b"text/plain", body.into_bytes());

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        let (_, value) = result
            .headers()
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
            .expect("Large response should be compressed");
        assert_eq!(value, b"gzip");

        assert!(has_header(&result, "vary"), "Should have Vary header");

        match result.body_ref() {
            ResponseBody::Bytes(compressed) => assert!(
                compressed.len() < original_size,
                "Compressed size should be smaller"
            ),
            _ => panic!("Expected Bytes body"),
        }
    }

    #[test]
    fn compression_skips_without_accept_encoding() {
        let config = CompressionConfig::new().min_size(10);
        let middleware = CompressionMiddleware::with_config(config);
        let ctx = test_context();

        // No accept-encoding header at all.
        let req = Request::new(Method::Get, "/");
        let response = response_with(b"text/plain", "Hello, World! ".repeat(100).into_bytes());

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        assert!(
            !has_header(&result, "content-encoding"),
            "Should not compress without Accept-Encoding"
        );
    }

    #[test]
    fn compression_skips_already_compressed_content() {
        let config = CompressionConfig::new().min_size(10);
        let middleware = CompressionMiddleware::with_config(config);
        let ctx = test_context();
        let req = request_accepting(b"gzip");

        let response = response_with(b"image/jpeg", "Some image data".repeat(100).into_bytes());

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        assert!(
            !has_header(&result, "content-encoding"),
            "Should not compress already-compressed content types"
        );
    }

    /// A response that already declares an encoding must be left alone. There
    /// must be no second `content-encoding` header, and the original value is kept.
    #[test]
    fn compression_skips_if_already_has_content_encoding() {
        let config = CompressionConfig::new().min_size(10);
        let middleware = CompressionMiddleware::with_config(config);
        let ctx = test_context();
        let req = request_accepting(b"gzip");

        let body = "Hello, World! ".repeat(100);
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .header("content-encoding", b"br".to_vec())
            .body(ResponseBody::Bytes(body.into_bytes()));

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        let encodings: Vec<_> = result
            .headers()
            .iter()
            .filter(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
            .collect();

        assert_eq!(encodings.len(), 1);
        assert_eq!(encodings[0].1, b"br");
    }

    #[test]
    fn accepts_gzip_parses_header_correctly() {
        // Explicit gzip (alone, in a list, with a q-value) and the `*` wildcard.
        let accepted: [&[u8]; 4] = [
            b"gzip",
            b"deflate, gzip, br",
            b"gzip;q=1.0, identity;q=0.5",
            b"*",
        ];
        for value in accepted {
            assert!(
                CompressionMiddleware::accepts_gzip(&request_accepting(value)),
                "should accept gzip for {:?}",
                String::from_utf8_lossy(value)
            );
        }

        // Other encodings only.
        assert!(!CompressionMiddleware::accepts_gzip(&request_accepting(
            b"deflate, br"
        )));

        // No accept-encoding header at all.
        let req_no_header = Request::new(Method::Get, "/");
        assert!(!CompressionMiddleware::accepts_gzip(&req_no_header));
    }

    #[test]
    fn compression_middleware_name() {
        let middleware = CompressionMiddleware::new();
        assert_eq!(middleware.name(), "Compression");
    }
}
11893
#[cfg(test)]
mod request_inspection_tests {
    //! Tests for `RequestInspectionMiddleware` configuration, its hooks, and its
    //! header/body formatting helpers (including `try_pretty_json`).

    use super::*;
    use crate::request::Method;
    use crate::response::ResponseBody;

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    #[test]
    fn inspection_middleware_default_creates_normal_verbosity() {
        let mw = RequestInspectionMiddleware::new();
        assert_eq!(mw.verbosity, InspectionVerbosity::Normal);
        assert_eq!(mw.slow_threshold_ms, 1000);
        assert_eq!(mw.max_body_preview, 2048);
        assert_eq!(mw.name(), "RequestInspection");
    }

    #[test]
    fn inspection_middleware_builder_methods() {
        let mw = RequestInspectionMiddleware::new()
            .verbosity(InspectionVerbosity::Verbose)
            .slow_threshold_ms(500)
            .max_body_preview(4096)
            .log_config(LogConfig::development())
            .redact_header("x-api-key");

        assert_eq!(mw.verbosity, InspectionVerbosity::Verbose);
        assert_eq!(mw.slow_threshold_ms, 500);
        assert_eq!(mw.max_body_preview, 4096);
        assert!(mw.redact_headers.contains("x-api-key"));
        // Custom redactions are added to the defaults rather than replacing them.
        assert!(mw.redact_headers.contains("authorization"));
        assert!(mw.redact_headers.contains("cookie"));
    }

    #[test]
    fn inspection_before_continues_processing() {
        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Minimal);
        let ctx = test_context();
        let mut req = Request::new(Method::Post, "/api/users");

        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
        assert!(result.is_continue());
    }

    #[test]
    fn inspection_after_returns_response_unchanged() {
        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Minimal);
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/health");

        let _ = futures_executor::block_on(mw.before(&ctx, &mut req));

        let response = Response::ok().body(ResponseBody::Bytes(b"OK".to_vec()));

        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
        assert_eq!(result.status().as_u16(), 200);
        assert_eq!(result.body_ref().len(), 2);
    }

    #[test]
    fn inspection_stores_start_extension() {
        let mw = RequestInspectionMiddleware::new();
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");

        let _ = futures_executor::block_on(mw.before(&ctx, &mut req));

        assert!(req.get_extension::<InspectionStart>().is_some());
    }

    #[test]
    fn inspection_all_verbosity_levels_continue() {
        for verbosity in [
            InspectionVerbosity::Minimal,
            InspectionVerbosity::Normal,
            InspectionVerbosity::Verbose,
        ] {
            let mw = RequestInspectionMiddleware::new().verbosity(verbosity);
            let ctx = test_context();
            let mut req = Request::new(Method::Get, "/test");
            req.headers_mut()
                .insert("content-type", b"text/plain".to_vec());

            let result = futures_executor::block_on(mw.before(&ctx, &mut req));
            assert!(
                result.is_continue(),
                "Verbosity {verbosity:?} should continue"
            );
        }
    }

    #[test]
    fn inspection_verbose_with_json_body() {
        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Verbose);
        let ctx = test_context();
        let body = br#"{"name":"Alice","age":30}"#;
        let mut req = Request::new(Method::Post, "/api/users");
        req.headers_mut()
            .insert("content-type", b"application/json".to_vec());
        req.set_body(Body::Bytes(body.to_vec()));

        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
        assert!(result.is_continue());
    }

    #[test]
    fn inspection_verbose_after_with_json_response() {
        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Verbose);
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/users/1");

        let _ = futures_executor::block_on(mw.before(&ctx, &mut req));

        let response = Response::ok()
            .header("content-type", b"application/json".to_vec())
            .body(ResponseBody::Bytes(br#"{"id":1,"name":"Alice"}"#.to_vec()));

        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
        assert_eq!(result.status().as_u16(), 200);
    }

    #[test]
    fn inspection_redacts_sensitive_headers() {
        let mw = RequestInspectionMiddleware::new();

        // Credential-bearing headers are redacted by default.
        assert!(mw.redact_headers.contains("authorization"));
        assert!(mw.redact_headers.contains("proxy-authorization"));
        assert!(mw.redact_headers.contains("cookie"));
        assert!(mw.redact_headers.contains("set-cookie"));
    }

    #[test]
    fn inspection_format_headers_redacts() {
        let mw = RequestInspectionMiddleware::new().redact_header("x-secret");

        let headers = vec![
            ("content-type", b"text/plain".as_slice()),
            ("x-secret", b"my-secret-value".as_slice()),
            ("x-normal", b"visible".as_slice()),
        ];

        let output = mw.format_inspection_headers(headers.into_iter());
        assert!(output.contains("content-type: text/plain"));
        assert!(output.contains("x-secret: [REDACTED]"));
        assert!(output.contains("x-normal: visible"));
        assert!(!output.contains("my-secret-value"));
    }

    #[test]
    fn inspection_format_body_preview_truncates() {
        let mw = RequestInspectionMiddleware::new().max_body_preview(10);

        let body = b"Hello, World! This is a long body.";
        let text = mw
            .format_body_preview(body, None)
            .expect("a non-empty body should produce a preview");
        assert!(text.ends_with("..."));
        // Presumably 10 preview bytes plus the "..." marker; the bound leaves slack.
        assert!(text.len() <= 15);
    }

    #[test]
    fn inspection_format_body_preview_empty() {
        let mw = RequestInspectionMiddleware::new();
        assert!(mw.format_body_preview(b"", None).is_none());
    }

    #[test]
    fn inspection_format_body_preview_zero_max() {
        let mw = RequestInspectionMiddleware::new().max_body_preview(0);
        assert!(mw.format_body_preview(b"hello", None).is_none());
    }

    #[test]
    fn inspection_format_body_preview_json_pretty() {
        let mw = RequestInspectionMiddleware::new();
        let body = br#"{"key":"value","num":42}"#;
        let ct = b"application/json".as_slice();
        let text = mw
            .format_body_preview(body, Some(ct))
            .expect("a JSON body should produce a preview");
        // Pretty-printed: spread over several lines with `": "` separators.
        assert!(text.contains('\n'));
        assert!(text.contains("\"key\": \"value\""));
    }

    #[test]
    fn inspection_format_body_preview_non_json() {
        let mw = RequestInspectionMiddleware::new();
        let body = b"Hello, World!";
        let ct = b"text/plain".as_slice();
        assert_eq!(
            mw.format_body_preview(body, Some(ct)).as_deref(),
            Some("Hello, World!")
        );
    }

    #[test]
    fn inspection_format_body_preview_binary() {
        let mw = RequestInspectionMiddleware::new();
        // Not valid UTF-8, so the preview is expected to describe it as binary.
        let body: &[u8] = &[0xFF, 0xFE, 0xFD, 0x00];
        let text = mw
            .format_body_preview(body, None)
            .expect("a binary body should still produce a preview");
        assert!(text.contains("binary"));
    }

    #[test]
    fn try_pretty_json_valid_object() {
        let pretty = try_pretty_json(r#"{"a":"b","c":1}"#).expect("valid JSON object");
        assert!(pretty.contains('\n'));
        assert!(pretty.contains("  \"a\": \"b\""));
    }

    #[test]
    fn try_pretty_json_valid_array() {
        let pretty = try_pretty_json("[1,2,3]").expect("valid JSON array");
        assert!(pretty.contains('\n'));
    }

    #[test]
    fn try_pretty_json_empty_object() {
        assert_eq!(try_pretty_json("{}").as_deref(), Some("{}"));
    }

    #[test]
    fn try_pretty_json_empty_array() {
        assert_eq!(try_pretty_json("[]").as_deref(), Some("[]"));
    }

    /// Only objects and arrays are pretty-printed; bare text and scalars yield `None`.
    #[test]
    fn try_pretty_json_not_json() {
        assert!(try_pretty_json("hello world").is_none());
        assert!(try_pretty_json("12345").is_none());
    }

    #[test]
    fn try_pretty_json_nested() {
        let input = r#"{"user":{"name":"Alice","roles":["admin","user"]}}"#;
        let pretty = try_pretty_json(input).expect("valid nested JSON");
        assert!(pretty.contains("\"user\":"));
        assert!(pretty.contains("\"name\": \"Alice\""));
        assert!(pretty.contains("\"roles\":"));
    }

    #[test]
    fn try_pretty_json_with_escapes() {
        let input = r#"{"msg":"hello \"world\""}"#;
        let pretty = try_pretty_json(input).expect("valid JSON with escapes");
        // Escaped quotes inside strings must survive reformatting untouched.
        assert!(pretty.contains(r#"\"world\""#));
    }

    #[test]
    fn inspection_name() {
        let mw = RequestInspectionMiddleware::new();
        assert_eq!(mw.name(), "RequestInspection");
    }

    #[test]
    fn inspection_default_via_default_trait() {
        let mw = RequestInspectionMiddleware::default();
        assert_eq!(mw.verbosity, InspectionVerbosity::Normal);
        assert_eq!(mw.slow_threshold_ms, 1000);
    }

    #[test]
    fn inspection_with_query_string() {
        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Minimal);
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/search");
        req.set_query(Some("q=rust&page=1".to_string()));

        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
        assert!(result.is_continue());
    }

    /// Despite its name, this test exercises an `Empty` response body (no
    /// streaming body is constructed). An empty body yields no preview.
    #[test]
    fn inspection_response_body_stream() {
        let mw = RequestInspectionMiddleware::new();
        let result = mw.format_response_preview(&ResponseBody::Empty, None);
        assert!(result.is_none());
    }
}
12195
#[cfg(test)]
mod rate_limit_tests {
    //! Tests for the rate-limiting middleware, its key extractors and CIDR
    //! helpers, plus the ETag middleware. All of them share the small
    //! synchronous harness helpers defined at the top of this module.

    use super::*;
    use crate::request::Method;
    use crate::response::{ResponseBody, StatusCode};
    use std::time::Duration;

    // ------------------------------------------------------------------
    // Harness helpers
    // ------------------------------------------------------------------

    /// Builds a fresh request context for driving middleware futures.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// Drives `RateLimitMiddleware::before` to completion on a fresh context.
    fn run_rate_limit_before(mw: &RateLimitMiddleware, req: &mut Request) -> ControlFlow {
        let ctx = test_context();
        let fut = mw.before(&ctx, req);
        futures_executor::block_on(fut)
    }

    /// Drives `RateLimitMiddleware::after` to completion on a fresh context.
    fn run_rate_limit_after(mw: &RateLimitMiddleware, req: &Request, resp: Response) -> Response {
        let ctx = test_context();
        let fut = mw.after(&ctx, req, resp);
        futures_executor::block_on(fut)
    }

    /// Drives `ETagMiddleware::after` to completion on a fresh context.
    fn run_etag_after(mw: &ETagMiddleware, req: &Request, resp: Response) -> Response {
        let ctx = test_context();
        let fut = mw.after(&ctx, req, resp);
        futures_executor::block_on(fut)
    }

    /// Builds a GET request for `path` carrying one header `name: value`.
    fn get_with_header(path: &'static str, name: &'static str, value: &str) -> Request {
        let mut req = Request::new(Method::Get, path);
        req.headers_mut().insert(name, value.as_bytes().to_vec());
        req
    }

    /// Builds a GET request for `path` whose `X-Forwarded-For` header is `ip`.
    fn forwarded_get(path: &'static str, ip: &str) -> Request {
        get_with_header(path, "x-forwarded-for", ip)
    }

    /// Returns whether `resp` has a header called `name` (ASCII case-insensitive).
    fn has_header(resp: &Response, name: &str) -> bool {
        resp.headers()
            .iter()
            .any(|(n, _)| n.eq_ignore_ascii_case(name))
    }

    /// Returns the value of the first header called `name` (ASCII
    /// case-insensitive) as an owned string, or `None` if it is absent.
    ///
    /// # Panics
    /// Panics if the header value is not valid UTF-8.
    fn header_value(resp: &Response, name: &str) -> Option<String> {
        resp.headers()
            .iter()
            .find(|(n, _)| n.eq_ignore_ascii_case(name))
            .map(|(_, v)| std::str::from_utf8(v).unwrap().to_string())
    }

    /// Extracts the short-circuit response from a `ControlFlow::Break`.
    ///
    /// # Panics
    /// Panics, mentioning `context`, if the middleware let the request through.
    fn expect_break(flow: ControlFlow, context: &str) -> Response {
        match flow {
            ControlFlow::Break(resp) => resp,
            ControlFlow::Continue => panic!("{context}: expected Break(429), got Continue"),
        }
    }

    // ------------------------------------------------------------------
    // Rate limiting algorithms
    // ------------------------------------------------------------------

    #[test]
    fn rate_limit_default_allows_requests() {
        let mw = RateLimitMiddleware::new();
        let mut req = forwarded_get("/api/test", "192.168.1.1");

        let result = run_rate_limit_before(&mw, &mut req);
        assert!(result.is_continue(), "first request should be allowed");
    }

    #[test]
    fn rate_limit_fixed_window_blocks_after_limit() {
        let mw = RateLimitMiddleware::builder()
            .requests(3)
            .per(Duration::from_secs(60))
            .algorithm(RateLimitAlgorithm::FixedWindow)
            .key_extractor(IpKeyExtractor)
            .build();

        for i in 0..3 {
            let mut req = forwarded_get("/api/test", "10.0.0.1");
            let result = run_rate_limit_before(&mw, &mut req);
            assert!(
                result.is_continue(),
                "request {i} should be allowed within limit"
            );
        }

        // The fourth request from the same key exceeds the window's budget.
        let mut req = forwarded_get("/api/test", "10.0.0.1");
        let resp = expect_break(
            run_rate_limit_before(&mw, &mut req),
            "fourth request should be blocked",
        );
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[test]
    fn rate_limit_different_keys_independent() {
        let mw = RateLimitMiddleware::builder()
            .requests(2)
            .per(Duration::from_secs(60))
            .algorithm(RateLimitAlgorithm::FixedWindow)
            .key_extractor(IpKeyExtractor)
            .build();

        for _ in 0..2 {
            let mut req = forwarded_get("/", "1.1.1.1");
            assert!(run_rate_limit_before(&mw, &mut req).is_continue());
        }

        // 1.1.1.1 has exhausted its budget...
        let mut req = forwarded_get("/", "1.1.1.1");
        assert!(run_rate_limit_before(&mw, &mut req).is_break());

        // ...but a different client key still has its own, untouched budget.
        let mut req = forwarded_get("/", "2.2.2.2");
        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
    }

    #[test]
    fn rate_limit_token_bucket_allows_burst() {
        let mw = RateLimitMiddleware::builder()
            .requests(5)
            .per(Duration::from_secs(60))
            .algorithm(RateLimitAlgorithm::TokenBucket)
            .key_extractor(IpKeyExtractor)
            .build();

        // A full bucket admits a burst up to its capacity...
        for i in 0..5 {
            let mut req = forwarded_get("/", "10.0.0.1");
            let result = run_rate_limit_before(&mw, &mut req);
            assert!(result.is_continue(), "burst request {i} should be allowed");
        }

        // ...and rejects the next request before any tokens refill.
        let mut req = forwarded_get("/", "10.0.0.1");
        assert!(run_rate_limit_before(&mw, &mut req).is_break());
    }

    #[test]
    fn rate_limit_sliding_window_basic() {
        let mw = RateLimitMiddleware::builder()
            .requests(3)
            .per(Duration::from_secs(60))
            .algorithm(RateLimitAlgorithm::SlidingWindow)
            .key_extractor(IpKeyExtractor)
            .build();

        for i in 0..3 {
            let mut req = forwarded_get("/", "10.0.0.1");
            assert!(
                run_rate_limit_before(&mw, &mut req).is_continue(),
                "sliding window request {i} should be allowed"
            );
        }

        let mut req = forwarded_get("/", "10.0.0.1");
        assert!(run_rate_limit_before(&mw, &mut req).is_break());
    }

    // ------------------------------------------------------------------
    // Key extraction as seen through the middleware
    // ------------------------------------------------------------------

    #[test]
    fn rate_limit_header_key_extractor() {
        let mw = RateLimitMiddleware::builder()
            .requests(2)
            .per(Duration::from_secs(60))
            .algorithm(RateLimitAlgorithm::FixedWindow)
            .key_extractor(HeaderKeyExtractor::new("x-api-key"))
            .build();

        for _ in 0..2 {
            let mut req = get_with_header("/", "x-api-key", "key-abc");
            assert!(run_rate_limit_before(&mw, &mut req).is_continue());
        }

        let mut req = get_with_header("/", "x-api-key", "key-abc");
        assert!(run_rate_limit_before(&mw, &mut req).is_break());

        // A different API key is limited independently.
        let mut req = get_with_header("/", "x-api-key", "key-xyz");
        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
    }

    #[test]
    fn rate_limit_path_key_extractor() {
        let mw = RateLimitMiddleware::builder()
            .requests(1)
            .per(Duration::from_secs(60))
            .algorithm(RateLimitAlgorithm::FixedWindow)
            .key_extractor(PathKeyExtractor)
            .build();

        let mut req = Request::new(Method::Get, "/api/a");
        assert!(run_rate_limit_before(&mw, &mut req).is_continue());

        let mut req = Request::new(Method::Get, "/api/a");
        assert!(run_rate_limit_before(&mw, &mut req).is_break());

        // Each path has its own budget.
        let mut req = Request::new(Method::Get, "/api/b");
        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
    }

    #[test]
    fn rate_limit_no_key_skips_limiting() {
        let mw = RateLimitMiddleware::builder()
            .requests(1)
            .per(Duration::from_secs(60))
            .algorithm(RateLimitAlgorithm::FixedWindow)
            .key_extractor(HeaderKeyExtractor::new("x-api-key"))
            .build();

        // None of these requests carries `x-api-key`, so the extractor yields no
        // key and none of them is counted against the limit of 1.
        for i in 0..=10 {
            let mut req = Request::new(Method::Get, "/");
            assert!(
                run_rate_limit_before(&mw, &mut req).is_continue(),
                "keyless request {i} should not be limited"
            );
        }
    }

    // ------------------------------------------------------------------
    // Response shaping
    // ------------------------------------------------------------------

    #[test]
    fn rate_limit_response_headers_on_success() {
        let mw = RateLimitMiddleware::builder()
            .requests(10)
            .per(Duration::from_secs(60))
            .algorithm(RateLimitAlgorithm::FixedWindow)
            .key_extractor(IpKeyExtractor)
            .build();

        let mut req = forwarded_get("/", "10.0.0.1");
        assert!(run_rate_limit_before(&mw, &mut req).is_continue());

        let resp = run_rate_limit_after(&mw, &req, Response::with_status(StatusCode::OK));

        assert!(
            has_header(&resp, "x-ratelimit-limit"),
            "should have X-RateLimit-Limit header"
        );
        assert!(
            has_header(&resp, "x-ratelimit-remaining"),
            "should have X-RateLimit-Remaining header"
        );
        assert!(
            has_header(&resp, "x-ratelimit-reset"),
            "should have X-RateLimit-Reset header"
        );
        assert_eq!(
            header_value(&resp, "x-ratelimit-limit").as_deref(),
            Some("10")
        );
    }

    #[test]
    fn rate_limit_429_response_has_retry_after() {
        let mw = RateLimitMiddleware::builder()
            .requests(1)
            .per(Duration::from_secs(60))
            .algorithm(RateLimitAlgorithm::FixedWindow)
            .key_extractor(IpKeyExtractor)
            .build();

        let mut req = forwarded_get("/", "10.0.0.1");
        assert!(run_rate_limit_before(&mw, &mut req).is_continue());

        let mut req = forwarded_get("/", "10.0.0.1");
        let resp = expect_break(
            run_rate_limit_before(&mw, &mut req),
            "second request should be blocked",
        );

        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(
            has_header(&resp, "retry-after"),
            "429 response should have Retry-After header"
        );
        assert_eq!(
            header_value(&resp, "content-type").as_deref(),
            Some("application/json"),
            "429 response should have JSON content type"
        );
    }

    #[test]
    fn rate_limit_no_headers_when_disabled() {
        let mw = RateLimitMiddleware::builder()
            .requests(10)
            .per(Duration::from_secs(60))
            .algorithm(RateLimitAlgorithm::FixedWindow)
            .key_extractor(IpKeyExtractor)
            .include_headers(false)
            .build();

        let mut req = forwarded_get("/", "10.0.0.1");
        assert!(run_rate_limit_before(&mw, &mut req).is_continue());

        let resp = run_rate_limit_after(&mw, &req, Response::with_status(StatusCode::OK));

        for name in ["x-ratelimit-limit", "x-ratelimit-remaining", "x-ratelimit-reset"] {
            assert!(
                !has_header(&resp, name),
                "should NOT have {name} header when disabled"
            );
        }
    }

    #[test]
    fn rate_limit_custom_retry_message() {
        let mw = RateLimitMiddleware::builder()
            .requests(1)
            .per(Duration::from_secs(60))
            .algorithm(RateLimitAlgorithm::FixedWindow)
            .key_extractor(IpKeyExtractor)
            .retry_message("Slow down, partner!")
            .build();

        // Spend the single allowed request; assert it so a misbehaving limiter
        // cannot make this priming step fail silently.
        let mut req = forwarded_get("/", "10.0.0.1");
        assert!(run_rate_limit_before(&mw, &mut req).is_continue());

        let mut req = forwarded_get("/", "10.0.0.1");
        let resp = expect_break(
            run_rate_limit_before(&mw, &mut req),
            "second request should be blocked",
        );

        match resp.body_ref() {
            ResponseBody::Bytes(body) => {
                let body_str = std::str::from_utf8(body).unwrap();
                assert!(
                    body_str.contains("Slow down, partner!"),
                    "expected custom message in body, got: {body_str}"
                );
            }
            _ => panic!("expected Bytes body"),
        }
    }

    // ------------------------------------------------------------------
    // Key extractors in isolation
    // ------------------------------------------------------------------

    #[test]
    fn rate_limit_ip_extractor_x_forwarded_for() {
        let extractor = IpKeyExtractor;
        // Only the left-most (originating client) address is used as the key.
        let req = forwarded_get("/", "1.2.3.4, 5.6.7.8");
        assert_eq!(extractor.extract_key(&req), Some("1.2.3.4".to_string()));
    }

    #[test]
    fn rate_limit_ip_extractor_x_real_ip() {
        let extractor = IpKeyExtractor;
        let req = get_with_header("/", "x-real-ip", "9.8.7.6");
        assert_eq!(extractor.extract_key(&req), Some("9.8.7.6".to_string()));
    }

    #[test]
    fn rate_limit_ip_extractor_fallback() {
        let extractor = IpKeyExtractor;
        let req = Request::new(Method::Get, "/");
        assert_eq!(extractor.extract_key(&req), Some("unknown".to_string()));
    }

    #[test]
    fn connected_ip_extractor_with_remote_addr() {
        use std::net::{IpAddr, Ipv4Addr};

        let extractor = ConnectedIpKeyExtractor;
        let mut req = Request::new(Method::Get, "/");
        req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100))));

        assert_eq!(
            extractor.extract_key(&req),
            Some("192.168.1.100".to_string())
        );
    }

    #[test]
    fn connected_ip_extractor_without_remote_addr() {
        let extractor = ConnectedIpKeyExtractor;
        let req = Request::new(Method::Get, "/");

        assert_eq!(extractor.extract_key(&req), None);
    }

    #[test]
    fn connected_ip_extractor_ignores_headers() {
        use std::net::{IpAddr, Ipv4Addr};

        let extractor = ConnectedIpKeyExtractor;
        // A spoofable forwarding header must not override the socket peer address.
        let mut req = forwarded_get("/", "1.2.3.4");
        req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));

        assert_eq!(extractor.extract_key(&req), Some("10.0.0.1".to_string()));
    }

    #[test]
    fn trusted_proxy_extractor_from_trusted_proxy() {
        use std::net::{IpAddr, Ipv4Addr};

        let extractor = TrustedProxyIpKeyExtractor::new().trust_cidr("10.0.0.0/8");

        // The peer is inside the trusted range, so its X-Forwarded-For is used.
        let mut req = forwarded_get("/", "203.0.113.50");
        req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));

        assert_eq!(
            extractor.extract_key(&req),
            Some("203.0.113.50".to_string())
        );
    }

    #[test]
    fn trusted_proxy_extractor_from_untrusted_direct() {
        use std::net::{IpAddr, Ipv4Addr};

        let extractor = TrustedProxyIpKeyExtractor::new().trust_cidr("10.0.0.0/8");

        // The peer is outside the trusted range, so the header is ignored and
        // the peer address itself becomes the key.
        let mut req = forwarded_get("/", "1.2.3.4");
        req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50))));

        assert_eq!(
            extractor.extract_key(&req),
            Some("203.0.113.50".to_string())
        );
    }

    #[test]
    fn trusted_proxy_extractor_no_remote_addr() {
        let extractor = TrustedProxyIpKeyExtractor::new().trust_loopback();

        // Without a known peer address the forwarding header cannot be trusted.
        let req = forwarded_get("/", "1.2.3.4");

        assert_eq!(extractor.extract_key(&req), None);
    }

    #[test]
    fn trusted_proxy_extractor_loopback_ipv4() {
        use std::net::{IpAddr, Ipv4Addr};

        let extractor = TrustedProxyIpKeyExtractor::new().trust_loopback();

        let mut req = forwarded_get("/", "8.8.8.8");
        req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::LOCALHOST)));

        assert_eq!(extractor.extract_key(&req), Some("8.8.8.8".to_string()));
    }

    #[test]
    fn trusted_proxy_extractor_loopback_ipv6() {
        use std::net::{IpAddr, Ipv6Addr};

        let extractor = TrustedProxyIpKeyExtractor::new().trust_loopback();

        let mut req = forwarded_get("/", "8.8.8.8");
        req.insert_extension(RemoteAddr(IpAddr::V6(Ipv6Addr::LOCALHOST)));

        assert_eq!(extractor.extract_key(&req), Some("8.8.8.8".to_string()));
    }

    #[test]
    fn cidr_parsing() {
        // Well-formed IPv4 and IPv6 ranges, including the catch-all /0 prefixes.
        assert!(parse_cidr("10.0.0.0/8").is_some());
        assert!(parse_cidr("192.168.1.0/24").is_some());
        assert!(parse_cidr("0.0.0.0/0").is_some());
        assert!(parse_cidr("::1/128").is_some());
        assert!(parse_cidr("::/0").is_some());

        // Prefix wider than the address, garbage input, and a missing prefix.
        assert!(parse_cidr("10.0.0.0/33").is_none());
        assert!(parse_cidr("invalid").is_none());
        assert!(parse_cidr("10.0.0.0").is_none());
    }

    #[test]
    fn ip_in_cidr_matching() {
        use std::net::{IpAddr, Ipv4Addr};

        let cidr_10 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0));

        // Both ends of 10.0.0.0/8 are inside the range.
        assert!(ip_in_cidr(
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            cidr_10,
            8
        ));
        assert!(ip_in_cidr(
            IpAddr::V4(Ipv4Addr::new(10, 255, 255, 255)),
            cidr_10,
            8
        ));

        // Addresses just outside and far outside the range are rejected.
        assert!(!ip_in_cidr(
            IpAddr::V4(Ipv4Addr::new(11, 0, 0, 1)),
            cidr_10,
            8
        ));
        assert!(!ip_in_cidr(
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
            cidr_10,
            8
        ));
    }

    #[test]
    fn rate_limit_composite_key_extractor() {
        let extractor =
            CompositeKeyExtractor::new(vec![Box::new(IpKeyExtractor), Box::new(PathKeyExtractor)]);

        let req = forwarded_get("/api/users", "10.0.0.1");

        // Component keys are joined with ':' in extractor order.
        assert_eq!(
            extractor.extract_key(&req),
            Some("10.0.0.1:/api/users".to_string())
        );
    }

    // ------------------------------------------------------------------
    // Builder configuration
    // ------------------------------------------------------------------

    #[test]
    fn rate_limit_builder_defaults() {
        let mw = RateLimitMiddleware::builder().build();
        assert_eq!(mw.config.max_requests, 100);
        assert_eq!(mw.config.window, Duration::from_secs(60));
        assert_eq!(mw.config.algorithm, RateLimitAlgorithm::TokenBucket);
        assert!(mw.config.include_headers);
    }

    #[test]
    fn rate_limit_builder_per_minute() {
        let mw = RateLimitMiddleware::builder()
            .requests(50)
            .per_minute(2)
            .algorithm(RateLimitAlgorithm::SlidingWindow)
            .build();
        assert_eq!(mw.config.max_requests, 50);
        assert_eq!(mw.config.window, Duration::from_secs(120));
        assert_eq!(mw.config.algorithm, RateLimitAlgorithm::SlidingWindow);
    }

    #[test]
    fn rate_limit_builder_per_hour() {
        let mw = RateLimitMiddleware::builder()
            .requests(1000)
            .per_hour(1)
            .build();
        assert_eq!(mw.config.window, Duration::from_secs(3600));
    }

    #[test]
    fn rate_limit_middleware_name() {
        let mw = RateLimitMiddleware::new();
        assert_eq!(mw.name(), "RateLimit");
    }

    #[test]
    fn rate_limit_default_via_default_trait() {
        let mw = RateLimitMiddleware::default();
        assert_eq!(mw.config.max_requests, 100);
    }

    // ------------------------------------------------------------------
    // ETag middleware
    // ------------------------------------------------------------------

    #[test]
    fn etag_middleware_generates_etag_for_get() {
        let mw = ETagMiddleware::new();
        let req = Request::new(Method::Get, "/resource");

        let response = Response::ok()
            .header("content-type", b"application/json".to_vec())
            .body(ResponseBody::Bytes(br#"{"status":"ok"}"#.to_vec()));
        let response = run_etag_after(&mw, &req, response);

        let etag = header_value(&response, "etag").expect("Response should have ETag header");
        // A strong ETag is a quoted string; require both quotes to be present.
        assert!(etag.len() >= 2, "ETag should be a quoted string, got {etag}");
        assert!(etag.starts_with('"'), "ETag should start with quote");
        assert!(etag.ends_with('"'), "ETag should end with quote");
    }

    #[test]
    fn etag_middleware_returns_304_on_match() {
        let mw = ETagMiddleware::new();
        let body = br#"{"status":"ok"}"#.to_vec();

        // First request has no validator, so the middleware attaches an ETag.
        let first = run_etag_after(
            &mw,
            &Request::new(Method::Get, "/resource"),
            Response::ok().body(ResponseBody::Bytes(body.clone())),
        );
        let etag = header_value(&first, "etag").expect("first response should carry an ETag");

        // Revalidating with that ETag yields 304 Not Modified with no body.
        let req = get_with_header("/resource", "if-none-match", &etag);
        let second = run_etag_after(&mw, &req, Response::ok().body(ResponseBody::Bytes(body)));

        assert_eq!(second.status().as_u16(), 304);
        assert!(second.body_ref().is_empty());
    }

    #[test]
    fn etag_middleware_returns_full_response_on_mismatch() {
        let mw = ETagMiddleware::new();
        let req = get_with_header("/resource", "if-none-match", "\"old-etag\"");

        let body = br#"{"status":"updated"}"#.to_vec();
        let response = run_etag_after(
            &mw,
            &req,
            Response::ok().body(ResponseBody::Bytes(body.clone())),
        );

        assert_eq!(response.status().as_u16(), 200);
        // A stale validator must not alter the payload.
        match response.body_ref() {
            ResponseBody::Bytes(bytes) => {
                assert_eq!(bytes, &body, "body should pass through unchanged");
            }
            _ => panic!("expected Bytes body"),
        }
    }

    #[test]
    fn etag_middleware_weak_etag_generation() {
        let mw = ETagMiddleware::with_config(ETagConfig::new().weak(true));
        let req = Request::new(Method::Get, "/resource");

        let response = run_etag_after(
            &mw,
            &req,
            Response::ok().body(ResponseBody::Bytes(b"data".to_vec())),
        );

        let etag = header_value(&response, "etag").expect("Response should have ETag header");
        assert!(etag.starts_with("W/"), "Weak ETag should start with W/");
    }

    #[test]
    fn etag_middleware_skips_post_requests() {
        let mw = ETagMiddleware::new();
        let req = Request::new(Method::Post, "/resource");

        let response = run_etag_after(
            &mw,
            &req,
            Response::ok().body(ResponseBody::Bytes(b"created".to_vec())),
        );

        assert!(!has_header(&response, "etag"), "POST should not have ETag");
    }

    #[test]
    fn etag_middleware_handles_head_requests() {
        let mw = ETagMiddleware::new();
        let req = Request::new(Method::Head, "/resource");

        let response = run_etag_after(
            &mw,
            &req,
            Response::ok().body(ResponseBody::Bytes(b"data".to_vec())),
        );

        assert!(has_header(&response, "etag"), "HEAD should have ETag");
    }

    #[test]
    fn etag_middleware_disabled_mode() {
        let mw = ETagMiddleware::with_config(ETagConfig::new().mode(ETagMode::Disabled));
        let req = Request::new(Method::Get, "/resource");

        let response = run_etag_after(
            &mw,
            &req,
            Response::ok().body(ResponseBody::Bytes(b"data".to_vec())),
        );

        assert!(
            !has_header(&response, "etag"),
            "Disabled mode should not add ETag"
        );
    }

    #[test]
    fn etag_middleware_min_size_filter() {
        let mw = ETagMiddleware::with_config(ETagConfig::new().min_size(1000));
        let req = Request::new(Method::Get, "/resource");

        // A 5-byte body is below the 1000-byte threshold.
        let response = run_etag_after(
            &mw,
            &req,
            Response::ok().body(ResponseBody::Bytes(b"small".to_vec())),
        );

        assert!(
            !has_header(&response, "etag"),
            "Small body should not get ETag"
        );
    }

    #[test]
    fn etag_middleware_preserves_existing_etag() {
        let mw = ETagMiddleware::with_config(ETagConfig::new().mode(ETagMode::Manual));
        let req = get_with_header("/resource", "if-none-match", "\"custom-etag\"");

        // In manual mode the handler-supplied ETag is compared against the request.
        let response = Response::ok()
            .header("etag", b"\"custom-etag\"".to_vec())
            .body(ResponseBody::Bytes(b"data".to_vec()));
        let response = run_etag_after(&mw, &req, response);

        assert_eq!(response.status().as_u16(), 304);
    }

    #[test]
    fn etag_middleware_wildcard_if_none_match() {
        let mw = ETagMiddleware::new();
        let req = get_with_header("/resource", "if-none-match", "*");

        let response = run_etag_after(
            &mw,
            &req,
            Response::ok().body(ResponseBody::Bytes(b"data".to_vec())),
        );

        assert_eq!(response.status().as_u16(), 304);
    }

    #[test]
    fn etag_middleware_weak_comparison_matches() {
        let mw = ETagMiddleware::new();
        let body = b"test data".to_vec();

        let first = run_etag_after(
            &mw,
            &Request::new(Method::Get, "/resource"),
            Response::ok().body(ResponseBody::Bytes(body.clone())),
        );
        let etag = header_value(&first, "etag").expect("first response should carry an ETag");

        // If-None-Match uses weak comparison, so a W/-prefixed copy still matches.
        let weak_etag = format!("W/{etag}");
        let req = get_with_header("/resource", "if-none-match", &weak_etag);
        let second = run_etag_after(&mw, &req, Response::ok().body(ResponseBody::Bytes(body)));

        assert_eq!(second.status().as_u16(), 304);
    }

    #[test]
    fn etag_middleware_name() {
        let mw = ETagMiddleware::new();
        assert_eq!(mw.name(), "ETagMiddleware");
    }

    #[test]
    fn etag_config_builder() {
        let config = ETagConfig::new()
            .mode(ETagMode::Auto)
            .weak(true)
            .min_size(512);

        assert_eq!(config.mode, ETagMode::Auto);
        assert!(config.weak);
        assert_eq!(config.min_size, 512);
    }

    #[test]
    fn etag_generates_consistent_hash() {
        // Identical input hashes identically...
        let etag1 = ETagMiddleware::generate_etag(b"hello world", false);
        let etag2 = ETagMiddleware::generate_etag(b"hello world", false);
        assert_eq!(etag1, etag2);

        // ...while a one-byte change produces a different tag.
        let etag3 = ETagMiddleware::generate_etag(b"hello world!", false);
        assert_ne!(etag1, etag3);
    }
}
13074}