use std::any::Any;
use std::collections::HashMap;
use std::fmt;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};

use async_trait::async_trait;
use thiserror::Error;
use tracing::{debug, info, warn, Instrument};

use crate::metrics::MetricsCollector;
30
/// Errors that a middleware (or the pipeline itself) can return.
///
/// Note: the built-in [`AuthMiddleware`] and [`RateLimitMiddleware`] in this
/// file reject requests by returning `Ok` with a non-`Ok` [`ResponseStatus`],
/// not by returning these variants.
#[derive(Error, Debug)]
pub enum MiddlewareError {
    /// Authentication failure; the string describes the cause.
    #[error("Authentication failed: {0}")]
    AuthFailed(String),

    /// Rejection due to rate limiting; the string describes the cause.
    #[error("Rate limited: {0}")]
    RateLimited(String),

    /// Unexpected failure inside a middleware; the string describes the cause.
    #[error("Internal middleware error: {0}")]
    Internal(String),

    /// Error attributed to the pipeline rather than a specific middleware
    /// (not constructed anywhere in the visible code).
    #[error("Pipeline error: {0}")]
    Pipeline(String),
}

/// Result type used by middleware and the pipeline.
pub type Result<T> = std::result::Result<T, MiddlewareError>;
52
/// Outcome category carried by a [`Response`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseStatus {
    /// The request was handled successfully.
    Ok,
    /// The request failed; `Response::error` stores the message in the body.
    Error,
    /// The request was rejected by rate limiting.
    RateLimited,
    /// The request was rejected by authentication.
    Unauthorized,
}

impl fmt::Display for ResponseStatus {
    /// Writes a fixed label for the status (`"OK"`, `"Error"`,
    /// `"RateLimited"`, `"Unauthorized"`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            ResponseStatus::Ok => "OK",
            ResponseStatus::Error => "Error",
            ResponseStatus::RateLimited => "RateLimited",
            ResponseStatus::Unauthorized => "Unauthorized",
        };
        f.write_str(label)
    }
}
76
77#[derive(Debug, Clone)]
79pub struct Response {
80 pub status: ResponseStatus,
81 pub body: Option<Vec<u8>>,
82 pub headers: HashMap<String, String>,
83 pub duration: Duration,
84}
85
86impl Response {
87 pub fn ok() -> Self {
89 Self {
90 status: ResponseStatus::Ok,
91 body: None,
92 headers: HashMap::new(),
93 duration: Duration::ZERO,
94 }
95 }
96
97 pub fn error(msg: impl Into<String>) -> Self {
99 Self {
100 status: ResponseStatus::Error,
101 body: Some(msg.into().into_bytes()),
102 headers: HashMap::new(),
103 duration: Duration::ZERO,
104 }
105 }
106
107 pub fn rate_limited(msg: impl Into<String>) -> Self {
109 Self {
110 status: ResponseStatus::RateLimited,
111 body: Some(msg.into().into_bytes()),
112 headers: HashMap::new(),
113 duration: Duration::ZERO,
114 }
115 }
116
117 pub fn unauthorized(msg: impl Into<String>) -> Self {
119 Self {
120 status: ResponseStatus::Unauthorized,
121 body: Some(msg.into().into_bytes()),
122 headers: HashMap::new(),
123 duration: Duration::ZERO,
124 }
125 }
126
127 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
129 self.headers.insert(key.into(), value.into());
130 self
131 }
132
133 pub fn with_body(mut self, body: Vec<u8>) -> Self {
135 self.body = Some(body);
136 self
137 }
138
139 pub fn with_duration(mut self, duration: Duration) -> Self {
141 self.duration = duration;
142 self
143 }
144}
145
/// Per-request state handed (as `&mut`) to every middleware in the pipeline.
pub struct RequestContext {
    /// Request identifier; `RequestContext::new` assigns a random UUID v4 string.
    pub request_id: String,
    /// Client address, when known.
    pub client_addr: Option<SocketAddr>,
    /// Free-form method name (e.g. `"GET"`).
    pub method: String,
    /// String metadata supplied with the request; `AuthMiddleware` reads the
    /// `"authorization"` key.
    pub metadata: HashMap<String, String>,
    /// Creation time of the context; `elapsed` measures from here.
    pub start_time: Instant,
    /// Type-erased values middleware can attach for later stages
    /// (see `set_attribute` / `get_attribute`).
    pub attributes: HashMap<String, Box<dyn Any + Send + Sync>>,
}
168
169impl RequestContext {
170 pub fn new(method: impl Into<String>) -> Self {
172 Self {
173 request_id: uuid::Uuid::new_v4().to_string(),
174 client_addr: None,
175 method: method.into(),
176 metadata: HashMap::new(),
177 start_time: Instant::now(),
178 attributes: HashMap::new(),
179 }
180 }
181
182 pub fn with_client_addr(mut self, addr: SocketAddr) -> Self {
184 self.client_addr = Some(addr);
185 self
186 }
187
188 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
190 self.metadata.insert(key.into(), value.into());
191 self
192 }
193
194 pub fn set_attribute<T: Any + Send + Sync>(&mut self, key: impl Into<String>, value: T) {
196 self.attributes.insert(key.into(), Box::new(value));
197 }
198
199 pub fn get_attribute<T: Any + Send + Sync>(&self, key: &str) -> Option<&T> {
201 self.attributes.get(key).and_then(|v| v.downcast_ref::<T>())
202 }
203
204 pub fn elapsed(&self) -> Duration {
206 self.start_time.elapsed()
207 }
208}
209
210impl fmt::Debug for RequestContext {
211 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
212 f.debug_struct("RequestContext")
213 .field("request_id", &self.request_id)
214 .field("client_addr", &self.client_addr)
215 .field("method", &self.method)
216 .field("metadata", &self.metadata)
217 .field("start_time", &self.start_time)
218 .field("attributes_count", &self.attributes.len())
219 .finish()
220 }
221}
222
/// The remainder of the pipeline as seen from inside a middleware.
///
/// Implemented by the pipeline's internal chain links; the final link returns
/// an empty `Ok` response.
#[async_trait]
pub trait Next: Send + Sync {
    /// Runs the rest of the pipeline with `ctx` and returns its result.
    async fn run(&self, ctx: &mut RequestContext) -> Result<Response>;
}
232
/// One stage of a [`MiddlewarePipeline`].
#[async_trait]
pub trait Middleware: Send + Sync {
    /// Handles the request. Call `next.run(ctx)` to continue down the
    /// pipeline, or return without calling it to short-circuit.
    async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response>;

