pmcp/shared/middleware.rs
1//! Middleware support for request/response processing.
2
3use crate::error::Result;
4use crate::shared::TransportMessage;
5use crate::types::{JSONRPCRequest, JSONRPCResponse};
6use async_trait::async_trait;
7use std::fmt;
8use std::sync::Arc;
9
10/// Middleware that can intercept and modify requests and responses.
11///
12/// # Examples
13///
14/// ```rust
15/// use pmcp::shared::{Middleware, TransportMessage};
16/// use pmcp::types::{JSONRPCRequest, JSONRPCResponse, RequestId};
17/// use async_trait::async_trait;
18///
19/// // Custom middleware that adds timing information
20/// #[derive(Debug)]
21/// struct TimingMiddleware {
22/// start_time: std::time::Instant,
23/// }
24///
25/// impl TimingMiddleware {
26/// fn new() -> Self {
27/// Self { start_time: std::time::Instant::now() }
28/// }
29/// }
30///
31/// #[async_trait]
32/// impl Middleware for TimingMiddleware {
33/// async fn on_request(&self, request: &mut JSONRPCRequest) -> pmcp::Result<()> {
34/// // Add timing metadata to request params
35/// println!("Processing request {} at {}ms",
36/// request.method,
37/// self.start_time.elapsed().as_millis());
38/// Ok(())
39/// }
40///
41/// async fn on_response(&self, response: &mut JSONRPCResponse) -> pmcp::Result<()> {
42/// println!("Response for {:?} received at {}ms",
43/// response.id,
44/// self.start_time.elapsed().as_millis());
45/// Ok(())
46/// }
47/// }
48///
49/// # async fn example() -> pmcp::Result<()> {
50/// let middleware = TimingMiddleware::new();
51/// let mut request = JSONRPCRequest {
52/// jsonrpc: "2.0".to_string(),
53/// method: "test".to_string(),
54/// params: None,
55/// id: RequestId::from(123i64),
56/// };
57///
58/// // Process request through middleware
59/// middleware.on_request(&mut request).await?;
60/// # Ok(())
61/// # }
62/// ```
63#[async_trait]
64pub trait Middleware: Send + Sync {
65 /// Called before a request is sent.
66 async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
67 let _ = request;
68 Ok(())
69 }
70
71 /// Called after a response is received.
72 async fn on_response(&self, response: &mut JSONRPCResponse) -> Result<()> {
73 let _ = response;
74 Ok(())
75 }
76
77 /// Called when a message is sent (any type).
78 async fn on_send(&self, message: &TransportMessage) -> Result<()> {
79 let _ = message;
80 Ok(())
81 }
82
83 /// Called when a message is received (any type).
84 async fn on_receive(&self, message: &TransportMessage) -> Result<()> {
85 let _ = message;
86 Ok(())
87 }
88}
89
90/// Chain of middleware handlers.
91///
92/// # Examples
93///
94/// ```rust
95/// use pmcp::shared::{MiddlewareChain, LoggingMiddleware, AuthMiddleware, RetryMiddleware};
96/// use pmcp::types::{JSONRPCRequest, JSONRPCResponse, RequestId};
97/// use std::sync::Arc;
98/// use tracing::Level;
99///
100/// # async fn example() -> pmcp::Result<()> {
101/// // Create a middleware chain
102/// let mut chain = MiddlewareChain::new();
103///
104/// // Add different types of middleware in order
105/// chain.add(Arc::new(LoggingMiddleware::new(Level::INFO)));
106/// chain.add(Arc::new(AuthMiddleware::new("Bearer token-123".to_string())));
107/// chain.add(Arc::new(RetryMiddleware::default()));
108///
109/// // Create a request to process
110/// let mut request = JSONRPCRequest {
111/// jsonrpc: "2.0".to_string(),
112/// method: "prompts.get".to_string(),
113/// params: Some(serde_json::json!({
114/// "name": "code_review",
115/// "arguments": {"language": "rust", "style": "detailed"}
116/// })),
117/// id: RequestId::from(1001i64),
118/// };
119///
120/// // Process request through all middleware in order
121/// chain.process_request(&mut request).await?;
122///
123/// // Create a response to process
124/// let mut response = JSONRPCResponse {
125/// jsonrpc: "2.0".to_string(),
126/// id: RequestId::from(1001i64),
127/// payload: pmcp::types::jsonrpc::ResponsePayload::Result(
128/// serde_json::json!({"prompt": "Review the following code..."})
129/// ),
130/// };
131///
132/// // Process response through all middleware
133/// chain.process_response(&mut response).await?;
134///
135/// // The chain processes middleware in the order they were added
136/// // 1. LoggingMiddleware logs the request/response
137/// // 2. AuthMiddleware adds authentication
138/// // 3. RetryMiddleware configures retry behavior
139/// # Ok(())
140/// # }
141/// ```
142pub struct MiddlewareChain {
143 middlewares: Vec<Arc<dyn Middleware>>,
144}
145
146impl fmt::Debug for MiddlewareChain {
147 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148 f.debug_struct("MiddlewareChain")
149 .field("count", &self.middlewares.len())
150 .finish()
151 }
152}
153
154impl Default for MiddlewareChain {
155 fn default() -> Self {
156 Self::new()
157 }
158}
159
160impl MiddlewareChain {
161 /// Create a new empty middleware chain.
162 pub fn new() -> Self {
163 Self {
164 middlewares: Vec::new(),
165 }
166 }
167
168 /// Add a middleware to the chain.
169 pub fn add(&mut self, middleware: Arc<dyn Middleware>) {
170 self.middlewares.push(middleware);
171 }
172
173 /// Process a request through all middleware.
174 pub async fn process_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
175 for middleware in &self.middlewares {
176 middleware.on_request(request).await?;
177 }
178 Ok(())
179 }
180
181 /// Process a response through all middleware.
182 pub async fn process_response(&self, response: &mut JSONRPCResponse) -> Result<()> {
183 for middleware in &self.middlewares {
184 middleware.on_response(response).await?;
185 }
186 Ok(())
187 }
188
189 /// Process an outgoing message through all middleware.
190 pub async fn process_send(&self, message: &TransportMessage) -> Result<()> {
191 for middleware in &self.middlewares {
192 middleware.on_send(message).await?;
193 }
194 Ok(())
195 }
196
197 /// Process an incoming message through all middleware.
198 pub async fn process_receive(&self, message: &TransportMessage) -> Result<()> {
199 for middleware in &self.middlewares {
200 middleware.on_receive(message).await?;
201 }
202 Ok(())
203 }
204}
205
206/// Logging middleware that logs all messages.
207///
208/// # Examples
209///
210/// ```rust
211/// use pmcp::shared::{LoggingMiddleware, Middleware};
212/// use pmcp::types::{JSONRPCRequest, RequestId};
213/// use tracing::Level;
214///
215/// # async fn example() -> pmcp::Result<()> {
216/// // Create logging middleware with different levels
217/// let debug_logger = LoggingMiddleware::new(Level::DEBUG);
218/// let info_logger = LoggingMiddleware::new(Level::INFO);
219/// let default_logger = LoggingMiddleware::default(); // Uses DEBUG level
220///
221/// let mut request = JSONRPCRequest {
222/// jsonrpc: "2.0".to_string(),
223/// method: "tools.list".to_string(),
224/// params: Some(serde_json::json!({"category": "development"})),
225/// id: RequestId::from(456i64),
226/// };
227///
228/// // Log at different levels
229/// debug_logger.on_request(&mut request).await?;
230/// info_logger.on_request(&mut request).await?;
231/// default_logger.on_request(&mut request).await?;
232/// # Ok(())
233/// # }
234/// ```
235#[derive(Debug)]
236pub struct LoggingMiddleware {
237 level: tracing::Level,
238}
239
240impl LoggingMiddleware {
241 /// Create a new logging middleware with the specified level.
242 pub fn new(level: tracing::Level) -> Self {
243 Self { level }
244 }
245}
246
247impl Default for LoggingMiddleware {
248 fn default() -> Self {
249 Self::new(tracing::Level::DEBUG)
250 }
251}
252
253#[async_trait]
254impl Middleware for LoggingMiddleware {
255 async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
256 match self.level {
257 tracing::Level::TRACE => tracing::trace!("Sending request: {:?}", request),
258 tracing::Level::DEBUG => tracing::debug!("Sending request: {}", request.method),
259 tracing::Level::INFO => tracing::info!("Sending request: {}", request.method),
260 tracing::Level::WARN => tracing::warn!("Sending request: {}", request.method),
261 tracing::Level::ERROR => tracing::error!("Sending request: {}", request.method),
262 }
263 Ok(())
264 }
265
266 async fn on_response(&self, response: &mut JSONRPCResponse) -> Result<()> {
267 match self.level {
268 tracing::Level::TRACE => tracing::trace!("Received response: {:?}", response),
269 tracing::Level::DEBUG => tracing::debug!("Received response for: {:?}", response.id),
270 tracing::Level::INFO => tracing::info!("Received response"),
271 tracing::Level::WARN => tracing::warn!("Received response"),
272 tracing::Level::ERROR => tracing::error!("Received response"),
273 }
274 Ok(())
275 }
276}
277
/// Authentication middleware intended to attach credentials to requests.
///
/// Note: its `Middleware::on_request` implementation is currently a
/// placeholder that only emits a debug log; requests are not modified and the
/// stored token is not sent anywhere yet.
///
/// The token is treated as a secret: the `Debug` output of this type redacts
/// it, so the middleware can be logged or debug-printed safely.
///
/// # Examples
///
/// ```rust
/// use pmcp::shared::{AuthMiddleware, Middleware};
/// use pmcp::types::{JSONRPCRequest, RequestId};
///
/// # async fn example() -> pmcp::Result<()> {
/// // Create auth middleware with API token
/// let auth_middleware = AuthMiddleware::new("Bearer api-token-12345".to_string());
///
/// let mut request = JSONRPCRequest {
///     jsonrpc: "2.0".to_string(),
///     method: "resources.read".to_string(),
///     params: Some(serde_json::json!({"uri": "file:///secure/data.txt"})),
///     id: RequestId::from(789i64),
/// };
///
/// // Currently only logs; the request is passed on unchanged
/// auth_middleware.on_request(&mut request).await?;
///
/// // The token never appears in debug output
/// assert!(!format!("{:?}", auth_middleware).contains("api-token-12345"));
/// # Ok(())
/// # }
/// ```
pub struct AuthMiddleware {
    // Not read anywhere yet: the `on_request` hook does not inject it, and the
    // `Debug` impl below deliberately does not print it.
    #[allow(dead_code)]
    auth_token: String,
}

