1use mcpkit_core::debug::{
7 MessageInspector, MessageRecord, MessageStats, ProtocolValidator, RecordedSession,
8 SessionRecorder, ValidationResult,
9};
10use mcpkit_core::protocol::Message;
11use std::sync::Arc;
12
13pub struct TestSession {
15 name: String,
17 inspector: Arc<MessageInspector>,
19 recorder: SessionRecorder,
21 validator: Arc<std::sync::RwLock<ProtocolValidator>>,
23 strict_mode: bool,
25}
26
27impl std::fmt::Debug for TestSession {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("TestSession")
30 .field("name", &self.name)
31 .field("inspector", &self.inspector)
32 .field("strict_mode", &self.strict_mode)
33 .finish_non_exhaustive()
34 }
35}
36
37impl TestSession {
38 pub fn new(name: impl Into<String>) -> Self {
40 let name = name.into();
41 Self {
42 name: name.clone(),
43 inspector: Arc::new(MessageInspector::new()),
44 recorder: SessionRecorder::new(name),
45 validator: Arc::new(std::sync::RwLock::new(ProtocolValidator::new())),
46 strict_mode: false,
47 }
48 }
49
50 #[must_use]
52 pub fn strict(mut self) -> Self {
53 self.strict_mode = true;
54 if let Ok(mut validator) = self.validator.write() {
55 *validator = ProtocolValidator::new().strict();
56 }
57 self
58 }
59
60 #[must_use]
62 pub fn name(&self) -> &str {
63 &self.name
64 }
65
66 #[must_use]
68 pub fn inspector(&self) -> &MessageInspector {
69 &self.inspector
70 }
71
72 #[must_use]
74 pub fn recorder(&self) -> &SessionRecorder {
75 &self.recorder
76 }
77
78 pub fn record_outbound(&self, message: Message) {
80 self.inspector.record_outbound(message.clone());
81 self.recorder.record_sent(message.clone());
82 if let Ok(mut validator) = self.validator.write() {
83 validator.validate(&message);
84 }
85 }
86
87 pub fn record_inbound(&self, message: Message) {
89 self.inspector.record_inbound(message.clone());
90 self.recorder.record_received(message.clone());
91 if let Ok(mut validator) = self.validator.write() {
92 validator.validate(&message);
93 }
94 }
95
96 pub fn record_error(&self, error: impl Into<String>) {
98 self.recorder.record_error(error);
99 }
100
101 #[must_use]
103 pub fn stats(&self) -> MessageStats {
104 self.inspector.stats()
105 }
106
107 #[must_use]
109 pub fn records(&self) -> Vec<MessageRecord> {
110 self.inspector.records()
111 }
112
113 #[must_use]
115 pub fn validation_result(&self) -> ValidationResult {
116 self.validator
117 .read()
118 .map_or_else(|_| ValidationResult::pass(), |v| v.result())
119 }
120
121 #[must_use]
123 pub fn finalize(self) -> TestSessionResult {
124 let validation = self.validator.write().map_or_else(
125 |_| ValidationResult::pass(),
126 |mut v| {
127 v.check_unmatched_requests();
128 v.result()
129 },
130 );
131
132 let session = self.recorder.finalize();
133
134 TestSessionResult {
135 name: self.name,
136 session,
137 stats: self.inspector.stats(),
138 validation,
139 }
140 }
141
142 pub fn assert_valid(&self) {
148 let result = self.validation_result();
149 assert!(
150 result.valid,
151 "Session validation failed:\n{}",
152 result
153 .errors
154 .iter()
155 .map(|e| format!(" - {e}"))
156 .collect::<Vec<_>>()
157 .join("\n")
158 );
159 }
160
161 pub fn assert_stats(&self, check: impl FnOnce(&MessageStats) -> bool, message: &str) {
167 let stats = self.stats();
168 assert!(check(&stats), "{message}. Stats: {stats:?}");
169 }
170}
171
172impl Clone for TestSession {
173 fn clone(&self) -> Self {
174 Self {
175 name: self.name.clone(),
176 inspector: Arc::clone(&self.inspector),
177 recorder: self.recorder.clone(),
178 validator: Arc::clone(&self.validator),
179 strict_mode: self.strict_mode,
180 }
181 }
182}
183
184#[derive(Debug)]
186pub struct TestSessionResult {
187 pub name: String,
189 pub session: RecordedSession,
191 pub stats: MessageStats,
193 pub validation: ValidationResult,
195}
196
197impl TestSessionResult {
198 #[must_use]
200 pub fn is_valid(&self) -> bool {
201 self.validation.valid
202 }
203
204 #[must_use]
206 pub fn message_count(&self) -> usize {
207 self.stats.total_messages
208 }
209
210 #[must_use]
212 pub fn error_count(&self) -> usize {
213 self.stats.errors
214 }
215
216 #[must_use]
218 pub fn errors(&self) -> &[mcpkit_core::debug::ValidationError] {
219 &self.validation.errors
220 }
221
222 #[must_use]
224 pub fn warnings(&self) -> &[String] {
225 &self.validation.warnings
226 }
227
228 pub fn to_json(&self) -> Result<String, serde_json::Error> {
234 self.session.to_json()
235 }
236
237 pub fn assert_valid(&self) {
243 assert!(
244 self.is_valid(),
245 "Session '{}' validation failed:\n{}",
246 self.name,
247 self.validation
248 .errors
249 .iter()
250 .map(|e| format!(" - {e}"))
251 .collect::<Vec<_>>()
252 .join("\n")
253 );
254 }
255
256 pub fn assert_message_count(&self, expected: usize) {
262 assert_eq!(
263 self.message_count(),
264 expected,
265 "Expected {} messages, got {}",
266 expected,
267 self.message_count()
268 );
269 }
270
271 pub fn assert_no_errors(&self) {
277 assert_eq!(
278 self.error_count(),
279 0,
280 "Expected no errors, got {}",
281 self.error_count()
282 );
283 }
284}
285
/// Builder for [`TestSession`].
///
/// The default builder has no name, non-strict validation and no record limit.
#[derive(Debug, Default)]
pub struct TestSessionBuilder {
    /// Session name; `None` lets `build` pick its default name.
    name: Option<String>,
    /// Whether the built session uses strict validation.
    strict_mode: bool,
    /// Requested record limit. NOTE(review): `build` does not read this yet.
    max_records: Option<usize>,
}
293
294impl TestSessionBuilder {
295 #[must_use]
297 pub fn new() -> Self {
298 Self::default()
299 }
300
301 pub fn name(mut self, name: impl Into<String>) -> Self {
303 self.name = Some(name.into());
304 self
305 }
306
307 #[must_use]
309 pub fn strict(mut self) -> Self {
310 self.strict_mode = true;
311 self
312 }
313
314 #[must_use]
316 pub fn max_records(mut self, max: usize) -> Self {
317 self.max_records = Some(max);
318 self
319 }
320
321 #[must_use]
323 pub fn build(self) -> TestSession {
324 let name = self.name.unwrap_or_else(|| "test-session".to_string());
325 let mut session = TestSession::new(name);
326
327 if self.strict_mode {
328 session = session.strict();
329 }
330
331 session
332 }
333}
334
/// Positional differences between two recorded sessions.
///
/// Messages are compared index by index. Every index refers to a position in
/// the sessions' message lists.
#[derive(Debug)]
pub struct SessionDiff {
    /// Indices that exist only in the first session (it has more messages).
    pub only_in_a: Vec<usize>,
    /// Indices that exist only in the second session.
    pub only_in_b: Vec<usize>,
    /// Indices present in both sessions whose messages differ, with a short note.
    pub different: Vec<(usize, String)>,
}
345
346impl SessionDiff {
347 #[must_use]
349 pub fn compare(a: &RecordedSession, b: &RecordedSession) -> Self {
350 let mut diff = SessionDiff {
351 only_in_a: Vec::new(),
352 only_in_b: Vec::new(),
353 different: Vec::new(),
354 };
355
356 let a_messages = a.messages();
357 let b_messages = b.messages();
358
359 for i in 0..a_messages.len().max(b_messages.len()) {
360 match (a_messages.get(i), b_messages.get(i)) {
361 (Some(_), None) => diff.only_in_a.push(i),
362 (None, Some(_)) => diff.only_in_b.push(i),
363 (Some(ma), Some(mb)) => {
364 if format!("{ma:?}") != format!("{mb:?}") {
365 diff.different.push((i, format!("Message {i} differs")));
366 }
367 }
368 (None, None) => {}
369 }
370 }
371
372 diff
373 }
374
375 #[must_use]
377 pub fn is_identical(&self) -> bool {
378 self.only_in_a.is_empty() && self.only_in_b.is_empty() && self.different.is_empty()
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use mcpkit_core::protocol::{Request, RequestId, Response};

    /// A matched request/response pair is counted, validates and yields two messages.
    #[test]
    fn test_test_session() {
        let session = TestSession::new("test");
        let request = Message::Request(Request::new("ping", 1));
        let response = Message::Response(Response::success(
            RequestId::from(1),
            serde_json::json!({}),
        ));

        session.record_outbound(request);
        session.record_inbound(response);

        let stats = session.stats();
        assert_eq!(stats.requests, 1);
        assert_eq!(stats.responses, 1);

        let outcome = session.finalize();
        assert!(outcome.is_valid());
        assert_eq!(outcome.message_count(), 2);
    }

    /// A response with no preceding request makes the session invalid.
    #[test]
    fn test_test_session_validation() {
        let session = TestSession::new("test");
        // No request with id 999 was sent, so this response is unmatched.
        let orphan = Message::Response(Response::success(
            RequestId::from(999),
            serde_json::json!({}),
        ));

        session.record_inbound(orphan);

        let outcome = session.finalize();
        assert!(!outcome.is_valid());
        assert!(!outcome.errors().is_empty());
    }

    /// The builder applies the requested name.
    #[test]
    fn test_session_builder() {
        let built = TestSessionBuilder::new()
            .name("custom-session")
            .strict()
            .build();

        assert_eq!(built.name(), "custom-session");
    }

    /// A session with an extra trailing message differs from the shorter one.
    #[test]
    fn test_session_diff() {
        let short_recorder = SessionRecorder::new("a");
        short_recorder.record_sent(Message::Request(Request::new("ping", 1)));
        let shorter = short_recorder.finalize();

        let long_recorder = SessionRecorder::new("b");
        long_recorder.record_sent(Message::Request(Request::new("ping", 1)));
        long_recorder.record_sent(Message::Request(Request::new("pong", 2)));
        let longer = long_recorder.finalize();

        let diff = SessionDiff::compare(&shorter, &longer);
        assert!(!diff.is_identical());
        assert!(!diff.only_in_b.is_empty());
    }
}