    /// Short identifier for this middleware.
    fn name(&self) -> &str;

    /// Sort key used by `MiddlewarePipelineBuilder::build`: lower values run
    /// first (outermost). The sort is stable, so equal values keep insertion order.
    fn order(&self) -> i32 {
        0
    }
}
247
248struct PipelineTail;
254
255#[async_trait]
256impl Next for PipelineTail {
257 async fn run(&self, _ctx: &mut RequestContext) -> Result<Response> {
258 Ok(Response::ok())
259 }
260}
261
262struct PipelineLink {
264 middleware: Arc<dyn Middleware>,
265 next: Arc<dyn Next>,
266}
267
268#[async_trait]
269impl Next for PipelineLink {
270 async fn run(&self, ctx: &mut RequestContext) -> Result<Response> {
271 self.middleware.process(ctx, self.next.as_ref()).await
272 }
273}
274
275pub struct MiddlewarePipeline {
283 chain: Arc<dyn Next>,
284}
285
286impl MiddlewarePipeline {
287 pub async fn execute(&self, ctx: &mut RequestContext) -> Result<Response> {
289 let result = self.chain.run(ctx).await;
290 match result {
292 Ok(mut resp) => {
293 resp.duration = ctx.elapsed();
294 Ok(resp)
295 }
296 Err(e) => Err(e),
297 }
298 }
299}
300
301pub struct MiddlewarePipelineBuilder {
303 middleware: Vec<Arc<dyn Middleware>>,
304}
305
306impl Default for MiddlewarePipelineBuilder {
307 fn default() -> Self {
308 Self::new()
309 }
310}
311
312impl MiddlewarePipelineBuilder {
313 pub fn new() -> Self {
315 Self {
316 middleware: Vec::new(),
317 }
318 }
319
320 pub fn with<M: Middleware + 'static>(mut self, m: M) -> Self {
322 self.middleware.push(Arc::new(m));
323 self
324 }
325
326 pub fn add_arc(mut self, m: Arc<dyn Middleware>) -> Self {
328 self.middleware.push(m);
329 self
330 }
331
332 pub fn build(mut self) -> MiddlewarePipeline {
334 self.middleware.sort_by_key(|m| m.order());
336
337 let mut next: Arc<dyn Next> = Arc::new(PipelineTail);
339 for mw in self.middleware.into_iter().rev() {
340 next = Arc::new(PipelineLink {
341 middleware: mw,
342 next,
343 });
344 }
345
346 MiddlewarePipeline { chain: next }
347 }
348}
349
/// Middleware that logs the start and the outcome of each request.
pub struct LoggingMiddleware {
    /// Level for the "started"/"completed" events; failures always log at `warn`.
    level: LogLevel,
}