impl fmt::Debug for AuthMiddleware {
    /// Prints the struct with the token replaced by `<redacted>`, so secrets
    /// are not leaked through `{:?}` formatting or debug logs.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AuthMiddleware")
            .field("auth_token", &"<redacted>")
            .finish()
    }
}

impl AuthMiddleware {
    /// Create a new auth middleware with the given token.
    pub fn new(auth_token: String) -> Self {
        Self { auth_token }
    }
}
317
318#[async_trait]
319impl Middleware for AuthMiddleware {
320 async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
321 // In a real implementation, this would add auth headers
322 // For JSON-RPC, we might add auth to params or use a wrapper
323 tracing::debug!("Adding authentication to request: {}", request.method);
324 Ok(())
325 }
326}
327
/// Retry configuration with capped exponential backoff.
///
/// This type does not re-send anything itself: its `Middleware` hook only
/// logs the configured retry limit. The retry loop is expected to live at the
/// transport level. That code can read the configuration through the accessor
/// methods and compute the wait before each retry with
/// [`RetryMiddleware::delay_for_attempt`].
///
/// # Examples
///
/// ```rust
/// use pmcp::shared::{RetryMiddleware, Middleware};
/// use pmcp::types::{JSONRPCRequest, RequestId};
/// use std::time::Duration;
///
/// # async fn example() -> pmcp::Result<()> {
/// // Create retry middleware with custom settings
/// let retry_middleware = RetryMiddleware::new(
///     5,     // max_retries
///     1000,  // initial_delay_ms (1 second)
///     30000  // max_delay_ms (30 seconds)
/// );
///
/// // Delays double with every attempt and are capped at max_delay_ms
/// assert_eq!(retry_middleware.delay_for_attempt(0), Duration::from_millis(1000));
/// assert_eq!(retry_middleware.delay_for_attempt(1), Duration::from_millis(2000));
/// assert_eq!(retry_middleware.delay_for_attempt(10), Duration::from_millis(30000));
///
/// // Default retry middleware (3 retries, 1s initial, 30s max)
/// let default_retry = RetryMiddleware::default();
///
/// let mut request = JSONRPCRequest {
///     jsonrpc: "2.0".to_string(),
///     method: "tools.call".to_string(),
///     params: Some(serde_json::json!({
///         "name": "network_tool",
///         "arguments": {"url": "https://api.example.com/data"}
///     })),
///     id: RequestId::from(999i64),
/// };
///
/// // The middleware hook only logs the retry limit
/// retry_middleware.on_request(&mut request).await?;
/// default_retry.on_request(&mut request).await?;
///
/// // The actual retry logic would be implemented at the transport level
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct RetryMiddleware {
    /// Maximum number of retries.
    max_retries: u32,
    /// Delay before the first retry (attempt 0), in milliseconds.
    initial_delay_ms: u64,
    /// Upper bound on any single backoff delay, in milliseconds.
    max_delay_ms: u64,
}

