Skip to main content

nika_mcp/validation/
validator.rs

//! MCP Validator Module (Layer 2)
//!
//! Pre-call validation of MCP tool parameters against cached schemas.
//!
//! ## Design
//!
//! - Uses cached schemas from `ToolSchemaCache`
//! - Validates parameters before calling an MCP tool
//! - Returns detailed validation errors with suggestions
//!
//! ## Usage
//!
//! ```rust,ignore
//! use nika_mcp::validation::{McpValidator, ValidationConfig};
//!
//! let validator = McpValidator::new(ValidationConfig::default());
//! validator.cache().populate("novanet", &tools)?;
//!
//! let result = validator.validate("novanet", "novanet_context", &params);
//! if !result.is_valid {
//!     for error in result.errors {
//!         println!("{}", error.message);
//!     }
//! }
//! ```
26
27use super::schema_cache::{CachedSchema, ToolSchemaCache};
28use super::ValidationConfig;
29
30/// Validation result with detailed errors
31#[derive(Debug, Clone)]
32pub struct ValidationResult {
33    /// Whether validation passed
34    pub is_valid: bool,
35
36    /// List of validation errors (empty if valid)
37    pub errors: Vec<ValidationError>,
38}
39
40/// Single validation error
41#[derive(Debug, Clone)]
42pub struct ValidationError {
43    /// JSON path to the error (e.g., "/entity", "/locale")
44    pub path: String,
45
46    /// Error kind
47    pub kind: ValidationErrorKind,
48
49    /// Human-readable message
50    pub message: String,
51}
52
/// Classification of a validation error.
///
/// All payloads are plain strings, so the enum is fully comparable (`Eq`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationErrorKind {
    /// A required field is missing
    MissingRequired { field: String },

    /// A field has the wrong JSON type
    TypeMismatch { expected: String, actual: String },

    /// A field is not declared in the schema; `suggestions` holds similarly
    /// named declared fields (for "did you mean?" hints)
    UnknownField {
        field: String,
        suggestions: Vec<String>,
    },

    /// A value doesn't match a pattern/format (also the catch-all kind)
    InvalidValue { reason: String },

    /// An enum value is not in the allowed list
    InvalidEnum { value: String, allowed: Vec<String> },
}
74
75/// MCP parameter validator
76pub struct McpValidator {
77    cache: ToolSchemaCache,
78    config: ValidationConfig,
79}
80
81impl McpValidator {
82    /// Create a new validator with the given config
83    pub fn new(config: ValidationConfig) -> Self {
84        Self {
85            cache: ToolSchemaCache::new(),
86            config,
87        }
88    }
89
90    /// Get reference to schema cache (for populating)
91    pub fn cache(&self) -> &ToolSchemaCache {
92        &self.cache
93    }
94
95    /// Get reference to validation config
96    pub fn config(&self) -> &ValidationConfig {
97        &self.config
98    }
99
100    /// Validate parameters against cached schema
101    pub fn validate(
102        &self,
103        server: &str,
104        tool: &str,
105        params: &serde_json::Value,
106    ) -> ValidationResult {
107        // If validation disabled, always pass
108        if !self.config.pre_validate {
109            return ValidationResult {
110                is_valid: true,
111                errors: vec![],
112            };
113        }
114
115        // Get cached schema
116        let Some(schema_ref) = self.cache.get(server, tool) else {
117            // No schema cached = can't validate, pass through
118            tracing::debug!(
119                server = %server,
120                tool = %tool,
121                "No cached schema, skipping validation"
122            );
123            return ValidationResult {
124                is_valid: true,
125                errors: vec![],
126            };
127        };
128
129        let schema = schema_ref.value();
130        let mut errors = Vec::new();
131
132        // Run JSON Schema validation
133        let validation = schema.validator.iter_errors(params);
134
135        for error in validation {
136            let path = error.instance_path.to_string();
137            let kind = self.classify_error(&error, schema);
138            let message = self.format_error(&error, schema);
139
140            errors.push(ValidationError {
141                path,
142                kind,
143                message,
144            });
145        }
146
147        ValidationResult {
148            is_valid: errors.is_empty(),
149            errors,
150        }
151    }
152
153    /// Classify validation error into a kind
154    fn classify_error(
155        &self,
156        error: &jsonschema::ValidationError,
157        schema: &CachedSchema,
158    ) -> ValidationErrorKind {
159        let error_kind = format!("{:?}", error.kind);
160        let error_msg = error.to_string();
161
162        if error_kind.contains("Required") {
163            // Extract field name from error message
164            let field = self.extract_missing_field(&error_msg);
165            ValidationErrorKind::MissingRequired { field }
166        } else if error_kind.contains("Type") {
167            ValidationErrorKind::TypeMismatch {
168                expected: self.extract_expected_type(&error_msg),
169                actual: self.extract_actual_type(&error_msg),
170            }
171        } else if error_kind.contains("AdditionalProperties") {
172            // Extract unknown field name from the error message.
173            // jsonschema sets instance_path to the parent object, not the unknown field.
174            // The error message format is: "Additional properties are not allowed ('field' ...)"
175            let field = Self::extract_additional_property_field(&error_msg);
176            let suggestions = self.find_suggestions(&field, &schema.properties);
177            ValidationErrorKind::UnknownField { field, suggestions }
178        } else if error_kind.contains("Enum") {
179            ValidationErrorKind::InvalidEnum {
180                value: format!("{}", error.instance),
181                allowed: vec![], // Could extract from schema if needed
182            }
183        } else {
184            ValidationErrorKind::InvalidValue { reason: error_msg }
185        }
186    }
187
188    /// Extract missing field name from error message
189    fn extract_missing_field(&self, error_msg: &str) -> String {
190        // Pattern: "fieldname" is a required property (double quotes)
191        if let Some(start) = error_msg.find('"') {
192            if let Some(end) = error_msg[start + 1..].find('"') {
193                return error_msg[start + 1..start + 1 + end].to_string();
194            }
195        }
196        // Fallback: try single quotes
197        if let Some(start) = error_msg.find('\'') {
198            if let Some(end) = error_msg[start + 1..].find('\'') {
199                return error_msg[start + 1..start + 1 + end].to_string();
200            }
201        }
202        "unknown".to_string()
203    }
204
205    /// Extract the unknown field name from an additionalProperties error message.
206    ///
207    /// jsonschema formats it as: "Additional properties are not allowed ('field_name' ...)"
208    fn extract_additional_property_field(error_msg: &str) -> String {
209        // Try single quotes first (jsonschema's typical format)
210        if let Some(start) = error_msg.find('\'') {
211            if let Some(end) = error_msg[start + 1..].find('\'') {
212                return error_msg[start + 1..start + 1 + end].to_string();
213            }
214        }
215        // Fallback: try double quotes
216        if let Some(start) = error_msg.find('"') {
217            if let Some(end) = error_msg[start + 1..].find('"') {
218                return error_msg[start + 1..start + 1 + end].to_string();
219            }
220        }
221        "unknown".to_string()
222    }
223
224    /// Extract expected type from error message
225    fn extract_expected_type(&self, error_msg: &str) -> String {
226        // Simple extraction - could be improved
227        if error_msg.contains("string") {
228            "string".to_string()
229        } else if error_msg.contains("integer") {
230            "integer".to_string()
231        } else if error_msg.contains("number") {
232            "number".to_string()
233        } else if error_msg.contains("boolean") {
234            "boolean".to_string()
235        } else if error_msg.contains("array") {
236            "array".to_string()
237        } else if error_msg.contains("object") {
238            "object".to_string()
239        } else {
240            "expected".to_string()
241        }
242    }
243
244    /// Extract actual type from error message
245    ///
246    /// jsonschema error messages contain type info like "... is not of type ..."
247    /// We extract what the value actually IS by looking at contextual clues.
248    fn extract_actual_type(&self, error_msg: &str) -> String {
249        // jsonschema errors typically say "X is not of type Y" — the actual type
250        // can be inferred from the value representation in the message
251        let msg = error_msg.to_lowercase();
252        if msg.contains("null") && !msg.contains("not of type \"null\"") {
253            "null".to_string()
254        } else if msg.contains("true") || msg.contains("false") {
255            "boolean".to_string()
256        } else if msg.contains('[') {
257            "array".to_string()
258        } else if msg.contains('{') {
259            "object".to_string()
260        } else if msg.contains("\"\"") || msg.contains("''") {
261            "string".to_string()
262        } else {
263            // Try to infer from what it's NOT — if "not of type string", it's likely a number
264            if msg.contains("not of type \"string\"") {
265                "number".to_string()
266            } else if msg.contains("not of type \"integer\"")
267                || msg.contains("not of type \"number\"")
268            {
269                "string".to_string()
270            } else {
271                "unknown".to_string()
272            }
273        }
274    }
275
276    /// Format a human-readable error message
277    fn format_error(&self, error: &jsonschema::ValidationError, schema: &CachedSchema) -> String {
278        let base = error.to_string();
279
280        // Add suggestions for missing fields
281        if !schema.required.is_empty() {
282            format!(
283                "{}. Required fields: [{}]",
284                base,
285                schema.required.join(", ")
286            )
287        } else {
288            base
289        }
290    }
291
292    /// Find similar field names (for "did you mean?")
293    pub fn find_suggestions(&self, field: &str, properties: &[String]) -> Vec<String> {
294        properties
295            .iter()
296            .filter(|p| Self::edit_distance(field, p) <= self.config.suggestion_distance)
297            .cloned()
298            .collect()
299    }
300
301    /// Simple Levenshtein distance (case-insensitive)
302    pub fn edit_distance(a: &str, b: &str) -> usize {
303        let a = a.to_lowercase();
304        let b = b.to_lowercase();
305
306        if a.is_empty() {
307            return b.len();
308        }
309        if b.is_empty() {
310            return a.len();
311        }
312
313        let a_chars: Vec<char> = a.chars().collect();
314        let b_chars: Vec<char> = b.chars().collect();
315
316        let mut matrix = vec![vec![0usize; b_chars.len() + 1]; a_chars.len() + 1];
317
318        for (i, row) in matrix.iter_mut().enumerate().take(a_chars.len() + 1) {
319            row[0] = i;
320        }
321        for (j, val) in matrix[0].iter_mut().enumerate() {
322            *val = j;
323        }
324
325        for i in 1..=a_chars.len() {
326            for j in 1..=b_chars.len() {
327                let cost = if a_chars[i - 1] == b_chars[j - 1] {
328                    0
329                } else {
330                    1
331                };
332                matrix[i][j] = std::cmp::min(
333                    std::cmp::min(
334                        matrix[i - 1][j] + 1, // deletion
335                        matrix[i][j - 1] + 1, // insertion
336                    ),
337                    matrix[i - 1][j - 1] + cost, // substitution
338                );
339            }
340        }
341
342        matrix[a_chars.len()][b_chars.len()]
343    }
344}
345
// ============================================================================
// TESTS (TDD)
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ToolDefinition;
    use serde_json::json;

    /// Build a validator with default config whose cache holds exactly one
    /// tool (`tool` on `server`) with the given input schema.
    fn validator_with_schema(server: &str, tool: &str, schema: serde_json::Value) -> McpValidator {
        let validator = McpValidator::new(ValidationConfig::default());
        validator
            .cache()
            .populate(
                server,
                &[ToolDefinition::new(tool).with_input_schema(schema)],
            )
            .unwrap();
        validator
    }

    // ========================================================================
    // Test: Validate missing required field
    // ========================================================================
    #[test]
    fn test_validate_missing_required_field() {
        let validator = validator_with_schema(
            "novanet",
            "novanet_context",
            json!({
                "type": "object",
                "properties": {
                    "entity": { "type": "string" },
                    "locale": { "type": "string" }
                },
                "required": ["entity"]
            }),
        );

        // Missing required "entity" field
        let result = validator.validate(
            "novanet",
            "novanet_context",
            &json!({
                "locale": "fr-FR"
            }),
        );

        assert!(!result.is_valid);
        assert_eq!(result.errors.len(), 1);

        // Check that it's a MissingRequired error
        match &result.errors[0].kind {
            ValidationErrorKind::MissingRequired { field } => {
                assert_eq!(field, "entity");
            }
            other => {
                panic!("Expected MissingRequired, got {:?}", other);
            }
        }
    }

    // ========================================================================
    // Test: Valid params passes
    // ========================================================================
    #[test]
    fn test_validate_valid_params_passes() {
        let validator = validator_with_schema(
            "novanet",
            "novanet_context",
            json!({
                "type": "object",
                "properties": {
                    "entity": { "type": "string" }
                },
                "required": ["entity"]
            }),
        );

        let result = validator.validate(
            "novanet",
            "novanet_context",
            &json!({
                "entity": "qr-code"
            }),
        );

        assert!(result.is_valid);
        assert!(result.errors.is_empty());
    }

    // ========================================================================
    // Test: Validation disabled always passes
    // ========================================================================
    #[test]
    fn test_validate_disabled_always_passes() {
        let config = ValidationConfig {
            pre_validate: false,
            ..Default::default()
        };
        let validator = McpValidator::new(config);

        // No schema cached, but should pass
        let result = validator.validate("any", "tool", &json!({}));
        assert!(result.is_valid);
    }

    // ========================================================================
    // Test: No cached schema passes
    // ========================================================================
    #[test]
    fn test_validate_no_cached_schema_passes() {
        let validator = McpValidator::new(ValidationConfig::default());

        // No schema cached for this tool
        let result = validator.validate(
            "unknown",
            "tool",
            &json!({
                "anything": "goes"
            }),
        );

        assert!(result.is_valid);
    }

    // ========================================================================
    // Test: Type mismatch
    // ========================================================================
    #[test]
    fn test_validate_type_mismatch() {
        let validator = validator_with_schema(
            "s",
            "t",
            json!({
                "type": "object",
                "properties": {
                    "count": { "type": "integer" }
                }
            }),
        );

        let result = validator.validate(
            "s",
            "t",
            &json!({
                "count": "not-an-integer"
            }),
        );

        assert!(!result.is_valid);
        assert!(matches!(
            &result.errors[0].kind,
            ValidationErrorKind::TypeMismatch { .. }
        ));
    }

    // ========================================================================
    // Test: Edit distance exact match
    // ========================================================================
    #[test]
    fn test_edit_distance_exact_match() {
        assert_eq!(McpValidator::edit_distance("entity", "entity"), 0);
    }

    // ========================================================================
    // Test: Edit distance one char diff
    // ========================================================================
    #[test]
    fn test_edit_distance_one_char_diff() {
        assert_eq!(McpValidator::edit_distance("entity", "entityy"), 1);
        assert_eq!(McpValidator::edit_distance("entty", "entity"), 1);
    }

    // ========================================================================
    // Test: Edit distance case insensitive
    // ========================================================================
    #[test]
    fn test_edit_distance_case_insensitive() {
        assert_eq!(McpValidator::edit_distance("Entity", "ENTITY"), 0);
    }

    // ========================================================================
    // Test: Find suggestions within distance
    // ========================================================================
    #[test]
    fn test_find_suggestions_within_distance() {
        let validator = validator_with_schema(
            "s",
            "t",
            json!({
                "type": "object",
                "properties": {
                    "entity": {},
                    "locale": {},
                    "forms": {}
                }
            }),
        );

        let schema = validator.cache().get("s", "t").unwrap();
        let suggestions = validator.find_suggestions("entiy", &schema.properties);

        assert!(suggestions.contains(&"entity".to_string()));
    }

    // ========================================================================
    // Test: Edit distance empty strings
    // ========================================================================
    #[test]
    fn test_edit_distance_empty_strings() {
        assert_eq!(McpValidator::edit_distance("", ""), 0);
        assert_eq!(McpValidator::edit_distance("abc", ""), 3);
        assert_eq!(McpValidator::edit_distance("", "xyz"), 3);
    }

    // ========================================================================
    // Test: Edit distance completely different
    // ========================================================================
    #[test]
    fn test_edit_distance_completely_different() {
        assert_eq!(McpValidator::edit_distance("abc", "xyz"), 3);
    }

    // ========================================================================
    // Test: Multiple validation errors
    // ========================================================================
    #[test]
    fn test_multiple_validation_errors() {
        let validator = validator_with_schema(
            "s",
            "t",
            json!({
                "type": "object",
                "properties": {
                    "a": { "type": "string" },
                    "b": { "type": "integer" }
                },
                "required": ["a", "b"]
            }),
        );

        // Missing both required fields
        let result = validator.validate("s", "t", &json!({}));

        assert!(!result.is_valid);
        assert_eq!(result.errors.len(), 2);
    }

    // ========================================================================
    // Test: Error message includes required fields
    // ========================================================================
    #[test]
    fn test_error_message_includes_required_fields() {
        let validator = validator_with_schema(
            "s",
            "t",
            json!({
                "type": "object",
                "properties": {
                    "entity": { "type": "string" },
                    "locale": { "type": "string" }
                },
                "required": ["entity"]
            }),
        );

        let result = validator.validate("s", "t", &json!({}));

        assert!(!result.is_valid);
        // Message should mention required fields
        assert!(result.errors[0].message.contains("Required fields"));
        assert!(result.errors[0].message.contains("entity"));
    }

    // ========================================================================
    // Test: Suggestion distance config respected
    // ========================================================================
    #[test]
    fn test_suggestion_distance_config() {
        let config = ValidationConfig {
            suggestion_distance: 1,
            ..Default::default()
        };
        let validator = McpValidator::new(config);

        // "entiy" is distance 1 from "entity" - should be suggested
        let suggestions = validator.find_suggestions(
            "entiy",
            &["entity".to_string(), "completely_different".to_string()],
        );
        assert!(suggestions.contains(&"entity".to_string()));
        assert!(!suggestions.contains(&"completely_different".to_string()));
    }
}