/// Level [`LoggingMiddleware`] uses for its non-error events.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogLevel {
    /// Emit events with `tracing::debug!`.
    Debug,
    /// Emit events with `tracing::info!`.
    Info,
}

impl Default for LoggingMiddleware {
    /// Logs at [`LogLevel::Info`].
    fn default() -> Self {
        LoggingMiddleware {
            level: LogLevel::Info,
        }
    }
}

impl LoggingMiddleware {
    /// Creates a logging middleware that logs at [`LogLevel::Info`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Switches the level used for non-error events.
    pub fn with_level(mut self, level: LogLevel) -> Self {
        self.level = level;
        self
    }
}
390
391#[async_trait]
392impl Middleware for LoggingMiddleware {
393 async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
394 let method = ctx.method.clone();
395 let request_id = ctx.request_id.clone();
396 let client = ctx
397 .client_addr
398 .map_or_else(|| "unknown".to_string(), |a| a.to_string());
399
400 match self.level {
401 LogLevel::Info => info!(
402 request_id = %request_id,
403 method = %method,
404 client = %client,
405 "Request started"
406 ),
407 LogLevel::Debug => debug!(
408 request_id = %request_id,
409 method = %method,
410 client = %client,
411 "Request started"
412 ),
413 }
414
415 let result = next.run(ctx).await;
416
417 match &result {
418 Ok(resp) => match self.level {
419 LogLevel::Info => info!(
420 request_id = %request_id,
421 method = %method,
422 status = %resp.status,
423 duration_ms = %ctx.elapsed().as_millis(),
424 "Request completed"
425 ),
426 LogLevel::Debug => debug!(
427 request_id = %request_id,
428 method = %method,
429 status = %resp.status,
430 duration_ms = %ctx.elapsed().as_millis(),
431 "Request completed"
432 ),
433 },
434 Err(e) => warn!(
435 request_id = %request_id,
436 method = %method,
437 error = %e,
438 duration_ms = %ctx.elapsed().as_millis(),
439 "Request failed"
440 ),
441 }
442
443 result
444 }
445
446 fn name(&self) -> &str {
447 "logging"
448 }
449
450 fn order(&self) -> i32 {
451 -100
452 }
453}
454
455pub struct MetricsMiddleware {
461 collector: MetricsCollector,
462}
463
464impl MetricsMiddleware {
465 pub fn new(collector: MetricsCollector) -> Self {
466 Self { collector }
467 }
468}
469
470#[async_trait]
471impl Middleware for MetricsMiddleware {
472 async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
473 let result = next.run(ctx).await;
474 let duration = ctx.elapsed();
475
476 self.collector.inc_requests();
477 self.collector.observe_request_latency(duration);
478
479 match &result {
480 Ok(resp) => {
481 if resp.status == ResponseStatus::Ok {
482 self.collector.inc_success();
483 } else {
484 self.collector.inc_failed();
485 }
486 }
487 Err(_) => {
488 self.collector.inc_failed();
489 }
490 }
491
492 result
493 }
494
495 fn name(&self) -> &str {
496 "metrics"
497 }
498
499 fn order(&self) -> i32 {
500 -90
501 }
502}
503
/// Middleware that runs the rest of the pipeline inside a `tracing` span
/// carrying the request id, method, and client address.
#[derive(Default)]
pub struct TracingMiddleware;