impl RetryMiddleware {
    /// Create a new retry middleware.
    ///
    /// * `max_retries` - maximum number of retries.
    /// * `initial_delay_ms` - delay before the first retry, in milliseconds.
    /// * `max_delay_ms` - cap applied to every computed delay, in milliseconds.
    pub fn new(max_retries: u32, initial_delay_ms: u64, max_delay_ms: u64) -> Self {
        Self {
            max_retries,
            initial_delay_ms,
            max_delay_ms,
        }
    }

    /// Maximum number of retries.
    pub fn max_retries(&self) -> u32 {
        self.max_retries
    }

    /// Delay before the first retry, in milliseconds.
    pub fn initial_delay_ms(&self) -> u64 {
        self.initial_delay_ms
    }

    /// Upper bound on any single backoff delay, in milliseconds.
    pub fn max_delay_ms(&self) -> u64 {
        self.max_delay_ms
    }

    /// Backoff delay to wait before retry number `attempt` (0-based).
    ///
    /// The delay is `initial_delay_ms * 2^attempt`, capped at `max_delay_ms`.
    /// The arithmetic saturates instead of overflowing, so very large attempt
    /// numbers simply yield the cap. If `initial_delay_ms` exceeds
    /// `max_delay_ms`, the cap applies from the first attempt.
    pub fn delay_for_attempt(&self, attempt: u32) -> std::time::Duration {
        // 2^attempt overflows u64 for attempt >= 64; treat that as "infinitely large".
        let factor = 2u64.checked_pow(attempt).unwrap_or(u64::MAX);
        let millis = self
            .initial_delay_ms
            .saturating_mul(factor)
            .min(self.max_delay_ms);
        std::time::Duration::from_millis(millis)
    }
}

impl Default for RetryMiddleware {
    /// 3 retries, 1 second initial delay, 30 second maximum delay.
    fn default() -> Self {
        Self::new(3, 1000, 30000)
    }
}
390
391#[async_trait]
392impl Middleware for RetryMiddleware {
393 async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
394 // Retry logic would be implemented at the transport level
395 // This middleware just adds metadata for retry handling
396 tracing::debug!(
397 "Request {} configured with max {} retries",
398 request.method,
399 self.max_retries
400 );
401 Ok(())
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::RequestId;

    /// Minimal request shared by the tests: id 1, method "test", no params.
    fn sample_request() -> JSONRPCRequest {
        JSONRPCRequest {
            jsonrpc: "2.0".to_string(),
            id: RequestId::from(1i64),
            method: "test".to_string(),
            params: None,
        }
    }

    #[tokio::test]
    async fn test_middleware_chain() {
        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(LoggingMiddleware::default()));

        let mut request = sample_request();
        let outcome = chain.process_request(&mut request).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_auth_middleware() {
        let middleware = AuthMiddleware::new("test-token".to_string());

        let mut request = sample_request();
        let outcome = middleware.on_request(&mut request).await;
        assert!(outcome.is_ok());
    }
}