pmcp/shared/
middleware.rs

1//! Middleware support for request/response processing.
2
3use crate::error::Result;
4use crate::shared::TransportMessage;
5use crate::types::{JSONRPCRequest, JSONRPCResponse};
6use async_trait::async_trait;
7use std::fmt;
8use std::sync::Arc;
9
10/// Middleware that can intercept and modify requests and responses.
11///
12/// # Examples
13///
14/// ```rust
15/// use pmcp::shared::{Middleware, TransportMessage};
16/// use pmcp::types::{JSONRPCRequest, JSONRPCResponse, RequestId};
17/// use async_trait::async_trait;
18///
19/// // Custom middleware that adds timing information
20/// #[derive(Debug)]
21/// struct TimingMiddleware {
22///     start_time: std::time::Instant,
23/// }
24///
25/// impl TimingMiddleware {
26///     fn new() -> Self {
27///         Self { start_time: std::time::Instant::now() }
28///     }
29/// }
30///
31/// #[async_trait]
32/// impl Middleware for TimingMiddleware {
33///     async fn on_request(&self, request: &mut JSONRPCRequest) -> pmcp::Result<()> {
34///         // Add timing metadata to request params
35///         println!("Processing request {} at {}ms",
36///             request.method,
37///             self.start_time.elapsed().as_millis());
38///         Ok(())
39///     }
40///
41///     async fn on_response(&self, response: &mut JSONRPCResponse) -> pmcp::Result<()> {
42///         println!("Response for {:?} received at {}ms",
43///             response.id,
44///             self.start_time.elapsed().as_millis());
45///         Ok(())
46///     }
47/// }
48///
49/// # async fn example() -> pmcp::Result<()> {
50/// let middleware = TimingMiddleware::new();
51/// let mut request = JSONRPCRequest {
52///     jsonrpc: "2.0".to_string(),
53///     method: "test".to_string(),
54///     params: None,
55///     id: RequestId::from(123i64),
56/// };
57///
58/// // Process request through middleware
59/// middleware.on_request(&mut request).await?;
60/// # Ok(())
61/// # }
62/// ```
63#[async_trait]
64pub trait Middleware: Send + Sync {
65    /// Called before a request is sent.
66    async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
67        let _ = request;
68        Ok(())
69    }
70
71    /// Called after a response is received.
72    async fn on_response(&self, response: &mut JSONRPCResponse) -> Result<()> {
73        let _ = response;
74        Ok(())
75    }
76
77    /// Called when a message is sent (any type).
78    async fn on_send(&self, message: &TransportMessage) -> Result<()> {
79        let _ = message;
80        Ok(())
81    }
82
83    /// Called when a message is received (any type).
84    async fn on_receive(&self, message: &TransportMessage) -> Result<()> {
85        let _ = message;
86        Ok(())
87    }
88}
89
90/// Chain of middleware handlers.
91///
92/// # Examples
93///
94/// ```rust
95/// use pmcp::shared::{MiddlewareChain, LoggingMiddleware, AuthMiddleware, RetryMiddleware};
96/// use pmcp::types::{JSONRPCRequest, JSONRPCResponse, RequestId};
97/// use std::sync::Arc;
98/// use tracing::Level;
99///
100/// # async fn example() -> pmcp::Result<()> {
101/// // Create a middleware chain
102/// let mut chain = MiddlewareChain::new();
103///
104/// // Add different types of middleware in order
105/// chain.add(Arc::new(LoggingMiddleware::new(Level::INFO)));
106/// chain.add(Arc::new(AuthMiddleware::new("Bearer token-123".to_string())));
107/// chain.add(Arc::new(RetryMiddleware::default()));
108///
109/// // Create a request to process
110/// let mut request = JSONRPCRequest {
111///     jsonrpc: "2.0".to_string(),
112///     method: "prompts.get".to_string(),
113///     params: Some(serde_json::json!({
114///         "name": "code_review",
115///         "arguments": {"language": "rust", "style": "detailed"}
116///     })),
117///     id: RequestId::from(1001i64),
118/// };
119///
120/// // Process request through all middleware in order
121/// chain.process_request(&mut request).await?;
122///
123/// // Create a response to process
124/// let mut response = JSONRPCResponse {
125///     jsonrpc: "2.0".to_string(),
126///     id: RequestId::from(1001i64),
127///     payload: pmcp::types::jsonrpc::ResponsePayload::Result(
128///         serde_json::json!({"prompt": "Review the following code..."})
129///     ),
130/// };
131///
132/// // Process response through all middleware
133/// chain.process_response(&mut response).await?;
134///
135/// // The chain processes middleware in the order they were added
136/// // 1. LoggingMiddleware logs the request/response
137/// // 2. AuthMiddleware adds authentication
138/// // 3. RetryMiddleware configures retry behavior
139/// # Ok(())
140/// # }
141/// ```
142pub struct MiddlewareChain {
143    middlewares: Vec<Arc<dyn Middleware>>,
144}
145
146impl fmt::Debug for MiddlewareChain {
147    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148        f.debug_struct("MiddlewareChain")
149            .field("count", &self.middlewares.len())
150            .finish()
151    }
152}
153
154impl Default for MiddlewareChain {
155    fn default() -> Self {
156        Self::new()
157    }
158}
159
160impl MiddlewareChain {
161    /// Create a new empty middleware chain.
162    pub fn new() -> Self {
163        Self {
164            middlewares: Vec::new(),
165        }
166    }
167
168    /// Add a middleware to the chain.
169    pub fn add(&mut self, middleware: Arc<dyn Middleware>) {
170        self.middlewares.push(middleware);
171    }
172
173    /// Process a request through all middleware.
174    pub async fn process_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
175        for middleware in &self.middlewares {
176            middleware.on_request(request).await?;
177        }
178        Ok(())
179    }
180
181    /// Process a response through all middleware.
182    pub async fn process_response(&self, response: &mut JSONRPCResponse) -> Result<()> {
183        for middleware in &self.middlewares {
184            middleware.on_response(response).await?;
185        }
186        Ok(())
187    }
188
189    /// Process an outgoing message through all middleware.
190    pub async fn process_send(&self, message: &TransportMessage) -> Result<()> {
191        for middleware in &self.middlewares {
192            middleware.on_send(message).await?;
193        }
194        Ok(())
195    }
196
197    /// Process an incoming message through all middleware.
198    pub async fn process_receive(&self, message: &TransportMessage) -> Result<()> {
199        for middleware in &self.middlewares {
200            middleware.on_receive(message).await?;
201        }
202        Ok(())
203    }
204}
205
206/// Logging middleware that logs all messages.
207///
208/// # Examples
209///
210/// ```rust
211/// use pmcp::shared::{LoggingMiddleware, Middleware};
212/// use pmcp::types::{JSONRPCRequest, RequestId};
213/// use tracing::Level;
214///
215/// # async fn example() -> pmcp::Result<()> {
216/// // Create logging middleware with different levels
217/// let debug_logger = LoggingMiddleware::new(Level::DEBUG);
218/// let info_logger = LoggingMiddleware::new(Level::INFO);
219/// let default_logger = LoggingMiddleware::default(); // Uses DEBUG level
220///
221/// let mut request = JSONRPCRequest {
222///     jsonrpc: "2.0".to_string(),
223///     method: "tools.list".to_string(),
224///     params: Some(serde_json::json!({"category": "development"})),
225///     id: RequestId::from(456i64),
226/// };
227///
228/// // Log at different levels
229/// debug_logger.on_request(&mut request).await?;
230/// info_logger.on_request(&mut request).await?;
231/// default_logger.on_request(&mut request).await?;
232/// # Ok(())
233/// # }
234/// ```
235#[derive(Debug)]
236pub struct LoggingMiddleware {
237    level: tracing::Level,
238}
239
240impl LoggingMiddleware {
241    /// Create a new logging middleware with the specified level.
242    pub fn new(level: tracing::Level) -> Self {
243        Self { level }
244    }
245}
246
247impl Default for LoggingMiddleware {
248    fn default() -> Self {
249        Self::new(tracing::Level::DEBUG)
250    }
251}
252
253#[async_trait]
254impl Middleware for LoggingMiddleware {
255    async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
256        match self.level {
257            tracing::Level::TRACE => tracing::trace!("Sending request: {:?}", request),
258            tracing::Level::DEBUG => tracing::debug!("Sending request: {}", request.method),
259            tracing::Level::INFO => tracing::info!("Sending request: {}", request.method),
260            tracing::Level::WARN => tracing::warn!("Sending request: {}", request.method),
261            tracing::Level::ERROR => tracing::error!("Sending request: {}", request.method),
262        }
263        Ok(())
264    }
265
266    async fn on_response(&self, response: &mut JSONRPCResponse) -> Result<()> {
267        match self.level {
268            tracing::Level::TRACE => tracing::trace!("Received response: {:?}", response),
269            tracing::Level::DEBUG => tracing::debug!("Received response for: {:?}", response.id),
270            tracing::Level::INFO => tracing::info!("Received response"),
271            tracing::Level::WARN => tracing::warn!("Received response"),
272            tracing::Level::ERROR => tracing::error!("Received response"),
273        }
274        Ok(())
275    }
276}
277
/// Authentication middleware intended to attach credentials to outgoing requests.
///
/// This is currently a placeholder: its [`Middleware::on_request`] hook only
/// emits a `DEBUG` log line and leaves the request unchanged. The stored token
/// is never printed; the `Debug` output redacts it.
///
/// # Examples
///
/// ```rust
/// use pmcp::shared::{AuthMiddleware, Middleware};
/// use pmcp::types::{JSONRPCRequest, RequestId};
///
/// # async fn example() -> pmcp::Result<()> {
/// // Create auth middleware with API token
/// let auth_middleware = AuthMiddleware::new("Bearer api-token-12345".to_string());
///
/// // Debug output never exposes the token
/// assert!(!format!("{:?}", auth_middleware).contains("api-token-12345"));
///
/// let mut request = JSONRPCRequest {
///     jsonrpc: "2.0".to_string(),
///     method: "resources.read".to_string(),
///     params: Some(serde_json::json!({"uri": "file:///secure/data.txt"})),
///     id: RequestId::from(789i64),
/// };
///
/// // Process request and add authentication
/// auth_middleware.on_request(&mut request).await?;
///
/// // In a real implementation, the middleware would modify the request
/// // to include authentication information
/// # Ok(())
/// # }
/// ```
pub struct AuthMiddleware {
    // Not read anywhere yet: retained for when `on_request` actually injects
    // credentials into the request.
    #[allow(dead_code)]
    auth_token: String,
}

