pmcp/shared/
middleware.rs

1//! Advanced middleware support for request/response processing.
2//!
3//! PMCP-4004: Enhanced transport middleware system with advanced capabilities:
4//! - Rate limiting and circuit breaker patterns
5//! - Metrics collection and performance monitoring
6//! - Conditional middleware execution
7//! - Priority-based middleware ordering
8//! - Compression and caching middleware
9//! - Context propagation across middleware layers
10
11use crate::error::Result;
12use crate::shared::TransportMessage;
13use crate::types::{JSONRPCRequest, JSONRPCResponse};
14use async_trait::async_trait;
15use dashmap::DashMap;
16use parking_lot::RwLock;
17use std::fmt;
18use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21
/// Execution context for middleware chains with performance tracking.
///
/// Cloning is cheap and intentional: `metadata` and `metrics` are shared via
/// `Arc`, so every clone observes (and contributes to) the same metadata map
/// and metric counters for a given request.
#[derive(Debug, Clone)]
pub struct MiddlewareContext {
    /// Request ID for correlation
    pub request_id: Option<String>,
    /// Custom metadata that can be passed between middleware (shared across clones)
    pub metadata: Arc<DashMap<String, String>>,
    /// Performance metrics for the request (shared across clones)
    pub metrics: Arc<PerformanceMetrics>,
    /// Start time of the middleware chain execution
    pub start_time: Instant,
    /// Priority level for the request
    pub priority: Option<crate::shared::transport::MessagePriority>,
}
36
37impl Default for MiddlewareContext {
38    fn default() -> Self {
39        Self {
40            request_id: None,
41            metadata: Arc::new(DashMap::new()),
42            metrics: Arc::new(PerformanceMetrics::new()),
43            start_time: Instant::now(),
44            priority: None,
45        }
46    }
47}
48
49impl MiddlewareContext {
50    /// Create a new context with request ID
51    pub fn with_request_id(request_id: String) -> Self {
52        Self {
53            request_id: Some(request_id),
54            ..Default::default()
55        }
56    }
57
58    /// Set metadata value
59    pub fn set_metadata(&self, key: String, value: String) {
60        self.metadata.insert(key, value);
61    }
62
63    /// Get metadata value
64    pub fn get_metadata(&self, key: &str) -> Option<String> {
65        self.metadata.get(key).map(|v| v.clone())
66    }
67
68    /// Record a metric
69    pub fn record_metric(&self, name: String, value: f64) {
70        self.metrics.record(name, value);
71    }
72
73    /// Get elapsed time since context creation
74    pub fn elapsed(&self) -> Duration {
75        self.start_time.elapsed()
76    }
77}
78
/// Performance metrics collection for middleware operations.
///
/// All counters are atomic and the metrics map is concurrent, so a single
/// instance can be shared across threads behind an `Arc` without extra
/// locking.
#[derive(Debug, Default)]
pub struct PerformanceMetrics {
    /// Custom metrics storage (last write wins per metric name)
    metrics: DashMap<String, f64>,
    /// Request count
    request_count: AtomicU64,
    /// Error count
    error_count: AtomicU64,
    /// Total processing time in microseconds
    total_time_us: AtomicU64,
}
91
92impl PerformanceMetrics {
93    /// Create new performance metrics
94    pub fn new() -> Self {
95        Self::default()
96    }
97
98    /// Record a custom metric
99    pub fn record(&self, name: String, value: f64) {
100        self.metrics.insert(name, value);
101    }
102
103    /// Get a metric value
104    pub fn get(&self, name: &str) -> Option<f64> {
105        self.metrics.get(name).map(|v| *v)
106    }
107
108    /// Increment request count
109    pub fn inc_requests(&self) {
110        self.request_count.fetch_add(1, Ordering::Relaxed);
111    }
112
113    /// Increment error count
114    pub fn inc_errors(&self) {
115        self.error_count.fetch_add(1, Ordering::Relaxed);
116    }
117
118    /// Add processing time
119    pub fn add_time(&self, duration: Duration) {
120        self.total_time_us
121            .fetch_add(duration.as_micros() as u64, Ordering::Relaxed);
122    }
123
124    /// Get total request count
125    pub fn request_count(&self) -> u64 {
126        self.request_count.load(Ordering::Relaxed)
127    }
128
129    /// Get total error count
130    pub fn error_count(&self) -> u64 {
131        self.error_count.load(Ordering::Relaxed)
132    }
133
134    /// Get average processing time
135    pub fn average_time(&self) -> Duration {
136        let total_time = self.total_time_us.load(Ordering::Relaxed);
137        let count = self.request_count.load(Ordering::Relaxed);
138        if count > 0 {
139            Duration::from_micros(total_time / count)
140        } else {
141            Duration::ZERO
142        }
143    }
144}
145
/// Middleware execution priority for ordering.
///
/// Lower discriminants sort first, so `Critical` middleware runs before
/// `Lowest` when a chain sorts by priority. `Normal` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum MiddlewarePriority {
    /// Highest priority - executed first in chain
    Critical = 0,
    /// High priority - authentication, security
    High = 1,
    /// Normal priority - business logic (default)
    #[default]
    Normal = 2,
    /// Low priority - logging, metrics
    Low = 3,
    /// Lowest priority - cleanup, finalization
    Lowest = 4,
}
166
167/// Enhanced middleware trait with context support and priority.
168#[async_trait]
169pub trait AdvancedMiddleware: Send + Sync {
170    /// Get middleware priority for execution ordering
171    fn priority(&self) -> MiddlewarePriority {
172        MiddlewarePriority::Normal
173    }
174
175    /// Get middleware name for identification
176    fn name(&self) -> &'static str {
177        "unknown"
178    }
179
180    /// Check if middleware should be executed for this context
181    async fn should_execute(&self, _context: &MiddlewareContext) -> bool {
182        true
183    }
184
185    /// Called before a request is sent with context.
186    async fn on_request_with_context(
187        &self,
188        request: &mut JSONRPCRequest,
189        context: &MiddlewareContext,
190    ) -> Result<()> {
191        let _ = (request, context);
192        Ok(())
193    }
194
195    /// Called after a response is received with context.
196    async fn on_response_with_context(
197        &self,
198        response: &mut JSONRPCResponse,
199        context: &MiddlewareContext,
200    ) -> Result<()> {
201        let _ = (response, context);
202        Ok(())
203    }
204
205    /// Called when a message is sent with context.
206    async fn on_send_with_context(
207        &self,
208        message: &TransportMessage,
209        context: &MiddlewareContext,
210    ) -> Result<()> {
211        let _ = (message, context);
212        Ok(())
213    }
214
215    /// Called when a message is received with context.
216    async fn on_receive_with_context(
217        &self,
218        message: &TransportMessage,
219        context: &MiddlewareContext,
220    ) -> Result<()> {
221        let _ = (message, context);
222        Ok(())
223    }
224
225    /// Called when middleware chain starts
226    async fn on_chain_start(&self, _context: &MiddlewareContext) -> Result<()> {
227        Ok(())
228    }
229
230    /// Called when middleware chain completes
231    async fn on_chain_complete(&self, _context: &MiddlewareContext) -> Result<()> {
232        Ok(())
233    }
234
235    /// Called when an error occurs in the chain
236    async fn on_error(
237        &self,
238        _error: &crate::error::Error,
239        _context: &MiddlewareContext,
240    ) -> Result<()> {
241        Ok(())
242    }
243}
244
245/// Middleware that can intercept and modify requests and responses.
246///
247/// # Examples
248///
249/// ```rust
250/// use pmcp::shared::{Middleware, TransportMessage};
251/// use pmcp::types::{JSONRPCRequest, JSONRPCResponse, RequestId};
252/// use async_trait::async_trait;
253///
254/// // Custom middleware that adds timing information
255/// #[derive(Debug)]
256/// struct TimingMiddleware {
257///     start_time: std::time::Instant,
258/// }
259///
260/// impl TimingMiddleware {
261///     fn new() -> Self {
262///         Self { start_time: std::time::Instant::now() }
263///     }
264/// }
265///
266/// #[async_trait]
267/// impl Middleware for TimingMiddleware {
268///     async fn on_request(&self, request: &mut JSONRPCRequest) -> pmcp::Result<()> {
269///         // Add timing metadata to request params
270///         println!("Processing request {} at {}ms",
271///             request.method,
272///             self.start_time.elapsed().as_millis());
273///         Ok(())
274///     }
275///
276///     async fn on_response(&self, response: &mut JSONRPCResponse) -> pmcp::Result<()> {
277///         println!("Response for {:?} received at {}ms",
278///             response.id,
279///             self.start_time.elapsed().as_millis());
280///         Ok(())
281///     }
282/// }
283///
284/// # async fn example() -> pmcp::Result<()> {
285/// let middleware = TimingMiddleware::new();
286/// let mut request = JSONRPCRequest {
287///     jsonrpc: "2.0".to_string(),
288///     method: "test".to_string(),
289///     params: None,
290///     id: RequestId::from(123i64),
291/// };
292///
293/// // Process request through middleware
294/// middleware.on_request(&mut request).await?;
295/// # Ok(())
296/// # }
297/// ```
298#[async_trait]
299pub trait Middleware: Send + Sync {
300    /// Called before a request is sent.
301    async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
302        let _ = request;
303        Ok(())
304    }
305
306    /// Called after a response is received.
307    async fn on_response(&self, response: &mut JSONRPCResponse) -> Result<()> {
308        let _ = response;
309        Ok(())
310    }
311
312    /// Called when a message is sent (any type).
313    async fn on_send(&self, message: &TransportMessage) -> Result<()> {
314        let _ = message;
315        Ok(())
316    }
317
318    /// Called when a message is received (any type).
319    async fn on_receive(&self, message: &TransportMessage) -> Result<()> {
320        let _ = message;
321        Ok(())
322    }
323}
324
/// Enhanced middleware chain with priority ordering and context support.
///
/// # Examples
///
/// ```rust
/// use pmcp::shared::{EnhancedMiddlewareChain, MiddlewareContext};
/// use pmcp::types::{JSONRPCRequest, JSONRPCResponse, RequestId};
/// use std::sync::Arc;
///
/// # async fn example() -> pmcp::Result<()> {
/// // Create an enhanced middleware chain
/// let mut chain = EnhancedMiddlewareChain::new();
/// let context = MiddlewareContext::with_request_id("req-123".to_string());
///
/// // Create a request to process
/// let mut request = JSONRPCRequest {
///     jsonrpc: "2.0".to_string(),
///     method: "prompts.get".to_string(),
///     params: Some(serde_json::json!({
///         "name": "code_review",
///         "arguments": {"language": "rust", "style": "detailed"}
///     })),
///     id: RequestId::from(1001i64),
/// };
///
/// // Process request through all middleware with context
/// chain.process_request_with_context(&mut request, &context).await?;
/// # Ok(())
/// # }
/// ```
pub struct EnhancedMiddlewareChain {
    /// Registered middleware; kept sorted by priority while `auto_sort` is true.
    middlewares: Vec<Arc<dyn AdvancedMiddleware>>,
    /// When true, `add` re-sorts the chain by priority after every insert.
    auto_sort: bool,
}
359
360impl fmt::Debug for EnhancedMiddlewareChain {
361    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
362        f.debug_struct("EnhancedMiddlewareChain")
363            .field("count", &self.middlewares.len())
364            .field("auto_sort", &self.auto_sort)
365            .finish()
366    }
367}
368
369impl Default for EnhancedMiddlewareChain {
370    fn default() -> Self {
371        Self::new()
372    }
373}
374
375impl EnhancedMiddlewareChain {
376    /// Create a new enhanced middleware chain with automatic sorting by priority.
377    pub fn new() -> Self {
378        Self {
379            middlewares: Vec::new(),
380            auto_sort: true,
381        }
382    }
383
384    /// Create a new chain without automatic sorting.
385    pub fn new_no_sort() -> Self {
386        Self {
387            middlewares: Vec::new(),
388            auto_sort: false,
389        }
390    }
391
392    /// Add an advanced middleware to the chain.
393    pub fn add(&mut self, middleware: Arc<dyn AdvancedMiddleware>) {
394        self.middlewares.push(middleware);
395        if self.auto_sort {
396            self.sort_by_priority();
397        }
398    }
399
400    /// Sort middleware by priority (critical first).
401    pub fn sort_by_priority(&mut self) {
402        self.middlewares.sort_by_key(|m| m.priority());
403    }
404
405    /// Get middleware count.
406    pub fn len(&self) -> usize {
407        self.middlewares.len()
408    }
409
410    /// Check if chain is empty.
411    pub fn is_empty(&self) -> bool {
412        self.middlewares.is_empty()
413    }
414
415    /// Process a request through all applicable middleware with context.
416    pub async fn process_request_with_context(
417        &self,
418        request: &mut JSONRPCRequest,
419        context: &MiddlewareContext,
420    ) -> Result<()> {
421        context.metrics.inc_requests();
422        let start_time = Instant::now();
423
424        // Notify chain start
425        for middleware in &self.middlewares {
426            if middleware.should_execute(context).await {
427                middleware.on_chain_start(context).await?;
428            }
429        }
430
431        // Process through middleware
432        for middleware in &self.middlewares {
433            if middleware.should_execute(context).await {
434                if let Err(e) = middleware.on_request_with_context(request, context).await {
435                    context.metrics.inc_errors();
436                    // Notify error to all middleware
437                    for m in &self.middlewares {
438                        if m.should_execute(context).await {
439                            let _ = m.on_error(&e, context).await;
440                        }
441                    }
442                    return Err(e);
443                }
444            }
445        }
446
447        // Notify chain complete
448        for middleware in &self.middlewares {
449            if middleware.should_execute(context).await {
450                middleware.on_chain_complete(context).await?;
451            }
452        }
453
454        context.metrics.add_time(start_time.elapsed());
455        Ok(())
456    }
457
458    /// Process a response through all applicable middleware with context.
459    pub async fn process_response_with_context(
460        &self,
461        response: &mut JSONRPCResponse,
462        context: &MiddlewareContext,
463    ) -> Result<()> {
464        let start_time = Instant::now();
465
466        // Process through middleware in reverse order for responses
467        for middleware in self.middlewares.iter().rev() {
468            if middleware.should_execute(context).await {
469                if let Err(e) = middleware.on_response_with_context(response, context).await {
470                    context.metrics.inc_errors();
471                    // Notify error to all middleware
472                    for m in &self.middlewares {
473                        if m.should_execute(context).await {
474                            let _ = m.on_error(&e, context).await;
475                        }
476                    }
477                    return Err(e);
478                }
479            }
480        }
481
482        context.metrics.add_time(start_time.elapsed());
483        Ok(())
484    }
485
486    /// Process an outgoing message through all applicable middleware.
487    pub async fn process_send_with_context(
488        &self,
489        message: &TransportMessage,
490        context: &MiddlewareContext,
491    ) -> Result<()> {
492        let start_time = Instant::now();
493
494        for middleware in &self.middlewares {
495            if middleware.should_execute(context).await {
496                if let Err(e) = middleware.on_send_with_context(message, context).await {
497                    context.metrics.inc_errors();
498                    for m in &self.middlewares {
499                        if m.should_execute(context).await {
500                            let _ = m.on_error(&e, context).await;
501                        }
502                    }
503                    return Err(e);
504                }
505            }
506        }
507
508        context.metrics.add_time(start_time.elapsed());
509        Ok(())
510    }
511
512    /// Process an incoming message through all applicable middleware.
513    pub async fn process_receive_with_context(
514        &self,
515        message: &TransportMessage,
516        context: &MiddlewareContext,
517    ) -> Result<()> {
518        let start_time = Instant::now();
519
520        for middleware in &self.middlewares {
521            if middleware.should_execute(context).await {
522                if let Err(e) = middleware.on_receive_with_context(message, context).await {
523                    context.metrics.inc_errors();
524                    for m in &self.middlewares {
525                        if m.should_execute(context).await {
526                            let _ = m.on_error(&e, context).await;
527                        }
528                    }
529                    return Err(e);
530                }
531            }
532        }
533
534        context.metrics.add_time(start_time.elapsed());
535        Ok(())
536    }
537
538    /// Get performance metrics for the chain.
539    pub fn get_metrics(&self) -> Vec<Arc<PerformanceMetrics>> {
540        // This would collect metrics from all contexts that have been processed
541        // For now, we return an empty vector as metrics are stored per-context
542        Vec::new()
543    }
544}
545
/// Chain of middleware handlers (legacy).
///
/// Runs middleware strictly in insertion order (no priorities, no context);
/// prefer `EnhancedMiddlewareChain` for new code.
///
/// # Examples
///
/// ```rust
/// use pmcp::shared::{MiddlewareChain, LoggingMiddleware, AuthMiddleware, RetryMiddleware};
/// use pmcp::types::{JSONRPCRequest, JSONRPCResponse, RequestId};
/// use std::sync::Arc;
/// use tracing::Level;
///
/// # async fn example() -> pmcp::Result<()> {
/// // Create a middleware chain
/// let mut chain = MiddlewareChain::new();
///
/// // Add different types of middleware in order
/// chain.add(Arc::new(LoggingMiddleware::new(Level::INFO)));
/// chain.add(Arc::new(AuthMiddleware::new("Bearer token-123".to_string())));
/// chain.add(Arc::new(RetryMiddleware::default()));
///
/// // Create a request to process
/// let mut request = JSONRPCRequest {
///     jsonrpc: "2.0".to_string(),
///     method: "prompts.get".to_string(),
///     params: Some(serde_json::json!({
///         "name": "code_review",
///         "arguments": {"language": "rust", "style": "detailed"}
///     })),
///     id: RequestId::from(1001i64),
/// };
///
/// // Process request through all middleware in order
/// chain.process_request(&mut request).await?;
///
/// // Create a response to process
/// let mut response = JSONRPCResponse {
///     jsonrpc: "2.0".to_string(),
///     id: RequestId::from(1001i64),
///     payload: pmcp::types::jsonrpc::ResponsePayload::Result(
///         serde_json::json!({"prompt": "Review the following code..."})
///     ),
/// };
///
/// // Process response through all middleware
/// chain.process_response(&mut response).await?;
///
/// // The chain processes middleware in the order they were added
/// // 1. LoggingMiddleware logs the request/response
/// // 2. AuthMiddleware adds authentication
/// // 3. RetryMiddleware configures retry behavior
/// # Ok(())
/// # }
/// ```
pub struct MiddlewareChain {
    /// Handlers, invoked in insertion order for every hook.
    middlewares: Vec<Arc<dyn Middleware>>,
}
601
602impl fmt::Debug for MiddlewareChain {
603    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
604        f.debug_struct("MiddlewareChain")
605            .field("count", &self.middlewares.len())
606            .finish()
607    }
608}
609
610impl Default for MiddlewareChain {
611    fn default() -> Self {
612        Self::new()
613    }
614}
615
616impl MiddlewareChain {
617    /// Create a new empty middleware chain.
618    pub fn new() -> Self {
619        Self {
620            middlewares: Vec::new(),
621        }
622    }
623
624    /// Add a middleware to the chain.
625    pub fn add(&mut self, middleware: Arc<dyn Middleware>) {
626        self.middlewares.push(middleware);
627    }
628
629    /// Process a request through all middleware.
630    pub async fn process_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
631        for middleware in &self.middlewares {
632            middleware.on_request(request).await?;
633        }
634        Ok(())
635    }
636
637    /// Process a response through all middleware.
638    pub async fn process_response(&self, response: &mut JSONRPCResponse) -> Result<()> {
639        for middleware in &self.middlewares {
640            middleware.on_response(response).await?;
641        }
642        Ok(())
643    }
644
645    /// Process an outgoing message through all middleware.
646    pub async fn process_send(&self, message: &TransportMessage) -> Result<()> {
647        for middleware in &self.middlewares {
648            middleware.on_send(message).await?;
649        }
650        Ok(())
651    }
652
653    /// Process an incoming message through all middleware.
654    pub async fn process_receive(&self, message: &TransportMessage) -> Result<()> {
655        for middleware in &self.middlewares {
656            middleware.on_receive(message).await?;
657        }
658        Ok(())
659    }
660}
661
/// Logging middleware that logs all messages.
///
/// # Examples
///
/// ```rust
/// use pmcp::shared::{LoggingMiddleware, Middleware};
/// use pmcp::types::{JSONRPCRequest, RequestId};
/// use tracing::Level;
///
/// # async fn example() -> pmcp::Result<()> {
/// // Create logging middleware with different levels
/// let debug_logger = LoggingMiddleware::new(Level::DEBUG);
/// let info_logger = LoggingMiddleware::new(Level::INFO);
/// let default_logger = LoggingMiddleware::default(); // Uses DEBUG level
///
/// let mut request = JSONRPCRequest {
///     jsonrpc: "2.0".to_string(),
///     method: "tools.list".to_string(),
///     params: Some(serde_json::json!({"category": "development"})),
///     id: RequestId::from(456i64),
/// };
///
/// // Log at different levels
/// debug_logger.on_request(&mut request).await?;
/// info_logger.on_request(&mut request).await?;
/// default_logger.on_request(&mut request).await?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct LoggingMiddleware {
    /// Level at which request/response events are emitted (TRACE logs full payloads).
    level: tracing::Level,
}
695
696impl LoggingMiddleware {
697    /// Create a new logging middleware with the specified level.
698    pub fn new(level: tracing::Level) -> Self {
699        Self { level }
700    }
701}
702
703impl Default for LoggingMiddleware {
704    fn default() -> Self {
705        Self::new(tracing::Level::DEBUG)
706    }
707}
708
709#[async_trait]
710impl Middleware for LoggingMiddleware {
711    async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
712        match self.level {
713            tracing::Level::TRACE => tracing::trace!("Sending request: {:?}", request),
714            tracing::Level::DEBUG => tracing::debug!("Sending request: {}", request.method),
715            tracing::Level::INFO => tracing::info!("Sending request: {}", request.method),
716            tracing::Level::WARN => tracing::warn!("Sending request: {}", request.method),
717            tracing::Level::ERROR => tracing::error!("Sending request: {}", request.method),
718        }
719        Ok(())
720    }
721
722    async fn on_response(&self, response: &mut JSONRPCResponse) -> Result<()> {
723        match self.level {
724            tracing::Level::TRACE => tracing::trace!("Received response: {:?}", response),
725            tracing::Level::DEBUG => tracing::debug!("Received response for: {:?}", response.id),
726            tracing::Level::INFO => tracing::info!("Received response"),
727            tracing::Level::WARN => tracing::warn!("Received response"),
728            tracing::Level::ERROR => tracing::error!("Received response"),
729        }
730        Ok(())
731    }
732}
733
/// Authentication middleware that adds auth headers.
///
/// # Examples
///
/// ```rust
/// use pmcp::shared::{AuthMiddleware, Middleware};
/// use pmcp::types::{JSONRPCRequest, RequestId};
///
/// # async fn example() -> pmcp::Result<()> {
/// // Create auth middleware with API token
/// let auth_middleware = AuthMiddleware::new("Bearer api-token-12345".to_string());
///
/// let mut request = JSONRPCRequest {
///     jsonrpc: "2.0".to_string(),
///     method: "resources.read".to_string(),
///     params: Some(serde_json::json!({"uri": "file:///secure/data.txt"})),
///     id: RequestId::from(789i64),
/// };
///
/// // Process request and add authentication
/// auth_middleware.on_request(&mut request).await?;
///
/// // In a real implementation, the middleware would modify the request
/// // to include authentication information
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct AuthMiddleware {
    /// Credential to attach to outgoing requests. Currently unused — the
    /// `on_request` impl only logs; kept for the future real implementation.
    #[allow(dead_code)]
    auth_token: String,
}
766
767impl AuthMiddleware {
768    /// Create a new auth middleware with the given token.
769    pub fn new(auth_token: String) -> Self {
770        Self { auth_token }
771    }
772}
773
774#[async_trait]
775impl Middleware for AuthMiddleware {
776    async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
777        // In a real implementation, this would add auth headers
778        // For JSON-RPC, we might add auth to params or use a wrapper
779        tracing::debug!("Adding authentication to request: {}", request.method);
780        Ok(())
781    }
782}
783
/// Retry middleware that implements exponential backoff.
///
/// # Examples
///
/// ```rust
/// use pmcp::shared::{RetryMiddleware, Middleware};
/// use pmcp::types::{JSONRPCRequest, RequestId};
///
/// # async fn example() -> pmcp::Result<()> {
/// // Create retry middleware with custom settings
/// let retry_middleware = RetryMiddleware::new(
///     5,      // max_retries
///     1000,   // initial_delay_ms (1 second)
///     30000   // max_delay_ms (30 seconds)
/// );
///
/// // Default retry middleware (3 retries, 1s initial, 30s max)
/// let default_retry = RetryMiddleware::default();
///
/// let mut request = JSONRPCRequest {
///     jsonrpc: "2.0".to_string(),
///     method: "tools.call".to_string(),
///     params: Some(serde_json::json!({
///         "name": "network_tool",
///         "arguments": {"url": "https://api.example.com/data"}
///     })),
///     id: RequestId::from(999i64),
/// };
///
/// // Configure request for retry handling
/// retry_middleware.on_request(&mut request).await?;
/// default_retry.on_request(&mut request).await?;
///
/// // The actual retry logic would be implemented at the transport level
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct RetryMiddleware {
    /// Maximum retry attempts (currently only logged; retry execution is
    /// intended to live at the transport level).
    max_retries: u32,
    /// Initial backoff delay in milliseconds — unused here, intended for the
    /// transport-level retry logic.
    #[allow(dead_code)]
    initial_delay_ms: u64,
    /// Upper bound on backoff delay in milliseconds — unused here, intended
    /// for the transport-level retry logic.
    #[allow(dead_code)]
    max_delay_ms: u64,
}
829
830impl RetryMiddleware {
831    /// Create a new retry middleware.
832    pub fn new(max_retries: u32, initial_delay_ms: u64, max_delay_ms: u64) -> Self {
833        Self {
834            max_retries,
835            initial_delay_ms,
836            max_delay_ms,
837        }
838    }
839}
840
841impl Default for RetryMiddleware {
842    fn default() -> Self {
843        Self::new(3, 1000, 30000)
844    }
845}
846
847#[async_trait]
848impl Middleware for RetryMiddleware {
849    async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
850        // Retry logic would be implemented at the transport level
851        // This middleware just adds metadata for retry handling
852        tracing::debug!(
853            "Request {} configured with max {} retries",
854            request.method,
855            self.max_retries
856        );
857        Ok(())
858    }
859}
860
/// Rate limiting middleware with token bucket algorithm.
///
/// # Examples
///
/// ```rust
/// use pmcp::shared::{RateLimitMiddleware, AdvancedMiddleware, MiddlewareContext};
/// use pmcp::types::{JSONRPCRequest, RequestId};
/// use std::time::Duration;
///
/// # async fn example() -> pmcp::Result<()> {
/// // Create rate limiter: 10 requests per second, burst of 20
/// let rate_limiter = RateLimitMiddleware::new(10, 20, Duration::from_secs(1));
/// let context = MiddlewareContext::default();
///
/// let mut request = JSONRPCRequest {
///     jsonrpc: "2.0".to_string(),
///     method: "tools.call".to_string(),
///     params: Some(serde_json::json!({"name": "api_call"})),
///     id: RequestId::from(123i64),
/// };
///
/// // This will succeed if under rate limit, fail if over
/// rate_limiter.on_request_with_context(&mut request, &context).await?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct RateLimitMiddleware {
    /// Tokens added to the bucket per elapsed `refill_duration`.
    max_requests: u32,
    /// Maximum tokens the bucket can hold (burst capacity); bucket starts full.
    bucket_size: u32,
    /// Interval at which `max_requests` tokens are replenished.
    refill_duration: Duration,
    /// Currently available tokens; one is consumed per allowed request.
    tokens: Arc<AtomicUsize>,
    /// Time of the last refill (guards the refill bookkeeping).
    last_refill: Arc<RwLock<Instant>>,
}
895
896impl RateLimitMiddleware {
897    /// Create a new rate limiting middleware.
898    pub fn new(max_requests: u32, bucket_size: u32, refill_duration: Duration) -> Self {
899        Self {
900            max_requests,
901            bucket_size,
902            refill_duration,
903            tokens: Arc::new(AtomicUsize::new(bucket_size as usize)),
904            last_refill: Arc::new(RwLock::new(Instant::now())),
905        }
906    }
907
908    /// Check if request is within rate limits.
909    fn check_rate_limit(&self) -> bool {
910        // Refill tokens based on time elapsed
911        let now = Instant::now();
912        let mut last_refill = self.last_refill.write();
913        let elapsed = now.duration_since(*last_refill);
914
915        if elapsed >= self.refill_duration {
916            let refill_count = (elapsed.as_millis() / self.refill_duration.as_millis()) as u32;
917            let tokens_to_add = (refill_count * self.max_requests).min(self.bucket_size);
918
919            self.tokens.store(
920                (self.tokens.load(Ordering::Relaxed) + tokens_to_add as usize)
921                    .min(self.bucket_size as usize),
922                Ordering::Relaxed,
923            );
924            *last_refill = now;
925        }
926
927        // Try to consume a token
928        loop {
929            let current = self.tokens.load(Ordering::Relaxed);
930            if current == 0 {
931                return false;
932            }
933            if self
934                .tokens
935                .compare_exchange_weak(current, current - 1, Ordering::Relaxed, Ordering::Relaxed)
936                .is_ok()
937            {
938                return true;
939            }
940        }
941    }
942}
943
944#[async_trait]
945impl AdvancedMiddleware for RateLimitMiddleware {
946    fn name(&self) -> &'static str {
947        "rate_limit"
948    }
949
950    fn priority(&self) -> MiddlewarePriority {
951        MiddlewarePriority::High
952    }
953
954    async fn on_request_with_context(
955        &self,
956        request: &mut JSONRPCRequest,
957        context: &MiddlewareContext,
958    ) -> Result<()> {
959        if !self.check_rate_limit() {
960            tracing::warn!("Rate limit exceeded for request: {}", request.method);
961            context.record_metric("rate_limit_exceeded".to_string(), 1.0);
962            return Err(crate::error::Error::RateLimited);
963        }
964
965        tracing::debug!("Rate limit check passed for request: {}", request.method);
966        context.record_metric("rate_limit_passed".to_string(), 1.0);
967        Ok(())
968    }
969}
970
971/// Circuit breaker middleware for fault tolerance.
972///
973/// # Examples
974///
975/// ```rust
976/// use pmcp::shared::{CircuitBreakerMiddleware, AdvancedMiddleware, MiddlewareContext};
977/// use pmcp::types::{JSONRPCRequest, RequestId};
978/// use std::time::Duration;
979///
980/// # async fn example() -> pmcp::Result<()> {
981/// // Circuit breaker: 5 failures in 60s window trips for 30s
982/// let circuit_breaker = CircuitBreakerMiddleware::new(
983///     5,                          // failure_threshold
984///     Duration::from_secs(60),    // time_window
985///     Duration::from_secs(30),    // timeout_duration
986/// );
987/// let context = MiddlewareContext::default();
988///
989/// let mut request = JSONRPCRequest {
990///     jsonrpc: "2.0".to_string(),
991///     method: "external_service.call".to_string(),
992///     params: Some(serde_json::json!({"data": "test"})),
993///     id: RequestId::from(456i64),
994/// };
995///
996/// // This will fail fast if circuit is open
997/// circuit_breaker.on_request_with_context(&mut request, &context).await?;
998/// # Ok(())
999/// # }
1000/// ```
#[derive(Debug)]
pub struct CircuitBreakerMiddleware {
    /// Number of failures within `time_window` that trips the circuit open.
    failure_threshold: u32,
    /// Rolling window; the failure count is reset once the most recent
    /// failure is older than this.
    time_window: Duration,
    /// How long the circuit stays open before requests are allowed again.
    timeout_duration: Duration,
    /// Running count of recorded failures.
    failure_count: Arc<AtomicU64>,
    /// Timestamp of the most recent recorded failure, if any.
    last_failure: Arc<RwLock<Option<Instant>>>,
    /// When the circuit was opened; `None` while the circuit is closed.
    circuit_open_time: Arc<RwLock<Option<Instant>>>,
}
1010
impl CircuitBreakerMiddleware {
    /// Create a new circuit breaker middleware.
    ///
    /// `failure_threshold` failures recorded within `time_window` open the
    /// circuit; after `timeout_duration` the circuit closes again.
    pub fn new(failure_threshold: u32, time_window: Duration, timeout_duration: Duration) -> Self {
        Self {
            failure_threshold,
            time_window,
            timeout_duration,
            failure_count: Arc::new(AtomicU64::new(0)),
            last_failure: Arc::new(RwLock::new(None)),
            circuit_open_time: Arc::new(RwLock::new(None)),
        }
    }

