Skip to main content

amaters_server/
middleware.rs

1//! Composable middleware pipeline for request processing.
2//!
3//! Provides a [`MiddlewarePipeline`] that executes a chain of [`Middleware`]
4//! implementations in order. Each middleware can inspect/modify the
5//! [`RequestContext`], optionally short-circuit the pipeline (e.g. on auth
6//! failure), or let processing continue by calling [`Next::run`].
7//!
8//! # Built-in middleware
9//!
10//! | Middleware | Purpose |
11//! |---|---|
12//! | [`LoggingMiddleware`] | Logs request/response with duration |
13//! | [`MetricsMiddleware`] | Records operation metrics |
14//! | [`AuthMiddleware`] | API-key authentication |
15//! | [`RateLimitMiddleware`] | Token-bucket rate limiting |
16//! | [`TracingMiddleware`] | Creates a tracing span per request |
17
18use constant_time_eq::constant_time_eq;
19use std::any::Any;
20use std::collections::HashMap;
21use std::fmt;
22use std::net::SocketAddr;
23use std::sync::Arc;
24use std::time::{Duration, Instant};
25
26use async_trait::async_trait;
27use thiserror::Error;
28use tracing::{debug, info, warn};
29
30use crate::metrics::MetricsCollector;
31
32// ---------------------------------------------------------------------------
33// Errors
34// ---------------------------------------------------------------------------
35
36/// Errors that can occur in the middleware pipeline.
37#[derive(Error, Debug)]
38pub enum MiddlewareError {
39    #[error("Authentication failed: {0}")]
40    AuthFailed(String),
41
42    #[error("Rate limited: {0}")]
43    RateLimited(String),
44
45    #[error("Internal middleware error: {0}")]
46    Internal(String),
47
48    #[error("Pipeline error: {0}")]
49    Pipeline(String),
50}
51
52pub type Result<T> = std::result::Result<T, MiddlewareError>;
53
54// ---------------------------------------------------------------------------
55// ResponseStatus / Response
56// ---------------------------------------------------------------------------
57
/// Outcome category carried by a pipeline response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseStatus {
    /// Displayed as `OK`.
    Ok,
    /// Displayed as `Error`.
    Error,
    /// Displayed as `RateLimited`.
    RateLimited,
    /// Displayed as `Unauthorized`.
    Unauthorized,
}