impl fmt::Debug for AuthMiddleware {
    /// Formats the middleware with the token redacted, so credentials cannot
    /// leak through `{:?}` into logs or panic messages.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AuthMiddleware")
            .field("auth_token", &"<redacted>")
            .finish()
    }
}
310
311impl AuthMiddleware {
312    /// Create a new auth middleware with the given token.
313    pub fn new(auth_token: String) -> Self {
314        Self { auth_token }
315    }
316}
317
318#[async_trait]
319impl Middleware for AuthMiddleware {
320    async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
321        // In a real implementation, this would add auth headers
322        // For JSON-RPC, we might add auth to params or use a wrapper
323        tracing::debug!("Adding authentication to request: {}", request.method);
324        Ok(())
325    }
326}
327
/// Retry settings for exponential backoff, carried as a middleware.
///
/// This middleware does not retry anything itself: its
/// [`Middleware::on_request`] hook only logs the configured retry limit. The
/// settings are exposed through [`max_retries`](Self::max_retries),
/// [`initial_delay_ms`](Self::initial_delay_ms),
/// [`max_delay_ms`](Self::max_delay_ms) and
/// [`delay_ms_for_attempt`](Self::delay_ms_for_attempt), so a transport can
/// implement the actual retry loop.
///
/// # Examples
///
/// ```rust
/// use pmcp::shared::{RetryMiddleware, Middleware};
/// use pmcp::types::{JSONRPCRequest, RequestId};
///
/// # async fn example() -> pmcp::Result<()> {
/// // Create retry middleware with custom settings
/// let retry_middleware = RetryMiddleware::new(
///     5,      // max_retries
///     1000,   // initial_delay_ms (1 second)
///     30000   // max_delay_ms (30 seconds)
/// );
///
/// // The delay doubles per attempt and is capped at max_delay_ms
/// assert_eq!(retry_middleware.delay_ms_for_attempt(0), 1000);
/// assert_eq!(retry_middleware.delay_ms_for_attempt(3), 8000);
/// assert_eq!(retry_middleware.delay_ms_for_attempt(10), 30000);
///
/// // Default retry middleware (3 retries, 1s initial, 30s max)
/// let default_retry = RetryMiddleware::default();
/// assert_eq!(default_retry.max_retries(), 3);
///
/// let mut request = JSONRPCRequest {
///     jsonrpc: "2.0".to_string(),
///     method: "tools.call".to_string(),
///     params: Some(serde_json::json!({
///         "name": "network_tool",
///         "arguments": {"url": "https://api.example.com/data"}
///     })),
///     id: RequestId::from(999i64),
/// };
///
/// // Configure request for retry handling
/// retry_middleware.on_request(&mut request).await?;
/// default_retry.on_request(&mut request).await?;
///
/// // The actual retry logic would be implemented at the transport level
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct RetryMiddleware {
    /// Maximum number of retries.
    max_retries: u32,
    /// Base delay in milliseconds; the backoff for attempt 0.
    initial_delay_ms: u64,
    /// Upper bound, in milliseconds, applied to every computed delay.
    max_delay_ms: u64,
}