    /// Check if circuit breaker should allow the request.
    ///
    /// Returns `false` while the circuit is open and the timeout has not
    /// elapsed; otherwise `true` as long as the failure count is below the
    /// threshold.
    fn should_allow_request(&self) -> bool {
        let now = Instant::now();

        // If the circuit is open and the timeout has elapsed, close it fully.
        // NOTE(review): despite classic "half-open" terminology, this resets
        // the failure count to zero and closes the circuit outright, so ALL
        // subsequent requests are admitted until failures accumulate again —
        // it does not admit a single probe request.
        let open_time_value = *self.circuit_open_time.read();
        if let Some(open_time) = open_time_value {
            if now.duration_since(open_time) > self.timeout_duration {
                *self.circuit_open_time.write() = None;
                self.failure_count.store(0, Ordering::Relaxed);
                return true;
            }
            return false; // Circuit is still open
        }

        // Reset failure count if the last failure fell outside the window.
        let last_failure_value = *self.last_failure.read();
        if let Some(last_failure) = last_failure_value {
            if now.duration_since(last_failure) > self.time_window {
                self.failure_count.store(0, Ordering::Relaxed);
            }
        }

        // Allow only while the failure threshold has not been reached.
        self.failure_count.load(Ordering::Relaxed) < self.failure_threshold as u64
    }

