Skip to main content

amaters_server/
middleware.rs

1//! Composable middleware pipeline for request processing.
2//!
3//! Provides a [`MiddlewarePipeline`] that executes a chain of [`Middleware`]
4//! implementations in order. Each middleware can inspect/modify the
5//! [`RequestContext`], optionally short-circuit the pipeline (e.g. on auth
6//! failure), or let processing continue by calling [`Next::run`].
7//!
8//! # Built-in middleware
9//!
10//! | Middleware | Purpose |
11//! |---|---|
12//! | [`LoggingMiddleware`] | Logs request/response with duration |
13//! | [`MetricsMiddleware`] | Records operation metrics |
//! | [`AuthMiddleware`] | API-key authentication |
15//! | [`RateLimitMiddleware`] | Token-bucket rate limiting |
16//! | [`TracingMiddleware`] | Creates a tracing span per request |
17
18use std::any::Any;
19use std::collections::HashMap;
20use std::fmt;
21use std::net::SocketAddr;
22use std::sync::Arc;
23use std::time::{Duration, Instant};
24
25use async_trait::async_trait;
26use thiserror::Error;
27use tracing::{debug, info, warn};
28
29use crate::metrics::MetricsCollector;
30
31// ---------------------------------------------------------------------------
32// Errors
33// ---------------------------------------------------------------------------
34
35/// Errors that can occur in the middleware pipeline.
36#[derive(Error, Debug)]
37pub enum MiddlewareError {
38    #[error("Authentication failed: {0}")]
39    AuthFailed(String),
40
41    #[error("Rate limited: {0}")]
42    RateLimited(String),
43
44    #[error("Internal middleware error: {0}")]
45    Internal(String),
46
47    #[error("Pipeline error: {0}")]
48    Pipeline(String),
49}
50
51pub type Result<T> = std::result::Result<T, MiddlewareError>;
52
53// ---------------------------------------------------------------------------
54// ResponseStatus / Response
55// ---------------------------------------------------------------------------
56
/// Status of a middleware response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseStatus {
    /// The request was processed successfully.
    Ok,
    /// Processing failed (see [`Response::error`]).
    Error,
    /// The request was refused by rate limiting (see [`Response::rate_limited`]).
    RateLimited,
    /// The request was refused for missing or invalid credentials
    /// (see [`Response::unauthorized`]).
    Unauthorized,
}

impl fmt::Display for ResponseStatus {
    /// Writes the status label: `OK`, `Error`, `RateLimited` or `Unauthorized`.
    ///
    /// Uses [`fmt::Formatter::pad`] so width / fill / alignment flags such as
    /// `{:<12}` are honoured; a plain `write!` of a literal ignores them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Ok => "OK",
            Self::Error => "Error",
            Self::RateLimited => "RateLimited",
            Self::Unauthorized => "Unauthorized",
        };
        f.pad(label)
    }
}

/// Response wrapper returned by the middleware pipeline.
#[derive(Debug, Clone)]
pub struct Response {
    /// Outcome of the request.
    pub status: ResponseStatus,
    /// Optional payload. The message-bearing constructors store their message
    /// here as UTF-8 bytes.
    pub body: Option<Vec<u8>>,
    /// Response headers, keyed by name.
    pub headers: HashMap<String, String>,
    /// Processing time. Every constructor initialises it to [`Duration::ZERO`].
    pub duration: Duration,
}

impl Response {
    /// Shared constructor: the given status and body, no headers, zero duration.
    fn from_parts(status: ResponseStatus, body: Option<Vec<u8>>) -> Self {
        Self {
            status,
            body,
            headers: HashMap::new(),
            duration: Duration::ZERO,
        }
    }

    /// Shared constructor for status-plus-message responses; the message
    /// becomes the body as UTF-8 bytes.
    fn from_message(status: ResponseStatus, msg: String) -> Self {
        Self::from_parts(status, Some(msg.into_bytes()))
    }

    /// Create a successful response with no body.
    pub fn ok() -> Self {
        Self::from_parts(ResponseStatus::Ok, None)
    }

    /// Create an error response whose body is `msg`.
    pub fn error(msg: impl Into<String>) -> Self {
        Self::from_message(ResponseStatus::Error, msg.into())
    }

    /// Create a rate-limited response whose body is `msg`.
    pub fn rate_limited(msg: impl Into<String>) -> Self {
        Self::from_message(ResponseStatus::RateLimited, msg.into())
    }

    /// Create an unauthorized response whose body is `msg`.
    pub fn unauthorized(msg: impl Into<String>) -> Self {
        Self::from_message(ResponseStatus::Unauthorized, msg.into())
    }