impl RetryMiddleware {
    /// Create a new retry middleware.
    ///
    /// * `max_retries` - maximum number of retries
    /// * `initial_delay_ms` - delay for the first retry attempt, in milliseconds
    /// * `max_delay_ms` - cap applied to every computed delay, in milliseconds
    pub fn new(max_retries: u32, initial_delay_ms: u64, max_delay_ms: u64) -> Self {
        Self {
            max_retries,
            initial_delay_ms,
            max_delay_ms,
        }
    }

    /// Maximum number of retries.
    pub fn max_retries(&self) -> u32 {
        self.max_retries
    }

    /// Base backoff delay in milliseconds.
    pub fn initial_delay_ms(&self) -> u64 {
        self.initial_delay_ms
    }

    /// Maximum backoff delay in milliseconds.
    pub fn max_delay_ms(&self) -> u64 {
        self.max_delay_ms
    }

    /// Backoff delay, in milliseconds, before retry number `attempt` (zero-based).
    ///
    /// Computed as `initial_delay_ms * 2^attempt`, capped at `max_delay_ms`.
    /// Arithmetic overflow saturates to the cap, so any `attempt` value is safe.
    pub fn delay_ms_for_attempt(&self, attempt: u32) -> u64 {
        let uncapped = 1u64
            .checked_shl(attempt)
            .and_then(|factor| self.initial_delay_ms.checked_mul(factor));
        match uncapped {
            Some(delay) => delay.min(self.max_delay_ms),
            // 0 * 2^attempt is 0 however large `attempt` is.
            None if self.initial_delay_ms == 0 => 0,
            // The true product exceeds u64::MAX, hence also exceeds the cap.
            None => self.max_delay_ms,
        }
    }
}

impl Default for RetryMiddleware {
    /// 3 retries, 1000 ms initial delay, 30000 ms maximum delay.
    fn default() -> Self {
        Self::new(3, 1000, 30000)
    }
}
390
391#[async_trait]
392impl Middleware for RetryMiddleware {
393    async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
394        // Retry logic would be implemented at the transport level
395        // This middleware just adds metadata for retry handling
396        tracing::debug!(
397            "Request {} configured with max {} retries",
398            request.method,
399            self.max_retries
400        );
401        Ok(())
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::RequestId;

    /// Minimal request shared by the tests below.
    fn sample_request() -> JSONRPCRequest {
        JSONRPCRequest {
            jsonrpc: "2.0".to_string(),
            id: RequestId::from(1i64),
            method: "test".to_string(),
            params: None,
        }
    }

    #[tokio::test]
    async fn test_middleware_chain() {
        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(LoggingMiddleware::default()));

        let mut request = sample_request();
        let outcome = chain.process_request(&mut request).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_auth_middleware() {
        let middleware = AuthMiddleware::new("test-token".to_string());

        let mut request = sample_request();
        let outcome = middleware.on_request(&mut request).await;
        assert!(outcome.is_ok());
    }
}