    /// Record a failure and possibly open the circuit.
    ///
    /// Increments the failure count, stamps the failure time, and opens the
    /// circuit once the count reaches `failure_threshold`.
    fn record_failure(&self) {
        let now = Instant::now();
        // fetch_add returns the previous value, so add 1 for the new count.
        let failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
        *self.last_failure.write() = Some(now);

        if failures >= self.failure_threshold as u64 {
            *self.circuit_open_time.write() = Some(now);
            tracing::warn!("Circuit breaker opened due to {} failures", failures);
        }
    }
}
1064
1065#[async_trait]
1066impl AdvancedMiddleware for CircuitBreakerMiddleware {
1067    fn name(&self) -> &'static str {
1068        "circuit_breaker"
1069    }
1070
1071    fn priority(&self) -> MiddlewarePriority {
1072        MiddlewarePriority::High
1073    }
1074
1075    async fn on_request_with_context(
1076        &self,
1077        request: &mut JSONRPCRequest,
1078        context: &MiddlewareContext,
1079    ) -> Result<()> {
1080        if !self.should_allow_request() {
1081            tracing::warn!(
1082                "Circuit breaker open, rejecting request: {}",
1083                request.method
1084            );
1085            context.record_metric("circuit_breaker_open".to_string(), 1.0);
1086            return Err(crate::error::Error::CircuitBreakerOpen);
1087        }
1088
1089        context.record_metric("circuit_breaker_allowed".to_string(), 1.0);
1090        Ok(())
1091    }
1092
1093    async fn on_error(
1094        &self,
1095        _error: &crate::error::Error,
1096        _context: &MiddlewareContext,
1097    ) -> Result<()> {
1098        self.record_failure();
1099        Ok(())
1100    }
1101}
1102
1103/// Metrics collection middleware for observability.
1104///
1105/// # Examples
1106///
1107/// ```rust
1108/// use pmcp::shared::{MetricsMiddleware, AdvancedMiddleware, MiddlewareContext};
1109/// use pmcp::types::{JSONRPCRequest, RequestId};
1110///
1111/// # async fn example() -> pmcp::Result<()> {
1112/// let metrics = MetricsMiddleware::new("pmcp_client".to_string());
1113/// let context = MiddlewareContext::default();
1114///
1115/// let mut request = JSONRPCRequest {
1116///     jsonrpc: "2.0".to_string(),
1117///     method: "resources.list".to_string(),
1118///     params: None,
1119///     id: RequestId::from(789i64),
1120/// };
1121///
1122/// // Automatically collects timing and count metrics
1123/// metrics.on_request_with_context(&mut request, &context).await?;
1124/// # Ok(())
1125/// # }
1126/// ```
#[derive(Debug)]
pub struct MetricsMiddleware {
    /// Service label attached to emitted metrics/metadata.
    service_name: String,
    /// Per-method count of requests seen.
    request_counts: Arc<DashMap<String, AtomicU64>>,
    /// Per-method accumulated response duration in microseconds.
    request_durations: Arc<DashMap<String, AtomicU64>>,
    /// Per-method count of errors observed.
    error_counts: Arc<DashMap<String, AtomicU64>>,
}
1134
1135impl MetricsMiddleware {
1136    /// Create a new metrics collection middleware.
1137    pub fn new(service_name: String) -> Self {
1138        Self {
1139            service_name,
1140            request_counts: Arc::new(DashMap::new()),
1141            request_durations: Arc::new(DashMap::new()),
1142            error_counts: Arc::new(DashMap::new()),
1143        }
1144    }
1145
1146    /// Get request count for a method.
1147    pub fn get_request_count(&self, method: &str) -> u64 {
1148        self.request_counts
1149            .get(method)
1150            .map_or(0, |c| c.load(Ordering::Relaxed))
1151    }
1152
1153    /// Get error count for a method.
1154    pub fn get_error_count(&self, method: &str) -> u64 {
1155        self.error_counts
1156            .get(method)
1157            .map_or(0, |c| c.load(Ordering::Relaxed))
1158    }
1159
1160    /// Get average duration for a method in microseconds.
1161    pub fn get_average_duration(&self, method: &str) -> u64 {
1162        let total_duration = self
1163            .request_durations
1164            .get(method)
1165            .map_or(0, |d| d.load(Ordering::Relaxed));
1166        let count = self.get_request_count(method);
1167        if count > 0 {
1168            total_duration / count
1169        } else {
1170            0
1171        }
1172    }
1173}
1174
1175#[async_trait]
1176impl AdvancedMiddleware for MetricsMiddleware {
1177    fn name(&self) -> &'static str {
1178        "metrics"
1179    }
1180
1181    fn priority(&self) -> MiddlewarePriority {
1182        MiddlewarePriority::Low
1183    }
1184
1185    async fn on_request_with_context(
1186        &self,
1187        request: &mut JSONRPCRequest,
1188        context: &MiddlewareContext,
1189    ) -> Result<()> {
1190        // Increment request count
1191        self.request_counts
1192            .entry(request.method.clone())
1193            .or_insert_with(|| AtomicU64::new(0))
1194            .fetch_add(1, Ordering::Relaxed);
1195
1196        context.set_metadata(
1197            "request_start_time".to_string(),
1198            context.start_time.elapsed().as_micros().to_string(),
1199        );
1200        context.set_metadata("service_name".to_string(), self.service_name.clone());
1201
1202        tracing::debug!(
1203            "Metrics recorded for request: {} (service: {})",
1204            request.method,
1205            self.service_name
1206        );
1207        Ok(())
1208    }
1209
1210    async fn on_response_with_context(
1211        &self,
1212        response: &mut JSONRPCResponse,
1213        context: &MiddlewareContext,
1214    ) -> Result<()> {
1215        // Record response time if we have a request method in context
1216        let duration_us = context.elapsed().as_micros() as u64;
1217
1218        if let Some(method) = context.get_metadata("method") {
1219            self.request_durations
1220                .entry(method)
1221                .or_insert_with(|| AtomicU64::new(0))
1222                .fetch_add(duration_us, Ordering::Relaxed);
1223        }
1224
1225        tracing::debug!(
1226            "Response metrics recorded for ID: {:?} ({}μs)",
1227            response.id,
1228            duration_us
1229        );
1230        Ok(())
1231    }
1232
1233    async fn on_error(
1234        &self,
1235        error: &crate::error::Error,
1236        context: &MiddlewareContext,
1237    ) -> Result<()> {
1238        if let Some(method) = context.get_metadata("method") {
1239            self.error_counts
1240                .entry(method)
1241                .or_insert_with(|| AtomicU64::new(0))
1242                .fetch_add(1, Ordering::Relaxed);
1243        }
1244
1245        tracing::warn!("Error recorded in metrics: {:?}", error);
1246        Ok(())
1247    }
1248}
1249
1250/// Compression middleware for reducing message size.
1251///
1252/// # Examples
1253///
1254/// ```rust
1255/// use pmcp::shared::{CompressionMiddleware, AdvancedMiddleware, MiddlewareContext, CompressionType};
1256/// use pmcp::types::{JSONRPCRequest, RequestId};
1257///
1258/// # async fn example() -> pmcp::Result<()> {
1259/// let compression = CompressionMiddleware::new(CompressionType::Gzip, 1024);
1260/// let context = MiddlewareContext::default();
1261///
1262/// let mut request = JSONRPCRequest {
1263///     jsonrpc: "2.0".to_string(),
1264///     method: "resources.read".to_string(),
1265///     params: Some(serde_json::json!({"uri": "file:///large_file.json"})),
1266///     id: RequestId::from(101i64),
1267/// };
1268///
1269/// // Compresses request if over threshold
1270/// compression.on_request_with_context(&mut request, &context).await?;
1271/// # Ok(())
1272/// # }
1273/// ```
#[derive(Debug, Clone, Copy)]
pub enum CompressionType {
    /// No compression; the middleware becomes a no-op.
    None,
    /// Gzip compression.
    Gzip,
    /// Deflate compression.
    Deflate,
}
1283
1284/// Compression middleware for reducing message size.
/// Compression middleware for reducing message size.
#[derive(Debug)]
pub struct CompressionMiddleware {
    /// Algorithm to apply; `CompressionType::None` disables compression.
    compression_type: CompressionType,
    /// Minimum serialized size (bytes) before compression is considered.
    min_size: usize,
}
1290
1291impl CompressionMiddleware {
1292    /// Create a new compression middleware.
1293    pub fn new(compression_type: CompressionType, min_size: usize) -> Self {
1294        Self {
1295            compression_type,
1296            min_size,
1297        }
1298    }
1299
1300    /// Check if content should be compressed.
1301    fn should_compress(&self, content_size: usize) -> bool {
1302        content_size >= self.min_size && !matches!(self.compression_type, CompressionType::None)
1303    }
1304}
1305
1306#[async_trait]
1307impl AdvancedMiddleware for CompressionMiddleware {
1308    fn name(&self) -> &'static str {
1309        "compression"
1310    }
1311
1312    fn priority(&self) -> MiddlewarePriority {
1313        MiddlewarePriority::Normal
1314    }
1315
1316    async fn on_send_with_context(
1317        &self,
1318        message: &TransportMessage,
1319        context: &MiddlewareContext,
1320    ) -> Result<()> {
1321        let serialized = serde_json::to_string(message).unwrap_or_default();
1322        let content_size = serialized.len();
1323
1324        if self.should_compress(content_size) {
1325            context.set_metadata(
1326                "compression_type".to_string(),
1327                format!("{:?}", self.compression_type),
1328            );
1329            context.record_metric("compression_original_size".to_string(), content_size as f64);
1330
1331            tracing::debug!("Compression applied to message of {} bytes", content_size);
1332            // In a real implementation, this would compress the message content
1333        }
1334
1335        Ok(())
1336    }
1337}
1338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::RequestId;

    /// Minimal "test" request shared by the tests below.
    fn sample_request() -> JSONRPCRequest {
        JSONRPCRequest {
            jsonrpc: "2.0".to_string(),
            id: RequestId::from(1i64),
            method: "test".to_string(),
            params: None,
        }
    }

    #[tokio::test]
    async fn test_middleware_chain() {
        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(LoggingMiddleware::default()));

        let mut request = sample_request();
        assert!(chain.process_request(&mut request).await.is_ok());
    }

    #[tokio::test]
    async fn test_auth_middleware() {
        let middleware = AuthMiddleware::new("test-token".to_string());

        let mut request = sample_request();
        assert!(middleware.on_request(&mut request).await.is_ok());
    }
}