impl fmt::Display for ResponseStatus {
    /// Writes the fixed label for the variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Ok => "OK",
            Self::Error => "Error",
            Self::RateLimited => "RateLimited",
            Self::Unauthorized => "Unauthorized",
        };
        f.write_str(label)
    }
}
77
78/// Response wrapper returned by the middleware pipeline.
79#[derive(Debug, Clone)]
80pub struct Response {
81    pub status: ResponseStatus,
82    pub body: Option<Vec<u8>>,
83    pub headers: HashMap<String, String>,
84    pub duration: Duration,
85}
86
87impl Response {
88    /// Create a successful response with no body.
89    pub fn ok() -> Self {
90        Self {
91            status: ResponseStatus::Ok,
92            body: None,
93            headers: HashMap::new(),
94            duration: Duration::ZERO,
95        }
96    }
97
98    /// Create an error response.
99    pub fn error(msg: impl Into<String>) -> Self {
100        Self {
101            status: ResponseStatus::Error,
102            body: Some(msg.into().into_bytes()),
103            headers: HashMap::new(),
104            duration: Duration::ZERO,
105        }
106    }
107
108    /// Create a rate-limited response.
109    pub fn rate_limited(msg: impl Into<String>) -> Self {
110        Self {
111            status: ResponseStatus::RateLimited,
112            body: Some(msg.into().into_bytes()),
113            headers: HashMap::new(),
114            duration: Duration::ZERO,
115        }
116    }
117
118    /// Create an unauthorized response.
119    pub fn unauthorized(msg: impl Into<String>) -> Self {
120        Self {
121            status: ResponseStatus::Unauthorized,
122            body: Some(msg.into().into_bytes()),
123            headers: HashMap::new(),
124            duration: Duration::ZERO,
125        }
126    }
127
128    /// Set a header on the response.
129    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
130        self.headers.insert(key.into(), value.into());
131        self
132    }
133
134    /// Set the body.
135    pub fn with_body(mut self, body: Vec<u8>) -> Self {
136        self.body = Some(body);
137        self
138    }
139
140    /// Set the duration.
141    pub fn with_duration(mut self, duration: Duration) -> Self {
142        self.duration = duration;
143        self
144    }
145}
146
147// ---------------------------------------------------------------------------
148// RequestContext
149// ---------------------------------------------------------------------------
150
151/// Context that travels through the middleware pipeline.
152///
153/// Middleware can read/write [`metadata`](Self::metadata) (string key-value) or
154/// store arbitrary typed data in [`attributes`](Self::attributes).
155pub struct RequestContext {
156    /// Unique request identifier (UUID v4).
157    pub request_id: String,
158    /// Remote peer address, if known.
159    pub client_addr: Option<SocketAddr>,
160    /// Logical method / query type (e.g. `"GET"`, `"PUT"`, `"QUERY"`).
161    pub method: String,
162    /// Extensible string metadata.
163    pub metadata: HashMap<String, String>,
164    /// When the request started.
165    pub start_time: Instant,
166    /// Typed attributes that middleware can set/get.
167    pub attributes: HashMap<String, Box<dyn Any + Send + Sync>>,
168}
169
170impl RequestContext {
171    /// Create a new request context.
172    pub fn new(method: impl Into<String>) -> Self {
173        Self {
174            request_id: uuid::Uuid::new_v4().to_string(),
175            client_addr: None,
176            method: method.into(),
177            metadata: HashMap::new(),
178            start_time: Instant::now(),
179            attributes: HashMap::new(),
180        }
181    }
182
183    /// Set the client address.
184    pub fn with_client_addr(mut self, addr: SocketAddr) -> Self {
185        self.client_addr = Some(addr);
186        self
187    }
188
189    /// Insert string metadata.
190    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
191        self.metadata.insert(key.into(), value.into());
192        self
193    }
194
195    /// Store a typed attribute.
196    pub fn set_attribute<T: Any + Send + Sync>(&mut self, key: impl Into<String>, value: T) {
197        self.attributes.insert(key.into(), Box::new(value));
198    }
199
200    /// Retrieve a typed attribute by reference.
201    pub fn get_attribute<T: Any + Send + Sync>(&self, key: &str) -> Option<&T> {
202        self.attributes.get(key).and_then(|v| v.downcast_ref::<T>())
203    }
204
205    /// Elapsed time since `start_time`.
206    pub fn elapsed(&self) -> Duration {
207        self.start_time.elapsed()
208    }
209}
210
211impl fmt::Debug for RequestContext {
212    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213        f.debug_struct("RequestContext")
214            .field("request_id", &self.request_id)
215            .field("client_addr", &self.client_addr)
216            .field("method", &self.method)
217            .field("metadata", &self.metadata)
218            .field("start_time", &self.start_time)
219            .field("attributes_count", &self.attributes.len())
220            .finish()
221    }
222}
223
224// ---------------------------------------------------------------------------
225// Middleware + Next traits
226// ---------------------------------------------------------------------------
227
228/// Trait for the "rest of the pipeline" that a middleware calls to continue.
229#[async_trait]
230pub trait Next: Send + Sync {
231    async fn run(&self, ctx: &mut RequestContext) -> Result<Response>;
232}
233
234/// Trait implemented by each middleware layer.
235#[async_trait]
236pub trait Middleware: Send + Sync {
237    /// Process the request. Call `next.run(ctx)` to continue the pipeline.
238    async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response>;
239
240    /// Human-readable name of this middleware.
241    fn name(&self) -> &str;
242
243    /// Execution order — lower values run first.
244    fn order(&self) -> i32 {
245        0
246    }
247}
248
249// ---------------------------------------------------------------------------
250// Pipeline internals
251// ---------------------------------------------------------------------------
252
253/// Represents the tail of the middleware chain (produces the default response).
254struct PipelineTail;
255
256#[async_trait]
257impl Next for PipelineTail {
258    async fn run(&self, _ctx: &mut RequestContext) -> Result<Response> {
259        Ok(Response::ok())
260    }
261}
262
263/// Wraps one middleware layer + the remaining chain as a [`Next`].
264struct PipelineLink {
265    middleware: Arc<dyn Middleware>,
266    next: Arc<dyn Next>,
267}
268
269#[async_trait]
270impl Next for PipelineLink {
271    async fn run(&self, ctx: &mut RequestContext) -> Result<Response> {
272        self.middleware.process(ctx, self.next.as_ref()).await
273    }
274}
275
276// ---------------------------------------------------------------------------
277// MiddlewarePipeline + Builder
278// ---------------------------------------------------------------------------
279
280/// An immutable, ordered pipeline of middleware.
281///
282/// Built via [`MiddlewarePipelineBuilder`].
283pub struct MiddlewarePipeline {
284    chain: Arc<dyn Next>,
285}
286
287impl MiddlewarePipeline {
288    /// Execute the pipeline with the given context.
289    pub async fn execute(&self, ctx: &mut RequestContext) -> Result<Response> {
290        let result = self.chain.run(ctx).await;
291        // Stamp the duration on the response.
292        match result {
293            Ok(mut resp) => {
294                resp.duration = ctx.elapsed();
295                Ok(resp)
296            }
297            Err(e) => Err(e),
298        }
299    }
300}
301
302/// Builder for [`MiddlewarePipeline`].
303pub struct MiddlewarePipelineBuilder {
304    middleware: Vec<Arc<dyn Middleware>>,
305}
306
307impl Default for MiddlewarePipelineBuilder {
308    fn default() -> Self {
309        Self::new()
310    }
311}
312
313impl MiddlewarePipelineBuilder {
314    /// Create an empty builder.
315    pub fn new() -> Self {
316        Self {
317            middleware: Vec::new(),
318        }
319    }
320
321    /// Add a middleware to the pipeline.
322    pub fn with<M: Middleware + 'static>(mut self, m: M) -> Self {
323        self.middleware.push(Arc::new(m));
324        self
325    }
326
327    /// Add an already-arc'd middleware to the pipeline.
328    pub fn add_arc(mut self, m: Arc<dyn Middleware>) -> Self {
329        self.middleware.push(m);
330        self
331    }
332
333    /// Build the pipeline, sorting middleware by [`Middleware::order`].
334    pub fn build(mut self) -> MiddlewarePipeline {
335        // Stable sort so insertion order breaks ties.
336        self.middleware.sort_by_key(|m| m.order());
337
338        // Build the chain from back to front.
339        let mut next: Arc<dyn Next> = Arc::new(PipelineTail);
340        for mw in self.middleware.into_iter().rev() {
341            next = Arc::new(PipelineLink {
342                middleware: mw,
343                next,
344            });
345        }
346
347        MiddlewarePipeline { chain: next }
348    }
349}
350
351// ===========================================================================
352// Built-in middleware implementations
353// ===========================================================================
354
355// ---------------------------------------------------------------------------
356// LoggingMiddleware
357// ---------------------------------------------------------------------------
358
359/// Logs every request and its outcome.
360pub struct LoggingMiddleware {
361    level: LogLevel,
362}
363
364/// Log verbosity used by [`LoggingMiddleware`].
365#[derive(Debug, Clone, Copy, PartialEq, Eq)]
366pub enum LogLevel {
367    /// Log at `debug!` level.
368    Debug,
369    /// Log at `info!` level.
370    Info,
371}
372
373impl Default for LoggingMiddleware {
374    fn default() -> Self {
375        Self::new()
376    }
377}
378
379impl LoggingMiddleware {
380    pub fn new() -> Self {
381        Self {
382            level: LogLevel::Info,
383        }
384    }
385
386    pub fn with_level(mut self, level: LogLevel) -> Self {
387        self.level = level;
388        self
389    }
390}
391
392#[async_trait]
393impl Middleware for LoggingMiddleware {
394    async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
395        let method = ctx.method.clone();
396        let request_id = ctx.request_id.clone();
397        let client = ctx
398            .client_addr
399            .map_or_else(|| "unknown".to_string(), |a| a.to_string());
400
401        match self.level {
402            LogLevel::Info => info!(
403                request_id = %request_id,
404                method = %method,
405                client = %client,
406                "Request started"
407            ),
408            LogLevel::Debug => debug!(
409                request_id = %request_id,
410                method = %method,
411                client = %client,
412                "Request started"
413            ),
414        }
415
416        let result = next.run(ctx).await;
417
418        match &result {
419            Ok(resp) => match self.level {
420                LogLevel::Info => info!(
421                    request_id = %request_id,
422                    method = %method,
423                    status = %resp.status,
424                    duration_ms = %ctx.elapsed().as_millis(),
425                    "Request completed"
426                ),
427                LogLevel::Debug => debug!(
428                    request_id = %request_id,
429                    method = %method,
430                    status = %resp.status,
431                    duration_ms = %ctx.elapsed().as_millis(),
432                    "Request completed"
433                ),
434            },
435            Err(e) => warn!(
436                request_id = %request_id,
437                method = %method,
438                error = %e,
439                duration_ms = %ctx.elapsed().as_millis(),
440                "Request failed"
441            ),
442        }
443
444        result
445    }
446
447    fn name(&self) -> &str {
448        "logging"
449    }
450
451    fn order(&self) -> i32 {
452        -100
453    }
454}
455
456// ---------------------------------------------------------------------------
457// MetricsMiddleware
458// ---------------------------------------------------------------------------
459
460/// Records request metrics via the existing [`MetricsCollector`].
461pub struct MetricsMiddleware {
462    collector: MetricsCollector,
463}
464
465impl MetricsMiddleware {
466    pub fn new(collector: MetricsCollector) -> Self {
467        Self { collector }
468    }
469}
470
471#[async_trait]
472impl Middleware for MetricsMiddleware {
473    async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
474        let result = next.run(ctx).await;
475        let duration = ctx.elapsed();
476
477        self.collector.inc_requests();
478        self.collector.observe_request_latency(duration);
479
480        match &result {
481            Ok(resp) => {
482                if resp.status == ResponseStatus::Ok {
483                    self.collector.inc_success();
484                } else {
485                    self.collector.inc_failed();
486                }
487            }
488            Err(_) => {
489                self.collector.inc_failed();
490            }
491        }
492
493        result
494    }
495
496    fn name(&self) -> &str {
497        "metrics"
498    }
499
500    fn order(&self) -> i32 {
501        -90
502    }
503}
504
505// ---------------------------------------------------------------------------
506// TracingMiddleware
507// ---------------------------------------------------------------------------
508
509/// Creates a tracing span around the remainder of the pipeline.
510pub struct TracingMiddleware;
511
512impl Default for TracingMiddleware {
513    fn default() -> Self {
514        Self::new()
515    }
516}
517
518impl TracingMiddleware {
519    pub fn new() -> Self {
520        Self
521    }
522}
523
524#[async_trait]
525impl Middleware for TracingMiddleware {
526    async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
527        let span = tracing::info_span!(
528            "amaters.request",
529            "amaters.node_id" = "local",
530            "amaters.request_id" = %ctx.request_id,
531            method = %ctx.method,
532            client_addr = ?ctx.client_addr,
533        );
534
535        let _guard = span.enter();
536        next.run(ctx).await
537    }
538
539    fn name(&self) -> &str {
540        "tracing"
541    }
542
543    fn order(&self) -> i32 {
544        -95
545    }
546}
547
548// ---------------------------------------------------------------------------
549// OtelSpanMiddleware
550// ---------------------------------------------------------------------------
551
552/// Creates OTel-compatible spans for key server lifecycle events.
553///
554/// Unlike [`TracingMiddleware`] (which uses a static `"local"` node id),
555/// `OtelSpanMiddleware` carries a real `node_id` that is known at construction
556/// time, enabling per-node filtering in distributed tracing back-ends.
557pub struct OtelSpanMiddleware {
558    node_id: String,
559}
560
561impl OtelSpanMiddleware {
562    pub fn new(node_id: impl Into<String>) -> Self {
563        Self {
564            node_id: node_id.into(),
565        }
566    }
567}
568
569#[async_trait]
570impl Middleware for OtelSpanMiddleware {
571    async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
572        let span = tracing::info_span!(
573            "amaters.server.request",
574            "amaters.node_id" = self.node_id.as_str(),
575            "amaters.request_id" = %ctx.request_id,
576            "amaters.method" = %ctx.method,
577        );
578
579        let _guard = span.enter();
580        next.run(ctx).await
581    }
582
583    fn name(&self) -> &str {
584        "otel_span"
585    }
586
587    fn order(&self) -> i32 {
588        -97
589    }
590}
591
592// ---------------------------------------------------------------------------
593// AuthMiddleware
594// ---------------------------------------------------------------------------
595
596/// Validates authentication credentials found in request metadata.
597///
598/// Looks for an `"authorization"` key in [`RequestContext::metadata`].
599/// On success, stores the authenticated identity as an attribute under
600/// `"auth_principal"`.
601pub struct AuthMiddleware {
602    /// Valid API keys (key -> user-id mapping).
603    api_keys: HashMap<String, String>,
604    /// Whether to allow unauthenticated requests to pass through.
605    allow_anonymous: bool,
606}
607
608impl AuthMiddleware {
609    pub fn new(api_keys: HashMap<String, String>) -> Self {
610        Self {
611            api_keys,
612            allow_anonymous: false,
613        }
614    }
615
616    /// When `true`, requests without credentials are passed through instead of
617    /// being rejected.
618    pub fn with_allow_anonymous(mut self, allow: bool) -> Self {
619        self.allow_anonymous = allow;
620        self
621    }
622}
623
624#[async_trait]
625impl Middleware for AuthMiddleware {
626    async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
627        let auth_header = ctx.metadata.get("authorization").cloned();
628
629        match auth_header {
630            Some(key) => {
631                // Try API-key lookup with constant-time comparison to prevent
632                // timing side-channel attacks on the stored key values.
633                let key_bytes = key.as_bytes();
634                if let Some(user_id) = self
635                    .api_keys
636                    .iter()
637                    .find(|(k, _)| constant_time_eq(k.as_bytes(), key_bytes))
638                    .map(|(_, v)| v)
639                {
640                    ctx.set_attribute("auth_principal", user_id.clone());
641                    debug!(
642                        request_id = %ctx.request_id,
643                        user_id = %user_id,
644                        "Authentication successful"
645                    );
646                    next.run(ctx).await
647                } else {
648                    warn!(
649                        request_id = %ctx.request_id,
650                        "Authentication failed: invalid credentials"
651                    );
652                    Ok(Response::unauthorized("Invalid credentials"))
653                }
654            }
655            None => {
656                if self.allow_anonymous {
657                    next.run(ctx).await
658                } else {
659                    warn!(
660                        request_id = %ctx.request_id,
661                        "Authentication failed: no credentials provided"
662                    );
663                    Ok(Response::unauthorized("No credentials provided"))
664                }
665            }
666        }
667    }
668
669    fn name(&self) -> &str {
670        "auth"
671    }
672
673    fn order(&self) -> i32 {
674        -80
675    }
676}
677
678// ---------------------------------------------------------------------------
679// RateLimitMiddleware
680// ---------------------------------------------------------------------------
681
682/// Simple token-bucket rate limiter.
683///
684/// Tracks a global bucket of available tokens, refilled at a fixed rate.
685pub struct RateLimitMiddleware {
686    state: Arc<parking_lot::Mutex<RateLimitState>>,
687    max_tokens: u64,
688    refill_rate: f64, // tokens per second
689}
690
691struct RateLimitState {
692    tokens: f64,
693    last_refill: Instant,
694}
695
696impl RateLimitMiddleware {
697    /// Create a rate limiter with `max_tokens` capacity, refilling at
698    /// `refill_rate` tokens per second.
699    pub fn new(max_tokens: u64, refill_rate: f64) -> Self {
700        Self {
701            state: Arc::new(parking_lot::Mutex::new(RateLimitState {
702                tokens: max_tokens as f64,
703                last_refill: Instant::now(),
704            })),
705            max_tokens,
706            refill_rate,
707        }
708    }
709
710    fn try_acquire(&self) -> bool {
711        let mut state = self.state.lock();
712        let now = Instant::now();
713        let elapsed = now.duration_since(state.last_refill).as_secs_f64();
714        state.tokens = (state.tokens + elapsed * self.refill_rate).min(self.max_tokens as f64);
715        state.last_refill = now;
716
717        if state.tokens >= 1.0 {
718            state.tokens -= 1.0;
719            true
720        } else {
721            false
722        }
723    }
724}
725
726#[async_trait]
727impl Middleware for RateLimitMiddleware {
728    async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
729        if self.try_acquire() {
730            next.run(ctx).await
731        } else {
732            warn!(
733                request_id = %ctx.request_id,
734                "Rate limit exceeded"
735            );
736            Ok(Response::rate_limited("Rate limit exceeded"))
737        }
738    }
739
740    fn name(&self) -> &str {
741        "rate_limit"
742    }
743
744    fn order(&self) -> i32 {
745        -70
746    }
747}
748
749// ---------------------------------------------------------------------------
750// AdaptiveRateLimiter / AdaptiveRateLimitMiddleware
751// ---------------------------------------------------------------------------
752
753/// Adaptive rate limiter that reduces effective limits on high error rates.
754///
755/// Tracks a rolling window of request outcomes. When the error rate exceeds
756/// `error_threshold`, the effective limit is multiplied by `reduction_factor`.
757/// When the error rate drops back below the threshold, the limit recovers
758/// multiplicatively by `recovery_factor`, capped at `base_limit`.
759pub struct AdaptiveRateLimiter {
760    base_limit: u64,
761    current_limit: Arc<parking_lot::Mutex<u64>>,
762    error_window: Arc<parking_lot::Mutex<std::collections::VecDeque<bool>>>,
763    window_size: usize,
764    reduction_factor: f64,
765    recovery_factor: f64,
766    error_threshold: f64,
767}
768
769impl AdaptiveRateLimiter {
770    /// Create with `base_limit` tokens, default window (100), reduction 0.8,
771    /// recovery 1.05, error threshold 10 %.
772    pub fn new(base_limit: u64) -> Self {
773        Self {
774            base_limit,
775            current_limit: Arc::new(parking_lot::Mutex::new(base_limit)),
776            error_window: Arc::new(parking_lot::Mutex::new(
777                std::collections::VecDeque::with_capacity(101),
778            )),
779            window_size: 100,
780            reduction_factor: 0.8,
781            recovery_factor: 1.05,
782            error_threshold: 0.1,
783        }
784    }
785
786    /// Record a successful request and potentially recover the limit.
787    pub fn record_success(&self) {
788        self.push(false);
789        self.adjust();
790    }
791
792    /// Record a failed/errored request and potentially reduce the limit.
793    pub fn record_error(&self) {
794        self.push(true);
795        self.adjust();
796    }
797
798    /// Current effective token-bucket capacity.
799    pub fn current_limit(&self) -> u64 {
800        *self.current_limit.lock()
801    }
802
803    fn push(&self, is_error: bool) {
804        let mut window = self.error_window.lock();
805        if window.len() >= self.window_size {
806            window.pop_front();
807        }
808        window.push_back(is_error);
809    }
810
811    fn adjust(&self) {
812        let error_rate = {
813            let window = self.error_window.lock();
814            if window.is_empty() {
815                return;
816            }
817            let errors = window.iter().filter(|&&e| e).count();
818            errors as f64 / window.len() as f64
819        };
820
821        let mut limit = self.current_limit.lock();
822        if error_rate >= self.error_threshold {
823            let reduced = (*limit as f64 * self.reduction_factor).floor() as u64;
824            *limit = reduced.max(1);
825        } else {
826            let recovered = (*limit as f64 * self.recovery_factor).ceil() as u64;
827            *limit = recovered.min(self.base_limit);
828        }
829    }
830}
831
832/// Middleware wrapper around [`AdaptiveRateLimiter`].
833///
834/// Maintains a token-bucket whose *capacity* tracks the limiter's current
835/// effective limit. On a successful token acquisition it records a success;
836/// when the bucket is exhausted it records an error, which may further reduce
837/// the effective limit.
838pub struct AdaptiveRateLimitMiddleware {
839    limiter: Arc<AdaptiveRateLimiter>,
840    token_state: Arc<parking_lot::Mutex<RateLimitState>>,
841}
842
843impl AdaptiveRateLimitMiddleware {
844    pub fn new(base_limit: u64) -> Self {
845        let limiter = Arc::new(AdaptiveRateLimiter::new(base_limit));
846        Self {
847            token_state: Arc::new(parking_lot::Mutex::new(RateLimitState {
848                tokens: base_limit as f64,
849                last_refill: Instant::now(),
850            })),
851            limiter,
852        }
853    }
854
855    fn try_acquire(&self) -> bool {
856        let capacity = self.limiter.current_limit() as f64;
857        let mut state = self.token_state.lock();
858        let now = Instant::now();
859        let elapsed = now.duration_since(state.last_refill).as_secs_f64();
860        // Refill at a rate equal to capacity per second, capped at capacity.
861        state.tokens = (state.tokens + elapsed * capacity).min(capacity);
862        state.last_refill = now;
863
864        if state.tokens >= 1.0 {
865            state.tokens -= 1.0;
866            true
867        } else {
868            false
869        }
870    }
871}
872
873#[async_trait]
874impl Middleware for AdaptiveRateLimitMiddleware {
875    async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
876        if self.try_acquire() {
877            let result = next.run(ctx).await;
878            match &result {
879                Ok(resp) if resp.status == ResponseStatus::Ok => self.limiter.record_success(),
880                _ => self.limiter.record_error(),
881            }
882            result
883        } else {
884            self.limiter.record_error();
885            warn!(
886                request_id = %ctx.request_id,
887                "Adaptive rate limit exceeded"
888            );
889            Ok(Response::rate_limited("Adaptive rate limit exceeded"))
890        }
891    }
892
893    fn name(&self) -> &str {
894        "adaptive_rate_limit"
895    }
896
897    fn order(&self) -> i32 {
898        -65
899    }
900}
901
902// ===========================================================================
903// Tests
904// ===========================================================================
905
906#[cfg(test)]
907mod tests {
908    use super::*;
909    use std::sync::atomic::{AtomicUsize, Ordering};
910
911    // ---- helpers ----------------------------------------------------------
912
913    /// A trivial middleware that records the order it was called.
914    struct OrderRecorder {
915        id: i32,
916        log: Arc<parking_lot::Mutex<Vec<i32>>>,
917    }
918
919    #[async_trait]
920    impl Middleware for OrderRecorder {
921        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
922            self.log.lock().push(self.id);
923            next.run(ctx).await
924        }
925        fn name(&self) -> &str {
926            "order_recorder"
927        }
928        fn order(&self) -> i32 {
929            self.id
930        }
931    }
932
933    /// Middleware that short-circuits (does **not** call `next`).
934    struct ShortCircuit;
935
936    #[async_trait]
937    impl Middleware for ShortCircuit {
938        async fn process(&self, _ctx: &mut RequestContext, _next: &dyn Next) -> Result<Response> {
939            Ok(Response::unauthorized("blocked"))
940        }
941        fn name(&self) -> &str {
942            "short_circuit"
943        }
944        fn order(&self) -> i32 {
945            0
946        }
947    }
948
949    /// Middleware that sets an attribute for downstream consumption.
950    struct AttributeSetter {
951        key: String,
952        value: String,
953    }
954
955    #[async_trait]
956    impl Middleware for AttributeSetter {
957        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
958            ctx.set_attribute(&self.key, self.value.clone());
959            next.run(ctx).await
960        }
961        fn name(&self) -> &str {
962            "attr_setter"
963        }
964        fn order(&self) -> i32 {
965            -10
966        }
967    }
968
969    /// Middleware that reads an attribute set by an earlier middleware.
970    struct AttributeReader {
971        key: String,
972        found: Arc<parking_lot::Mutex<Option<String>>>,
973    }
974
975    #[async_trait]
976    impl Middleware for AttributeReader {
977        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
978            if let Some(val) = ctx.get_attribute::<String>(&self.key) {
979                *self.found.lock() = Some(val.clone());
980            }
981            next.run(ctx).await
982        }
983        fn name(&self) -> &str {
984            "attr_reader"
985        }
986        fn order(&self) -> i32 {
987            10
988        }
989    }
990
991    /// Middleware that propagates an error.
992    struct ErrorMiddleware;
993
994    #[async_trait]
995    impl Middleware for ErrorMiddleware {
996        async fn process(&self, _ctx: &mut RequestContext, _next: &dyn Next) -> Result<Response> {
997            Err(MiddlewareError::Internal("boom".to_string()))
998        }
999        fn name(&self) -> &str {
1000            "error"
1001        }
1002    }
1003
1004    /// Counter middleware — increments an atomic counter each call.
1005    struct CounterMiddleware {
1006        counter: Arc<AtomicUsize>,
1007        ord: i32,
1008    }
1009
1010    #[async_trait]
1011    impl Middleware for CounterMiddleware {
1012        async fn process(&self, ctx: &mut RequestContext, next: &dyn Next) -> Result<Response> {
1013            self.counter.fetch_add(1, Ordering::SeqCst);
1014            next.run(ctx).await
1015        }
1016        fn name(&self) -> &str {
1017            "counter"
1018        }
1019        fn order(&self) -> i32 {
1020            self.ord
1021        }
1022    }
1023
1024    // ---- tests ------------------------------------------------------------
1025
1026    #[tokio::test]
1027    async fn test_empty_pipeline_passes_through() {
1028        let pipeline = MiddlewarePipelineBuilder::new().build();
1029        let mut ctx = RequestContext::new("TEST");
1030        let resp = pipeline
1031            .execute(&mut ctx)
1032            .await
1033            .expect("empty pipeline should succeed");
1034        assert_eq!(resp.status, ResponseStatus::Ok);
1035    }
1036
1037    #[tokio::test]
1038    async fn test_pipeline_executes_in_order() {
1039        let log = Arc::new(parking_lot::Mutex::new(Vec::new()));
1040
1041        let pipeline = MiddlewarePipelineBuilder::new()
1042            .with(OrderRecorder {
1043                id: 3,
1044                log: Arc::clone(&log),
1045            })
1046            .with(OrderRecorder {
1047                id: 1,
1048                log: Arc::clone(&log),
1049            })
1050            .with(OrderRecorder {
1051                id: 2,
1052                log: Arc::clone(&log),
1053            })
1054            .build();
1055
1056        let mut ctx = RequestContext::new("TEST");
1057        pipeline
1058            .execute(&mut ctx)
1059            .await
1060            .expect("pipeline should succeed");
1061
1062        let order = log.lock().clone();
1063        assert_eq!(
1064            order,
1065            vec![1, 2, 3],
1066            "middleware should run sorted by order()"
1067        );
1068    }
1069
1070    #[tokio::test]
1071    async fn test_short_circuit_on_auth_failure() {
1072        let counter = Arc::new(AtomicUsize::new(0));
1073
1074        let pipeline = MiddlewarePipelineBuilder::new()
1075            .with(ShortCircuit)
1076            .with(CounterMiddleware {
1077                counter: Arc::clone(&counter),
1078                ord: 10,
1079            })
1080            .build();
1081
1082        let mut ctx = RequestContext::new("TEST");
1083        let resp = pipeline
1084            .execute(&mut ctx)
1085            .await
1086            .expect("should get unauthorized response");
1087
1088        assert_eq!(resp.status, ResponseStatus::Unauthorized);
1089        assert_eq!(
1090            counter.load(Ordering::SeqCst),
1091            0,
1092            "downstream middleware must not run after short-circuit"
1093        );
1094    }
1095
1096    #[tokio::test]
1097    async fn test_context_attributes_passed_between_middleware() {
1098        let found = Arc::new(parking_lot::Mutex::new(None));
1099
1100        let pipeline = MiddlewarePipelineBuilder::new()
1101            .with(AttributeSetter {
1102                key: "user".to_string(),
1103                value: "alice".to_string(),
1104            })
1105            .with(AttributeReader {
1106                key: "user".to_string(),
1107                found: Arc::clone(&found),
1108            })
1109            .build();
1110
1111        let mut ctx = RequestContext::new("TEST");
1112        pipeline
1113            .execute(&mut ctx)
1114            .await
1115            .expect("pipeline should succeed");
1116
1117        let val = found.lock().clone();
1118        assert_eq!(val, Some("alice".to_string()));
1119    }
1120
1121    #[tokio::test]
1122    async fn test_metrics_recorded_correctly() {
1123        let collector = MetricsCollector::new();
1124
1125        let pipeline = MiddlewarePipelineBuilder::new()
1126            .with(MetricsMiddleware::new(collector.clone()))
1127            .build();
1128
1129        let mut ctx = RequestContext::new("GET");
1130        pipeline
1131            .execute(&mut ctx)
1132            .await
1133            .expect("pipeline should succeed");
1134
1135        let snapshot = collector.snapshot();
1136        assert_eq!(snapshot.requests_total, 1);
1137        assert_eq!(snapshot.requests_success, 1);
1138        assert_eq!(snapshot.requests_failed, 0);
1139    }
1140
1141    #[tokio::test]
1142    async fn test_rate_limit_blocks_request() {
1143        // One token, no refill.
1144        let rl = RateLimitMiddleware::new(1, 0.0);
1145
1146        let pipeline = MiddlewarePipelineBuilder::new().with(rl).build();
1147
1148        // First request should pass.
1149        let mut ctx1 = RequestContext::new("GET");
1150        let r1 = pipeline
1151            .execute(&mut ctx1)
1152            .await
1153            .expect("first request should pass");
1154        assert_eq!(r1.status, ResponseStatus::Ok);
1155
1156        // Second request should be rate-limited.
1157        let mut ctx2 = RequestContext::new("GET");
1158        let r2 = pipeline
1159            .execute(&mut ctx2)
1160            .await
1161            .expect("second request should be rate-limited");
1162        assert_eq!(r2.status, ResponseStatus::RateLimited);
1163    }
1164
1165    #[tokio::test]
1166    async fn test_auth_middleware_valid_key() {
1167        let mut keys = HashMap::new();
1168        keys.insert("secret-key".to_string(), "user-42".to_string());
1169
1170        let pipeline = MiddlewarePipelineBuilder::new()
1171            .with(AuthMiddleware::new(keys))
1172            .build();
1173
1174        let mut ctx = RequestContext::new("GET").with_metadata("authorization", "secret-key");
1175        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
1176        assert_eq!(resp.status, ResponseStatus::Ok);
1177
1178        let principal = ctx
1179            .get_attribute::<String>("auth_principal")
1180            .expect("principal should be set");
1181        assert_eq!(principal, "user-42");
1182    }
1183
1184    #[tokio::test]
1185    async fn test_auth_middleware_invalid_key() {
1186        let mut keys = HashMap::new();
1187        keys.insert("secret-key".to_string(), "user-42".to_string());
1188
1189        let pipeline = MiddlewarePipelineBuilder::new()
1190            .with(AuthMiddleware::new(keys))
1191            .build();
1192
1193        let mut ctx = RequestContext::new("GET").with_metadata("authorization", "wrong-key");
1194        let resp = pipeline
1195            .execute(&mut ctx)
1196            .await
1197            .expect("should get unauthorized");
1198        assert_eq!(resp.status, ResponseStatus::Unauthorized);
1199    }
1200
1201    #[tokio::test]
1202    async fn test_auth_middleware_no_credentials() {
1203        let keys = HashMap::new();
1204        let pipeline = MiddlewarePipelineBuilder::new()
1205            .with(AuthMiddleware::new(keys))
1206            .build();
1207
1208        let mut ctx = RequestContext::new("GET");
1209        let resp = pipeline
1210            .execute(&mut ctx)
1211            .await
1212            .expect("should get unauthorized");
1213        assert_eq!(resp.status, ResponseStatus::Unauthorized);
1214    }
1215
1216    #[tokio::test]
1217    async fn test_auth_middleware_anonymous_allowed() {
1218        let keys = HashMap::new();
1219        let pipeline = MiddlewarePipelineBuilder::new()
1220            .with(AuthMiddleware::new(keys).with_allow_anonymous(true))
1221            .build();
1222
1223        let mut ctx = RequestContext::new("GET");
1224        let resp = pipeline
1225            .execute(&mut ctx)
1226            .await
1227            .expect("should pass through");
1228        assert_eq!(resp.status, ResponseStatus::Ok);
1229    }
1230
1231    #[tokio::test]
1232    async fn test_error_propagation() {
1233        let pipeline = MiddlewarePipelineBuilder::new()
1234            .with(ErrorMiddleware)
1235            .build();
1236
1237        let mut ctx = RequestContext::new("GET");
1238        let result = pipeline.execute(&mut ctx).await;
1239        assert!(result.is_err());
1240        let err = result.expect_err("should be an error");
1241        assert!(
1242            err.to_string().contains("boom"),
1243            "error message should propagate"
1244        );
1245    }
1246
1247    #[tokio::test]
1248    async fn test_middleware_ordering_by_order() {
1249        let log = Arc::new(parking_lot::Mutex::new(Vec::new()));
1250
1251        // Add in reverse order — builder should sort by order().
1252        let pipeline = MiddlewarePipelineBuilder::new()
1253            .with(OrderRecorder {
1254                id: 50,
1255                log: Arc::clone(&log),
1256            })
1257            .with(OrderRecorder {
1258                id: 10,
1259                log: Arc::clone(&log),
1260            })
1261            .with(OrderRecorder {
1262                id: 30,
1263                log: Arc::clone(&log),
1264            })
1265            .with(OrderRecorder {
1266                id: 20,
1267                log: Arc::clone(&log),
1268            })
1269            .with(OrderRecorder {
1270                id: 40,
1271                log: Arc::clone(&log),
1272            })
1273            .build();
1274
1275        let mut ctx = RequestContext::new("TEST");
1276        pipeline
1277            .execute(&mut ctx)
1278            .await
1279            .expect("pipeline should succeed");
1280
1281        let order = log.lock().clone();
1282        assert_eq!(order, vec![10, 20, 30, 40, 50]);
1283    }
1284
1285    #[tokio::test]
1286    async fn test_response_duration_is_set() {
1287        let pipeline = MiddlewarePipelineBuilder::new().build();
1288        let mut ctx = RequestContext::new("TEST");
1289        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
1290        // Duration should have been stamped by execute().
1291        // (Any Duration is valid; we just confirm execute didn't panic.)
1292        let _ = resp.duration;
1293    }
1294
1295    #[tokio::test]
1296    async fn test_logging_middleware_runs() {
1297        // Smoke test — just ensure it doesn't panic.
1298        let pipeline = MiddlewarePipelineBuilder::new()
1299            .with(LoggingMiddleware::new())
1300            .build();
1301
1302        let mut ctx = RequestContext::new("GET");
1303        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
1304        assert_eq!(resp.status, ResponseStatus::Ok);
1305    }
1306
1307    #[tokio::test]
1308    async fn test_tracing_middleware_runs() {
1309        let pipeline = MiddlewarePipelineBuilder::new()
1310            .with(TracingMiddleware::new())
1311            .build();
1312
1313        let mut ctx = RequestContext::new("QUERY");
1314        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
1315        assert_eq!(resp.status, ResponseStatus::Ok);
1316    }
1317
1318    #[tokio::test]
1319    async fn test_full_pipeline_integration() {
1320        let collector = MetricsCollector::new();
1321
1322        let mut api_keys = HashMap::new();
1323        api_keys.insert("valid-key".to_string(), "user-1".to_string());
1324
1325        let pipeline = MiddlewarePipelineBuilder::new()
1326            .with(LoggingMiddleware::new().with_level(LogLevel::Debug))
1327            .with(TracingMiddleware::new())
1328            .with(MetricsMiddleware::new(collector.clone()))
1329            .with(AuthMiddleware::new(api_keys))
1330            .with(RateLimitMiddleware::new(100, 100.0))
1331            .build();
1332
1333        // Authenticated request
1334        let mut ctx = RequestContext::new("QUERY").with_metadata("authorization", "valid-key");
1335        let resp = pipeline.execute(&mut ctx).await.expect("should succeed");
1336        assert_eq!(resp.status, ResponseStatus::Ok);
1337
1338        let snapshot = collector.snapshot();
1339        assert_eq!(snapshot.requests_total, 1);
1340        assert_eq!(snapshot.requests_success, 1);
1341    }
1342
1343    #[tokio::test]
1344    async fn test_pipeline_builder_default() {
1345        let builder = MiddlewarePipelineBuilder::default();
1346        let pipeline = builder.build();
1347        let mut ctx = RequestContext::new("TEST");
1348        let resp = pipeline
1349            .execute(&mut ctx)
1350            .await
1351            .expect("default pipeline should succeed");
1352        assert_eq!(resp.status, ResponseStatus::Ok);
1353    }
1354
/// The `Debug` output of a `RequestContext` should mention the type name and
/// the string the context was created with.
///
/// Plain `#[test]`: nothing here awaits, so no async runtime is needed.
#[test]
fn test_request_context_debug() {
    let ctx = RequestContext::new("GET");
    let debug_str = format!("{:?}", ctx);
    assert!(debug_str.contains("RequestContext"));
    assert!(debug_str.contains("GET"));
}
1362
/// Each `ResponseStatus` variant should render to its expected display string.
///
/// Plain `#[test]`: nothing here awaits, so no async runtime is needed.
#[test]
fn test_response_status_display() {
    assert_eq!(ResponseStatus::Ok.to_string(), "OK");
    assert_eq!(ResponseStatus::Error.to_string(), "Error");
    assert_eq!(ResponseStatus::RateLimited.to_string(), "RateLimited");
    assert_eq!(ResponseStatus::Unauthorized.to_string(), "Unauthorized");
}
1370
/// Checks the `Response` constructors and builder methods: `ok()` with a
/// header and body, and `error()` carrying its message as the body.
///
/// Plain `#[test]`: nothing here awaits, so no async runtime is needed.
#[test]
fn test_response_builders() {
    let r = Response::ok()
        .with_header("x-req", "123")
        .with_body(b"hello".to_vec());
    assert_eq!(r.status, ResponseStatus::Ok);
    assert_eq!(r.body, Some(b"hello".to_vec()));
    assert_eq!(r.headers.get("x-req"), Some(&"123".to_string()));

    let r2 = Response::error("oops");
    assert_eq!(r2.status, ResponseStatus::Error);
    assert_eq!(r2.body, Some(b"oops".to_vec()));
}
1384
/// After 50 recorded errors and no successes, the limiter's current limit
/// should have dropped below its starting value of 100.
#[test]
fn test_adaptive_rate_limiter_reduces_on_errors() {
    let limiter = AdaptiveRateLimiter::new(100);
    assert_eq!(limiter.current_limit(), 100);

    // Flood the window with errors (>10% threshold).
    (0..50).for_each(|_| {
        limiter.record_error();
    });

    let after_errors = limiter.current_limit();
    assert!(
        after_errors < 100,
        "limit should have decreased after high error rate"
    );
}
1399
/// After the limit has been driven down by errors, 200 recorded successes
/// should raise it above the reduced value.
#[test]
fn test_adaptive_rate_limiter_recovers() {
    let limiter = AdaptiveRateLimiter::new(100);

    // Drive the limit down first.
    (0..50).for_each(|_| {
        limiter.record_error();
    });
    let reduced = limiter.current_limit();
    assert!(reduced < 100, "limit should be reduced");

    // Now flood with successes to push error rate below threshold.
    (0..200).for_each(|_| {
        limiter.record_success();
    });

    let recovered = limiter.current_limit();
    assert!(
        recovered > reduced,
        "limit should recover after sustained successes"
    );
}
1420}