    /// Set (or overwrite) the header `key` to `value`.
    #[must_use]
    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(key.into(), value.into());
        self
    }

    /// Replace the body.
    #[must_use]
    pub fn with_body(mut self, body: Vec<u8>) -> Self {
        self.body = Some(body);
        self
    }

    /// Replace the duration.
    #[must_use]
    pub fn with_duration(mut self, duration: Duration) -> Self {
        self.duration = duration;
        self
    }
}
145
146// ---------------------------------------------------------------------------
147// RequestContext
148// ---------------------------------------------------------------------------
149
150/// Context that travels through the middleware pipeline.
151///
152/// Middleware can read/write [`metadata`](Self::metadata) (string key-value) or
153/// store arbitrary typed data in [`attributes`](Self::attributes).
154pub struct RequestContext {
155    /// Unique request identifier (UUID v4).
156    pub request_id: String,
157    /// Remote peer address, if known.
158    pub client_addr: Option<SocketAddr>,
159    /// Logical method / query type (e.g. `"GET"`, `"PUT"`, `"QUERY"`).
160    pub method: String,
161    /// Extensible string metadata.
162    pub metadata: HashMap<String, String>,
163    /// When the request started.
164    pub start_time: Instant,
165    /// Typed attributes that middleware can set/get.
166    pub attributes: HashMap<String, Box<dyn Any + Send + Sync>>,
167}
168
169impl RequestContext {
170    /// Create a new request context.
171    pub fn new(method: impl Into<String>) -> Self {
172        Self {
173            request_id: uuid::Uuid::new_v4().to_string(),
174            client_addr: None,
175            method: method.into(),
176            metadata: HashMap::new(),
177            start_time: Instant::now(),
178            attributes: HashMap::new(),
179        }
180    }
181
182    /// Set the client address.
183    pub fn with_client_addr(mut self, addr: SocketAddr) -> Self {
184        self.client_addr = Some(addr);
185        self
186    }
187
188    /// Insert string metadata.
189    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
190        self.metadata.insert(key.into(), value.into());
191        self
192    }
193
194    /// Store a typed attribute.
195    pub fn set_attribute<T: Any + Send + Sync>(&mut self, key: impl Into<String>, value: T) {
196        self.attributes.insert(key.into(), Box::new(value));
197    }
198
199    /// Retrieve a typed attribute by reference.
200    pub fn get_attribute<T: Any + Send + Sync>(&self, key: &str) -> Option<&T> {
201        self.attributes.get(key).and_then(|v| v.downcast_ref::<T>())
202    }
203
204    /// Elapsed time since `start_time`.
205    pub fn elapsed(&self) -> Duration {
206        self.start_time.elapsed()
207    }
208}
209
210impl fmt::Debug for RequestContext {
211    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
212        f.debug_struct("RequestContext")
213            .field("request_id", &self.request_id)
214            .field("client_addr", &self.client_addr)
215            .field("method", &self.method)
216            .field("metadata", &self.metadata)
217            .field("start_time", &self.start_time)
218            .field("attributes_count", &self.attributes.len())
219            .finish()
220    }
221}
222
223// ---------------------------------------------------------------------------
224// Middleware + Next traits
225// ---------------------------------------------------------------------------
226
227/// Trait for the "rest of the pipeline" that a middleware calls to continue.
228#[async_trait]
229pub trait Next: Send + Sync {
230    async fn run(&self, ctx: &mut RequestContext) -> Result<Response>;
231}
232
233/// Trait implemented by each middleware layer.
234#[async_trait]
235pub trait Middleware: Send + Sync {
236    /// Process the request. Call `next.run(ctx)` to continue the pipeline.
237    async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response>;
238
239    /// Human-readable name of this middleware.
240    fn name(&self) -> &str;
241
242    /// Execution order — lower values run first.
243    fn order(&self) -> i32 {
244        0
245    }
246}
247
248// ---------------------------------------------------------------------------
249// Pipeline internals
250// ---------------------------------------------------------------------------
251
252/// Represents the tail of the middleware chain (produces the default response).
253struct PipelineTail;
254
255#[async_trait]
256impl Next for PipelineTail {
257    async fn run(&self, _ctx: &mut RequestContext) -> Result<Response> {
258        Ok(Response::ok())
259    }
260}
261
262/// Wraps one middleware layer + the remaining chain as a [`Next`].
263struct PipelineLink {
264    middleware: Arc<dyn Middleware>,
265    next: Arc<dyn Next>,
266}
267
268#[async_trait]
269impl Next for PipelineLink {
270    async fn run(&self, ctx: &mut RequestContext) -> Result<Response> {
271        self.middleware.process(ctx, self.next.as_ref()).await
272    }
273}
274
275// ---------------------------------------------------------------------------
276// MiddlewarePipeline + Builder
277// ---------------------------------------------------------------------------
278
279/// An immutable, ordered pipeline of middleware.
280///
281/// Built via [`MiddlewarePipelineBuilder`].
282pub struct MiddlewarePipeline {
283    chain: Arc<dyn Next>,
284}
285
286impl MiddlewarePipeline {
287    /// Execute the pipeline with the given context.
288    pub async fn execute(&self, ctx: &mut RequestContext) -> Result<Response> {
289        let result = self.chain.run(ctx).await;
290        // Stamp the duration on the response.
291        match result {
292            Ok(mut resp) => {
293                resp.duration = ctx.elapsed();
294                Ok(resp)
295            }
296            Err(e) => Err(e),
297        }
298    }
299}
300
301/// Builder for [`MiddlewarePipeline`].
302pub struct MiddlewarePipelineBuilder {
303    middleware: Vec<Arc<dyn Middleware>>,
304}
305
306impl Default for MiddlewarePipelineBuilder {
307    fn default() -> Self {
308        Self::new()
309    }
310}
311
312impl MiddlewarePipelineBuilder {
313    /// Create an empty builder.
314    pub fn new() -> Self {
315        Self {
316            middleware: Vec::new(),
317        }
318    }
319
320    /// Add a middleware to the pipeline.
321    pub fn with<M: Middleware + 'static>(mut self, m: M) -> Self {
322        self.middleware.push(Arc::new(m));
323        self
324    }
325
326    /// Add an already-arc'd middleware to the pipeline.
327    pub fn add_arc(mut self, m: Arc<dyn Middleware>) -> Self {
328        self.middleware.push(m);
329        self
330    }
331
332    /// Build the pipeline, sorting middleware by [`Middleware::order`].
333    pub fn build(mut self) -> MiddlewarePipeline {
334        // Stable sort so insertion order breaks ties.
335        self.middleware.sort_by_key(|m| m.order());
336
337        // Build the chain from back to front.
338        let mut next: Arc<dyn Next> = Arc::new(PipelineTail);
339        for mw in self.middleware.into_iter().rev() {
340            next = Arc::new(PipelineLink {
341                middleware: mw,
342                next,
343            });
344        }
345
346        MiddlewarePipeline { chain: next }
347    }
348}
349
350// ===========================================================================
351// Built-in middleware implementations
352// ===========================================================================
353
354// ---------------------------------------------------------------------------
355// LoggingMiddleware
356// ---------------------------------------------------------------------------
357
/// Logs every request and its outcome.
///
/// The verbosity of the start/completion events is chosen by [`LogLevel`];
/// [`LoggingMiddleware::new`] selects [`LogLevel::Info`].
pub struct LoggingMiddleware {
    level: LogLevel,
}

