1use mcpkit_core::protocol::{Message, Notification, Request, RequestId, Response};
7use std::collections::VecDeque;
8use std::fmt;
9use std::sync::{Arc, RwLock};
10use std::time::Duration;
11
/// Shared custom response check used by [`ResponseMatcher::with_custom`].
///
/// The closure returns `Ok(())` to accept the response or `Err(message)` to reject it; the
/// message is propagated unchanged by [`ResponseMatcher::validate`]. Wrapped in `Arc` so that
/// matchers holding it stay cheaply `Clone`.
pub type ResponseMatcherFn = Arc<dyn Fn(&Response) -> Result<(), String> + Send + Sync>;

/// Shared custom notification check used by [`NotificationMatcher::with_custom`].
///
/// The closure returns `Ok(())` to accept the notification or `Err(message)` to reject it; the
/// message is propagated unchanged by [`NotificationMatcher::validate`]. Wrapped in `Arc` so
/// that matchers holding it stay cheaply `Clone`.
pub type NotificationMatcherFn = Arc<dyn Fn(&Notification) -> Result<(), String> + Send + Sync>;
17
/// One step of a [`TestScenario`].
///
/// Steps are plain data: this file only constructs them (via the `TestScenario` builder
/// methods). Executing them is presumably the job of a scenario runner defined elsewhere —
/// TODO confirm.
#[derive(Clone)]
pub enum TestStep {
    /// A request paired with the matcher its response is expected to satisfy.
    RequestResponse {
        /// The request to send.
        request: Request,
        /// Expectations checked against the response to `request`.
        expected: ResponseMatcher,
    },
    /// A notification to send.
    SendNotification(Notification),
    /// A notification expected to be received, checked with the given matcher.
    ExpectNotification(NotificationMatcher),
    /// A pause for the given duration.
    Wait(Duration),
    /// A free-form check identified by a human-readable description.
    Assert {
        /// Label for the check; the `Debug` impl prints this but renders `check` as `<fn>`.
        description: String,
        /// The check itself: `Ok(())` on success, `Err(message)` on failure.
        check: Arc<dyn Fn() -> Result<(), String> + Send + Sync>,
    },
}
42
43impl fmt::Debug for TestStep {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 match self {
46 Self::RequestResponse { request, expected } => f
47 .debug_struct("RequestResponse")
48 .field("request", request)
49 .field("expected", expected)
50 .finish(),
51 Self::SendNotification(notif) => {
52 f.debug_tuple("SendNotification").field(notif).finish()
53 }
54 Self::ExpectNotification(matcher) => {
55 f.debug_tuple("ExpectNotification").field(matcher).finish()
56 }
57 Self::Wait(duration) => f.debug_tuple("Wait").field(duration).finish(),
58 Self::Assert { description, .. } => f
59 .debug_struct("Assert")
60 .field("description", description)
61 .field("check", &"<fn>")
62 .finish(),
63 }
64 }
65}
66
/// Expectations for a single JSON-RPC [`Response`], checked by [`ResponseMatcher::validate`].
///
/// Build one with [`ResponseMatcher::new`], [`ResponseMatcher::success`] or
/// [`ResponseMatcher::error`], then chain `with_json` / `with_custom`.
#[derive(Clone)]
pub struct ResponseMatcher {
    /// `Some(true)` requires `response.error` to be `None`; `Some(false)` requires it to be
    /// `Some`; `None` skips the success/error check.
    pub expect_success: Option<bool>,
    /// Assertions evaluated against `response.result`; if any are present the response must
    /// carry a result.
    pub json_assertions: Vec<JsonAssertion>,
    /// Optional extra check run after all other expectations pass.
    pub custom: Option<ResponseMatcherFn>,
}
77
78impl fmt::Debug for ResponseMatcher {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 f.debug_struct("ResponseMatcher")
81 .field("expect_success", &self.expect_success)
82 .field("json_assertions", &self.json_assertions)
83 .field("custom", &self.custom.is_some())
84 .finish()
85 }
86}
87
88impl Default for ResponseMatcher {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94impl ResponseMatcher {
95 #[must_use]
97 pub fn new() -> Self {
98 Self {
99 expect_success: None,
100 json_assertions: Vec::new(),
101 custom: None,
102 }
103 }
104
105 #[must_use]
107 pub fn success() -> Self {
108 Self {
109 expect_success: Some(true),
110 ..Default::default()
111 }
112 }
113
114 #[must_use]
116 pub fn error() -> Self {
117 Self {
118 expect_success: Some(false),
119 ..Default::default()
120 }
121 }
122
123 #[must_use]
125 pub fn with_json(mut self, path: impl Into<String>, expected: serde_json::Value) -> Self {
126 self.json_assertions.push(JsonAssertion {
127 path: path.into(),
128 expected,
129 });
130 self
131 }
132
133 pub fn with_custom<F>(mut self, f: F) -> Self
135 where
136 F: Fn(&Response) -> Result<(), String> + Send + Sync + 'static,
137 {
138 self.custom = Some(Arc::new(f));
139 self
140 }
141
142 pub fn validate(&self, response: &Response) -> Result<(), String> {
144 if let Some(expect_success) = self.expect_success {
146 let is_success = response.error.is_none();
147 if expect_success && !is_success {
148 return Err(format!(
149 "Expected successful response, got error: {:?}",
150 response.error
151 ));
152 }
153 if !expect_success && is_success {
154 return Err("Expected error response, got success".to_string());
155 }
156 }
157
158 if let Some(result) = &response.result {
160 for assertion in &self.json_assertions {
161 assertion.validate(result)?;
162 }
163 } else if !self.json_assertions.is_empty() {
164 return Err("Expected result in response, but got none".to_string());
165 }
166
167 if let Some(custom) = &self.custom {
169 custom(response)?;
170 }
171
172 Ok(())
173 }
174}
175
/// An equality assertion on one value inside a JSON document.
///
/// Checked by [`JsonAssertion::validate`], which resolves `path` and compares the value found
/// there with `expected` using `serde_json::Value` equality.
#[derive(Debug, Clone)]
pub struct JsonAssertion {
    /// Dot-separated path to the value, e.g. `"user.profile.name"`. A segment written as
    /// `[N]` (e.g. `"items.[0].id"`) indexes into an array; empty segments are ignored.
    pub path: String,
    /// The value expected at `path`.
    pub expected: serde_json::Value,
}
184
185impl JsonAssertion {
186 pub fn validate(&self, value: &serde_json::Value) -> Result<(), String> {
188 let actual = get_json_path(value, &self.path)
189 .ok_or_else(|| format!("Path '{}' not found in response", self.path))?;
190
191 if *actual != self.expected {
192 return Err(format!(
193 "Path '{}': expected {:?}, got {:?}",
194 self.path, self.expected, actual
195 ));
196 }
197
198 Ok(())
199 }
200}
201
202fn get_json_path<'a>(value: &'a serde_json::Value, path: &str) -> Option<&'a serde_json::Value> {
204 let mut current = value;
205 for part in path.split('.') {
206 if part.is_empty() {
207 continue;
208 }
209 if let Some(index_str) = part.strip_prefix('[').and_then(|s| s.strip_suffix(']')) {
211 if let Ok(index) = index_str.parse::<usize>() {
212 current = current.get(index)?;
213 continue;
214 }
215 }
216 current = current.get(part)?;
217 }
218 Some(current)
219}
220
/// Expectations for a received [`Notification`], checked by [`NotificationMatcher::validate`].
#[derive(Clone)]
pub struct NotificationMatcher {
    /// If set, the notification's method must equal this string exactly.
    pub method: Option<String>,
    /// Optional extra check run after the method check passes.
    pub custom: Option<NotificationMatcherFn>,
}
229
230impl fmt::Debug for NotificationMatcher {
231 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 f.debug_struct("NotificationMatcher")
233 .field("method", &self.method)
234 .field("custom", &self.custom.is_some())
235 .finish()
236 }
237}
238
239impl Default for NotificationMatcher {
240 fn default() -> Self {
241 Self::new()
242 }
243}
244
245impl NotificationMatcher {
246 #[must_use]
248 pub fn new() -> Self {
249 Self {
250 method: None,
251 custom: None,
252 }
253 }
254
255 pub fn method(mut self, method: impl Into<String>) -> Self {
257 self.method = Some(method.into());
258 self
259 }
260
261 pub fn with_custom<F>(mut self, f: F) -> Self
263 where
264 F: Fn(&Notification) -> Result<(), String> + Send + Sync + 'static,
265 {
266 self.custom = Some(Arc::new(f));
267 self
268 }
269
270 pub fn validate(&self, notification: &Notification) -> Result<(), String> {
272 if let Some(expected_method) = &self.method {
273 if notification.method.as_ref() != expected_method {
274 return Err(format!(
275 "Expected notification method '{}', got '{}'",
276 expected_method, notification.method
277 ));
278 }
279 }
280
281 if let Some(custom) = &self.custom {
282 custom(notification)?;
283 }
284
285 Ok(())
286 }
287}
288
/// A named, ordered list of [`TestStep`]s, assembled with builder methods such as
/// [`TestScenario::request`] and [`TestScenario::initialize`].
#[derive(Debug)]
pub struct TestScenario {
    /// Scenario name.
    pub name: String,
    /// Optional free-form description, set via [`TestScenario::description`].
    pub description: Option<String>,
    /// Steps in the order the builder methods appended them.
    pub steps: Vec<TestStep>,
}
299
300impl TestScenario {
301 pub fn new(name: impl Into<String>) -> Self {
303 Self {
304 name: name.into(),
305 description: None,
306 steps: Vec::new(),
307 }
308 }
309
310 pub fn description(mut self, description: impl Into<String>) -> Self {
312 self.description = Some(description.into());
313 self
314 }
315
316 #[must_use]
318 pub fn request(mut self, request: Request, expected: ResponseMatcher) -> Self {
319 self.steps
320 .push(TestStep::RequestResponse { request, expected });
321 self
322 }
323
324 #[must_use]
326 pub fn send_notification(mut self, notification: Notification) -> Self {
327 self.steps.push(TestStep::SendNotification(notification));
328 self
329 }
330
331 #[must_use]
333 pub fn expect_notification(mut self, matcher: NotificationMatcher) -> Self {
334 self.steps.push(TestStep::ExpectNotification(matcher));
335 self
336 }
337
338 #[must_use]
340 pub fn wait(mut self, duration: Duration) -> Self {
341 self.steps.push(TestStep::Wait(duration));
342 self
343 }
344
345 pub fn assert<F>(mut self, description: impl Into<String>, check: F) -> Self
347 where
348 F: Fn() -> Result<(), String> + Send + Sync + 'static,
349 {
350 self.steps.push(TestStep::Assert {
351 description: description.into(),
352 check: Arc::new(check),
353 });
354 self
355 }
356
357 #[must_use]
359 pub fn initialize(self, client_name: &str, client_version: &str) -> Self {
360 let request = Request::new("initialize", RequestId::from(1)).params(serde_json::json!({
361 "protocolVersion": "2025-11-25",
362 "capabilities": {},
363 "clientInfo": {
364 "name": client_name,
365 "version": client_version
366 }
367 }));
368
369 self.request(request, ResponseMatcher::success())
370 .send_notification(Notification::new("initialized"))
371 }
372}
373
/// Outcome of a scenario, built with [`ScenarioResult::pass`] or [`ScenarioResult::fail`].
#[derive(Debug)]
pub struct ScenarioResult {
    /// `true` when built with `pass`, `false` when built with `fail`.
    pub success: bool,
    /// Per-step outcomes supplied by the caller.
    pub step_results: Vec<StepResult>,
    /// Failure message; `None` for a passing result, `Some` for a failing one.
    pub error: Option<String>,
}
384
385impl ScenarioResult {
386 #[must_use]
388 pub fn pass(step_results: Vec<StepResult>) -> Self {
389 Self {
390 success: true,
391 step_results,
392 error: None,
393 }
394 }
395
396 #[must_use]
398 pub fn fail(step_results: Vec<StepResult>, error: impl Into<String>) -> Self {
399 Self {
400 success: false,
401 step_results,
402 error: Some(error.into()),
403 }
404 }
405}
406
/// Outcome of a single scenario step.
///
/// Nothing in this file constructs it; the field meanings below are inferred from names and
/// should be confirmed against the runner that fills them in.
#[derive(Debug)]
pub struct StepResult {
    /// Presumably the step's position in `TestScenario::steps` — TODO confirm.
    pub index: usize,
    /// Human-readable description of the step.
    pub description: String,
    /// Whether the step passed.
    pub passed: bool,
    /// Failure message, presumably `Some` only when `passed` is `false` — TODO confirm.
    pub error: Option<String>,
    /// Time spent on the step, presumably measured by the runner — TODO confirm.
    pub duration: Duration,
}
421
/// A pair of FIFO message queues, each behind its own `RwLock`, so the queue can be shared
/// by reference across threads.
///
/// Messages are appended at the back and taken from the front. The "outgoing"/"incoming"
/// direction is from the perspective of whoever owns the queue — NOTE(review): confirm which
/// side (client under test or mock server) each direction represents.
#[derive(Debug, Default)]
pub struct MessageQueue {
    /// Messages queued via `queue_outgoing`, drained via `take_outgoing`.
    outgoing: RwLock<VecDeque<Message>>,
    /// Messages queued via `queue_incoming`, drained via `take_incoming`.
    incoming: RwLock<VecDeque<Message>>,
}
430
431impl MessageQueue {
432 #[must_use]
434 pub fn new() -> Self {
435 Self::default()
436 }
437
438 pub fn queue_outgoing(&self, message: Message) {
440 if let Ok(mut queue) = self.outgoing.write() {
441 queue.push_back(message);
442 }
443 }
444
445 pub fn queue_incoming(&self, message: Message) {
447 if let Ok(mut queue) = self.incoming.write() {
448 queue.push_back(message);
449 }
450 }
451
452 pub fn take_outgoing(&self) -> Option<Message> {
454 self.outgoing.write().ok()?.pop_front()
455 }
456
457 pub fn take_incoming(&self) -> Option<Message> {
459 self.incoming.write().ok()?.pop_front()
460 }
461
462 #[must_use]
464 pub fn has_outgoing(&self) -> bool {
465 self.outgoing.read().map(|q| !q.is_empty()).unwrap_or(false)
466 }
467
468 #[must_use]
470 pub fn has_incoming(&self) -> bool {
471 self.incoming.read().map(|q| !q.is_empty()).unwrap_or(false)
472 }
473
474 #[must_use]
476 pub fn outgoing_count(&self) -> usize {
477 self.outgoing.read().map(|q| q.len()).unwrap_or(0)
478 }
479
480 #[must_use]
482 pub fn incoming_count(&self) -> usize {
483 self.incoming.read().map(|q| q.len()).unwrap_or(0)
484 }
485
486 pub fn clear(&self) {
488 if let Ok(mut queue) = self.outgoing.write() {
489 queue.clear();
490 }
491 if let Ok(mut queue) = self.incoming.write() {
492 queue.clear();
493 }
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_response_matcher_success() {
        let matcher = ResponseMatcher::success();
        let response = Response::success(RequestId::from(1), serde_json::json!({}));
        assert!(matcher.validate(&response).is_ok());
    }

    #[test]
    fn test_response_matcher_error() {
        let matcher = ResponseMatcher::error();
        let response = Response::error(
            RequestId::from(1),
            mcpkit_core::JsonRpcError::invalid_request("Invalid"),
        );
        assert!(matcher.validate(&response).is_ok());
    }

    #[test]
    fn test_response_matcher_error_rejects_success() {
        let matcher = ResponseMatcher::error();
        let response = Response::success(RequestId::from(1), serde_json::json!({}));
        assert_eq!(
            matcher.validate(&response),
            Err("Expected error response, got success".to_string())
        );
    }

    #[test]
    fn test_response_matcher_custom_failure_propagates() {
        let matcher = ResponseMatcher::new().with_custom(|_| Err("custom failure".to_string()));
        let response = Response::success(RequestId::from(1), serde_json::json!({}));
        assert_eq!(
            matcher.validate(&response),
            Err("custom failure".to_string())
        );
    }

    #[test]
    fn test_response_matcher_json_path() {
        let matcher = ResponseMatcher::success()
            .with_json("name", serde_json::json!("test"))
            .with_json("count", serde_json::json!(42));

        let response = Response::success(
            RequestId::from(1),
            serde_json::json!({
                "name": "test",
                "count": 42
            }),
        );
        assert!(matcher.validate(&response).is_ok());
    }

    #[test]
    fn test_json_assertion_mismatch_and_missing() {
        let value = serde_json::json!({ "a": 2 });

        let mismatch = JsonAssertion {
            path: "a".to_string(),
            expected: serde_json::json!(1),
        };
        assert!(mismatch.validate(&value).is_err());

        let missing = JsonAssertion {
            path: "b".to_string(),
            expected: serde_json::json!(1),
        };
        assert_eq!(
            missing.validate(&value),
            Err("Path 'b' not found in response".to_string())
        );
    }

    #[test]
    fn test_json_path_nested() {
        let value = serde_json::json!({
            "user": {
                "profile": {
                    "name": "Alice"
                }
            }
        });

        let result = get_json_path(&value, "user.profile.name");
        assert_eq!(result, Some(&serde_json::json!("Alice")));
    }

    #[test]
    fn test_json_path_bracket_index_segment() {
        let value = serde_json::json!({ "items": [{ "id": 1 }, { "id": 2 }] });

        assert_eq!(
            get_json_path(&value, "items.[1].id"),
            Some(&serde_json::json!(2))
        );
        assert_eq!(get_json_path(&value, "items.[5].id"), None);
    }

    #[test]
    fn test_notification_matcher() {
        let matcher = NotificationMatcher::new().method("test/notify");
        let notification = Notification::new("test/notify");
        assert!(matcher.validate(&notification).is_ok());
    }

    #[test]
    fn test_notification_matcher_wrong_method() {
        let matcher = NotificationMatcher::new().method("test/notify");
        let notification = Notification::new("test/other");
        assert!(matcher.validate(&notification).is_err());
    }

    #[test]
    fn test_scenario_builder() {
        let scenario = TestScenario::new("test-scenario")
            .description("A test scenario")
            .request(Request::new("ping", 1), ResponseMatcher::success());

        assert_eq!(scenario.name, "test-scenario");
        assert_eq!(scenario.steps.len(), 1);
    }

    #[test]
    fn test_message_queue() {
        let queue = MessageQueue::new();

        queue.queue_outgoing(Message::Request(Request::new("test", 1)));
        queue.queue_incoming(Message::Response(Response::success(
            RequestId::from(1),
            serde_json::json!({}),
        )));

        assert!(queue.has_outgoing());
        assert!(queue.has_incoming());
        assert_eq!(queue.outgoing_count(), 1);
        assert_eq!(queue.incoming_count(), 1);

        let _ = queue.take_outgoing();
        let _ = queue.take_incoming();

        assert!(!queue.has_outgoing());
        assert!(!queue.has_incoming());
    }

    #[test]
    fn test_message_queue_clear() {
        let queue = MessageQueue::new();

        queue.queue_outgoing(Message::Request(Request::new("a", 1)));
        queue.queue_outgoing(Message::Request(Request::new("b", 2)));
        queue.queue_incoming(Message::Response(Response::success(
            RequestId::from(1),
            serde_json::json!({}),
        )));
        assert_eq!(queue.outgoing_count(), 2);

        queue.clear();

        assert_eq!(queue.outgoing_count(), 0);
        assert_eq!(queue.incoming_count(), 0);
        assert!(queue.take_outgoing().is_none());
        assert!(queue.take_incoming().is_none());
    }
}