mcpkit_testing/
session.rs

1//! Session testing utilities.
2//!
3//! This module provides utilities for testing MCP sessions,
4//! integrating with the debug module for recording and validation.
5
6use mcpkit_core::debug::{
7    MessageInspector, MessageRecord, MessageStats, ProtocolValidator, RecordedSession,
8    SessionRecorder, ValidationResult,
9};
10use mcpkit_core::protocol::Message;
11use std::sync::Arc;
12
13/// A test session wrapper that combines recording, inspection, and validation.
14pub struct TestSession {
15    /// Session name.
16    name: String,
17    /// Message inspector.
18    inspector: Arc<MessageInspector>,
19    /// Session recorder.
20    recorder: SessionRecorder,
21    /// Protocol validator.
22    validator: Arc<std::sync::RwLock<ProtocolValidator>>,
23    /// Whether to validate in strict mode.
24    strict_mode: bool,
25}
26
27impl std::fmt::Debug for TestSession {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("TestSession")
30            .field("name", &self.name)
31            .field("inspector", &self.inspector)
32            .field("strict_mode", &self.strict_mode)
33            .finish_non_exhaustive()
34    }
35}
36
37impl TestSession {
38    /// Create a new test session.
39    pub fn new(name: impl Into<String>) -> Self {
40        let name = name.into();
41        Self {
42            name: name.clone(),
43            inspector: Arc::new(MessageInspector::new()),
44            recorder: SessionRecorder::new(name),
45            validator: Arc::new(std::sync::RwLock::new(ProtocolValidator::new())),
46            strict_mode: false,
47        }
48    }
49
50    /// Enable strict validation mode.
51    #[must_use]
52    pub fn strict(mut self) -> Self {
53        self.strict_mode = true;
54        if let Ok(mut validator) = self.validator.write() {
55            *validator = ProtocolValidator::new().strict();
56        }
57        self
58    }
59
60    /// Get the session name.
61    #[must_use]
62    pub fn name(&self) -> &str {
63        &self.name
64    }
65
66    /// Get the message inspector.
67    #[must_use]
68    pub fn inspector(&self) -> &MessageInspector {
69        &self.inspector
70    }
71
72    /// Get the session recorder.
73    #[must_use]
74    pub fn recorder(&self) -> &SessionRecorder {
75        &self.recorder
76    }
77
78    /// Record an outbound message (sent by client).
79    pub fn record_outbound(&self, message: Message) {
80        self.inspector.record_outbound(message.clone());
81        self.recorder.record_sent(message.clone());
82        if let Ok(mut validator) = self.validator.write() {
83            validator.validate(&message);
84        }
85    }
86
87    /// Record an inbound message (received by client).
88    pub fn record_inbound(&self, message: Message) {
89        self.inspector.record_inbound(message.clone());
90        self.recorder.record_received(message.clone());
91        if let Ok(mut validator) = self.validator.write() {
92            validator.validate(&message);
93        }
94    }
95
96    /// Record an error.
97    pub fn record_error(&self, error: impl Into<String>) {
98        self.recorder.record_error(error);
99    }
100
101    /// Get message statistics.
102    #[must_use]
103    pub fn stats(&self) -> MessageStats {
104        self.inspector.stats()
105    }
106
107    /// Get all message records.
108    #[must_use]
109    pub fn records(&self) -> Vec<MessageRecord> {
110        self.inspector.records()
111    }
112
113    /// Get the current validation result.
114    #[must_use]
115    pub fn validation_result(&self) -> ValidationResult {
116        self.validator
117            .read()
118            .map_or_else(|_| ValidationResult::pass(), |v| v.result())
119    }
120
121    /// Finalize the session and get the recorded session.
122    #[must_use]
123    pub fn finalize(self) -> TestSessionResult {
124        let validation = self.validator.write().map_or_else(
125            |_| ValidationResult::pass(),
126            |mut v| {
127                v.check_unmatched_requests();
128                v.result()
129            },
130        );
131
132        let session = self.recorder.finalize();
133
134        TestSessionResult {
135            name: self.name,
136            session,
137            stats: self.inspector.stats(),
138            validation,
139        }
140    }
141
142    /// Assert that the session is valid so far.
143    ///
144    /// # Panics
145    ///
146    /// Panics if there are validation errors.
147    pub fn assert_valid(&self) {
148        let result = self.validation_result();
149        assert!(
150            result.valid,
151            "Session validation failed:\n{}",
152            result
153                .errors
154                .iter()
155                .map(|e| format!("  - {e}"))
156                .collect::<Vec<_>>()
157                .join("\n")
158        );
159    }
160
161    /// Assert specific statistics.
162    ///
163    /// # Panics
164    ///
165    /// Panics if the assertion fails.
166    pub fn assert_stats(&self, check: impl FnOnce(&MessageStats) -> bool, message: &str) {
167        let stats = self.stats();
168        assert!(check(&stats), "{message}. Stats: {stats:?}");
169    }
170}
171
172impl Clone for TestSession {
173    fn clone(&self) -> Self {
174        Self {
175            name: self.name.clone(),
176            inspector: Arc::clone(&self.inspector),
177            recorder: self.recorder.clone(),
178            validator: Arc::clone(&self.validator),
179            strict_mode: self.strict_mode,
180        }
181    }
182}
183
184/// Result of a completed test session.
185#[derive(Debug)]
186pub struct TestSessionResult {
187    /// Session name.
188    pub name: String,
189    /// Recorded session data.
190    pub session: RecordedSession,
191    /// Final statistics.
192    pub stats: MessageStats,
193    /// Validation result.
194    pub validation: ValidationResult,
195}
196
197impl TestSessionResult {
198    /// Check if the session is valid.
199    #[must_use]
200    pub fn is_valid(&self) -> bool {
201        self.validation.valid
202    }
203
204    /// Get the number of messages.
205    #[must_use]
206    pub fn message_count(&self) -> usize {
207        self.stats.total_messages
208    }
209
210    /// Get the number of errors.
211    #[must_use]
212    pub fn error_count(&self) -> usize {
213        self.stats.errors
214    }
215
216    /// Get all validation errors.
217    #[must_use]
218    pub fn errors(&self) -> &[mcpkit_core::debug::ValidationError] {
219        &self.validation.errors
220    }
221
222    /// Get all warnings.
223    #[must_use]
224    pub fn warnings(&self) -> &[String] {
225        &self.validation.warnings
226    }
227
228    /// Export the session to JSON.
229    ///
230    /// # Errors
231    ///
232    /// Returns an error if serialization fails.
233    pub fn to_json(&self) -> Result<String, serde_json::Error> {
234        self.session.to_json()
235    }
236
237    /// Assert that the session is valid.
238    ///
239    /// # Panics
240    ///
241    /// Panics if there are validation errors.
242    pub fn assert_valid(&self) {
243        assert!(
244            self.is_valid(),
245            "Session '{}' validation failed:\n{}",
246            self.name,
247            self.validation
248                .errors
249                .iter()
250                .map(|e| format!("  - {e}"))
251                .collect::<Vec<_>>()
252                .join("\n")
253        );
254    }
255
256    /// Assert message count.
257    ///
258    /// # Panics
259    ///
260    /// Panics if the count doesn't match.
261    pub fn assert_message_count(&self, expected: usize) {
262        assert_eq!(
263            self.message_count(),
264            expected,
265            "Expected {} messages, got {}",
266            expected,
267            self.message_count()
268        );
269    }
270
271    /// Assert no errors.
272    ///
273    /// # Panics
274    ///
275    /// Panics if there are errors.
276    pub fn assert_no_errors(&self) {
277        assert_eq!(
278            self.error_count(),
279            0,
280            "Expected no errors, got {}",
281            self.error_count()
282        );
283    }
284}
285
/// Fluent configuration for constructing a [`TestSession`].
#[derive(Debug, Default)]
pub struct TestSessionBuilder {
    /// Explicit session name; `None` means a default name is used.
    name: Option<String>,
    /// Whether the built session should use strict validation.
    strict_mode: bool,
    /// Requested cap on retained records, if any.
    max_records: Option<usize>,
}
293
294impl TestSessionBuilder {
295    /// Create a new builder.
296    #[must_use]
297    pub fn new() -> Self {
298        Self::default()
299    }
300
301    /// Set the session name.
302    pub fn name(mut self, name: impl Into<String>) -> Self {
303        self.name = Some(name.into());
304        self
305    }
306
307    /// Enable strict validation mode.
308    #[must_use]
309    pub fn strict(mut self) -> Self {
310        self.strict_mode = true;
311        self
312    }
313
314    /// Set maximum records to keep.
315    #[must_use]
316    pub fn max_records(mut self, max: usize) -> Self {
317        self.max_records = Some(max);
318        self
319    }
320
321    /// Build the test session.
322    #[must_use]
323    pub fn build(self) -> TestSession {
324        let name = self.name.unwrap_or_else(|| "test-session".to_string());
325        let mut session = TestSession::new(name);
326
327        if self.strict_mode {
328            session = session.strict();
329        }
330
331        session
332    }
333}
334
/// Result of comparing two recorded sessions message-by-message.
///
/// All indices refer to message positions within the respective session.
#[derive(Debug)]
pub struct SessionDiff {
    /// Positions that exist in session A but lie past the end of session B.
    pub only_in_a: Vec<usize>,
    /// Positions that exist in session B but lie past the end of session A.
    pub only_in_b: Vec<usize>,
    /// Positions present in both sessions whose messages differ, each with a description.
    pub different: Vec<(usize, String)>,
}
345
346impl SessionDiff {
347    /// Compare two recorded sessions.
348    #[must_use]
349    pub fn compare(a: &RecordedSession, b: &RecordedSession) -> Self {
350        let mut diff = SessionDiff {
351            only_in_a: Vec::new(),
352            only_in_b: Vec::new(),
353            different: Vec::new(),
354        };
355
356        let a_messages = a.messages();
357        let b_messages = b.messages();
358
359        for i in 0..a_messages.len().max(b_messages.len()) {
360            match (a_messages.get(i), b_messages.get(i)) {
361                (Some(_), None) => diff.only_in_a.push(i),
362                (None, Some(_)) => diff.only_in_b.push(i),
363                (Some(ma), Some(mb)) => {
364                    if format!("{ma:?}") != format!("{mb:?}") {
365                        diff.different.push((i, format!("Message {i} differs")));
366                    }
367                }
368                (None, None) => {}
369            }
370        }
371
372        diff
373    }
374
375    /// Check if sessions are identical.
376    #[must_use]
377    pub fn is_identical(&self) -> bool {
378        self.only_in_a.is_empty() && self.only_in_b.is_empty() && self.different.is_empty()
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use mcpkit_core::protocol::{Request, RequestId, Response};

    #[test]
    fn test_test_session() {
        let session = TestSession::new("test");

        session.record_outbound(Message::Request(Request::new("ping", 1)));
        session.record_inbound(Message::Response(Response::success(
            RequestId::from(1),
            serde_json::json!({}),
        )));

        let stats = session.stats();
        assert_eq!(stats.requests, 1);
        assert_eq!(stats.responses, 1);

        let result = session.finalize();
        assert!(result.is_valid());
        assert_eq!(result.message_count(), 2);
    }

    #[test]
    fn test_test_session_validation() {
        let session = TestSession::new("test");

        // Record orphan response (no matching request)
        session.record_inbound(Message::Response(Response::success(
            RequestId::from(999),
            serde_json::json!({}),
        )));

        let result = session.finalize();
        assert!(!result.is_valid());
        assert!(!result.errors().is_empty());
    }

    #[test]
    fn test_session_builder() {
        let session = TestSessionBuilder::new()
            .name("custom-session")
            .strict()
            .build();

        assert_eq!(session.name(), "custom-session");
        assert!(session.strict_mode);
    }

    #[test]
    fn test_session_builder_defaults() {
        let session = TestSessionBuilder::new().build();

        assert_eq!(session.name(), "test-session");
        assert!(!session.strict_mode);
    }

    #[test]
    fn test_clone_shares_inspector() {
        // Clones share the inspector via `Arc`, so messages recorded on a clone
        // must be visible through the original.
        let session = TestSession::new("shared");
        let clone = session.clone();

        clone.record_outbound(Message::Request(Request::new("ping", 1)));

        assert_eq!(session.stats().requests, 1);
    }

    #[test]
    fn test_session_diff() {
        let recorder_a = SessionRecorder::new("a");
        recorder_a.record_sent(Message::Request(Request::new("ping", 1)));
        let session_a = recorder_a.finalize();

        let recorder_b = SessionRecorder::new("b");
        recorder_b.record_sent(Message::Request(Request::new("ping", 1)));
        recorder_b.record_sent(Message::Request(Request::new("pong", 2)));
        let session_b = recorder_b.finalize();

        let diff = SessionDiff::compare(&session_a, &session_b);
        assert!(!diff.is_identical());
        assert!(diff.only_in_a.is_empty());
        assert!(!diff.only_in_b.is_empty());
    }

    #[test]
    fn test_session_diff_identical() {
        let recorder = SessionRecorder::new("same");
        recorder.record_sent(Message::Request(Request::new("ping", 1)));
        let session = recorder.finalize();

        let diff = SessionDiff::compare(&session, &session);
        assert!(diff.is_identical());
    }
}