impl TracingMiddleware {
    /// Creates the middleware; it has no configuration.
    pub fn new() -> Self {
        TracingMiddleware
    }
}
522
523#[async_trait]
524impl Middleware for TracingMiddleware {
525 async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
526 let span = tracing::info_span!(
527 "request",
528 request_id = %ctx.request_id,
529 method = %ctx.method,
530 client_addr = ?ctx.client_addr,
531 );
532
533 let _guard = span.enter();
534 next.run(ctx).await
535 }
536
537 fn name(&self) -> &str {
538 "tracing"
539 }
540
541 fn order(&self) -> i32 {
542 -95
543 }
544}
545
/// Middleware that authenticates requests by API key.
///
/// The key is read from the `"authorization"` metadata entry and looked up
/// verbatim in `api_keys`. On success, the mapped principal is stored as a
/// `String` context attribute under `"auth_principal"`.
pub struct AuthMiddleware {
    /// Maps API key to principal (user id).
    api_keys: HashMap<String, String>,
    /// Whether requests without an `"authorization"` entry may proceed.
    allow_anonymous: bool,
}

impl AuthMiddleware {
    /// Creates an auth middleware over `api_keys`; anonymous requests are
    /// rejected unless [`AuthMiddleware::with_allow_anonymous`] enables them.
    pub fn new(api_keys: HashMap<String, String>) -> Self {
        let allow_anonymous = false;
        AuthMiddleware {
            api_keys,
            allow_anonymous,
        }
    }

    /// Sets whether requests with no credentials are let through.
    /// Requests with invalid credentials are rejected regardless.
    pub fn with_allow_anonymous(self, allow: bool) -> Self {
        AuthMiddleware {
            allow_anonymous: allow,
            ..self
        }
    }
}
577
578#[async_trait]
579impl Middleware for AuthMiddleware {
580 async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
581 let auth_header = ctx.metadata.get("authorization").cloned();
582
583 match auth_header {
584 Some(key) => {
585 if let Some(user_id) = self.api_keys.get(&key) {
587 ctx.set_attribute("auth_principal", user_id.clone());
588 debug!(
589 request_id = %ctx.request_id,
590 user_id = %user_id,
591 "Authentication successful"
592 );
593 next.run(ctx).await
594 } else {
595 warn!(
596 request_id = %ctx.request_id,
597 "Authentication failed: invalid credentials"
598 );
599 Ok(Response::unauthorized("Invalid credentials"))
600 }
601 }
602 None => {
603 if self.allow_anonymous {
604 next.run(ctx).await
605 } else {
606 warn!(
607 request_id = %ctx.request_id,
608 "Authentication failed: no credentials provided"
609 );
610 Ok(Response::unauthorized("No credentials provided"))
611 }
612 }
613 }
614 }
615
616 fn name(&self) -> &str {
617 "auth"
618 }
619
620 fn order(&self) -> i32 {
621 -80
622 }
623}
624
625pub struct RateLimitMiddleware {
633 state: Arc<parking_lot::Mutex<RateLimitState>>,
634 max_tokens: u64,
635 refill_rate: f64, }
637
638struct RateLimitState {
639 tokens: f64,
640 last_refill: Instant,
641}
642
643impl RateLimitMiddleware {
644 pub fn new(max_tokens: u64, refill_rate: f64) -> Self {
647 Self {
648 state: Arc::new(parking_lot::Mutex::new(RateLimitState {
649 tokens: max_tokens as f64,
650 last_refill: Instant::now(),
651 })),
652 max_tokens,
653 refill_rate,
654 }
655 }
656
657 fn try_acquire(&self) -> bool {
658 let mut state = self.state.lock();
659 let now = Instant::now();
660 let elapsed = now.duration_since(state.last_refill).as_secs_f64();
661 state.tokens = (state.tokens + elapsed * self.refill_rate).min(self.max_tokens as f64);
662 state.last_refill = now;
663
664 if state.tokens >= 1.0 {
665 state.tokens -= 1.0;
666 true
667 } else {
668 false
669 }
670 }
671}
672
673#[async_trait]
674impl Middleware for RateLimitMiddleware {
675 async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
676 if self.try_acquire() {
677 next.run(ctx).await
678 } else {
679 warn!(
680 request_id = %ctx.request_id,
681 "Rate limit exceeded"
682 );
683 Ok(Response::rate_limited("Rate limit exceeded"))
684 }
685 }
686
687 fn name(&self) -> &str {
688 "rate_limit"
689 }
690
691 fn order(&self) -> i32 {
692 -70
693 }
694}
695
#[cfg(test)]
mod tests {
    // Unit tests for the pipeline builder/executor and the built-in middleware.

    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Appends its `id` to a shared log when it runs. `order()` returns `id`,
    /// so the log records the execution order chosen by the builder.
    struct OrderRecorder {
        id: i32,
        log: Arc<parking_lot::Mutex<Vec<i32>>>,
    }

