mcpkit_core/debug/validator.rs

//! Protocol validation utilities.
//!
//! The protocol validator checks MCP message sequences for correctness,
//! helping identify protocol violations during development.
5
6use crate::protocol::{Message, Notification, Request, RequestId, Response};
7use std::collections::{HashMap, HashSet};
8
9/// Protocol validation error.
10#[derive(Debug, Clone, thiserror::Error)]
11pub enum ValidationError {
12    /// Response without matching request.
13    #[error("orphan response: no request found for ID {id:?}")]
14    OrphanResponse {
15        /// The orphan response ID.
16        id: RequestId,
17    },
18
19    /// Duplicate request ID.
20    #[error("duplicate request ID: {id:?}")]
21    DuplicateRequestId {
22        /// The duplicate request ID.
23        id: RequestId,
24    },
25
26    /// Request without response (timed out).
27    #[error("unmatched request: {method} (ID: {id:?})")]
28    UnmatchedRequest {
29        /// The unmatched request ID.
30        id: RequestId,
31        /// The method name.
32        method: String,
33    },
34
35    /// Unknown method.
36    #[error("unknown method: {method}")]
37    UnknownMethod {
38        /// The unknown method name.
39        method: String,
40    },
41
42    /// Invalid message sequence.
43    #[error("invalid sequence: {message}")]
44    InvalidSequence {
45        /// Description of the sequence error.
46        message: String,
47    },
48
49    /// Missing required initialization.
50    #[error("missing initialization: {message}")]
51    MissingInitialization {
52        /// Description of what is missing.
53        message: String,
54    },
55}
56
57/// Result of protocol validation.
58#[derive(Debug, Clone)]
59pub struct ValidationResult {
60    /// Whether validation passed.
61    pub valid: bool,
62    /// Validation errors found.
63    pub errors: Vec<ValidationError>,
64    /// Warnings (non-fatal issues).
65    pub warnings: Vec<String>,
66    /// Summary statistics.
67    pub stats: ValidationStats,
68}
69
70impl ValidationResult {
71    /// Create a passing result.
72    #[must_use]
73    pub fn pass() -> Self {
74        Self {
75            valid: true,
76            errors: Vec::new(),
77            warnings: Vec::new(),
78            stats: ValidationStats::default(),
79        }
80    }
81
82    /// Create a failing result.
83    #[must_use]
84    pub fn fail(errors: Vec<ValidationError>) -> Self {
85        Self {
86            valid: false,
87            errors,
88            warnings: Vec::new(),
89            stats: ValidationStats::default(),
90        }
91    }
92
93    /// Add a warning.
94    pub fn add_warning(&mut self, warning: impl Into<String>) {
95        self.warnings.push(warning.into());
96    }
97}
98
/// Counters gathered while validating a message stream.
///
/// When filled in by [`ProtocolValidator`], each validated message increments
/// `total_messages` plus exactly one of `requests`, `responses` or
/// `notifications`. `matched_pairs` counts responses that closed a pending
/// request.
///
/// `PartialEq`/`Eq` are derived so that stats can be compared directly, e.g.
/// with `assert_eq!` in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ValidationStats {
    /// Total messages validated.
    pub total_messages: usize,
    /// Requests validated.
    pub requests: usize,
    /// Responses validated.
    pub responses: usize,
    /// Notifications validated.
    pub notifications: usize,
    /// Responses that matched an outstanding request.
    pub matched_pairs: usize,
}
113
114/// Protocol validator for checking MCP message sequences.
115///
116/// The validator tracks the protocol state and checks for:
117/// - Request-response matching
118/// - Duplicate request IDs
119/// - Proper initialization sequence
120/// - Known methods
121#[derive(Debug)]
122pub struct ProtocolValidator {
123    /// Known request methods.
124    known_request_methods: HashSet<String>,
125    /// Known notification methods.
126    known_notification_methods: HashSet<String>,
127    /// Pending requests (waiting for response).
128    pending_requests: HashMap<RequestId, String>,
129    /// Seen request IDs (for duplicate detection).
130    seen_request_ids: HashSet<RequestId>,
131    /// Whether initialization is complete.
132    initialized: bool,
133    /// Collected errors.
134    errors: Vec<ValidationError>,
135    /// Collected warnings.
136    warnings: Vec<String>,
137    /// Stats.
138    stats: ValidationStats,
139    /// Strict mode (unknown methods are errors).
140    strict_mode: bool,
141}
142
143impl Default for ProtocolValidator {
144    fn default() -> Self {
145        Self::new()
146    }
147}
148
149impl ProtocolValidator {
150    /// Create a new validator with default MCP methods.
151    #[must_use]
152    pub fn new() -> Self {
153        let mut validator = Self {
154            known_request_methods: HashSet::new(),
155            known_notification_methods: HashSet::new(),
156            pending_requests: HashMap::new(),
157            seen_request_ids: HashSet::new(),
158            initialized: false,
159            errors: Vec::new(),
160            warnings: Vec::new(),
161            stats: ValidationStats::default(),
162            strict_mode: false,
163        };
164
165        // Add standard MCP methods
166        validator.register_request_methods(&[
167            "initialize",
168            "ping",
169            "tools/list",
170            "tools/call",
171            "resources/list",
172            "resources/read",
173            "resources/subscribe",
174            "resources/unsubscribe",
175            "prompts/list",
176            "prompts/get",
177            "logging/setLevel",
178            "completion/complete",
179            "sampling/createMessage",
180            "roots/list",
181        ]);
182
183        validator.register_notification_methods(&[
184            "initialized",
185            "notifications/cancelled",
186            "notifications/progress",
187            "notifications/message",
188            "notifications/resources/updated",
189            "notifications/resources/list_changed",
190            "notifications/tools/list_changed",
191            "notifications/prompts/list_changed",
192            "notifications/roots/list_changed",
193        ]);
194
195        validator
196    }
197
198    /// Enable strict mode (unknown methods are errors).
199    #[must_use]
200    pub fn strict(mut self) -> Self {
201        self.strict_mode = true;
202        self
203    }
204
205    /// Register additional request methods.
206    pub fn register_request_methods(&mut self, methods: &[&str]) {
207        for method in methods {
208            self.known_request_methods.insert((*method).to_string());
209        }
210    }
211
212    /// Register additional notification methods.
213    pub fn register_notification_methods(&mut self, methods: &[&str]) {
214        for method in methods {
215            self.known_notification_methods
216                .insert((*method).to_string());
217        }
218    }
219
220    /// Validate a single message.
221    pub fn validate(&mut self, message: &Message) {
222        self.stats.total_messages += 1;
223
224        match message {
225            Message::Request(req) => self.validate_request(req),
226            Message::Response(res) => self.validate_response(res),
227            Message::Notification(notif) => self.validate_notification(notif),
228        }
229    }
230
231    fn validate_request(&mut self, request: &Request) {
232        self.stats.requests += 1;
233
234        // Check for duplicate ID
235        if self.seen_request_ids.contains(&request.id) {
236            self.errors.push(ValidationError::DuplicateRequestId {
237                id: request.id.clone(),
238            });
239        }
240        self.seen_request_ids.insert(request.id.clone());
241
242        // Track pending request
243        self.pending_requests
244            .insert(request.id.clone(), request.method.to_string());
245
246        // Check method
247        let method = request.method.as_ref();
248        if !self.known_request_methods.contains(method) {
249            if self.strict_mode {
250                self.errors.push(ValidationError::UnknownMethod {
251                    method: method.to_string(),
252                });
253            } else {
254                self.warnings
255                    .push(format!("Unknown request method: {method}"));
256            }
257        }
258
259        // Check initialization
260        if method == "initialize" {
261            if self.initialized {
262                self.warnings
263                    .push("Duplicate initialize request".to_string());
264            }
265        } else if !self.initialized && method != "ping" {
266            self.warnings
267                .push(format!("Request before initialization: {method}"));
268        }
269    }
270
271    fn validate_response(&mut self, response: &Response) {
272        self.stats.responses += 1;
273
274        // Check for matching request
275        if self.pending_requests.remove(&response.id).is_some() {
276            self.stats.matched_pairs += 1;
277        } else {
278            self.errors.push(ValidationError::OrphanResponse {
279                id: response.id.clone(),
280            });
281        }
282    }
283
284    fn validate_notification(&mut self, notification: &Notification) {
285        self.stats.notifications += 1;
286
287        let method = notification.method.as_ref();
288
289        // Check method
290        if !self.known_notification_methods.contains(method) {
291            if self.strict_mode {
292                self.errors.push(ValidationError::UnknownMethod {
293                    method: method.to_string(),
294                });
295            } else {
296                self.warnings
297                    .push(format!("Unknown notification method: {method}"));
298            }
299        }
300
301        // Track initialization
302        if method == "initialized" {
303            self.initialized = true;
304        }
305    }
306
307    /// Check for unmatched requests (call after all messages are validated).
308    pub fn check_unmatched_requests(&mut self) {
309        for (id, method) in self.pending_requests.drain() {
310            self.errors
311                .push(ValidationError::UnmatchedRequest { id, method });
312        }
313    }
314
315    /// Get the validation result.
316    #[must_use]
317    pub fn result(&self) -> ValidationResult {
318        ValidationResult {
319            valid: self.errors.is_empty(),
320            errors: self.errors.clone(),
321            warnings: self.warnings.clone(),
322            stats: self.stats.clone(),
323        }
324    }
325
326    /// Finalize validation and get result.
327    #[must_use]
328    pub fn finalize(mut self) -> ValidationResult {
329        self.check_unmatched_requests();
330        self.result()
331    }
332
333    /// Reset the validator state.
334    pub fn reset(&mut self) {
335        self.pending_requests.clear();
336        self.seen_request_ids.clear();
337        self.initialized = false;
338        self.errors.clear();
339        self.warnings.clear();
340        self.stats = ValidationStats::default();
341    }
342}
343
344/// Validate a sequence of messages.
345#[must_use]
346pub fn validate_message_sequence(messages: &[Message]) -> ValidationResult {
347    let mut validator = ProtocolValidator::new();
348
349    for msg in messages {
350        validator.validate(msg);
351    }
352
353    validator.finalize()
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when at least one recorded error satisfies `pred`.
    fn has_error(result: &ValidationResult, pred: impl Fn(&ValidationError) -> bool) -> bool {
        result.errors.iter().any(pred)
    }

    #[test]
    fn test_valid_sequence() {
        let result = validate_message_sequence(&[
            Message::Request(Request::new("initialize", 1)),
            Message::Response(Response::success(RequestId::from(1), serde_json::json!({}))),
            Message::Notification(Notification::new("initialized")),
            Message::Request(Request::new("tools/list", 2)),
            Message::Response(Response::success(
                RequestId::from(2),
                serde_json::json!({ "tools": [] }),
            )),
        ]);

        assert!(result.valid);
        assert!(result.errors.is_empty());
        assert_eq!(result.stats.matched_pairs, 2);
    }

    #[test]
    fn test_orphan_response() {
        let result = validate_message_sequence(&[Message::Response(Response::success(
            RequestId::from(999),
            serde_json::json!({}),
        ))]);

        assert!(!result.valid);
        assert!(has_error(&result, |e| matches!(
            e,
            ValidationError::OrphanResponse { .. }
        )));
    }

    #[test]
    fn test_duplicate_request_id() {
        // The second request reuses ID 1.
        let result = validate_message_sequence(&[
            Message::Request(Request::new("ping", 1)),
            Message::Request(Request::new("ping", 1)),
        ]);

        assert!(!result.valid);
        assert!(has_error(&result, |e| matches!(
            e,
            ValidationError::DuplicateRequestId { .. }
        )));
    }

    #[test]
    fn test_unmatched_request() {
        // The request never receives a response.
        let result = validate_message_sequence(&[Message::Request(Request::new("ping", 1))]);

        assert!(!result.valid);
        assert!(has_error(&result, |e| matches!(
            e,
            ValidationError::UnmatchedRequest { .. }
        )));
    }

    #[test]
    fn test_strict_mode() {
        let mut validator = ProtocolValidator::new().strict();
        validator.validate(&Message::Request(Request::new("unknown/method", 1)));

        let result = validator.result();
        assert!(!result.valid);
        assert!(has_error(&result, |e| matches!(
            e,
            ValidationError::UnknownMethod { .. }
        )));
    }
}