/// Log verbosity used by [`LoggingMiddleware`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogLevel {
    /// Log at `debug!` level.
    Debug,
    /// Log at `info!` level.
    Info,
}

impl Default for LoggingMiddleware {
    fn default() -> Self {
        Self::new()
    }
}

impl LoggingMiddleware {
    /// Create a logger emitting at [`LogLevel::Info`].
    pub fn new() -> Self {
        let level = LogLevel::Info;
        Self { level }
    }

    /// Builder-style override of the log level.
    pub fn with_level(mut self, level: LogLevel) -> Self {
        self.level = level;
        self
    }
}
390
391#[async_trait]
392impl Middleware for LoggingMiddleware {
393    async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
394        let method = ctx.method.clone();
395        let request_id = ctx.request_id.clone();
396        let client = ctx
397            .client_addr
398            .map_or_else(|| "unknown".to_string(), |a| a.to_string());
399
400        match self.level {
401            LogLevel::Info => info!(
402                request_id = %request_id,
403                method = %method,
404                client = %client,
405                "Request started"
406            ),
407            LogLevel::Debug => debug!(
408                request_id = %request_id,
409                method = %method,
410                client = %client,
411                "Request started"
412            ),
413        }
414
415        let result = next.run(ctx).await;
416
417        match &result {
418            Ok(resp) => match self.level {
419                LogLevel::Info => info!(
420                    request_id = %request_id,
421                    method = %method,
422                    status = %resp.status,
423                    duration_ms = %ctx.elapsed().as_millis(),
424                    "Request completed"
425                ),
426                LogLevel::Debug => debug!(
427                    request_id = %request_id,
428                    method = %method,
429                    status = %resp.status,
430                    duration_ms = %ctx.elapsed().as_millis(),
431                    "Request completed"
432                ),
433            },
434            Err(e) => warn!(
435                request_id = %request_id,
436                method = %method,
437                error = %e,
438                duration_ms = %ctx.elapsed().as_millis(),
439                "Request failed"
440            ),
441        }
442
443        result
444    }
445
446    fn name(&self) -> &str {
447        "logging"
448    }
449
450    fn order(&self) -> i32 {
451        -100
452    }
453}
454
455// ---------------------------------------------------------------------------
456// MetricsMiddleware
457// ---------------------------------------------------------------------------
458
459/// Records request metrics via the existing [`MetricsCollector`].
460pub struct MetricsMiddleware {
461    collector: MetricsCollector,
462}
463
464impl MetricsMiddleware {
465    pub fn new(collector: MetricsCollector) -> Self {
466        Self { collector }
467    }
468}
469
470#[async_trait]
471impl Middleware for MetricsMiddleware {
472    async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
473        let result = next.run(ctx).await;
474        let duration = ctx.elapsed();
475
476        self.collector.inc_requests();
477        self.collector.observe_request_latency(duration);
478
479        match &result {
480            Ok(resp) => {
481                if resp.status == ResponseStatus::Ok {
482                    self.collector.inc_success();
483                } else {
484                    self.collector.inc_failed();
485                }
486            }
487            Err(_) => {
488                self.collector.inc_failed();
489            }
490        }
491
492        result
493    }
494
495    fn name(&self) -> &str {
496        "metrics"
497    }
498
499    fn order(&self) -> i32 {
500        -90
501    }
502}
503
504// ---------------------------------------------------------------------------
505// TracingMiddleware
506// ---------------------------------------------------------------------------
507
508/// Creates a tracing span around the remainder of the pipeline.
509pub struct TracingMiddleware;
510
511impl Default for TracingMiddleware {
512    fn default() -> Self {
513        Self::new()
514    }
515}
516
517impl TracingMiddleware {
518    pub fn new() -> Self {
519        Self
520    }
521}
522
523#[async_trait]
524impl Middleware for TracingMiddleware {
525    async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
526        let span = tracing::info_span!(
527            "request",
528            request_id = %ctx.request_id,
529            method = %ctx.method,
530            client_addr = ?ctx.client_addr,
531        );
532
533        let _guard = span.enter();
534        next.run(ctx).await
535    }
536
537    fn name(&self) -> &str {
538        "tracing"
539    }
540
541    fn order(&self) -> i32 {
542        -95
543    }
544}
545
546// ---------------------------------------------------------------------------
547// AuthMiddleware
548// ---------------------------------------------------------------------------
549
/// Validates authentication credentials found in request metadata.
///
/// The `"authorization"` entry of [`RequestContext::metadata`] is looked up
/// verbatim as an API key. On a match, the mapped user id is stored as a
/// `String` attribute under `"auth_principal"`.
pub struct AuthMiddleware {
    /// Accepted API keys, each mapped to the user id it authenticates as.
    api_keys: HashMap<String, String>,
    /// When `true`, requests with no `"authorization"` entry are let through.
    allow_anonymous: bool,
}