    #[async_trait]
    impl Middleware for OrderRecorder {
        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
            self.log.lock().push(self.id);
            next.run(ctx).await
        }
        fn name(&self) -> &str {
            "order_recorder"
        }
        fn order(&self) -> i32 {
            self.id
        }
    }

    /// Like `OrderRecorder`, but the logged tag is independent of the sort
    /// key, so several instances can share the same `order()`.
    struct TagRecorder {
        tag: i32,
        ord: i32,
        log: Arc<parking_lot::Mutex<Vec<i32>>>,
    }

    #[async_trait]
    impl Middleware for TagRecorder {
        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
            self.log.lock().push(self.tag);
            next.run(ctx).await
        }
        fn name(&self) -> &str {
            "tag_recorder"
        }
        fn order(&self) -> i32 {
            self.ord
        }
    }

    /// Always rejects without calling `next`.
    struct ShortCircuit;

    #[async_trait]
    impl Middleware for ShortCircuit {
        async fn process(&self, _ctx: &mut RequestContext, _next: &dyn Next) -> Result<Response> {
            Ok(Response::unauthorized("blocked"))
        }
        fn name(&self) -> &str {
            "short_circuit"
        }
        fn order(&self) -> i32 {
            0
        }
    }

    /// Stores `value` as a `String` attribute under `key`.
    struct AttributeSetter {
        key: String,
        value: String,
    }

    #[async_trait]
    impl Middleware for AttributeSetter {
        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
            ctx.set_attribute(&self.key, self.value.clone());
            next.run(ctx).await
        }
        fn name(&self) -> &str {
            "attr_setter"
        }
        fn order(&self) -> i32 {
            -10
        }
    }

    /// Copies the `String` attribute under `key` (if any) into `found`.
    struct AttributeReader {
        key: String,
        found: Arc<parking_lot::Mutex<Option<String>>>,
    }

    #[async_trait]
    impl Middleware for AttributeReader {
        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
            if let Some(val) = ctx.get_attribute::<String>(&self.key) {
                *self.found.lock() = Some(val.clone());
            }
            next.run(ctx).await
        }
        fn name(&self) -> &str {
            "attr_reader"
        }
        fn order(&self) -> i32 {
            10
        }
    }

    /// Always fails with `MiddlewareError::Internal("boom")`.
    struct ErrorMiddleware;

    #[async_trait]
    impl Middleware for ErrorMiddleware {
        async fn process(&self, _ctx: &mut RequestContext, _next: &dyn Next) -> Result<Response> {
            Err(MiddlewareError::Internal("boom".to_string()))
        }
        fn name(&self) -> &str {
            "error"
        }
    }

    /// Counts how many times it runs.
    struct CounterMiddleware {
        counter: Arc<AtomicUsize>,
        ord: i32,
    }

    #[async_trait]
    impl Middleware for CounterMiddleware {
        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            next.run(ctx).await
        }
        fn name(&self) -> &str {
            "counter"
        }
        fn order(&self) -> i32 {
            self.ord
        }
    }

    #[tokio::test]
    async fn test_empty_pipeline_passes_through() {
        let pipeline = MiddlewarePipelineBuilder::new().build();
        let mut ctx = RequestContext::new("TEST");
        let resp = pipeline
            .execute(&mut ctx)
            .await
            .expect("empty pipeline should succeed");
        assert_eq!(resp.status, ResponseStatus::Ok);
    }

    #[tokio::test]
    async fn test_pipeline_executes_in_order() {
        let log = Arc::new(parking_lot::Mutex::new(Vec::new()));

        let pipeline = MiddlewarePipelineBuilder::new()
            .with(OrderRecorder {
                id: 3,
                log: Arc::clone(&log),
            })
            .with(OrderRecorder {
                id: 1,
                log: Arc::clone(&log),
            })
            .with(OrderRecorder {
                id: 2,
                log: Arc::clone(&log),
            })
            .build();

        let mut ctx = RequestContext::new("TEST");
        pipeline
            .execute(&mut ctx)
            .await
            .expect("pipeline should succeed");

        let order = log.lock().clone();
        assert_eq!(
            order,
            vec![1, 2, 3],
            "middleware should run sorted by order()"
        );
    }

    #[tokio::test]
    async fn test_equal_order_preserves_insertion_order() {
        let log = Arc::new(parking_lot::Mutex::new(Vec::new()));

        let pipeline = MiddlewarePipelineBuilder::new()
            .with(TagRecorder {
                tag: 1,
                ord: 0,
                log: Arc::clone(&log),
            })
            .with(TagRecorder {
                tag: 2,
                ord: -5,
                log: Arc::clone(&log),
            })
            .with(TagRecorder {
                tag: 3,
                ord: 0,
                log: Arc::clone(&log),
            })
            .with(TagRecorder {
                tag: 4,
                ord: 0,
                log: Arc::clone(&log),
            })
            .build();

        let mut ctx = RequestContext::new("TEST");
        pipeline
            .execute(&mut ctx)
            .await
            .expect("pipeline should succeed");

        let order = log.lock().clone();
        assert_eq!(
            order,
            vec![2, 1, 3, 4],
            "ties in order() must keep insertion order"
        );
    }

    #[tokio::test]
    async fn test_short_circuit_on_auth_failure() {
        let counter = Arc::new(AtomicUsize::new(0));

        let pipeline = MiddlewarePipelineBuilder::new()
            .with(ShortCircuit)
            .with(CounterMiddleware {
                counter: Arc::clone(&counter),
                ord: 10,
            })
            .build();

        let mut ctx = RequestContext::new("TEST");
        let resp = pipeline
            .execute(&mut ctx)
            .await
            .expect("should get unauthorized response");

        assert_eq!(resp.status, ResponseStatus::Unauthorized);
        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "downstream middleware must not run after short-circuit"
        );
    }

    #[tokio::test]
    async fn test_context_attributes_passed_between_middleware() {
        let found = Arc::new(parking_lot::Mutex::new(None));

        let pipeline = MiddlewarePipelineBuilder::new()
            .with(AttributeSetter {
                key: "user".to_string(),
                value: "alice".to_string(),
            })
            .with(AttributeReader {
                key: "user".to_string(),
                found: Arc::clone(&found),
            })
            .build();

        let mut ctx = RequestContext::new("TEST");
        pipeline
            .execute(&mut ctx)
            .await
            .expect("pipeline should succeed");

        let val = found.lock().clone();
        assert_eq!(val, Some("alice".to_string()));
    }

    #[tokio::test]
    async fn test_metrics_recorded_correctly() {
        let collector = MetricsCollector::new();

        let pipeline = MiddlewarePipelineBuilder::new()
            .with(MetricsMiddleware::new(collector.clone()))
            .build();

        let mut ctx = RequestContext::new("GET");
        pipeline
            .execute(&mut ctx)
            .await
            .expect("pipeline should succeed");

        let snapshot = collector.snapshot();
        assert_eq!(snapshot.requests_total, 1);
        assert_eq!(snapshot.requests_success, 1);
        assert_eq!(snapshot.requests_failed, 0);
    }

    #[tokio::test]
    async fn test_rate_limit_blocks_request() {
        // One token and no refill: exactly one request may pass.
        let rl = RateLimitMiddleware::new(1, 0.0);

        let pipeline = MiddlewarePipelineBuilder::new().with(rl).build();

        let mut ctx1 = RequestContext::new("GET");
        let r1 = pipeline
            .execute(&mut ctx1)
            .await
            .expect("first request should pass");
        assert_eq!(r1.status, ResponseStatus::Ok);

        let mut ctx2 = RequestContext::new("GET");
        let r2 = pipeline
            .execute(&mut ctx2)
            .await
            .expect("second request should be rate-limited");
        assert_eq!(r2.status, ResponseStatus::RateLimited);
    }

    #[tokio::test]
    async fn test_auth_middleware_valid_key() {
        let mut keys = HashMap::new();
        keys.insert("secret-key".to_string(), "user-42".to_string());

        let pipeline = MiddlewarePipelineBuilder::new()
            .with(AuthMiddleware::new(keys))
            .build();

        let mut ctx = RequestContext::new("GET").with_metadata("authorization", "secret-key");
        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
        assert_eq!(resp.status, ResponseStatus::Ok);

        let principal = ctx
            .get_attribute::<String>("auth_principal")
            .expect("principal should be set");
        assert_eq!(principal, "user-42");
    }

    #[tokio::test]
    async fn test_auth_middleware_invalid_key() {
        let mut keys = HashMap::new();
        keys.insert("secret-key".to_string(), "user-42".to_string());

        let pipeline = MiddlewarePipelineBuilder::new()
            .with(AuthMiddleware::new(keys))
            .build();

        let mut ctx = RequestContext::new("GET").with_metadata("authorization", "wrong-key");
        let resp = pipeline
            .execute(&mut ctx)
            .await
            .expect("should get unauthorized");
        assert_eq!(resp.status, ResponseStatus::Unauthorized);
    }

    #[tokio::test]
    async fn test_auth_middleware_no_credentials() {
        let keys = HashMap::new();
        let pipeline = MiddlewarePipelineBuilder::new()
            .with(AuthMiddleware::new(keys))
            .build();

        let mut ctx = RequestContext::new("GET");
        let resp = pipeline
            .execute(&mut ctx)
            .await
            .expect("should get unauthorized");
        assert_eq!(resp.status, ResponseStatus::Unauthorized);
    }

    #[tokio::test]
    async fn test_auth_middleware_anonymous_allowed() {
        let keys = HashMap::new();
        let pipeline = MiddlewarePipelineBuilder::new()
            .with(AuthMiddleware::new(keys).with_allow_anonymous(true))
            .build();

        let mut ctx = RequestContext::new("GET");
        let resp = pipeline
            .execute(&mut ctx)
            .await
            .expect("should pass through");
        assert_eq!(resp.status, ResponseStatus::Ok);
    }

    #[tokio::test]
    async fn test_error_propagation() {
        let pipeline = MiddlewarePipelineBuilder::new()
            .with(ErrorMiddleware)
            .build();

        let mut ctx = RequestContext::new("GET");
        let err = pipeline
            .execute(&mut ctx)
            .await
            .expect_err("error middleware should fail the pipeline");
        assert!(
            matches!(err, MiddlewareError::Internal(_)),
            "error variant should propagate unchanged"
        );
        assert!(
            err.to_string().contains("boom"),
            "error message should propagate"
        );
    }

    #[tokio::test]
    async fn test_middleware_ordering_by_order() {
        let log = Arc::new(parking_lot::Mutex::new(Vec::new()));

        // Added deliberately out of order.
        let pipeline = MiddlewarePipelineBuilder::new()
            .with(OrderRecorder {
                id: 50,
                log: Arc::clone(&log),
            })
            .with(OrderRecorder {
                id: 10,
                log: Arc::clone(&log),
            })
            .with(OrderRecorder {
                id: 30,
                log: Arc::clone(&log),
            })
            .with(OrderRecorder {
                id: 20,
                log: Arc::clone(&log),
            })
            .with(OrderRecorder {
                id: 40,
                log: Arc::clone(&log),
            })
            .build();

        let mut ctx = RequestContext::new("TEST");
        pipeline
            .execute(&mut ctx)
            .await
            .expect("pipeline should succeed");

        let order = log.lock().clone();
        assert_eq!(order, vec![10, 20, 30, 40, 50]);
    }

    #[tokio::test]
    async fn test_response_duration_is_set() {
        let backdate = Duration::from_millis(50);
        let mut ctx = RequestContext::new("TEST");
        // Pretend the request started `backdate` ago. This makes the duration
        // measurably non-zero without sleeping. `checked_sub` fails only if
        // the clock's origin is less than `backdate` in the past.
        let backdated = match ctx.start_time.checked_sub(backdate) {
            Some(earlier) => {
                ctx.start_time = earlier;
                true
            }
            None => false,
        };

        let pipeline = MiddlewarePipelineBuilder::new().build();
        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");

        assert!(
            resp.duration <= ctx.elapsed(),
            "duration cannot exceed the time elapsed since start"
        );
        if backdated {
            assert!(
                resp.duration >= backdate,
                "duration should be measured from ctx.start_time"
            );
        }
    }

    #[tokio::test]
    async fn test_logging_middleware_runs() {
        let pipeline = MiddlewarePipelineBuilder::new()
            .with(LoggingMiddleware::new())
            .build();

        let mut ctx = RequestContext::new("GET");
        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
        assert_eq!(resp.status, ResponseStatus::Ok);
    }

    #[tokio::test]
    async fn test_tracing_middleware_runs() {
        let pipeline = MiddlewarePipelineBuilder::new()
            .with(TracingMiddleware::new())
            .build();

        let mut ctx = RequestContext::new("QUERY");
        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
        assert_eq!(resp.status, ResponseStatus::Ok);
    }

    #[tokio::test]
    async fn test_pipeline_runs_on_spawned_task() {
        // `tokio::spawn` requires a `Send + 'static` future. This checks that a
        // pipeline with the span-instrumented middleware can run on a task.
        let pipeline = Arc::new(
            MiddlewarePipelineBuilder::new()
                .with(TracingMiddleware::new())
                .with(LoggingMiddleware::new())
                .build(),
        );

        let handle = tokio::spawn(async move {
            let mut ctx = RequestContext::new("QUERY");
            pipeline.execute(&mut ctx).await.map(|resp| resp.status)
        });

        let status = handle
            .await
            .expect("task should not panic")
            .expect("should succeed");
        assert_eq!(status, ResponseStatus::Ok);
    }

    #[tokio::test]
    async fn test_full_pipeline_integration() {
        let collector = MetricsCollector::new();

        let mut api_keys = HashMap::new();
        api_keys.insert("valid-key".to_string(), "user-1".to_string());

        let pipeline = MiddlewarePipelineBuilder::new()
            .with(LoggingMiddleware::new().with_level(LogLevel::Debug))
            .with(TracingMiddleware::new())
            .with(MetricsMiddleware::new(collector.clone()))
            .with(AuthMiddleware::new(api_keys))
            .with(RateLimitMiddleware::new(100, 100.0))
            .build();

        let mut ctx = RequestContext::new("QUERY").with_metadata("authorization", "valid-key");
        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
        assert_eq!(resp.status, ResponseStatus::Ok);

        let snapshot = collector.snapshot();
        assert_eq!(snapshot.requests_total, 1);
        assert_eq!(snapshot.requests_success, 1);
    }

    #[tokio::test]
    async fn test_pipeline_builder_default() {
        let builder = MiddlewarePipelineBuilder::default();
        let pipeline = builder.build();
        let mut ctx = RequestContext::new("TEST");
        let resp = pipeline
            .execute(&mut ctx)
            .await
            .expect("default pipeline should succeed");
        assert_eq!(resp.status, ResponseStatus::Ok);
    }

    #[tokio::test]
    async fn test_request_context_debug() {
        let ctx = RequestContext::new("GET");
        let debug_str = format!("{:?}", ctx);
        assert!(debug_str.contains("RequestContext"));
        assert!(debug_str.contains("GET"));
    }

    #[tokio::test]
    async fn test_response_status_display() {
        assert_eq!(ResponseStatus::Ok.to_string(), "OK");
        assert_eq!(ResponseStatus::Error.to_string(), "Error");
        assert_eq!(ResponseStatus::RateLimited.to_string(), "RateLimited");
        assert_eq!(ResponseStatus::Unauthorized.to_string(), "Unauthorized");
    }

    #[tokio::test]
    async fn test_response_builders() {
        let r = Response::ok()
            .with_header("x-req", "123")
            .with_body(b"hello".to_vec());
        assert_eq!(r.status, ResponseStatus::Ok);
        assert_eq!(r.body, Some(b"hello".to_vec()));
        assert_eq!(r.headers.get("x-req"), Some(&"123".to_string()));

        let r2 = Response::error("oops");
        assert_eq!(r2.status, ResponseStatus::Error);
        assert_eq!(r2.body, Some(b"oops".to_vec()));
    }
}