1use crate::plugins::core::{PluginResult, RequestContext, ResponseContext};
8use async_trait::async_trait;
9use std::sync::Arc;
10use tracing::{debug, error};
11
12pub type MiddlewareResult<T> = PluginResult<T>;
14
15#[async_trait]
20pub trait RequestMiddleware: Send + Sync + std::fmt::Debug {
21 async fn process_request(&self, context: &mut RequestContext) -> MiddlewareResult<()>;
29
30 fn name(&self) -> &str;
32}
33
34#[async_trait]
39pub trait ResponseMiddleware: Send + Sync + std::fmt::Debug {
40 async fn process_response(&self, context: &mut ResponseContext) -> MiddlewareResult<()>;
48
49 fn name(&self) -> &str;
51}
52
/// An ordered pipeline of request middleware and response middleware.
///
/// The two phases are stored in separate lists. Within each list, middleware
/// runs in the order it was registered.
#[derive(Debug)]
pub struct MiddlewareChain {
    /// Request-phase middleware, in registration order.
    request_middleware: Vec<Arc<dyn RequestMiddleware>>,

    /// Response-phase middleware, in registration order.
    response_middleware: Vec<Arc<dyn ResponseMiddleware>>,
}
77
78impl Default for MiddlewareChain {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl MiddlewareChain {
85 #[must_use]
87 pub fn new() -> Self {
88 Self {
89 request_middleware: Vec::new(),
90 response_middleware: Vec::new(),
91 }
92 }
93
94 pub fn add_request_middleware(&mut self, middleware: Arc<dyn RequestMiddleware>) {
101 debug!("Adding request middleware: {}", middleware.name());
102 self.request_middleware.push(middleware);
103 }
104
105 pub fn add_response_middleware(&mut self, middleware: Arc<dyn ResponseMiddleware>) {
112 debug!("Adding response middleware: {}", middleware.name());
113 self.response_middleware.push(middleware);
114 }
115
116 pub async fn execute_request_chain(
128 &self,
129 context: &mut RequestContext,
130 ) -> MiddlewareResult<()> {
131 debug!(
132 "Executing request middleware chain ({} middleware) for method: {}",
133 self.request_middleware.len(),
134 context.method()
135 );
136
137 for (index, middleware) in self.request_middleware.iter().enumerate() {
138 debug!(
139 "Processing request middleware {} of {}: {}",
140 index + 1,
141 self.request_middleware.len(),
142 middleware.name()
143 );
144
145 middleware.process_request(context).await.map_err(|e| {
146 error!(
147 "Request middleware '{}' failed for method '{}': {}",
148 middleware.name(),
149 context.method(),
150 e
151 );
152 e
153 })?;
154 }
155
156 debug!("Request middleware chain completed successfully");
157 Ok(())
158 }
159
160 pub async fn execute_response_chain(
172 &self,
173 context: &mut ResponseContext,
174 ) -> MiddlewareResult<()> {
175 debug!(
176 "Executing response middleware chain ({} middleware) for method: {}",
177 self.response_middleware.len(),
178 context.method()
179 );
180
181 let mut _last_error = None;
182
183 for (index, middleware) in self.response_middleware.iter().enumerate() {
184 debug!(
185 "Processing response middleware {} of {}: {}",
186 index + 1,
187 self.response_middleware.len(),
188 middleware.name()
189 );
190
191 if let Err(e) = middleware.process_response(context).await {
192 error!(
193 "Response middleware '{}' failed for method '{}': {}",
194 middleware.name(),
195 context.method(),
196 e
197 );
198 _last_error = Some(e);
199 }
201 }
202
203 debug!("Response middleware chain completed");
204
205 Ok(())
208 }
209
210 #[must_use]
212 pub fn request_middleware_count(&self) -> usize {
213 self.request_middleware.len()
214 }
215
216 #[must_use]
218 pub fn response_middleware_count(&self) -> usize {
219 self.response_middleware.len()
220 }
221
222 #[must_use]
224 pub fn get_request_middleware_names(&self) -> Vec<String> {
225 self.request_middleware
226 .iter()
227 .map(|m| m.name().to_string())
228 .collect()
229 }
230
231 #[must_use]
233 pub fn get_response_middleware_names(&self) -> Vec<String> {
234 self.response_middleware
235 .iter()
236 .map(|m| m.name().to_string())
237 .collect()
238 }
239
240 pub fn clear(&mut self) {
242 debug!("Clearing all middleware from chain");
243 self.request_middleware.clear();
244 self.response_middleware.clear();
245 }
246}
247
/// Adapter that exposes a plugin as [`RequestMiddleware`].
///
/// Its `RequestMiddleware` implementation forwards to the plugin's
/// `before_request` hook and reports the plugin's name.
#[derive(Debug)]
pub struct PluginRequestMiddleware<P> {
    plugin: P,
}

impl<P> PluginRequestMiddleware<P> {
    /// Wraps `plugin` so that, once placed in an `Arc`, it can be registered
    /// with `MiddlewareChain::add_request_middleware`.
    #[must_use]
    pub fn new(plugin: P) -> Self {
        Self { plugin }
    }
}
260
261#[async_trait]
262impl<P> RequestMiddleware for PluginRequestMiddleware<P>
263where
264 P: crate::plugins::core::ClientPlugin,
265{
266 async fn process_request(&self, context: &mut RequestContext) -> MiddlewareResult<()> {
267 self.plugin.before_request(context).await
268 }
269
270 fn name(&self) -> &str {
271 self.plugin.name()
272 }
273}
274
/// Adapter that exposes a plugin as [`ResponseMiddleware`].
///
/// Its `ResponseMiddleware` implementation forwards to the plugin's
/// `after_response` hook and reports the plugin's name.
#[derive(Debug)]
pub struct PluginResponseMiddleware<P> {
    plugin: P,
}

impl<P> PluginResponseMiddleware<P> {
    /// Wraps `plugin` so that, once placed in an `Arc`, it can be registered
    /// with `MiddlewareChain::add_response_middleware`.
    #[must_use]
    pub fn new(plugin: P) -> Self {
        Self { plugin }
    }
}
287
288#[async_trait]
289impl<P> ResponseMiddleware for PluginResponseMiddleware<P>
290where
291 P: crate::plugins::core::ClientPlugin,
292{
293 async fn process_response(&self, context: &mut ResponseContext) -> MiddlewareResult<()> {
294 self.plugin.after_response(context).await
295 }
296
297 fn name(&self) -> &str {
298 self.plugin.name()
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugins::core::{PluginError, RequestContext};
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};
    use turbomcp_protocol::MessageId;
    use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcVersion};

    /// Append-only record of middleware invocations.
    ///
    /// It may be shared by several middleware to observe cross-middleware
    /// execution order.
    type CallLog = Arc<Mutex<Vec<String>>>;

    fn new_log() -> CallLog {
        Arc::new(Mutex::new(Vec::new()))
    }

    /// The JSON-RPC request used by every test.
    fn test_request() -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            id: MessageId::from("test"),
            method: "test/method".to_string(),
            params: None,
        }
    }

    fn test_request_context() -> RequestContext {
        RequestContext::new(test_request(), HashMap::new())
    }

    fn test_response_context() -> ResponseContext {
        ResponseContext::new(
            test_request_context(),
            Some(json!({"result": "success"})),
            None,
            std::time::Duration::from_millis(100),
        )
    }

    /// Request middleware that records `"<name>:process_request:<method>"`
    /// into its call log and optionally fails.
    #[derive(Debug)]
    struct TestRequestMiddleware {
        name: String,
        calls: CallLog,
        should_fail: bool,
    }

    impl TestRequestMiddleware {
        fn new(name: &str) -> Self {
            Self::with_log(name, new_log())
        }

        fn with_log(name: &str, calls: CallLog) -> Self {
            Self {
                name: name.to_string(),
                calls,
                should_fail: false,
            }
        }

        fn with_failure(mut self) -> Self {
            self.should_fail = true;
            self
        }

        fn get_calls(&self) -> Vec<String> {
            self.calls.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl RequestMiddleware for TestRequestMiddleware {
        async fn process_request(&self, context: &mut RequestContext) -> MiddlewareResult<()> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("{}:process_request:{}", self.name, context.method()));

            if self.should_fail {
                Err(PluginError::request_processing("Test middleware failure"))
            } else {
                Ok(())
            }
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    /// Response middleware that records `"<name>:process_response:<method>"`
    /// into its call log and optionally fails.
    #[derive(Debug)]
    struct TestResponseMiddleware {
        name: String,
        calls: CallLog,
        should_fail: bool,
    }

    impl TestResponseMiddleware {
        fn new(name: &str) -> Self {
            Self::with_log(name, new_log())
        }

        fn with_log(name: &str, calls: CallLog) -> Self {
            Self {
                name: name.to_string(),
                calls,
                should_fail: false,
            }
        }

        fn with_failure(mut self) -> Self {
            self.should_fail = true;
            self
        }

        fn get_calls(&self) -> Vec<String> {
            self.calls.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl ResponseMiddleware for TestResponseMiddleware {
        async fn process_response(&self, context: &mut ResponseContext) -> MiddlewareResult<()> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("{}:process_response:{}", self.name, context.method()));

            if self.should_fail {
                Err(PluginError::response_processing("Test middleware failure"))
            } else {
                Ok(())
            }
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[tokio::test]
    async fn test_middleware_chain_creation() {
        let chain = MiddlewareChain::new();
        assert_eq!(chain.request_middleware_count(), 0);
        assert_eq!(chain.response_middleware_count(), 0);

        let defaulted = MiddlewareChain::default();
        assert_eq!(defaulted.request_middleware_count(), 0);
        assert_eq!(defaulted.response_middleware_count(), 0);
    }

    #[tokio::test]
    async fn test_empty_chains_succeed() {
        let chain = MiddlewareChain::new();

        let mut request_context = test_request_context();
        chain
            .execute_request_chain(&mut request_context)
            .await
            .unwrap();

        let mut response_context = test_response_context();
        chain
            .execute_response_chain(&mut response_context)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_request_middleware_registration() {
        let mut chain = MiddlewareChain::new();
        chain.add_request_middleware(Arc::new(TestRequestMiddleware::new("test")));

        assert_eq!(chain.request_middleware_count(), 1);
        assert_eq!(chain.get_request_middleware_names(), vec!["test"]);
    }

    #[tokio::test]
    async fn test_response_middleware_registration() {
        let mut chain = MiddlewareChain::new();
        chain.add_response_middleware(Arc::new(TestResponseMiddleware::new("test")));

        assert_eq!(chain.response_middleware_count(), 1);
        assert_eq!(chain.get_response_middleware_names(), vec!["test"]);
    }

    #[tokio::test]
    async fn test_request_middleware_execution() {
        let mut chain = MiddlewareChain::new();
        let middleware = Arc::new(TestRequestMiddleware::new("test"));
        chain.add_request_middleware(middleware.clone());

        let mut context = test_request_context();
        chain.execute_request_chain(&mut context).await.unwrap();

        assert_eq!(
            middleware.get_calls(),
            vec!["test:process_request:test/method"]
        );
    }

    #[tokio::test]
    async fn test_response_middleware_execution() {
        let mut chain = MiddlewareChain::new();
        let middleware = Arc::new(TestResponseMiddleware::new("test"));
        chain.add_response_middleware(middleware.clone());

        let mut context = test_response_context();
        chain.execute_response_chain(&mut context).await.unwrap();

        assert_eq!(
            middleware.get_calls(),
            vec!["test:process_response:test/method"]
        );
    }

    #[tokio::test]
    async fn test_request_middleware_error_handling() {
        let mut chain = MiddlewareChain::new();
        let good = Arc::new(TestRequestMiddleware::new("good"));
        let bad = Arc::new(TestRequestMiddleware::new("bad").with_failure());
        let skipped = Arc::new(TestRequestMiddleware::new("skipped"));

        chain.add_request_middleware(good.clone());
        chain.add_request_middleware(bad.clone());
        chain.add_request_middleware(skipped.clone());

        let mut context = test_request_context();
        let result = chain.execute_request_chain(&mut context).await;

        assert!(result.is_err());
        assert_eq!(good.get_calls(), vec!["good:process_request:test/method"]);
        assert_eq!(bad.get_calls(), vec!["bad:process_request:test/method"]);
        // The request phase is fail-fast: nothing after the failure runs.
        assert!(
            skipped.get_calls().is_empty(),
            "request middleware after a failure must not run"
        );
    }

    #[tokio::test]
    async fn test_response_middleware_error_handling() {
        let mut chain = MiddlewareChain::new();
        let good = Arc::new(TestResponseMiddleware::new("good"));
        let bad = Arc::new(TestResponseMiddleware::new("bad").with_failure());
        let after = Arc::new(TestResponseMiddleware::new("after"));

        chain.add_response_middleware(good.clone());
        chain.add_response_middleware(bad.clone());
        chain.add_response_middleware(after.clone());

        let mut context = test_response_context();
        let result = chain.execute_response_chain(&mut context).await;

        // Response middleware failures are logged, not propagated.
        assert!(result.is_ok());
        assert_eq!(good.get_calls(), vec!["good:process_response:test/method"]);
        assert_eq!(bad.get_calls(), vec!["bad:process_response:test/method"]);
        // The response phase is best-effort: it keeps going after a failure.
        assert_eq!(after.get_calls(), vec!["after:process_response:test/method"]);
    }

    #[tokio::test]
    async fn test_middleware_execution_order() {
        let mut chain = MiddlewareChain::new();
        let log = new_log();

        for name in ["first", "second", "third"] {
            chain.add_request_middleware(Arc::new(TestRequestMiddleware::with_log(
                name,
                Arc::clone(&log),
            )));
        }

        let mut context = test_request_context();
        chain.execute_request_chain(&mut context).await.unwrap();

        // All middleware share one log, so its order is the execution order.
        let recorded = log.lock().unwrap().clone();
        assert_eq!(
            recorded,
            vec![
                "first:process_request:test/method",
                "second:process_request:test/method",
                "third:process_request:test/method",
            ]
        );
        assert_eq!(
            chain.get_request_middleware_names(),
            vec!["first", "second", "third"]
        );
    }

    #[tokio::test]
    async fn test_response_middleware_execution_order() {
        let mut chain = MiddlewareChain::new();
        let log = new_log();

        for name in ["first", "second", "third"] {
            chain.add_response_middleware(Arc::new(TestResponseMiddleware::with_log(
                name,
                Arc::clone(&log),
            )));
        }

        let mut context = test_response_context();
        chain.execute_response_chain(&mut context).await.unwrap();

        let recorded = log.lock().unwrap().clone();
        assert_eq!(
            recorded,
            vec![
                "first:process_response:test/method",
                "second:process_response:test/method",
                "third:process_response:test/method",
            ]
        );
        assert_eq!(
            chain.get_response_middleware_names(),
            vec!["first", "second", "third"]
        );
    }

    #[tokio::test]
    async fn test_chain_clear() {
        let mut chain = MiddlewareChain::new();
        chain.add_request_middleware(Arc::new(TestRequestMiddleware::new("request")));
        chain.add_response_middleware(Arc::new(TestResponseMiddleware::new("response")));

        assert_eq!(chain.request_middleware_count(), 1);
        assert_eq!(chain.response_middleware_count(), 1);

        chain.clear();

        assert_eq!(chain.request_middleware_count(), 0);
        assert_eq!(chain.response_middleware_count(), 0);
    }
}