1use dashmap::DashMap;
22use jsonschema::Validator;
23use std::sync::Arc;
24
25use crate::error::{McpError, Result};
26use crate::types::ToolDefinition;
27
/// Key used by the schema cache: a `(server, tool)` pair of owned names.
type CacheKey = (String, String);
30
/// A tool's input schema in pre-processed form, as stored in the schema cache.
///
/// Besides the compiled validator, the entry keeps the original schema document
/// and two name lists extracted from the schema's top level.
pub struct CachedSchema {
    /// Copy of the JSON schema document this entry was built from.
    pub raw: serde_json::Value,

    /// Validator compiled from the schema, held behind an `Arc`.
    pub validator: Arc<Validator>,

    /// String entries of the schema's top-level `"required"` array.
    ///
    /// Empty when the key is absent or is not an array; non-string entries are
    /// skipped.
    pub required: Vec<String>,

    /// Keys of the schema's top-level `"properties"` object.
    ///
    /// Empty when the key is absent or is not an object. The order is whatever
    /// `serde_json`'s map iteration yields, so callers should not rely on it.
    pub properties: Vec<String>,
}
45
/// Point-in-time counters describing the contents of a schema cache.
///
/// `Default` yields all-zero counters, which is what an empty cache reports.
/// `Eq` is derived alongside `PartialEq` because both fields are plain integers.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of cached tool schemas across all servers.
    pub tool_count: usize,

    /// Number of distinct server names that have at least one cached schema.
    pub servers: usize,
}
55
/// Cache of compiled tool input schemas, keyed by `(server, tool)` name pairs.
///
/// Entries live in a `DashMap`, and every method takes `&self`.
pub struct ToolSchemaCache {
    /// Compiled schema entries, one per `(server, tool)` pair.
    cache: DashMap<CacheKey, CachedSchema>,
}
60
61impl Default for ToolSchemaCache {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67impl ToolSchemaCache {
68 pub fn new() -> Self {
70 Self {
71 cache: DashMap::new(),
72 }
73 }
74
75 pub fn populate(&self, server: &str, tools: &[ToolDefinition]) -> Result<usize> {
79 let mut count = 0;
80 for tool in tools {
81 if let Some(schema) = &tool.input_schema {
82 self.compile_and_cache(server, &tool.name, schema)?;
83 count += 1;
84 }
85 }
86 Ok(count)
87 }
88
89 pub fn get(
91 &self,
92 server: &str,
93 tool: &str,
94 ) -> Option<dashmap::mapref::one::Ref<'_, CacheKey, CachedSchema>> {
95 self.cache.get(&(server.to_string(), tool.to_string()))
96 }
97
98 pub fn clear(&self) {
100 self.cache.clear();
101 }
102
103 pub fn stats(&self) -> CacheStats {
105 let servers: std::collections::HashSet<_> =
106 self.cache.iter().map(|e| e.key().0.clone()).collect();
107
108 CacheStats {
109 tool_count: self.cache.len(),
110 servers: servers.len(),
111 }
112 }
113
114 fn compile_and_cache(
116 &self,
117 server: &str,
118 tool: &str,
119 schema: &serde_json::Value,
120 ) -> Result<()> {
121 let required = schema
123 .get("required")
124 .and_then(|r| r.as_array())
125 .map(|arr| {
126 arr.iter()
127 .filter_map(|v| v.as_str().map(String::from))
128 .collect()
129 })
130 .unwrap_or_default();
131
132 let properties = schema
134 .get("properties")
135 .and_then(|p| p.as_object())
136 .map(|obj| obj.keys().cloned().collect())
137 .unwrap_or_default();
138
139 let validator = Validator::new(schema).map_err(|e| McpError::McpProtocolError {
141 reason: format!("Invalid schema for {}.{}: {}", server, tool, e),
142 })?;
143
144 let cached = CachedSchema {
145 raw: schema.clone(),
146 validator: Arc::new(validator),
147 required,
148 properties,
149 };
150
151 self.cache
152 .insert((server.to_string(), tool.to_string()), cached);
153 Ok(())
154 }
155}
156
/// Unit tests for `ToolSchemaCache`: population, lookup, clearing, statistics
/// and error reporting for schemas that cannot be compiled.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A freshly constructed cache reports zero tools and zero servers.
    #[test]
    fn test_cache_empty_by_default() {
        let cache = ToolSchemaCache::new();
        assert_eq!(cache.stats().tool_count, 0);
        assert_eq!(cache.stats().servers, 0);
    }

    /// Populating with one schema-bearing tool caches exactly that tool.
    #[test]
    fn test_populate_from_tool_definitions() {
        let cache = ToolSchemaCache::new();
        let tools = vec![ToolDefinition::new("tool1").with_input_schema(json!({
            "type": "object",
            "properties": { "a": { "type": "string" } },
            "required": ["a"]
        }))];

        let count = cache.populate("server", &tools).unwrap();
        assert_eq!(count, 1);
        assert!(cache.get("server", "tool1").is_some());
    }

    /// Tools without an input schema are neither counted nor cached.
    #[test]
    fn test_populate_skips_tools_without_schema() {
        let cache = ToolSchemaCache::new();
        let tools = vec![
            ToolDefinition::new("no_schema"),
            ToolDefinition::new("has_schema").with_input_schema(json!({"type": "object"})),
        ];

        let count = cache.populate("server", &tools).unwrap();
        assert_eq!(count, 1);
        assert!(cache.get("server", "no_schema").is_none());
        assert!(cache.get("server", "has_schema").is_some());
    }

    /// Looking up a pair that was never populated yields `None`.
    #[test]
    fn test_get_nonexistent_returns_none() {
        let cache = ToolSchemaCache::new();
        assert!(cache.get("server", "tool").is_none());
    }

    /// `clear` drops every cached entry.
    #[test]
    fn test_clear_removes_all_entries() {
        let cache = ToolSchemaCache::new();
        cache
            .populate(
                "s",
                &[ToolDefinition::new("t").with_input_schema(json!({}))],
            )
            .unwrap();
        assert_eq!(cache.stats().tool_count, 1);

        cache.clear();
        assert_eq!(cache.stats().tool_count, 0);
    }

    /// The `required` list and the property names are extracted from the schema.
    #[test]
    fn test_extracts_required_fields() {
        let cache = ToolSchemaCache::new();
        cache
            .populate(
                "s",
                &[ToolDefinition::new("t").with_input_schema(json!({
                    "type": "object",
                    "properties": {
                        "entity": { "type": "string" },
                        "locale": { "type": "string" }
                    },
                    "required": ["entity"]
                }))],
            )
            .unwrap();

        let schema = cache.get("s", "t").unwrap();
        assert_eq!(schema.required, vec!["entity"]);
        assert!(schema.properties.contains(&"entity".to_string()));
        assert!(schema.properties.contains(&"locale".to_string()));
    }

    /// Statistics count tools and distinct servers independently.
    #[test]
    fn test_multiple_servers_tracked() {
        let cache = ToolSchemaCache::new();
        cache
            .populate(
                "server1",
                &[ToolDefinition::new("t1").with_input_schema(json!({}))],
            )
            .unwrap();
        cache
            .populate(
                "server2",
                &[ToolDefinition::new("t2").with_input_schema(json!({}))],
            )
            .unwrap();

        let stats = cache.stats();
        assert_eq!(stats.tool_count, 2);
        assert_eq!(stats.servers, 2);
    }

    /// The same tool name on two servers produces two independent entries.
    #[test]
    fn test_same_tool_name_different_servers() {
        let cache = ToolSchemaCache::new();
        cache
            .populate(
                "server1",
                &[ToolDefinition::new("tool").with_input_schema(json!({
                    "type": "object",
                    "properties": { "a": {} },
                    "required": ["a"]
                }))],
            )
            .unwrap();
        cache
            .populate(
                "server2",
                &[ToolDefinition::new("tool").with_input_schema(json!({
                    "type": "object",
                    "properties": { "b": {} },
                    "required": ["b"]
                }))],
            )
            .unwrap();

        let schema1 = cache.get("server1", "tool").unwrap();
        let schema2 = cache.get("server2", "tool").unwrap();

        assert_eq!(schema1.required, vec!["a"]);
        assert_eq!(schema2.required, vec!["b"]);
    }

    /// A document that is not a valid JSON Schema must be rejected with
    /// `McpProtocolError`, and nothing may be cached for it.
    #[test]
    fn test_invalid_schema_returns_error() {
        let cache = ToolSchemaCache::new();

        // `type` must be a string or an array of strings, so a number here
        // fails the meta-schema check performed when the validator is built.
        let result = cache.populate(
            "s",
            &[ToolDefinition::new("t").with_input_schema(json!({ "type": 12345 }))],
        );

        match result {
            Err(McpError::McpProtocolError { reason }) => {
                assert!(
                    reason.contains("Invalid schema for s.t"),
                    "unexpected reason: {}",
                    reason
                );
            }
            other => panic!("expected McpProtocolError, got {:?}", other),
        }
        assert!(cache.get("s", "t").is_none());
    }

    /// Whether an unresolvable `$ref` is rejected at compile time depends on
    /// the validator library; if it is, the error must be `McpProtocolError`.
    #[test]
    fn test_unresolvable_ref_error_variant() {
        let cache = ToolSchemaCache::new();

        let result = cache.populate(
            "s",
            &[ToolDefinition::new("t").with_input_schema(json!({
                "$ref": "#/definitions/nonexistent"
            }))],
        );

        // Deliberately lenient: acceptance of the schema is not treated as a failure.
        if let Err(err) = result {
            assert!(matches!(err, McpError::McpProtocolError { .. }));
        }
    }

    /// `Default` builds an empty cache.
    #[test]
    fn test_default_impl() {
        let cache = ToolSchemaCache::default();
        assert_eq!(cache.stats().tool_count, 0);
    }

    /// All property names are captured; their order is not asserted.
    #[test]
    fn test_properties_extraction() {
        let cache = ToolSchemaCache::new();
        cache
            .populate(
                "s",
                &[ToolDefinition::new("t").with_input_schema(json!({
                    "type": "object",
                    "properties": {
                        "z_field": {},
                        "a_field": {},
                        "m_field": {}
                    }
                }))],
            )
            .unwrap();

        let schema = cache.get("s", "t").unwrap();
        assert_eq!(schema.properties.len(), 3);
        assert!(schema.properties.contains(&"z_field".to_string()));
        assert!(schema.properties.contains(&"a_field".to_string()));
        assert!(schema.properties.contains(&"m_field".to_string()));
    }
}