1use crate::protocol::{Message, Notification, Request, RequestId, Response};
7use std::collections::{HashMap, HashSet};
8
9#[derive(Debug, Clone, thiserror::Error)]
11pub enum ValidationError {
12 #[error("orphan response: no request found for ID {id:?}")]
14 OrphanResponse {
15 id: RequestId,
17 },
18
19 #[error("duplicate request ID: {id:?}")]
21 DuplicateRequestId {
22 id: RequestId,
24 },
25
26 #[error("unmatched request: {method} (ID: {id:?})")]
28 UnmatchedRequest {
29 id: RequestId,
31 method: String,
33 },
34
35 #[error("unknown method: {method}")]
37 UnknownMethod {
38 method: String,
40 },
41
42 #[error("invalid sequence: {message}")]
44 InvalidSequence {
45 message: String,
47 },
48
49 #[error("missing initialization: {message}")]
51 MissingInitialization {
52 message: String,
54 },
55}
56
57#[derive(Debug, Clone)]
59pub struct ValidationResult {
60 pub valid: bool,
62 pub errors: Vec<ValidationError>,
64 pub warnings: Vec<String>,
66 pub stats: ValidationStats,
68}
69
70impl ValidationResult {
71 #[must_use]
73 pub fn pass() -> Self {
74 Self {
75 valid: true,
76 errors: Vec::new(),
77 warnings: Vec::new(),
78 stats: ValidationStats::default(),
79 }
80 }
81
82 #[must_use]
84 pub fn fail(errors: Vec<ValidationError>) -> Self {
85 Self {
86 valid: false,
87 errors,
88 warnings: Vec::new(),
89 stats: ValidationStats::default(),
90 }
91 }
92
93 pub fn add_warning(&mut self, warning: impl Into<String>) {
95 self.warnings.push(warning.into());
96 }
97}
98
/// Message counters accumulated while validating a sequence.
///
/// Derives `PartialEq`/`Eq` so whole stat snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ValidationStats {
    /// Every message seen, regardless of kind.
    pub total_messages: usize,
    /// Number of request messages seen.
    pub requests: usize,
    /// Number of response messages seen, whether matched or orphaned.
    pub responses: usize,
    /// Number of notification messages seen.
    pub notifications: usize,
    /// Responses that matched a still-pending request.
    pub matched_pairs: usize,
}
113
114#[derive(Debug)]
122pub struct ProtocolValidator {
123 known_request_methods: HashSet<String>,
125 known_notification_methods: HashSet<String>,
127 pending_requests: HashMap<RequestId, String>,
129 seen_request_ids: HashSet<RequestId>,
131 initialized: bool,
133 errors: Vec<ValidationError>,
135 warnings: Vec<String>,
137 stats: ValidationStats,
139 strict_mode: bool,
141}
142
143impl Default for ProtocolValidator {
144 fn default() -> Self {
145 Self::new()
146 }
147}
148
149impl ProtocolValidator {
150 #[must_use]
152 pub fn new() -> Self {
153 let mut validator = Self {
154 known_request_methods: HashSet::new(),
155 known_notification_methods: HashSet::new(),
156 pending_requests: HashMap::new(),
157 seen_request_ids: HashSet::new(),
158 initialized: false,
159 errors: Vec::new(),
160 warnings: Vec::new(),
161 stats: ValidationStats::default(),
162 strict_mode: false,
163 };
164
165 validator.register_request_methods(&[
167 "initialize",
168 "ping",
169 "tools/list",
170 "tools/call",
171 "resources/list",
172 "resources/read",
173 "resources/subscribe",
174 "resources/unsubscribe",
175 "prompts/list",
176 "prompts/get",
177 "logging/setLevel",
178 "completion/complete",
179 "sampling/createMessage",
180 "roots/list",
181 ]);
182
183 validator.register_notification_methods(&[
184 "initialized",
185 "notifications/cancelled",
186 "notifications/progress",
187 "notifications/message",
188 "notifications/resources/updated",
189 "notifications/resources/list_changed",
190 "notifications/tools/list_changed",
191 "notifications/prompts/list_changed",
192 "notifications/roots/list_changed",
193 ]);
194
195 validator
196 }
197
198 #[must_use]
200 pub fn strict(mut self) -> Self {
201 self.strict_mode = true;
202 self
203 }
204
205 pub fn register_request_methods(&mut self, methods: &[&str]) {
207 for method in methods {
208 self.known_request_methods.insert((*method).to_string());
209 }
210 }
211
212 pub fn register_notification_methods(&mut self, methods: &[&str]) {
214 for method in methods {
215 self.known_notification_methods
216 .insert((*method).to_string());
217 }
218 }
219
220 pub fn validate(&mut self, message: &Message) {
222 self.stats.total_messages += 1;
223
224 match message {
225 Message::Request(req) => self.validate_request(req),
226 Message::Response(res) => self.validate_response(res),
227 Message::Notification(notif) => self.validate_notification(notif),
228 }
229 }
230
231 fn validate_request(&mut self, request: &Request) {
232 self.stats.requests += 1;
233
234 if self.seen_request_ids.contains(&request.id) {
236 self.errors.push(ValidationError::DuplicateRequestId {
237 id: request.id.clone(),
238 });
239 }
240 self.seen_request_ids.insert(request.id.clone());
241
242 self.pending_requests
244 .insert(request.id.clone(), request.method.to_string());
245
246 let method = request.method.as_ref();
248 if !self.known_request_methods.contains(method) {
249 if self.strict_mode {
250 self.errors.push(ValidationError::UnknownMethod {
251 method: method.to_string(),
252 });
253 } else {
254 self.warnings
255 .push(format!("Unknown request method: {method}"));
256 }
257 }
258
259 if method == "initialize" {
261 if self.initialized {
262 self.warnings
263 .push("Duplicate initialize request".to_string());
264 }
265 } else if !self.initialized && method != "ping" {
266 self.warnings
267 .push(format!("Request before initialization: {method}"));
268 }
269 }
270
271 fn validate_response(&mut self, response: &Response) {
272 self.stats.responses += 1;
273
274 if self.pending_requests.remove(&response.id).is_some() {
276 self.stats.matched_pairs += 1;
277 } else {
278 self.errors.push(ValidationError::OrphanResponse {
279 id: response.id.clone(),
280 });
281 }
282 }
283
284 fn validate_notification(&mut self, notification: &Notification) {
285 self.stats.notifications += 1;
286
287 let method = notification.method.as_ref();
288
289 if !self.known_notification_methods.contains(method) {
291 if self.strict_mode {
292 self.errors.push(ValidationError::UnknownMethod {
293 method: method.to_string(),
294 });
295 } else {
296 self.warnings
297 .push(format!("Unknown notification method: {method}"));
298 }
299 }
300
301 if method == "initialized" {
303 self.initialized = true;
304 }
305 }
306
307 pub fn check_unmatched_requests(&mut self) {
309 for (id, method) in self.pending_requests.drain() {
310 self.errors
311 .push(ValidationError::UnmatchedRequest { id, method });
312 }
313 }
314
315 #[must_use]
317 pub fn result(&self) -> ValidationResult {
318 ValidationResult {
319 valid: self.errors.is_empty(),
320 errors: self.errors.clone(),
321 warnings: self.warnings.clone(),
322 stats: self.stats.clone(),
323 }
324 }
325
326 #[must_use]
328 pub fn finalize(mut self) -> ValidationResult {
329 self.check_unmatched_requests();
330 self.result()
331 }
332
333 pub fn reset(&mut self) {
335 self.pending_requests.clear();
336 self.seen_request_ids.clear();
337 self.initialized = false;
338 self.errors.clear();
339 self.warnings.clear();
340 self.stats = ValidationStats::default();
341 }
342}
343
344#[must_use]
346pub fn validate_message_sequence(messages: &[Message]) -> ValidationResult {
347 let mut validator = ProtocolValidator::new();
348
349 for msg in messages {
350 validator.validate(msg);
351 }
352
353 validator.finalize()
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_sequence() {
        let messages = vec![
            Message::Request(Request::new("initialize", 1)),
            Message::Response(Response::success(RequestId::from(1), serde_json::json!({}))),
            Message::Notification(Notification::new("initialized")),
            Message::Request(Request::new("tools/list", 2)),
            Message::Response(Response::success(
                RequestId::from(2),
                serde_json::json!({ "tools": [] }),
            )),
        ];

        let result = validate_message_sequence(&messages);
        assert!(result.valid);
        assert!(result.errors.is_empty());
        assert!(result.warnings.is_empty());
        assert_eq!(result.stats.matched_pairs, 2);
        assert_eq!(result.stats.total_messages, 5);
        assert_eq!(result.stats.requests, 2);
        assert_eq!(result.stats.responses, 2);
        assert_eq!(result.stats.notifications, 1);
    }

    #[test]
    fn test_orphan_response() {
        let messages = vec![Message::Response(Response::success(
            RequestId::from(999),
            serde_json::json!({}),
        ))];

        let result = validate_message_sequence(&messages);
        assert!(!result.valid);
        assert!(
            result
                .errors
                .iter()
                .any(|e| matches!(e, ValidationError::OrphanResponse { .. }))
        );
    }

    #[test]
    fn test_duplicate_request_id() {
        let messages = vec![
            Message::Request(Request::new("ping", 1)),
            Message::Request(Request::new("ping", 1)),
        ];

        let result = validate_message_sequence(&messages);
        assert!(!result.valid);
        assert!(
            result
                .errors
                .iter()
                .any(|e| matches!(e, ValidationError::DuplicateRequestId { .. }))
        );
    }

    #[test]
    fn test_request_id_reuse_after_response_is_duplicate() {
        // IDs are tracked for the whole sequence, not just while a request is pending.
        let messages = vec![
            Message::Request(Request::new("ping", 1)),
            Message::Response(Response::success(RequestId::from(1), serde_json::json!({}))),
            Message::Request(Request::new("ping", 1)),
            Message::Response(Response::success(RequestId::from(1), serde_json::json!({}))),
        ];

        let result = validate_message_sequence(&messages);
        assert!(!result.valid);
        assert_eq!(result.stats.matched_pairs, 2);
        assert!(
            result
                .errors
                .iter()
                .any(|e| matches!(e, ValidationError::DuplicateRequestId { .. }))
        );
    }

    #[test]
    fn test_unmatched_request() {
        let messages = vec![Message::Request(Request::new("ping", 1))];

        let result = validate_message_sequence(&messages);
        assert!(!result.valid);
        assert!(
            result
                .errors
                .iter()
                .any(|e| matches!(e, ValidationError::UnmatchedRequest { .. }))
        );
    }

    #[test]
    fn test_ping_before_initialization_is_not_warned() {
        let messages = vec![
            Message::Request(Request::new("ping", 1)),
            Message::Response(Response::success(RequestId::from(1), serde_json::json!({}))),
        ];

        let result = validate_message_sequence(&messages);
        assert!(result.valid);
        assert!(result.warnings.is_empty());
        assert_eq!(result.stats.matched_pairs, 1);
    }

    #[test]
    fn test_duplicate_initialize_warns() {
        let messages = vec![
            Message::Request(Request::new("initialize", 1)),
            Message::Response(Response::success(RequestId::from(1), serde_json::json!({}))),
            Message::Notification(Notification::new("initialized")),
            Message::Request(Request::new("initialize", 2)),
            Message::Response(Response::success(RequestId::from(2), serde_json::json!({}))),
        ];

        let result = validate_message_sequence(&messages);
        assert!(result.valid);
        assert!(
            result
                .warnings
                .iter()
                .any(|w| w == "Duplicate initialize request")
        );
    }

    #[test]
    fn test_strict_mode() {
        let mut validator = ProtocolValidator::new().strict();

        validator.validate(&Message::Request(Request::new("unknown/method", 1)));

        let result = validator.result();
        assert!(!result.valid);
        assert!(
            result
                .errors
                .iter()
                .any(|e| matches!(e, ValidationError::UnknownMethod { .. }))
        );
    }

    #[test]
    fn test_non_strict_unknown_method_warns() {
        let mut validator = ProtocolValidator::new();

        validator.validate(&Message::Request(Request::new("unknown/method", 1)));

        let result = validator.result();
        assert!(result.valid);
        assert!(
            result
                .warnings
                .iter()
                .any(|w| w == "Unknown request method: unknown/method")
        );
    }

    #[test]
    fn test_reset_clears_state() {
        let mut validator = ProtocolValidator::new();
        validator.validate(&Message::Response(Response::success(
            RequestId::from(7),
            serde_json::json!({}),
        )));
        assert!(!validator.result().valid);

        validator.reset();

        let result = validator.result();
        assert!(result.valid);
        assert!(result.errors.is_empty());
        assert_eq!(result.stats.total_messages, 0);
    }
}