impl AuthMiddleware {
    /// Create an authenticator over `api_keys`. Anonymous requests are
    /// rejected unless enabled with
    /// [`with_allow_anonymous`](Self::with_allow_anonymous).
    pub fn new(api_keys: HashMap<String, String>) -> Self {
        let allow_anonymous = false;
        Self {
            api_keys,
            allow_anonymous,
        }
    }

    /// When `true`, requests without credentials are passed through instead of
    /// being rejected. Requests presenting an unknown key are still rejected.
    pub fn with_allow_anonymous(mut self, allow: bool) -> Self {
        self.allow_anonymous = allow;
        self
    }
}
577
578#[async_trait]
579impl Middleware for AuthMiddleware {
580    async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
581        let auth_header = ctx.metadata.get("authorization").cloned();
582
583        match auth_header {
584            Some(key) => {
585                // Try API-key lookup.
586                if let Some(user_id) = self.api_keys.get(&key) {
587                    ctx.set_attribute("auth_principal", user_id.clone());
588                    debug!(
589                        request_id = %ctx.request_id,
590                        user_id = %user_id,
591                        "Authentication successful"
592                    );
593                    next.run(ctx).await
594                } else {
595                    warn!(
596                        request_id = %ctx.request_id,
597                        "Authentication failed: invalid credentials"
598                    );
599                    Ok(Response::unauthorized("Invalid credentials"))
600                }
601            }
602            None => {
603                if self.allow_anonymous {
604                    next.run(ctx).await
605                } else {
606                    warn!(
607                        request_id = %ctx.request_id,
608                        "Authentication failed: no credentials provided"
609                    );
610                    Ok(Response::unauthorized("No credentials provided"))
611                }
612            }
613        }
614    }
615
616    fn name(&self) -> &str {
617        "auth"
618    }
619
620    fn order(&self) -> i32 {
621        -80
622    }
623}
624
625// ---------------------------------------------------------------------------
626// RateLimitMiddleware
627// ---------------------------------------------------------------------------
628
629/// Simple token-bucket rate limiter.
630///
631/// Tracks a global bucket of available tokens, refilled at a fixed rate.
632pub struct RateLimitMiddleware {
633    state: Arc<parking_lot::Mutex<RateLimitState>>,
634    max_tokens: u64,
635    refill_rate: f64, // tokens per second
636}
637
638struct RateLimitState {
639    tokens: f64,
640    last_refill: Instant,
641}
642
643impl RateLimitMiddleware {
644    /// Create a rate limiter with `max_tokens` capacity, refilling at
645    /// `refill_rate` tokens per second.
646    pub fn new(max_tokens: u64, refill_rate: f64) -> Self {
647        Self {
648            state: Arc::new(parking_lot::Mutex::new(RateLimitState {
649                tokens: max_tokens as f64,
650                last_refill: Instant::now(),
651            })),
652            max_tokens,
653            refill_rate,
654        }
655    }
656
657    fn try_acquire(&self) -> bool {
658        let mut state = self.state.lock();
659        let now = Instant::now();
660        let elapsed = now.duration_since(state.last_refill).as_secs_f64();
661        state.tokens = (state.tokens + elapsed * self.refill_rate).min(self.max_tokens as f64);
662        state.last_refill = now;
663
664        if state.tokens >= 1.0 {
665            state.tokens -= 1.0;
666            true
667        } else {
668            false
669        }
670    }
671}
672
673#[async_trait]
674impl Middleware for RateLimitMiddleware {
675    async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
676        if self.try_acquire() {
677            next.run(ctx).await
678        } else {
679            warn!(
680                request_id = %ctx.request_id,
681                "Rate limit exceeded"
682            );
683            Ok(Response::rate_limited("Rate limit exceeded"))
684        }
685    }
686
687    fn name(&self) -> &str {
688        "rate_limit"
689    }
690
691    fn order(&self) -> i32 {
692        -70
693    }
694}
695
696// ===========================================================================
697// Tests
698// ===========================================================================
699
700#[cfg(test)]
701mod tests {
702    use super::*;
703    use std::sync::atomic::{AtomicUsize, Ordering};
704
705    // ---- helpers ----------------------------------------------------------
706
707    /// A trivial middleware that records the order it was called.
708    struct OrderRecorder {
709        id: i32,
710        log: Arc<parking_lot::Mutex<Vec<i32>>>,
711    }
712
713    #[async_trait]
714    impl Middleware for OrderRecorder {
715        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
716            self.log.lock().push(self.id);
717            next.run(ctx).await
718        }
719        fn name(&self) -> &str {
720            "order_recorder"
721        }
722        fn order(&self) -> i32 {
723            self.id
724        }
725    }
726
727    /// Middleware that short-circuits (does **not** call `next`).
728    struct ShortCircuit;
729
730    #[async_trait]
731    impl Middleware for ShortCircuit {
732        async fn process(&self, _ctx: &mut RequestContext, _next: &dyn Next) -> Result<Response> {
733            Ok(Response::unauthorized("blocked"))
734        }
735        fn name(&self) -> &str {
736            "short_circuit"
737        }
738        fn order(&self) -> i32 {
739            0
740        }
741    }
742
743    /// Middleware that sets an attribute for downstream consumption.
744    struct AttributeSetter {
745        key: String,
746        value: String,
747    }
748
749    #[async_trait]
750    impl Middleware for AttributeSetter {
751        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
752            ctx.set_attribute(&self.key, self.value.clone());
753            next.run(ctx).await
754        }
755        fn name(&self) -> &str {
756            "attr_setter"
757        }
758        fn order(&self) -> i32 {
759            -10
760        }
761    }
762
763    /// Middleware that reads an attribute set by an earlier middleware.
764    struct AttributeReader {
765        key: String,
766        found: Arc<parking_lot::Mutex<Option<String>>>,
767    }
768
769    #[async_trait]
770    impl Middleware for AttributeReader {
771        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
772            if let Some(val) = ctx.get_attribute::<String>(&self.key) {
773                *self.found.lock() = Some(val.clone());
774            }
775            next.run(ctx).await
776        }
777        fn name(&self) -> &str {
778            "attr_reader"
779        }
780        fn order(&self) -> i32 {
781            10
782        }
783    }
784
785    /// Middleware that propagates an error.
786    struct ErrorMiddleware;
787
788    #[async_trait]
789    impl Middleware for ErrorMiddleware {
790        async fn process(&self, _ctx: &mut RequestContext, _next: &dyn Next) -> Result<Response> {
791            Err(MiddlewareError::Internal("boom".to_string()))
792        }
793        fn name(&self) -> &str {
794            "error"
795        }
796    }
797
798    /// Counter middleware — increments an atomic counter each call.
799    struct CounterMiddleware {
800        counter: Arc<AtomicUsize>,
801        ord: i32,
802    }
803
804    #[async_trait]
805    impl Middleware for CounterMiddleware {
806        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
807            self.counter.fetch_add(1, Ordering::SeqCst);
808            next.run(ctx).await
809        }
810        fn name(&self) -> &str {
811            "counter"
812        }
813        fn order(&self) -> i32 {
814            self.ord
815        }
816    }
817
818    // ---- tests ------------------------------------------------------------
819
820    #[tokio::test]
821    async fn test_empty_pipeline_passes_through() {
822        let pipeline = MiddlewarePipelineBuilder::new().build();
823        let mut ctx = RequestContext::new("TEST");
824        let resp = pipeline
825            .execute(&mut ctx)
826            .await
827            .expect("empty pipeline should succeed");
828        assert_eq!(resp.status, ResponseStatus::Ok);
829    }
830
831    #[tokio::test]
832    async fn test_pipeline_executes_in_order() {
833        let log = Arc::new(parking_lot::Mutex::new(Vec::new()));
834
835        let pipeline = MiddlewarePipelineBuilder::new()
836            .with(OrderRecorder {
837                id: 3,
838                log: Arc::clone(&log),
839            })
840            .with(OrderRecorder {
841                id: 1,
842                log: Arc::clone(&log),
843            })
844            .with(OrderRecorder {
845                id: 2,
846                log: Arc::clone(&log),
847            })
848            .build();
849
850        let mut ctx = RequestContext::new("TEST");
851        pipeline
852            .execute(&mut ctx)
853            .await
854            .expect("pipeline should succeed");
855
856        let order = log.lock().clone();
857        assert_eq!(
858            order,
859            vec![1, 2, 3],
860            "middleware should run sorted by order()"
861        );
862    }
863
864    #[tokio::test]
865    async fn test_short_circuit_on_auth_failure() {
866        let counter = Arc::new(AtomicUsize::new(0));
867
868        let pipeline = MiddlewarePipelineBuilder::new()
869            .with(ShortCircuit)
870            .with(CounterMiddleware {
871                counter: Arc::clone(&counter),
872                ord: 10,
873            })
874            .build();
875
876        let mut ctx = RequestContext::new("TEST");
877        let resp = pipeline
878            .execute(&mut ctx)
879            .await
880            .expect("should get unauthorized response");
881
882        assert_eq!(resp.status, ResponseStatus::Unauthorized);
883        assert_eq!(
884            counter.load(Ordering::SeqCst),
885            0,
886            "downstream middleware must not run after short-circuit"
887        );
888    }
889
890    #[tokio::test]
891    async fn test_context_attributes_passed_between_middleware() {
892        let found = Arc::new(parking_lot::Mutex::new(None));
893
894        let pipeline = MiddlewarePipelineBuilder::new()
895            .with(AttributeSetter {
896                key: "user".to_string(),
897                value: "alice".to_string(),
898            })
899            .with(AttributeReader {
900                key: "user".to_string(),
901                found: Arc::clone(&found),
902            })
903            .build();
904
905        let mut ctx = RequestContext::new("TEST");
906        pipeline
907            .execute(&mut ctx)
908            .await
909            .expect("pipeline should succeed");
910
911        let val = found.lock().clone();
912        assert_eq!(val, Some("alice".to_string()));
913    }
914
915    #[tokio::test]
916    async fn test_metrics_recorded_correctly() {
917        let collector = MetricsCollector::new();
918
919        let pipeline = MiddlewarePipelineBuilder::new()
920            .with(MetricsMiddleware::new(collector.clone()))
921            .build();
922
923        let mut ctx = RequestContext::new("GET");
924        pipeline
925            .execute(&mut ctx)
926            .await
927            .expect("pipeline should succeed");
928
929        let snapshot = collector.snapshot();
930        assert_eq!(snapshot.requests_total, 1);
931        assert_eq!(snapshot.requests_success, 1);
932        assert_eq!(snapshot.requests_failed, 0);
933    }
934
935    #[tokio::test]
936    async fn test_rate_limit_blocks_request() {
937        // One token, no refill.
938        let rl = RateLimitMiddleware::new(1, 0.0);
939
940        let pipeline = MiddlewarePipelineBuilder::new().with(rl).build();
941
942        // First request should pass.
943        let mut ctx1 = RequestContext::new("GET");
944        let r1 = pipeline
945            .execute(&mut ctx1)
946            .await
947            .expect("first request should pass");
948        assert_eq!(r1.status, ResponseStatus::Ok);
949
950        // Second request should be rate-limited.
951        let mut ctx2 = RequestContext::new("GET");
952        let r2 = pipeline
953            .execute(&mut ctx2)
954            .await
955            .expect("second request should be rate-limited");
956        assert_eq!(r2.status, ResponseStatus::RateLimited);
957    }
958
959    #[tokio::test]
960    async fn test_auth_middleware_valid_key() {
961        let mut keys = HashMap::new();
962        keys.insert("secret-key".to_string(), "user-42".to_string());
963
964        let pipeline = MiddlewarePipelineBuilder::new()
965            .with(AuthMiddleware::new(keys))
966            .build();
967
968        let mut ctx = RequestContext::new("GET").with_metadata("authorization", "secret-key");
969        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
970        assert_eq!(resp.status, ResponseStatus::Ok);
971
972        let principal = ctx
973            .get_attribute::<String>("auth_principal")
974            .expect("principal should be set");
975        assert_eq!(principal, "user-42");
976    }
977
978    #[tokio::test]
979    async fn test_auth_middleware_invalid_key() {
980        let mut keys = HashMap::new();
981        keys.insert("secret-key".to_string(), "user-42".to_string());
982
983        let pipeline = MiddlewarePipelineBuilder::new()
984            .with(AuthMiddleware::new(keys))
985            .build();
986
987        let mut ctx = RequestContext::new("GET").with_metadata("authorization", "wrong-key");
988        let resp = pipeline
989            .execute(&mut ctx)
990            .await
991            .expect("should get unauthorized");
992        assert_eq!(resp.status, ResponseStatus::Unauthorized);
993    }
994
995    #[tokio::test]
996    async fn test_auth_middleware_no_credentials() {
997        let keys = HashMap::new();
998        let pipeline = MiddlewarePipelineBuilder::new()
999            .with(AuthMiddleware::new(keys))
1000            .build();
1001
1002        let mut ctx = RequestContext::new("GET");
1003        let resp = pipeline
1004            .execute(&mut ctx)
1005            .await
1006            .expect("should get unauthorized");
1007        assert_eq!(resp.status, ResponseStatus::Unauthorized);
1008    }
1009
1010    #[tokio::test]
1011    async fn test_auth_middleware_anonymous_allowed() {
1012        let keys = HashMap::new();
1013        let pipeline = MiddlewarePipelineBuilder::new()
1014            .with(AuthMiddleware::new(keys).with_allow_anonymous(true))
1015            .build();
1016
1017        let mut ctx = RequestContext::new("GET");
1018        let resp = pipeline
1019            .execute(&mut ctx)
1020            .await
1021            .expect("should pass through");
1022        assert_eq!(resp.status, ResponseStatus::Ok);
1023    }
1024
1025    #[tokio::test]
1026    async fn test_error_propagation() {
1027        let pipeline = MiddlewarePipelineBuilder::new()
1028            .with(ErrorMiddleware)
1029            .build();
1030
1031        let mut ctx = RequestContext::new("GET");
1032        let result = pipeline.execute(&mut ctx).await;
1033        assert!(result.is_err());
1034        let err = result.expect_err("should be an error");
1035        assert!(
1036            err.to_string().contains("boom"),
1037            "error message should propagate"
1038        );
1039    }
1040
1041    #[tokio::test]
1042    async fn test_middleware_ordering_by_order() {
1043        let log = Arc::new(parking_lot::Mutex::new(Vec::new()));
1044
1045        // Add in reverse order — builder should sort by order().
1046        let pipeline = MiddlewarePipelineBuilder::new()
1047            .with(OrderRecorder {
1048                id: 50,
1049                log: Arc::clone(&log),
1050            })
1051            .with(OrderRecorder {
1052                id: 10,
1053                log: Arc::clone(&log),
1054            })
1055            .with(OrderRecorder {
1056                id: 30,
1057                log: Arc::clone(&log),
1058            })
1059            .with(OrderRecorder {
1060                id: 20,
1061                log: Arc::clone(&log),
1062            })
1063            .with(OrderRecorder {
1064                id: 40,
1065                log: Arc::clone(&log),
1066            })
1067            .build();
1068
1069        let mut ctx = RequestContext::new("TEST");
1070        pipeline
1071            .execute(&mut ctx)
1072            .await
1073            .expect("pipeline should succeed");
1074
1075        let order = log.lock().clone();
1076        assert_eq!(order, vec![10, 20, 30, 40, 50]);
1077    }
1078
1079    #[tokio::test]
1080    async fn test_response_duration_is_set() {
1081        let pipeline = MiddlewarePipelineBuilder::new().build();
1082        let mut ctx = RequestContext::new("TEST");
1083        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
1084        // Duration should have been stamped by execute().
1085        // (Any Duration is valid; we just confirm execute didn't panic.)
1086        let _ = resp.duration;
1087    }
1088
1089    #[tokio::test]
1090    async fn test_logging_middleware_runs() {
1091        // Smoke test — just ensure it doesn't panic.
1092        let pipeline = MiddlewarePipelineBuilder::new()
1093            .with(LoggingMiddleware::new())
1094            .build();
1095
1096        let mut ctx = RequestContext::new("GET");
1097        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
1098        assert_eq!(resp.status, ResponseStatus::Ok);
1099    }
1100
1101    #[tokio::test]
1102    async fn test_tracing_middleware_runs() {
1103        let pipeline = MiddlewarePipelineBuilder::new()
1104            .with(TracingMiddleware::new())
1105            .build();
1106
1107        let mut ctx = RequestContext::new("QUERY");
1108        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
1109        assert_eq!(resp.status, ResponseStatus::Ok);
1110    }
1111
1112    #[tokio::test]
1113    async fn test_full_pipeline_integration() {
1114        let collector = MetricsCollector::new();
1115
1116        let mut api_keys = HashMap::new();
1117        api_keys.insert("valid-key".to_string(), "user-1".to_string());
1118
1119        let pipeline = MiddlewarePipelineBuilder::new()
1120            .with(LoggingMiddleware::new().with_level(LogLevel::Debug))
1121            .with(TracingMiddleware::new())
1122            .with(MetricsMiddleware::new(collector.clone()))
1123            .with(AuthMiddleware::new(api_keys))
1124            .with(RateLimitMiddleware::new(100, 100.0))
1125            .build();
1126
1127        // Authenticated request
1128        let mut ctx = RequestContext::new("QUERY").with_metadata("authorization", "valid-key");
1129        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
1130        assert_eq!(resp.status, ResponseStatus::Ok);
1131
1132        let snapshot = collector.snapshot();
1133        assert_eq!(snapshot.requests_total, 1);
1134        assert_eq!(snapshot.requests_success, 1);
1135    }
1136
1137    #[tokio::test]
1138    async fn test_pipeline_builder_default() {
1139        let builder = MiddlewarePipelineBuilder::default();
1140        let pipeline = builder.build();
1141        let mut ctx = RequestContext::new("TEST");
1142        let resp = pipeline
1143            .execute(&mut ctx)
1144            .await
1145            .expect("default pipeline should succeed");
1146        assert_eq!(resp.status, ResponseStatus::Ok);
1147    }
1148
1149    #[tokio::test]
1150    async fn test_request_context_debug() {
1151        let ctx = RequestContext::new("GET");
1152        let debug_str = format!("{:?}", ctx);
1153        assert!(debug_str.contains("RequestContext"));
1154        assert!(debug_str.contains("GET"));
1155    }
1156
1157    #[tokio::test]
1158    async fn test_response_status_display() {
1159        assert_eq!(ResponseStatus::Ok.to_string(), "OK");
1160        assert_eq!(ResponseStatus::Error.to_string(), "Error");
1161        assert_eq!(ResponseStatus::RateLimited.to_string(), "RateLimited");
1162        assert_eq!(ResponseStatus::Unauthorized.to_string(), "Unauthorized");
1163    }
1164
1165    #[tokio::test]
1166    async fn test_response_builders() {
1167        let r = Response::ok()
1168            .with_header("x-req", "123")
1169            .with_body(b"hello".to_vec());
1170        assert_eq!(r.status, ResponseStatus::Ok);
1171        assert_eq!(r.body, Some(b"hello".to_vec()));
1172        assert_eq!(r.headers.get("x-req"), Some(&"123".to_string()));
1173
1174        let r2 = Response::error("oops");
1175        assert_eq!(r2.status, ResponseStatus::Error);
1176        assert_eq!(r2.body, Some(b"oops".to_vec()));
1177    }
1178}