1use constant_time_eq::constant_time_eq;
19use std::any::Any;
20use std::collections::HashMap;
21use std::fmt;
22use std::net::SocketAddr;
23use std::sync::Arc;
24use std::time::{Duration, Instant};
25
26use async_trait::async_trait;
27use thiserror::Error;
28use tracing::{debug, info, warn};
29
30use crate::metrics::MetricsCollector;
31
32#[derive(Error, Debug)]
38pub enum MiddlewareError {
39 #[error("Authentication failed: {0}")]
40 AuthFailed(String),
41
42 #[error("Rate limited: {0}")]
43 RateLimited(String),
44
45 #[error("Internal middleware error: {0}")]
46 Internal(String),
47
48 #[error("Pipeline error: {0}")]
49 Pipeline(String),
50}
51
52pub type Result<T> = std::result::Result<T, MiddlewareError>;
53
/// Coarse outcome classification for a request handled by the pipeline.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ResponseStatus {
    /// The request completed normally.
    Ok,
    /// The request failed.
    Error,
    /// A rate-limiting middleware rejected the request.
    RateLimited,
    /// The authentication middleware rejected the request.
    Unauthorized,
}
66
67impl fmt::Display for ResponseStatus {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 match self {
70 Self::Ok => write!(f, "OK"),
71 Self::Error => write!(f, "Error"),
72 Self::RateLimited => write!(f, "RateLimited"),
73 Self::Unauthorized => write!(f, "Unauthorized"),
74 }
75 }
76}
77
78#[derive(Debug, Clone)]
80pub struct Response {
81 pub status: ResponseStatus,
82 pub body: Option<Vec<u8>>,
83 pub headers: HashMap<String, String>,
84 pub duration: Duration,
85}
86
87impl Response {
88 pub fn ok() -> Self {
90 Self {
91 status: ResponseStatus::Ok,
92 body: None,
93 headers: HashMap::new(),
94 duration: Duration::ZERO,
95 }
96 }
97
98 pub fn error(msg: impl Into<String>) -> Self {
100 Self {
101 status: ResponseStatus::Error,
102 body: Some(msg.into().into_bytes()),
103 headers: HashMap::new(),
104 duration: Duration::ZERO,
105 }
106 }
107
108 pub fn rate_limited(msg: impl Into<String>) -> Self {
110 Self {
111 status: ResponseStatus::RateLimited,
112 body: Some(msg.into().into_bytes()),
113 headers: HashMap::new(),
114 duration: Duration::ZERO,
115 }
116 }
117
118 pub fn unauthorized(msg: impl Into<String>) -> Self {
120 Self {
121 status: ResponseStatus::Unauthorized,
122 body: Some(msg.into().into_bytes()),
123 headers: HashMap::new(),
124 duration: Duration::ZERO,
125 }
126 }
127
128 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
130 self.headers.insert(key.into(), value.into());
131 self
132 }
133
134 pub fn with_body(mut self, body: Vec<u8>) -> Self {
136 self.body = Some(body);
137 self
138 }
139
140 pub fn with_duration(mut self, duration: Duration) -> Self {
142 self.duration = duration;
143 self
144 }
145}
146
/// Per-request state passed (mutably) through every middleware.
pub struct RequestContext {
    /// Identifier used to correlate log events and spans for this request.
    pub request_id: String,
    /// Remote peer address, when known.
    pub client_addr: Option<SocketAddr>,
    /// Name of the invoked operation (e.g. `"GET"`, `"QUERY"`).
    pub method: String,
    /// String key/value metadata sent with the request; `AuthMiddleware`
    /// reads the `"authorization"` entry.
    pub metadata: HashMap<String, String>,
    /// Creation time of the context; `elapsed()` measures from here.
    pub start_time: Instant,
    /// Type-erased values attached by middleware for later stages.
    pub attributes: HashMap<String, Box<dyn Any + Send + Sync>>,
}
169
170impl RequestContext {
171 pub fn new(method: impl Into<String>) -> Self {
173 Self {
174 request_id: uuid::Uuid::new_v4().to_string(),
175 client_addr: None,
176 method: method.into(),
177 metadata: HashMap::new(),
178 start_time: Instant::now(),
179 attributes: HashMap::new(),
180 }
181 }
182
183 pub fn with_client_addr(mut self, addr: SocketAddr) -> Self {
185 self.client_addr = Some(addr);
186 self
187 }
188
189 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
191 self.metadata.insert(key.into(), value.into());
192 self
193 }
194
195 pub fn set_attribute<T: Any + Send + Sync>(&mut self, key: impl Into<String>, value: T) {
197 self.attributes.insert(key.into(), Box::new(value));
198 }
199
200 pub fn get_attribute<T: Any + Send + Sync>(&self, key: &str) -> Option<&T> {
202 self.attributes.get(key).and_then(|v| v.downcast_ref::<T>())
203 }
204
205 pub fn elapsed(&self) -> Duration {
207 self.start_time.elapsed()
208 }
209}
210
211impl fmt::Debug for RequestContext {
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 f.debug_struct("RequestContext")
214 .field("request_id", &self.request_id)
215 .field("client_addr", &self.client_addr)
216 .field("method", &self.method)
217 .field("metadata", &self.metadata)
218 .field("start_time", &self.start_time)
219 .field("attributes_count", &self.attributes.len())
220 .finish()
221 }
222}
223
224#[async_trait]
230pub trait Next: Send + Sync {
231 async fn run(&self, ctx: &mut RequestContext) -> Result<Response>;
232}
233
234#[async_trait]
236pub trait Middleware: Send + Sync {
237 async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response>;
239
240 fn name(&self) -> &str;
242
243 fn order(&self) -> i32 {
245 0
246 }
247}
248
/// Final continuation, reached once every middleware has delegated onward.
struct PipelineTail;
255
256#[async_trait]
257impl Next for PipelineTail {
258 async fn run(&self, _ctx: &mut RequestContext) -> Result<Response> {
259 Ok(Response::ok())
260 }
261}
262
263struct PipelineLink {
265 middleware: Arc<dyn Middleware>,
266 next: Arc<dyn Next>,
267}
268
269#[async_trait]
270impl Next for PipelineLink {
271 async fn run(&self, ctx: &mut RequestContext) -> Result<Response> {
272 self.middleware.process(ctx, self.next.as_ref()).await
273 }
274}
275
276pub struct MiddlewarePipeline {
284 chain: Arc<dyn Next>,
285}
286
287impl MiddlewarePipeline {
288 pub async fn execute(&self, ctx: &mut RequestContext) -> Result<Response> {
290 let result = self.chain.run(ctx).await;
291 match result {
293 Ok(mut resp) => {
294 resp.duration = ctx.elapsed();
295 Ok(resp)
296 }
297 Err(e) => Err(e),
298 }
299 }
300}
301
302pub struct MiddlewarePipelineBuilder {
304 middleware: Vec<Arc<dyn Middleware>>,
305}
306
307impl Default for MiddlewarePipelineBuilder {
308 fn default() -> Self {
309 Self::new()
310 }
311}
312
313impl MiddlewarePipelineBuilder {
314 pub fn new() -> Self {
316 Self {
317 middleware: Vec::new(),
318 }
319 }
320
321 pub fn with<M: Middleware + 'static>(mut self, m: M) -> Self {
323 self.middleware.push(Arc::new(m));
324 self
325 }
326
327 pub fn add_arc(mut self, m: Arc<dyn Middleware>) -> Self {
329 self.middleware.push(m);
330 self
331 }
332
333 pub fn build(mut self) -> MiddlewarePipeline {
335 self.middleware.sort_by_key(|m| m.order());
337
338 let mut next: Arc<dyn Next> = Arc::new(PipelineTail);
340 for mw in self.middleware.into_iter().rev() {
341 next = Arc::new(PipelineLink {
342 middleware: mw,
343 next,
344 });
345 }
346
347 MiddlewarePipeline { chain: next }
348 }
349}
350
351pub struct LoggingMiddleware {
361 level: LogLevel,
362}
363
/// Level used by `LoggingMiddleware` for its non-failure events.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LogLevel {
    /// Emit through `debug!`.
    Debug,
    /// Emit through `info!`.
    Info,
}
372
373impl Default for LoggingMiddleware {
374 fn default() -> Self {
375 Self::new()
376 }
377}
378
379impl LoggingMiddleware {
380 pub fn new() -> Self {
381 Self {
382 level: LogLevel::Info,
383 }
384 }
385
386 pub fn with_level(mut self, level: LogLevel) -> Self {
387 self.level = level;
388 self
389 }
390}
391
392#[async_trait]
393impl Middleware for LoggingMiddleware {
394 async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
395 let method = ctx.method.clone();
396 let request_id = ctx.request_id.clone();
397 let client = ctx
398 .client_addr
399 .map_or_else(|| "unknown".to_string(), |a| a.to_string());
400
401 match self.level {
402 LogLevel::Info => info!(
403 request_id = %request_id,
404 method = %method,
405 client = %client,
406 "Request started"
407 ),
408 LogLevel::Debug => debug!(
409 request_id = %request_id,
410 method = %method,
411 client = %client,
412 "Request started"
413 ),
414 }
415
416 let result = next.run(ctx).await;
417
418 match &result {
419 Ok(resp) => match self.level {
420 LogLevel::Info => info!(
421 request_id = %request_id,
422 method = %method,
423 status = %resp.status,
424 duration_ms = %ctx.elapsed().as_millis(),
425 "Request completed"
426 ),
427 LogLevel::Debug => debug!(
428 request_id = %request_id,
429 method = %method,
430 status = %resp.status,
431 duration_ms = %ctx.elapsed().as_millis(),
432 "Request completed"
433 ),
434 },
435 Err(e) => warn!(
436 request_id = %request_id,
437 method = %method,
438 error = %e,
439 duration_ms = %ctx.elapsed().as_millis(),
440 "Request failed"
441 ),
442 }
443
444 result
445 }
446
447 fn name(&self) -> &str {
448 "logging"
449 }
450
451 fn order(&self) -> i32 {
452 -100
453 }
454}
455
456pub struct MetricsMiddleware {
462 collector: MetricsCollector,
463}
464
465impl MetricsMiddleware {
466 pub fn new(collector: MetricsCollector) -> Self {
467 Self { collector }
468 }
469}
470
471#[async_trait]
472impl Middleware for MetricsMiddleware {
473 async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
474 let result = next.run(ctx).await;
475 let duration = ctx.elapsed();
476
477 self.collector.inc_requests();
478 self.collector.observe_request_latency(duration);
479
480 match &result {
481 Ok(resp) => {
482 if resp.status == ResponseStatus::Ok {
483 self.collector.inc_success();
484 } else {
485 self.collector.inc_failed();
486 }
487 }
488 Err(_) => {
489 self.collector.inc_failed();
490 }
491 }
492
493 result
494 }
495
496 fn name(&self) -> &str {
497 "metrics"
498 }
499
500 fn order(&self) -> i32 {
501 -90
502 }
503}
504
/// Stateless middleware that runs downstream stages inside an
/// `amaters.request` tracing span.
pub struct TracingMiddleware;
511
512impl Default for TracingMiddleware {
513 fn default() -> Self {
514 Self::new()
515 }
516}
517
518impl TracingMiddleware {
519 pub fn new() -> Self {
520 Self
521 }
522}
523
524#[async_trait]
525impl Middleware for TracingMiddleware {
526 async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
527 let span = tracing::info_span!(
528 "amaters.request",
529 "amaters.node_id" = "local",
530 "amaters.request_id" = %ctx.request_id,
531 method = %ctx.method,
532 client_addr = ?ctx.client_addr,
533 );
534
535 let _guard = span.enter();
536 next.run(ctx).await
537 }
538
539 fn name(&self) -> &str {
540 "tracing"
541 }
542
543 fn order(&self) -> i32 {
544 -95
545 }
546}
547
/// Runs downstream stages inside an `amaters.server.request` span tagged
/// with this node's id.
pub struct OtelSpanMiddleware {
    /// Recorded as the span's `amaters.node_id` field.
    node_id: String,
}
560
561impl OtelSpanMiddleware {
562 pub fn new(node_id: impl Into<String>) -> Self {
563 Self {
564 node_id: node_id.into(),
565 }
566 }
567}
568
569#[async_trait]
570impl Middleware for OtelSpanMiddleware {
571 async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
572 let span = tracing::info_span!(
573 "amaters.server.request",
574 "amaters.node_id" = self.node_id.as_str(),
575 "amaters.request_id" = %ctx.request_id,
576 "amaters.method" = %ctx.method,
577 );
578
579 let _guard = span.enter();
580 next.run(ctx).await
581 }
582
583 fn name(&self) -> &str {
584 "otel_span"
585 }
586
587 fn order(&self) -> i32 {
588 -97
589 }
590}
591
/// Authenticates requests by treating the `"authorization"` metadata entry
/// as an API key.
pub struct AuthMiddleware {
    /// API key -> principal (user id) that the key authenticates as.
    api_keys: HashMap<String, String>,
    /// When `true`, requests without an `"authorization"` entry pass through
    /// unauthenticated.
    allow_anonymous: bool,
}
607
608impl AuthMiddleware {
609 pub fn new(api_keys: HashMap<String, String>) -> Self {
610 Self {
611 api_keys,
612 allow_anonymous: false,
613 }
614 }
615
616 pub fn with_allow_anonymous(mut self, allow: bool) -> Self {
619 self.allow_anonymous = allow;
620 self
621 }
622}
623
624#[async_trait]
625impl Middleware for AuthMiddleware {
626 async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
627 let auth_header = ctx.metadata.get("authorization").cloned();
628
629 match auth_header {
630 Some(key) => {
631 let key_bytes = key.as_bytes();
634 if let Some(user_id) = self
635 .api_keys
636 .iter()
637 .find(|(k, _)| constant_time_eq(k.as_bytes(), key_bytes))
638 .map(|(_, v)| v)
639 {
640 ctx.set_attribute("auth_principal", user_id.clone());
641 debug!(
642 request_id = %ctx.request_id,
643 user_id = %user_id,
644 "Authentication successful"
645 );
646 next.run(ctx).await
647 } else {
648 warn!(
649 request_id = %ctx.request_id,
650 "Authentication failed: invalid credentials"
651 );
652 Ok(Response::unauthorized("Invalid credentials"))
653 }
654 }
655 None => {
656 if self.allow_anonymous {
657 next.run(ctx).await
658 } else {
659 warn!(
660 request_id = %ctx.request_id,
661 "Authentication failed: no credentials provided"
662 );
663 Ok(Response::unauthorized("No credentials provided"))
664 }
665 }
666 }
667 }
668
669 fn name(&self) -> &str {
670 "auth"
671 }
672
673 fn order(&self) -> i32 {
674 -80
675 }
676}
677
678pub struct RateLimitMiddleware {
686 state: Arc<parking_lot::Mutex<RateLimitState>>,
687 max_tokens: u64,
688 refill_rate: f64, }
690
/// Mutable token-bucket state.
struct RateLimitState {
    /// Tokens currently available (may be fractional).
    tokens: f64,
    /// When `tokens` was last topped up.
    last_refill: Instant,
}
695
696impl RateLimitMiddleware {
697 pub fn new(max_tokens: u64, refill_rate: f64) -> Self {
700 Self {
701 state: Arc::new(parking_lot::Mutex::new(RateLimitState {
702 tokens: max_tokens as f64,
703 last_refill: Instant::now(),
704 })),
705 max_tokens,
706 refill_rate,
707 }
708 }
709
710 fn try_acquire(&self) -> bool {
711 let mut state = self.state.lock();
712 let now = Instant::now();
713 let elapsed = now.duration_since(state.last_refill).as_secs_f64();
714 state.tokens = (state.tokens + elapsed * self.refill_rate).min(self.max_tokens as f64);
715 state.last_refill = now;
716
717 if state.tokens >= 1.0 {
718 state.tokens -= 1.0;
719 true
720 } else {
721 false
722 }
723 }
724}
725
726#[async_trait]
727impl Middleware for RateLimitMiddleware {
728 async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
729 if self.try_acquire() {
730 next.run(ctx).await
731 } else {
732 warn!(
733 request_id = %ctx.request_id,
734 "Rate limit exceeded"
735 );
736 Ok(Response::rate_limited("Rate limit exceeded"))
737 }
738 }
739
740 fn name(&self) -> &str {
741 "rate_limit"
742 }
743
744 fn order(&self) -> i32 {
745 -70
746 }
747}
748
749pub struct AdaptiveRateLimiter {
760 base_limit: u64,
761 current_limit: Arc<parking_lot::Mutex<u64>>,
762 error_window: Arc<parking_lot::Mutex<std::collections::VecDeque<bool>>>,
763 window_size: usize,
764 reduction_factor: f64,
765 recovery_factor: f64,
766 error_threshold: f64,
767}
768
769impl AdaptiveRateLimiter {
770 pub fn new(base_limit: u64) -> Self {
773 Self {
774 base_limit,
775 current_limit: Arc::new(parking_lot::Mutex::new(base_limit)),
776 error_window: Arc::new(parking_lot::Mutex::new(
777 std::collections::VecDeque::with_capacity(101),
778 )),
779 window_size: 100,
780 reduction_factor: 0.8,
781 recovery_factor: 1.05,
782 error_threshold: 0.1,
783 }
784 }
785
786 pub fn record_success(&self) {
788 self.push(false);
789 self.adjust();
790 }
791
792 pub fn record_error(&self) {
794 self.push(true);
795 self.adjust();
796 }
797
798 pub fn current_limit(&self) -> u64 {
800 *self.current_limit.lock()
801 }
802
803 fn push(&self, is_error: bool) {
804 let mut window = self.error_window.lock();
805 if window.len() >= self.window_size {
806 window.pop_front();
807 }
808 window.push_back(is_error);
809 }
810
811 fn adjust(&self) {
812 let error_rate = {
813 let window = self.error_window.lock();
814 if window.is_empty() {
815 return;
816 }
817 let errors = window.iter().filter(|&&e| e).count();
818 errors as f64 / window.len() as f64
819 };
820
821 let mut limit = self.current_limit.lock();
822 if error_rate >= self.error_threshold {
823 let reduced = (*limit as f64 * self.reduction_factor).floor() as u64;
824 *limit = reduced.max(1);
825 } else {
826 let recovered = (*limit as f64 * self.recovery_factor).ceil() as u64;
827 *limit = recovered.min(self.base_limit);
828 }
829 }
830}
831
832pub struct AdaptiveRateLimitMiddleware {
839 limiter: Arc<AdaptiveRateLimiter>,
840 token_state: Arc<parking_lot::Mutex<RateLimitState>>,
841}
842
843impl AdaptiveRateLimitMiddleware {
844 pub fn new(base_limit: u64) -> Self {
845 let limiter = Arc::new(AdaptiveRateLimiter::new(base_limit));
846 Self {
847 token_state: Arc::new(parking_lot::Mutex::new(RateLimitState {
848 tokens: base_limit as f64,
849 last_refill: Instant::now(),
850 })),
851 limiter,
852 }
853 }
854
855 fn try_acquire(&self) -> bool {
856 let capacity = self.limiter.current_limit() as f64;
857 let mut state = self.token_state.lock();
858 let now = Instant::now();
859 let elapsed = now.duration_since(state.last_refill).as_secs_f64();
860 state.tokens = (state.tokens + elapsed * capacity).min(capacity);
862 state.last_refill = now;
863
864 if state.tokens >= 1.0 {
865 state.tokens -= 1.0;
866 true
867 } else {
868 false
869 }
870 }
871}
872
873#[async_trait]
874impl Middleware for AdaptiveRateLimitMiddleware {
875 async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
876 if self.try_acquire() {
877 let result = next.run(ctx).await;
878 match &result {
879 Ok(resp) if resp.status == ResponseStatus::Ok => self.limiter.record_success(),
880 _ => self.limiter.record_error(),
881 }
882 result
883 } else {
884 self.limiter.record_error();
885 warn!(
886 request_id = %ctx.request_id,
887 "Adaptive rate limit exceeded"
888 );
889 Ok(Response::rate_limited("Adaptive rate limit exceeded"))
890 }
891 }
892
893 fn name(&self) -> &str {
894 "adaptive_rate_limit"
895 }
896
897 fn order(&self) -> i32 {
898 -65
899 }
900}
901
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Appends its `id` to a shared log when run; `id` doubles as `order()`.
    struct OrderRecorder {
        id: i32,
        log: Arc<parking_lot::Mutex<Vec<i32>>>,
    }

    #[async_trait]
    impl Middleware for OrderRecorder {
        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
            self.log.lock().push(self.id);
            next.run(ctx).await
        }
        fn name(&self) -> &str {
            "order_recorder"
        }
        fn order(&self) -> i32 {
            self.id
        }
    }

    /// Always answers `Unauthorized` without calling `next`.
    struct ShortCircuit;

    #[async_trait]
    impl Middleware for ShortCircuit {
        async fn process(&self, _ctx: &mut RequestContext, _next: &dyn Next) -> Result<Response> {
            Ok(Response::unauthorized("blocked"))
        }
        fn name(&self) -> &str {
            "short_circuit"
        }
        fn order(&self) -> i32 {
            0
        }
    }

    /// Stores a `String` attribute, then continues (order -10, so runs early).
    struct AttributeSetter {
        key: String,
        value: String,
    }

    #[async_trait]
    impl Middleware for AttributeSetter {
        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
            ctx.set_attribute(&self.key, self.value.clone());
            next.run(ctx).await
        }
        fn name(&self) -> &str {
            "attr_setter"
        }
        fn order(&self) -> i32 {
            -10
        }
    }

    /// Copies a `String` attribute into `found` when present (order 10, runs late).
    struct AttributeReader {
        key: String,
        found: Arc<parking_lot::Mutex<Option<String>>>,
    }

    #[async_trait]
    impl Middleware for AttributeReader {
        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
            if let Some(val) = ctx.get_attribute::<String>(&self.key) {
                *self.found.lock() = Some(val.clone());
            }
            next.run(ctx).await
        }
        fn name(&self) -> &str {
            "attr_reader"
        }
        fn order(&self) -> i32 {
            10
        }
    }

    /// Fails with `MiddlewareError::Internal("boom")`.
    struct ErrorMiddleware;

    #[async_trait]
    impl Middleware for ErrorMiddleware {
        async fn process(&self, _ctx: &mut RequestContext, _next: &dyn Next) -> Result<Response> {
            Err(MiddlewareError::Internal("boom".to_string()))
        }
        fn name(&self) -> &str {
            "error"
        }
    }

    /// Counts how many times it has been run.
    struct CounterMiddleware {
        counter: Arc<AtomicUsize>,
        ord: i32,
    }

    #[async_trait]
    impl Middleware for CounterMiddleware {
        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            next.run(ctx).await
        }
        fn name(&self) -> &str {
            "counter"
        }
        fn order(&self) -> i32 {
            self.ord
        }
    }

    #[tokio::test]
    async fn test_empty_pipeline_passes_through() {
        let pipeline = MiddlewarePipelineBuilder::new().build();
        let mut ctx = RequestContext::new("TEST");
        let resp = pipeline
            .execute(&mut ctx)
            .await
            .expect("empty pipeline should succeed");
        assert_eq!(resp.status, ResponseStatus::Ok);
    }

    #[tokio::test]
    async fn test_pipeline_executes_in_order() {
        let log = Arc::new(parking_lot::Mutex::new(Vec::new()));

        let pipeline = MiddlewarePipelineBuilder::new()
            .with(OrderRecorder {
                id: 3,
                log: Arc::clone(&log),
            })
            .with(OrderRecorder {
                id: 1,
                log: Arc::clone(&log),
            })
            .with(OrderRecorder {
                id: 2,
                log: Arc::clone(&log),
            })
            .build();

        let mut ctx = RequestContext::new("TEST");
        pipeline
            .execute(&mut ctx)
            .await
            .expect("pipeline should succeed");

        let order = log.lock().clone();
        assert_eq!(
            order,
            vec![1, 2, 3],
            "middleware should run sorted by order()"
        );
    }

    #[tokio::test]
    async fn test_short_circuit_on_auth_failure() {
        let counter = Arc::new(AtomicUsize::new(0));

        let pipeline = MiddlewarePipelineBuilder::new()
            .with(ShortCircuit)
            .with(CounterMiddleware {
                counter: Arc::clone(&counter),
                ord: 10,
            })
            .build();

        let mut ctx = RequestContext::new("TEST");
        let resp = pipeline
            .execute(&mut ctx)
            .await
            .expect("should get unauthorized response");

        assert_eq!(resp.status, ResponseStatus::Unauthorized);
        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "downstream middleware must not run after short-circuit"
        );
    }

    #[tokio::test]
    async fn test_context_attributes_passed_between_middleware() {
        let found = Arc::new(parking_lot::Mutex::new(None));

        let pipeline = MiddlewarePipelineBuilder::new()
            .with(AttributeSetter {
                key: "user".to_string(),
                value: "alice".to_string(),
            })
            .with(AttributeReader {
                key: "user".to_string(),
                found: Arc::clone(&found),
            })
            .build();

        let mut ctx = RequestContext::new("TEST");
        pipeline
            .execute(&mut ctx)
            .await
            .expect("pipeline should succeed");

        let val = found.lock().clone();
        assert_eq!(val, Some("alice".to_string()));
    }

    #[tokio::test]
    async fn test_metrics_recorded_correctly() {
        let collector = MetricsCollector::new();

        let pipeline = MiddlewarePipelineBuilder::new()
            .with(MetricsMiddleware::new(collector.clone()))
            .build();

        let mut ctx = RequestContext::new("GET");
        pipeline
            .execute(&mut ctx)
            .await
            .expect("pipeline should succeed");

        let snapshot = collector.snapshot();
        assert_eq!(snapshot.requests_total, 1);
        assert_eq!(snapshot.requests_success, 1);
        assert_eq!(snapshot.requests_failed, 0);
    }

    #[tokio::test]
    async fn test_rate_limit_blocks_request() {
        // One token, no refill: exactly one request can ever pass.
        let rl = RateLimitMiddleware::new(1, 0.0);

        let pipeline = MiddlewarePipelineBuilder::new().with(rl).build();

        let mut ctx1 = RequestContext::new("GET");
        let r1 = pipeline
            .execute(&mut ctx1)
            .await
            .expect("first request should pass");
        assert_eq!(r1.status, ResponseStatus::Ok);

        let mut ctx2 = RequestContext::new("GET");
        let r2 = pipeline
            .execute(&mut ctx2)
            .await
            .expect("second request should be rate-limited");
        assert_eq!(r2.status, ResponseStatus::RateLimited);
    }

    #[tokio::test]
    async fn test_adaptive_rate_limit_blocks_when_bucket_empty() {
        let pipeline = MiddlewarePipelineBuilder::new()
            .with(AdaptiveRateLimitMiddleware::new(1))
            .build();

        let mut first = RequestContext::new("GET");
        let r1 = pipeline
            .execute(&mut first)
            .await
            .expect("first request should pass");
        assert_eq!(r1.status, ResponseStatus::Ok);

        // The bucket refills at `current_limit` (1) token per second, so an
        // immediate second request finds it empty.
        let mut second = RequestContext::new("GET");
        let r2 = pipeline
            .execute(&mut second)
            .await
            .expect("second request should be rate-limited");
        assert_eq!(r2.status, ResponseStatus::RateLimited);
    }

    #[tokio::test]
    async fn test_auth_middleware_valid_key() {
        let mut keys = HashMap::new();
        keys.insert("secret-key".to_string(), "user-42".to_string());

        let pipeline = MiddlewarePipelineBuilder::new()
            .with(AuthMiddleware::new(keys))
            .build();

        let mut ctx = RequestContext::new("GET").with_metadata("authorization", "secret-key");
        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
        assert_eq!(resp.status, ResponseStatus::Ok);

        let principal = ctx
            .get_attribute::<String>("auth_principal")
            .expect("principal should be set");
        assert_eq!(principal, "user-42");
    }

    #[tokio::test]
    async fn test_auth_middleware_invalid_key() {
        let mut keys = HashMap::new();
        keys.insert("secret-key".to_string(), "user-42".to_string());

        let pipeline = MiddlewarePipelineBuilder::new()
            .with(AuthMiddleware::new(keys))
            .build();

        let mut ctx = RequestContext::new("GET").with_metadata("authorization", "wrong-key");
        let resp = pipeline
            .execute(&mut ctx)
            .await
            .expect("should get unauthorized");
        assert_eq!(resp.status, ResponseStatus::Unauthorized);
    }

    #[tokio::test]
    async fn test_auth_middleware_no_credentials() {
        let keys = HashMap::new();
        let pipeline = MiddlewarePipelineBuilder::new()
            .with(AuthMiddleware::new(keys))
            .build();

        let mut ctx = RequestContext::new("GET");
        let resp = pipeline
            .execute(&mut ctx)
            .await
            .expect("should get unauthorized");
        assert_eq!(resp.status, ResponseStatus::Unauthorized);
    }

    #[tokio::test]
    async fn test_auth_middleware_anonymous_allowed() {
        let keys = HashMap::new();
        let pipeline = MiddlewarePipelineBuilder::new()
            .with(AuthMiddleware::new(keys).with_allow_anonymous(true))
            .build();

        let mut ctx = RequestContext::new("GET");
        let resp = pipeline
            .execute(&mut ctx)
            .await
            .expect("should pass through");
        assert_eq!(resp.status, ResponseStatus::Ok);
    }

    #[tokio::test]
    async fn test_error_propagation() {
        let pipeline = MiddlewarePipelineBuilder::new()
            .with(ErrorMiddleware)
            .build();

        let mut ctx = RequestContext::new("GET");
        let err = pipeline
            .execute(&mut ctx)
            .await
            .expect_err("ErrorMiddleware should make the pipeline fail");
        assert!(
            matches!(err, MiddlewareError::Internal(ref msg) if msg.as_str() == "boom"),
            "the original error variant should propagate unchanged"
        );
        assert!(
            err.to_string().contains("boom"),
            "error message should propagate"
        );
    }

    #[tokio::test]
    async fn test_middleware_ordering_by_order() {
        let log = Arc::new(parking_lot::Mutex::new(Vec::new()));

        // Registered deliberately out of order.
        let pipeline = MiddlewarePipelineBuilder::new()
            .with(OrderRecorder {
                id: 50,
                log: Arc::clone(&log),
            })
            .with(OrderRecorder {
                id: 10,
                log: Arc::clone(&log),
            })
            .with(OrderRecorder {
                id: 30,
                log: Arc::clone(&log),
            })
            .with(OrderRecorder {
                id: 20,
                log: Arc::clone(&log),
            })
            .with(OrderRecorder {
                id: 40,
                log: Arc::clone(&log),
            })
            .build();

        let mut ctx = RequestContext::new("TEST");
        pipeline
            .execute(&mut ctx)
            .await
            .expect("pipeline should succeed");

        let order = log.lock().clone();
        assert_eq!(order, vec![10, 20, 30, 40, 50]);
    }

    #[tokio::test]
    async fn test_response_duration_is_set() {
        let pipeline = MiddlewarePipelineBuilder::new().build();
        let mut ctx = RequestContext::new("TEST");

        // Back-date the context so that the duration stamped by `execute` is
        // observably non-zero (the tail's own response has a zero duration).
        let backdate = Duration::from_millis(50);
        let backdated = match ctx.start_time.checked_sub(backdate) {
            Some(earlier) => {
                ctx.start_time = earlier;
                true
            }
            None => false,
        };

        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");

        // Stamped from the same clock, so it cannot exceed a later reading.
        assert!(resp.duration <= ctx.elapsed());
        if backdated {
            assert!(
                resp.duration >= backdate,
                "duration should reflect time since context creation"
            );
        }
    }

    #[tokio::test]
    async fn test_logging_middleware_runs() {
        let pipeline = MiddlewarePipelineBuilder::new()
            .with(LoggingMiddleware::new())
            .build();

        let mut ctx = RequestContext::new("GET");
        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
        assert_eq!(resp.status, ResponseStatus::Ok);
    }

    #[tokio::test]
    async fn test_tracing_middleware_runs() {
        let pipeline = MiddlewarePipelineBuilder::new()
            .with(TracingMiddleware::new())
            .build();

        let mut ctx = RequestContext::new("QUERY");
        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
        assert_eq!(resp.status, ResponseStatus::Ok);
    }

    #[tokio::test]
    async fn test_otel_span_middleware_runs() {
        let pipeline = MiddlewarePipelineBuilder::new()
            .with(OtelSpanMiddleware::new("node-1"))
            .build();

        let mut ctx = RequestContext::new("QUERY");
        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
        assert_eq!(resp.status, ResponseStatus::Ok);
    }

    #[tokio::test]
    async fn test_full_pipeline_integration() {
        let collector = MetricsCollector::new();

        let mut api_keys = HashMap::new();
        api_keys.insert("valid-key".to_string(), "user-1".to_string());

        let pipeline = MiddlewarePipelineBuilder::new()
            .with(LoggingMiddleware::new().with_level(LogLevel::Debug))
            .with(TracingMiddleware::new())
            .with(MetricsMiddleware::new(collector.clone()))
            .with(AuthMiddleware::new(api_keys))
            .with(RateLimitMiddleware::new(100, 100.0))
            .build();

        let mut ctx = RequestContext::new("QUERY").with_metadata("authorization", "valid-key");
        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
        assert_eq!(resp.status, ResponseStatus::Ok);

        let snapshot = collector.snapshot();
        assert_eq!(snapshot.requests_total, 1);
        assert_eq!(snapshot.requests_success, 1);
    }

    #[tokio::test]
    async fn test_pipeline_builder_default() {
        let builder = MiddlewarePipelineBuilder::default();
        let pipeline = builder.build();
        let mut ctx = RequestContext::new("TEST");
        let resp = pipeline
            .execute(&mut ctx)
            .await
            .expect("default pipeline should succeed");
        assert_eq!(resp.status, ResponseStatus::Ok);
    }

    #[test]
    fn test_request_context_debug() {
        let ctx = RequestContext::new("GET");
        let debug_str = format!("{:?}", ctx);
        assert!(debug_str.contains("RequestContext"));
        assert!(debug_str.contains("GET"));
    }

    #[test]
    fn test_response_status_display() {
        assert_eq!(ResponseStatus::Ok.to_string(), "OK");
        assert_eq!(ResponseStatus::Error.to_string(), "Error");
        assert_eq!(ResponseStatus::RateLimited.to_string(), "RateLimited");
        assert_eq!(ResponseStatus::Unauthorized.to_string(), "Unauthorized");
    }

    #[test]
    fn test_response_builders() {
        let r = Response::ok()
            .with_header("x-req", "123")
            .with_body(b"hello".to_vec());
        assert_eq!(r.status, ResponseStatus::Ok);
        assert_eq!(r.body, Some(b"hello".to_vec()));
        assert_eq!(r.headers.get("x-req"), Some(&"123".to_string()));

        let r2 = Response::error("oops");
        assert_eq!(r2.status, ResponseStatus::Error);
        assert_eq!(r2.body, Some(b"oops".to_vec()));
    }

    #[test]
    fn test_adaptive_rate_limiter_reduces_on_errors() {
        let limiter = AdaptiveRateLimiter::new(100);
        assert_eq!(limiter.current_limit(), 100);

        for _ in 0..50 {
            limiter.record_error();
        }
        assert!(
            limiter.current_limit() < 100,
            "limit should have decreased after high error rate"
        );
    }

    #[test]
    fn test_adaptive_rate_limiter_recovers() {
        let limiter = AdaptiveRateLimiter::new(100);

        for _ in 0..50 {
            limiter.record_error();
        }
        let reduced = limiter.current_limit();
        assert!(reduced < 100, "limit should be reduced");

        for _ in 0..200 {
            limiter.record_success();
        }
        assert!(
            limiter.current_limit() > reduced,
            "limit should recover after sustained successes"
        );
    }
}