turbomcp_client/plugins/
middleware.rs

1//! Middleware pattern implementation for plugin system
2//!
3//! Provides middleware abstractions and chain execution patterns for
4//! request/response processing. This module focuses on the middleware
5//! pattern specifically, allowing plugins to be composed as middleware.
6
7use crate::plugins::core::{PluginResult, RequestContext, ResponseContext};
8use async_trait::async_trait;
9use std::sync::Arc;
10use tracing::{debug, error};
11
12/// Result type for middleware operations
13pub type MiddlewareResult<T> = PluginResult<T>;
14
15/// Trait for request middleware
16///
17/// Request middleware can modify the request before it's sent to the server.
18/// They are executed in the order they are registered.
19#[async_trait]
20pub trait RequestMiddleware: Send + Sync + std::fmt::Debug {
21    /// Process the request context
22    ///
23    /// # Arguments
24    /// * `context` - Mutable request context that can be modified
25    ///
26    /// # Returns
27    /// Returns `Ok(())` to continue processing, or `PluginError` to abort.
28    async fn process_request(&self, context: &mut RequestContext) -> MiddlewareResult<()>;
29
30    /// Get middleware name for debugging
31    fn name(&self) -> &str;
32}
33
34/// Trait for response middleware
35///
36/// Response middleware process responses after they're received from the server.
37/// They are executed in the order they are registered.
38#[async_trait]
39pub trait ResponseMiddleware: Send + Sync + std::fmt::Debug {
40    /// Process the response context
41    ///
42    /// # Arguments
43    /// * `context` - Mutable response context that can be modified
44    ///
45    /// # Returns
46    /// Returns `Ok(())` if processing succeeds, or `PluginError` if it fails.
47    async fn process_response(&self, context: &mut ResponseContext) -> MiddlewareResult<()>;
48
49    /// Get middleware name for debugging
50    fn name(&self) -> &str;
51}
52
53/// Chain of middleware for sequential execution
54///
55/// The MiddlewareChain manages the execution of multiple middleware
56/// components in a defined order. It provides error handling and
57/// short-circuiting behavior.
58///
59/// # Examples
60///
61/// ```rust,no_run
62/// use turbomcp_client::plugins::middleware::{MiddlewareChain, RequestMiddleware};
63/// use std::sync::Arc;
64///
65/// let mut chain = MiddlewareChain::new();
66/// // chain.add_request_middleware(Arc::new(some_middleware));
67/// // chain.add_response_middleware(Arc::new(other_middleware));
68/// ```
69#[derive(Debug)]
70pub struct MiddlewareChain {
71    /// Request middleware in execution order
72    request_middleware: Vec<Arc<dyn RequestMiddleware>>,
73
74    /// Response middleware in execution order
75    response_middleware: Vec<Arc<dyn ResponseMiddleware>>,
76}
77
78impl Default for MiddlewareChain {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl MiddlewareChain {
85    /// Create a new empty middleware chain
86    #[must_use]
87    pub fn new() -> Self {
88        Self {
89            request_middleware: Vec::new(),
90            response_middleware: Vec::new(),
91        }
92    }
93
94    /// Add request middleware to the chain
95    ///
96    /// Middleware will be executed in the order they are added.
97    ///
98    /// # Arguments
99    /// * `middleware` - The request middleware to add
100    pub fn add_request_middleware(&mut self, middleware: Arc<dyn RequestMiddleware>) {
101        debug!("Adding request middleware: {}", middleware.name());
102        self.request_middleware.push(middleware);
103    }
104
105    /// Add response middleware to the chain
106    ///
107    /// Middleware will be executed in the order they are added.
108    ///
109    /// # Arguments
110    /// * `middleware` - The response middleware to add
111    pub fn add_response_middleware(&mut self, middleware: Arc<dyn ResponseMiddleware>) {
112        debug!("Adding response middleware: {}", middleware.name());
113        self.response_middleware.push(middleware);
114    }
115
116    /// Execute the request middleware chain
117    ///
118    /// Processes the request context through all registered request middleware
119    /// in order. If any middleware returns an error, processing is aborted
120    /// and the error is returned.
121    ///
122    /// # Arguments
123    /// * `context` - Mutable request context
124    ///
125    /// # Returns
126    /// Returns `Ok(())` if all middleware succeed, or the first error encountered.
127    pub async fn execute_request_chain(
128        &self,
129        context: &mut RequestContext,
130    ) -> MiddlewareResult<()> {
131        debug!(
132            "Executing request middleware chain ({} middleware) for method: {}",
133            self.request_middleware.len(),
134            context.method()
135        );
136
137        for (index, middleware) in self.request_middleware.iter().enumerate() {
138            debug!(
139                "Processing request middleware {} of {}: {}",
140                index + 1,
141                self.request_middleware.len(),
142                middleware.name()
143            );
144
145            middleware.process_request(context).await.map_err(|e| {
146                error!(
147                    "Request middleware '{}' failed for method '{}': {}",
148                    middleware.name(),
149                    context.method(),
150                    e
151                );
152                e
153            })?;
154        }
155
156        debug!("Request middleware chain completed successfully");
157        Ok(())
158    }
159
160    /// Execute the response middleware chain
161    ///
162    /// Processes the response context through all registered response middleware
163    /// in order. Unlike request middleware, this continues execution even if
164    /// a middleware fails, logging errors but not aborting the chain.
165    ///
166    /// # Arguments
167    /// * `context` - Mutable response context
168    ///
169    /// # Returns
170    /// Returns `Ok(())` unless all middleware fail, in which case returns the last error.
171    pub async fn execute_response_chain(
172        &self,
173        context: &mut ResponseContext,
174    ) -> MiddlewareResult<()> {
175        debug!(
176            "Executing response middleware chain ({} middleware) for method: {}",
177            self.response_middleware.len(),
178            context.method()
179        );
180
181        let mut _last_error = None;
182
183        for (index, middleware) in self.response_middleware.iter().enumerate() {
184            debug!(
185                "Processing response middleware {} of {}: {}",
186                index + 1,
187                self.response_middleware.len(),
188                middleware.name()
189            );
190
191            if let Err(e) = middleware.process_response(context).await {
192                error!(
193                    "Response middleware '{}' failed for method '{}': {}",
194                    middleware.name(),
195                    context.method(),
196                    e
197                );
198                _last_error = Some(e);
199                // Continue with other middleware
200            }
201        }
202
203        debug!("Response middleware chain completed");
204
205        // For now, we don't propagate response middleware errors
206        // as they shouldn't break the response processing
207        Ok(())
208    }
209
210    /// Get the number of request middleware
211    #[must_use]
212    pub fn request_middleware_count(&self) -> usize {
213        self.request_middleware.len()
214    }
215
216    /// Get the number of response middleware
217    #[must_use]
218    pub fn response_middleware_count(&self) -> usize {
219        self.response_middleware.len()
220    }
221
222    /// Get names of all request middleware
223    #[must_use]
224    pub fn get_request_middleware_names(&self) -> Vec<String> {
225        self.request_middleware
226            .iter()
227            .map(|m| m.name().to_string())
228            .collect()
229    }
230
231    /// Get names of all response middleware
232    #[must_use]
233    pub fn get_response_middleware_names(&self) -> Vec<String> {
234        self.response_middleware
235            .iter()
236            .map(|m| m.name().to_string())
237            .collect()
238    }
239
240    /// Clear all middleware
241    pub fn clear(&mut self) {
242        debug!("Clearing all middleware from chain");
243        self.request_middleware.clear();
244        self.response_middleware.clear();
245    }
246}
247
/// Adapter that lets a `ClientPlugin` be registered as [`RequestMiddleware`].
///
/// The `RequestMiddleware` implementation forwards `process_request` to the
/// plugin's `before_request` hook and `name` to the plugin's `name`.
#[derive(Debug)]
pub struct PluginRequestMiddleware<P> {
    /// The wrapped plugin.
    plugin: P,
}

impl<P> PluginRequestMiddleware<P> {
    /// Wrap `plugin` so it can be added to a middleware chain as request
    /// middleware.
    #[must_use]
    pub fn new(plugin: P) -> Self {
        Self { plugin }
    }
}
260
261#[async_trait]
262impl<P> RequestMiddleware for PluginRequestMiddleware<P>
263where
264    P: crate::plugins::core::ClientPlugin,
265{
266    async fn process_request(&self, context: &mut RequestContext) -> MiddlewareResult<()> {
267        self.plugin.before_request(context).await
268    }
269
270    fn name(&self) -> &str {
271        self.plugin.name()
272    }
273}
274
/// Adapter that lets a `ClientPlugin` be registered as [`ResponseMiddleware`].
///
/// The `ResponseMiddleware` implementation forwards `process_response` to the
/// plugin's `after_response` hook and `name` to the plugin's `name`.
#[derive(Debug)]
pub struct PluginResponseMiddleware<P> {
    /// The wrapped plugin.
    plugin: P,
}

impl<P> PluginResponseMiddleware<P> {
    /// Wrap `plugin` so it can be added to a middleware chain as response
    /// middleware.
    #[must_use]
    pub fn new(plugin: P) -> Self {
        Self { plugin }
    }
}
287
288#[async_trait]
289impl<P> ResponseMiddleware for PluginResponseMiddleware<P>
290where
291    P: crate::plugins::core::ClientPlugin,
292{
293    async fn process_response(&self, context: &mut ResponseContext) -> MiddlewareResult<()> {
294        self.plugin.after_response(context).await
295    }
296
297    fn name(&self) -> &str {
298        self.plugin.name()
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugins::core::{PluginError, RequestContext};
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};
    use turbomcp_protocol::MessageId;
    use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcVersion};

    /// Build a request context for `method` with no params and an empty
    /// `HashMap` as the second `RequestContext::new` argument.
    fn make_request_context(method: &str) -> RequestContext {
        let request = JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            id: MessageId::from("test"),
            method: method.to_string(),
            params: None,
        };
        RequestContext::new(request, HashMap::new())
    }

    /// Build a successful response context for a request to `method`.
    fn make_response_context(method: &str) -> ResponseContext {
        ResponseContext::new(
            make_request_context(method),
            Some(json!({"result": "success"})),
            None,
            std::time::Duration::from_millis(100),
        )
    }

    // Test middleware implementations
    #[derive(Debug)]
    struct TestRequestMiddleware {
        name: String,
        calls: Arc<Mutex<Vec<String>>>,
        /// Optional log shared between middleware to observe the order in
        /// which the chain actually invokes them.
        order_log: Option<Arc<Mutex<Vec<String>>>>,
        should_fail: bool,
    }

    impl TestRequestMiddleware {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                calls: Arc::new(Mutex::new(Vec::new())),
                order_log: None,
                should_fail: false,
            }
        }

        fn with_failure(mut self) -> Self {
            self.should_fail = true;
            self
        }

        fn with_order_log(mut self, log: Arc<Mutex<Vec<String>>>) -> Self {
            self.order_log = Some(log);
            self
        }

        fn get_calls(&self) -> Vec<String> {
            self.calls.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl RequestMiddleware for TestRequestMiddleware {
        async fn process_request(&self, context: &mut RequestContext) -> MiddlewareResult<()> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("process_request:{}", context.method()));

            if let Some(log) = &self.order_log {
                log.lock().unwrap().push(self.name.clone());
            }

            if self.should_fail {
                Err(PluginError::request_processing("Test middleware failure"))
            } else {
                Ok(())
            }
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[derive(Debug)]
    struct TestResponseMiddleware {
        name: String,
        calls: Arc<Mutex<Vec<String>>>,
        should_fail: bool,
    }

    impl TestResponseMiddleware {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                calls: Arc::new(Mutex::new(Vec::new())),
                should_fail: false,
            }
        }

        fn with_failure(mut self) -> Self {
            self.should_fail = true;
            self
        }

        fn get_calls(&self) -> Vec<String> {
            self.calls.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl ResponseMiddleware for TestResponseMiddleware {
        async fn process_response(&self, context: &mut ResponseContext) -> MiddlewareResult<()> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("process_response:{}", context.method()));

            if self.should_fail {
                Err(PluginError::response_processing("Test middleware failure"))
            } else {
                Ok(())
            }
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[tokio::test]
    async fn test_middleware_chain_creation() {
        let chain = MiddlewareChain::new();
        assert_eq!(chain.request_middleware_count(), 0);
        assert_eq!(chain.response_middleware_count(), 0);
    }

    #[tokio::test]
    async fn test_request_middleware_registration() {
        let mut chain = MiddlewareChain::new();
        let middleware = Arc::new(TestRequestMiddleware::new("test"));

        chain.add_request_middleware(middleware);

        assert_eq!(chain.request_middleware_count(), 1);
        assert_eq!(chain.get_request_middleware_names(), vec!["test"]);
    }

    #[tokio::test]
    async fn test_response_middleware_registration() {
        let mut chain = MiddlewareChain::new();
        let middleware = Arc::new(TestResponseMiddleware::new("test"));

        chain.add_response_middleware(middleware);

        assert_eq!(chain.response_middleware_count(), 1);
        assert_eq!(chain.get_response_middleware_names(), vec!["test"]);
    }

    #[tokio::test]
    async fn test_request_middleware_execution() {
        let mut chain = MiddlewareChain::new();
        let middleware = Arc::new(TestRequestMiddleware::new("test"));

        chain.add_request_middleware(middleware.clone());

        let mut context = make_request_context("test/method");
        chain.execute_request_chain(&mut context).await.unwrap();

        let calls = middleware.get_calls();
        assert!(calls.contains(&"process_request:test/method".to_string()));
    }

    #[tokio::test]
    async fn test_response_middleware_execution() {
        let mut chain = MiddlewareChain::new();
        let middleware = Arc::new(TestResponseMiddleware::new("test"));

        chain.add_response_middleware(middleware.clone());

        let mut response_context = make_response_context("test/method");
        chain
            .execute_response_chain(&mut response_context)
            .await
            .unwrap();

        let calls = middleware.get_calls();
        assert!(calls.contains(&"process_response:test/method".to_string()));
    }

    #[tokio::test]
    async fn test_request_middleware_error_handling() {
        let mut chain = MiddlewareChain::new();
        let good_middleware = Arc::new(TestRequestMiddleware::new("good"));
        let bad_middleware = Arc::new(TestRequestMiddleware::new("bad").with_failure());

        chain.add_request_middleware(good_middleware.clone());
        chain.add_request_middleware(bad_middleware.clone());

        let mut context = make_request_context("test/method");
        let result = chain.execute_request_chain(&mut context).await;

        assert!(result.is_err());
        assert!(
            good_middleware
                .get_calls()
                .contains(&"process_request:test/method".to_string())
        );
        assert!(
            bad_middleware
                .get_calls()
                .contains(&"process_request:test/method".to_string())
        );
    }

    #[tokio::test]
    async fn test_request_middleware_short_circuits_after_error() {
        let mut chain = MiddlewareChain::new();
        let bad_middleware = Arc::new(TestRequestMiddleware::new("bad").with_failure());
        let after_middleware = Arc::new(TestRequestMiddleware::new("after"));

        chain.add_request_middleware(bad_middleware.clone());
        chain.add_request_middleware(after_middleware.clone());

        let mut context = make_request_context("test/method");
        let result = chain.execute_request_chain(&mut context).await;

        assert!(result.is_err());
        assert_eq!(bad_middleware.get_calls().len(), 1);
        // Middleware registered after the failing one must never run.
        assert!(after_middleware.get_calls().is_empty());
    }

    #[tokio::test]
    async fn test_response_middleware_error_handling() {
        let mut chain = MiddlewareChain::new();
        let good_middleware = Arc::new(TestResponseMiddleware::new("good"));
        let bad_middleware = Arc::new(TestResponseMiddleware::new("bad").with_failure());
        let after_middleware = Arc::new(TestResponseMiddleware::new("after"));

        chain.add_response_middleware(good_middleware.clone());
        chain.add_response_middleware(bad_middleware.clone());
        chain.add_response_middleware(after_middleware.clone());

        let mut response_context = make_response_context("test/method");

        // Response middleware continues even with errors
        let result = chain.execute_response_chain(&mut response_context).await;
        assert!(result.is_ok());

        for middleware in [&good_middleware, &bad_middleware, &after_middleware] {
            assert!(
                middleware
                    .get_calls()
                    .contains(&"process_response:test/method".to_string())
            );
        }
    }

    #[tokio::test]
    async fn test_middleware_execution_order() {
        let order = Arc::new(Mutex::new(Vec::new()));
        let middleware1 = Arc::new(
            TestRequestMiddleware::new("first").with_order_log(Arc::clone(&order)),
        );
        let middleware2 = Arc::new(
            TestRequestMiddleware::new("second").with_order_log(Arc::clone(&order)),
        );
        let middleware3 = Arc::new(
            TestRequestMiddleware::new("third").with_order_log(Arc::clone(&order)),
        );

        let mut chain = MiddlewareChain::new();
        chain.add_request_middleware(middleware1.clone());
        chain.add_request_middleware(middleware2.clone());
        chain.add_request_middleware(middleware3.clone());

        let mut context = make_request_context("test/method");
        chain.execute_request_chain(&mut context).await.unwrap();

        // Middleware must actually run in registration order
        assert_eq!(*order.lock().unwrap(), vec!["first", "second", "third"]);

        // Registered names are reported in the same order
        let names = chain.get_request_middleware_names();
        assert_eq!(names, vec!["first", "second", "third"]);
    }

    #[tokio::test]
    async fn test_chain_clear() {
        let mut chain = MiddlewareChain::new();
        let req_middleware = Arc::new(TestRequestMiddleware::new("request"));
        let resp_middleware = Arc::new(TestResponseMiddleware::new("response"));

        chain.add_request_middleware(req_middleware);
        chain.add_response_middleware(resp_middleware);

        assert_eq!(chain.request_middleware_count(), 1);
        assert_eq!(chain.response_middleware_count(), 1);

        chain.clear();

        assert_eq!(chain.request_middleware_count(), 0);
        assert_eq!(chain.response_middleware_count(), 0);
    }
}