Skip to main content

nika_mcp/validation/
schema_cache.rs

//! Schema Cache Module (Layer 1)
//!
//! Caches the tool input schemas returned by MCP `list_tools()` so that tool
//! arguments can be validated against them.
//!
//! ## Design
//!
//! - On `connect()`, tool schemas from `list_tools()` are compiled and cached
//! - Thread-safe via `DashMap`, keyed by `(server_name, tool_name)`
//! - Required fields and property names are extracted up front for fast lookup
//!
//! ## Usage
//!
//! ```rust,ignore
//! use nika_mcp::validation::ToolSchemaCache;
//!
//! let cache = ToolSchemaCache::new();
//! let count = cache.populate("novanet", &tools)?;
//! let schema = cache.get("novanet", "novanet_context");
//! ```
20
21use dashmap::DashMap;
22use jsonschema::Validator;
23use std::sync::Arc;
24
25use crate::error::{McpError, Result};
26use crate::types::ToolDefinition;
27
/// Composite lookup key for cached schemas.
///
/// The first element is the MCP server name and the second is the tool
/// name, so identically named tools on different servers never collide.
type CacheKey = (String, String);
30
31/// Cached compiled JSON Schema validator
32pub struct CachedSchema {
33    /// Raw schema JSON (for error messages)
34    pub raw: serde_json::Value,
35
36    /// Compiled validator (thread-safe)
37    pub validator: Arc<Validator>,
38
39    /// Required properties (extracted for quick access)
40    pub required: Vec<String>,
41
42    /// All property names (for suggestions)
43    pub properties: Vec<String>,
44}
45
/// Summary of what a `ToolSchemaCache` currently holds.
///
/// `CacheStats::default()` is the summary of an empty cache.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct CacheStats {
    /// Number of cached tool schemas.
    pub tool_count: usize,

    /// Number of distinct server names among the cached entries.
    pub servers: usize,
}
55
56/// Thread-safe schema cache for MCP tools
57pub struct ToolSchemaCache {
58    cache: DashMap<CacheKey, CachedSchema>,
59}
60
61impl Default for ToolSchemaCache {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67impl ToolSchemaCache {
68    /// Create a new empty cache
69    pub fn new() -> Self {
70        Self {
71            cache: DashMap::new(),
72        }
73    }
74
75    /// Populate cache from list_tools() results
76    ///
77    /// Returns the number of tools cached (skips tools without input_schema)
78    pub fn populate(&self, server: &str, tools: &[ToolDefinition]) -> Result<usize> {
79        let mut count = 0;
80        for tool in tools {
81            if let Some(schema) = &tool.input_schema {
82                self.compile_and_cache(server, &tool.name, schema)?;
83                count += 1;
84            }
85        }
86        Ok(count)
87    }
88
89    /// Get cached schema for a tool
90    pub fn get(
91        &self,
92        server: &str,
93        tool: &str,
94    ) -> Option<dashmap::mapref::one::Ref<'_, CacheKey, CachedSchema>> {
95        self.cache.get(&(server.to_string(), tool.to_string()))
96    }
97
98    /// Clear all cached schemas
99    pub fn clear(&self) {
100        self.cache.clear();
101    }
102
103    /// Get cache statistics
104    pub fn stats(&self) -> CacheStats {
105        let servers: std::collections::HashSet<_> =
106            self.cache.iter().map(|e| e.key().0.clone()).collect();
107
108        CacheStats {
109            tool_count: self.cache.len(),
110            servers: servers.len(),
111        }
112    }
113
114    /// Compile and cache a schema
115    fn compile_and_cache(
116        &self,
117        server: &str,
118        tool: &str,
119        schema: &serde_json::Value,
120    ) -> Result<()> {
121        // Extract required fields
122        let required = schema
123            .get("required")
124            .and_then(|r| r.as_array())
125            .map(|arr| {
126                arr.iter()
127                    .filter_map(|v| v.as_str().map(String::from))
128                    .collect()
129            })
130            .unwrap_or_default();
131
132        // Extract property names
133        let properties = schema
134            .get("properties")
135            .and_then(|p| p.as_object())
136            .map(|obj| obj.keys().cloned().collect())
137            .unwrap_or_default();
138
139        // Compile validator
140        let validator = Validator::new(schema).map_err(|e| McpError::McpProtocolError {
141            reason: format!("Invalid schema for {}.{}: {}", server, tool, e),
142        })?;
143
144        let cached = CachedSchema {
145            raw: schema.clone(),
146            validator: Arc::new(validator),
147            required,
148            properties,
149        };
150
151        self.cache
152            .insert((server.to_string(), tool.to_string()), cached);
153        Ok(())
154    }
155}
156
// ============================================================================
// TESTS (TDD - Written First)
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ========================================================================
    // Test: Cache is empty by default
    // ========================================================================
    #[test]
    fn test_cache_empty_by_default() {
        let cache = ToolSchemaCache::new();
        assert_eq!(cache.stats().tool_count, 0);
        assert_eq!(cache.stats().servers, 0);
    }

    // ========================================================================
    // Test: Populate from tool definitions
    // ========================================================================
    #[test]
    fn test_populate_from_tool_definitions() {
        let cache = ToolSchemaCache::new();
        let tools = vec![ToolDefinition::new("tool1").with_input_schema(json!({
            "type": "object",
            "properties": { "a": { "type": "string" } },
            "required": ["a"]
        }))];

        let count = cache.populate("server", &tools).unwrap();
        assert_eq!(count, 1);
        assert!(cache.get("server", "tool1").is_some());
    }

    // ========================================================================
    // Test: Populate skips tools without schema
    // ========================================================================
    #[test]
    fn test_populate_skips_tools_without_schema() {
        let cache = ToolSchemaCache::new();
        let tools = vec![
            ToolDefinition::new("no_schema"),
            ToolDefinition::new("has_schema").with_input_schema(json!({"type": "object"})),
        ];

        let count = cache.populate("server", &tools).unwrap();
        assert_eq!(count, 1);
        assert!(cache.get("server", "no_schema").is_none());
        assert!(cache.get("server", "has_schema").is_some());
    }

    // ========================================================================
    // Test: Get nonexistent returns None
    // ========================================================================
    #[test]
    fn test_get_nonexistent_returns_none() {
        let cache = ToolSchemaCache::new();
        assert!(cache.get("server", "tool").is_none());
    }

    // ========================================================================
    // Test: Clear removes all entries
    // ========================================================================
    #[test]
    fn test_clear_removes_all_entries() {
        let cache = ToolSchemaCache::new();
        cache
            .populate(
                "s",
                &[ToolDefinition::new("t").with_input_schema(json!({}))],
            )
            .unwrap();
        assert_eq!(cache.stats().tool_count, 1);

        cache.clear();
        assert_eq!(cache.stats().tool_count, 0);
    }

    // ========================================================================
    // Test: Extracts required fields
    // ========================================================================
    #[test]
    fn test_extracts_required_fields() {
        let cache = ToolSchemaCache::new();
        cache
            .populate(
                "s",
                &[ToolDefinition::new("t").with_input_schema(json!({
                    "type": "object",
                    "properties": {
                        "entity": { "type": "string" },
                        "locale": { "type": "string" }
                    },
                    "required": ["entity"]
                }))],
            )
            .unwrap();

        let schema = cache.get("s", "t").unwrap();
        assert_eq!(schema.required, vec!["entity"]);
        assert!(schema.properties.contains(&"entity".to_string()));
        assert!(schema.properties.contains(&"locale".to_string()));
    }

    // ========================================================================
    // Test: Multiple servers tracked separately
    // ========================================================================
    #[test]
    fn test_multiple_servers_tracked() {
        let cache = ToolSchemaCache::new();
        cache
            .populate(
                "server1",
                &[ToolDefinition::new("t1").with_input_schema(json!({}))],
            )
            .unwrap();
        cache
            .populate(
                "server2",
                &[ToolDefinition::new("t2").with_input_schema(json!({}))],
            )
            .unwrap();

        let stats = cache.stats();
        assert_eq!(stats.tool_count, 2);
        assert_eq!(stats.servers, 2);
    }

    // ========================================================================
    // Test: Same tool name, different servers
    // ========================================================================
    #[test]
    fn test_same_tool_name_different_servers() {
        let cache = ToolSchemaCache::new();
        cache
            .populate(
                "server1",
                &[ToolDefinition::new("tool").with_input_schema(json!({
                    "type": "object",
                    "properties": { "a": {} },
                    "required": ["a"]
                }))],
            )
            .unwrap();
        cache
            .populate(
                "server2",
                &[ToolDefinition::new("tool").with_input_schema(json!({
                    "type": "object",
                    "properties": { "b": {} },
                    "required": ["b"]
                }))],
            )
            .unwrap();

        let schema1 = cache.get("server1", "tool").unwrap();
        let schema2 = cache.get("server2", "tool").unwrap();

        assert_eq!(schema1.required, vec!["a"]);
        assert_eq!(schema2.required, vec!["b"]);
    }

    // ========================================================================
    // Test: Invalid schema returns error
    // ========================================================================
    #[test]
    fn test_invalid_schema_returns_error() {
        let cache = ToolSchemaCache::new();

        // `type` must be a string or an array of strings, so this schema can
        // never compile. (The previous `$ref`-based version of this test
        // passed vacuously whenever compilation happened to succeed.)
        let result = cache.populate(
            "s",
            &[ToolDefinition::new("t").with_input_schema(json!({ "type": 42 }))],
        );

        match result {
            Err(McpError::McpProtocolError { reason }) => {
                assert!(reason.contains("s.t"), "unexpected reason: {}", reason);
            }
            other => panic!("expected McpProtocolError, got {:?}", other),
        }
        assert!(cache.get("s", "t").is_none());
    }

    // ========================================================================
    // Test: Unresolvable $ref is handled gracefully
    // ========================================================================
    #[test]
    fn test_unresolvable_ref_handled_gracefully() {
        let cache = ToolSchemaCache::new();

        let result = cache.populate(
            "s",
            &[ToolDefinition::new("t").with_input_schema(json!({
                "$ref": "#/definitions/nonexistent"
            }))],
        );

        // Depending on the jsonschema version, an unresolvable `$ref` may or
        // may not be rejected at compile time. Either outcome is acceptable,
        // as long as any error is reported as a protocol error, not a panic.
        if let Err(err) = result {
            assert!(matches!(err, McpError::McpProtocolError { .. }));
        }
    }

    // ========================================================================
    // Test: Default impl works
    // ========================================================================
    #[test]
    fn test_default_impl() {
        let cache = ToolSchemaCache::default();
        assert_eq!(cache.stats().tool_count, 0);
    }

    // ========================================================================
    // Test: Property extraction does not depend on declaration order
    // ========================================================================
    #[test]
    fn test_properties_extraction() {
        let cache = ToolSchemaCache::new();
        cache
            .populate(
                "s",
                &[ToolDefinition::new("t").with_input_schema(json!({
                    "type": "object",
                    "properties": {
                        "z_field": {},
                        "a_field": {},
                        "m_field": {}
                    }
                }))],
            )
            .unwrap();

        let schema = cache.get("s", "t").unwrap();
        // Should have all 3 properties
        assert_eq!(schema.properties.len(), 3);
        assert!(schema.properties.contains(&"z_field".to_string()));
        assert!(schema.properties.contains(&"a_field".to_string()));
        assert!(schema.properties.contains(&"m_field".to_string()));
    }
}