use dashmap::DashMap;
use jsonschema::Validator;
use std::sync::Arc;
use crate::error::{NikaError, Result};
use crate::mcp::types::ToolDefinition;
/// Map key for a cached schema: `(server name, tool name)`.
type CacheKey = (String, String);
/// A tool's input schema, compiled once and stored with metadata extracted from it.
pub struct CachedSchema {
    /// The JSON Schema document exactly as it was supplied.
    pub raw: serde_json::Value,
    /// Compiled validator for `raw`. It is wrapped in `Arc`, presumably so callers
    /// can clone it out of the map cheaply (confirm against call sites).
    pub validator: Arc<Validator>,
    /// String entries of the schema's top-level `required` array. Non-string
    /// entries are skipped; the list is empty if `required` is absent or not an array.
    pub required: Vec<String>,
    /// Keys of the schema's top-level `properties` object. Empty if `properties`
    /// is absent or not an object.
    pub properties: Vec<String>,
}
/// Summary counts describing what a `ToolSchemaCache` currently holds.
///
/// Both fields are plain counters, so the type is fully comparable (`Eq`) and
/// has a natural all-zero `Default` (the stats of an empty cache).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Total number of cached `(server, tool)` schemas.
    pub tool_count: usize,
    /// Number of distinct server names with at least one cached schema.
    pub servers: usize,
}
/// Cache of compiled MCP tool input schemas, keyed by `(server, tool)`.
///
/// Backed by a `DashMap`, so every operation (including mutation) takes `&self`.
pub struct ToolSchemaCache {
    // (server, tool) -> compiled schema plus extracted metadata.
    cache: DashMap<CacheKey, CachedSchema>,
}
impl Default for ToolSchemaCache {
fn default() -> Self {
Self::new()
}
}
impl ToolSchemaCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            cache: DashMap::new(),
        }
    }

    /// Compiles and caches the input schema of every tool in `tools` under `server`.
    ///
    /// - Tools without an `input_schema` are skipped.
    /// - An existing entry for the same `(server, tool)` key is overwritten.
    /// - Tools cached earlier for `server` but absent from `tools` are left in place.
    ///
    /// The operation is all-or-nothing. Every schema is compiled before any entry
    /// is inserted, so if one schema is invalid the cache is left unchanged.
    ///
    /// # Returns
    ///
    /// The number of tools that had a schema and were therefore cached.
    ///
    /// # Errors
    ///
    /// Returns `NikaError::McpProtocolError` if any schema fails to compile.
    pub fn populate(&self, server: &str, tools: &[ToolDefinition]) -> Result<usize> {
        // Stage first, insert second: a failure part-way through the list must not
        // leave `server` with only some of its tools cached.
        let mut staged = Vec::with_capacity(tools.len());
        for tool in tools {
            if let Some(schema) = &tool.input_schema {
                let compiled = Self::compile_schema(server, &tool.name, schema)?;
                staged.push((tool.name.to_string(), compiled));
            }
        }

        let count = staged.len();
        for (tool_name, compiled) in staged {
            self.cache.insert((server.to_string(), tool_name), compiled);
        }
        Ok(count)
    }

    /// Looks up the cached schema for `tool` on `server`.
    ///
    /// The returned guard borrows the entry inside the map. DashMap documents that
    /// `insert` and `clear` may deadlock while any reference into the map is held.
    /// Drop the guard before calling `populate` or `clear` on the same cache.
    pub fn get(
        &self,
        server: &str,
        tool: &str,
    ) -> Option<dashmap::mapref::one::Ref<'_, CacheKey, CachedSchema>> {
        self.cache.get(&(server.to_string(), tool.to_string()))
    }

    /// Removes every cached schema for every server.
    pub fn clear(&self) {
        self.cache.clear();
    }

    /// Returns the number of cached tools and of distinct servers.
    ///
    /// Walks every entry to count distinct servers, so the cost is linear in the
    /// number of cached tools.
    pub fn stats(&self) -> CacheStats {
        let servers: std::collections::HashSet<_> =
            self.cache.iter().map(|e| e.key().0.clone()).collect();
        CacheStats {
            tool_count: self.cache.len(),
            servers: servers.len(),
        }
    }

    /// Compiles `schema` and extracts its top-level `required` names and
    /// `properties` keys, without touching the cache.
    ///
    /// `server` and `tool` are used only to label the error message.
    fn compile_schema(
        server: &str,
        tool: &str,
        schema: &serde_json::Value,
    ) -> Result<CachedSchema> {
        // Compile first so an invalid schema fails before any metadata work.
        let validator = Validator::new(schema).map_err(|e| NikaError::McpProtocolError {
            reason: format!("Invalid schema for {}.{}: {}", server, tool, e),
        })?;

        // Non-string entries in `required` are ignored rather than rejected.
        let required = schema
            .get("required")
            .and_then(|r| r.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_str().map(String::from))
                    .collect()
            })
            .unwrap_or_default();
        let properties = schema
            .get("properties")
            .and_then(|p| p.as_object())
            .map(|obj| obj.keys().cloned().collect())
            .unwrap_or_default();

        Ok(CachedSchema {
            raw: schema.clone(),
            validator: Arc::new(validator),
            required,
            properties,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_cache_empty_by_default() {
        let cache = ToolSchemaCache::new();
        assert_eq!(cache.stats().tool_count, 0);
        assert_eq!(cache.stats().servers, 0);
    }

    #[test]
    fn test_populate_from_tool_definitions() {
        let cache = ToolSchemaCache::new();
        let tools = vec![ToolDefinition::new("tool1").with_input_schema(json!({
            "type": "object",
            "properties": { "a": { "type": "string" } },
            "required": ["a"]
        }))];
        let count = cache.populate("server", &tools).unwrap();
        assert_eq!(count, 1);
        assert!(cache.get("server", "tool1").is_some());
    }

    #[test]
    fn test_populate_skips_tools_without_schema() {
        let cache = ToolSchemaCache::new();
        let tools = vec![
            ToolDefinition::new("no_schema"),
            ToolDefinition::new("has_schema").with_input_schema(json!({"type": "object"})),
        ];
        let count = cache.populate("server", &tools).unwrap();
        assert_eq!(count, 1);
        assert!(cache.get("server", "no_schema").is_none());
        assert!(cache.get("server", "has_schema").is_some());
    }

    #[test]
    fn test_get_nonexistent_returns_none() {
        let cache = ToolSchemaCache::new();
        assert!(cache.get("server", "tool").is_none());
    }

    #[test]
    fn test_clear_removes_all_entries() {
        let cache = ToolSchemaCache::new();
        cache
            .populate(
                "s",
                &[ToolDefinition::new("t").with_input_schema(json!({}))],
            )
            .unwrap();
        assert_eq!(cache.stats().tool_count, 1);
        cache.clear();
        assert_eq!(cache.stats().tool_count, 0);
    }

    #[test]
    fn test_extracts_required_fields() {
        let cache = ToolSchemaCache::new();
        cache
            .populate(
                "s",
                &[ToolDefinition::new("t").with_input_schema(json!({
                    "type": "object",
                    "properties": {
                        "entity": { "type": "string" },
                        "locale": { "type": "string" }
                    },
                    "required": ["entity"]
                }))],
            )
            .unwrap();
        let schema = cache.get("s", "t").unwrap();
        assert_eq!(schema.required, vec!["entity"]);
        assert!(schema.properties.contains(&"entity".to_string()));
        assert!(schema.properties.contains(&"locale".to_string()));
    }

    #[test]
    fn test_multiple_servers_tracked() {
        let cache = ToolSchemaCache::new();
        cache
            .populate(
                "server1",
                &[ToolDefinition::new("t1").with_input_schema(json!({}))],
            )
            .unwrap();
        cache
            .populate(
                "server2",
                &[ToolDefinition::new("t2").with_input_schema(json!({}))],
            )
            .unwrap();
        let stats = cache.stats();
        assert_eq!(stats.tool_count, 2);
        assert_eq!(stats.servers, 2);
    }

    #[test]
    fn test_same_tool_name_different_servers() {
        let cache = ToolSchemaCache::new();
        cache
            .populate(
                "server1",
                &[ToolDefinition::new("tool").with_input_schema(json!({
                    "type": "object",
                    "properties": { "a": {} },
                    "required": ["a"]
                }))],
            )
            .unwrap();
        cache
            .populate(
                "server2",
                &[ToolDefinition::new("tool").with_input_schema(json!({
                    "type": "object",
                    "properties": { "b": {} },
                    "required": ["b"]
                }))],
            )
            .unwrap();
        let schema1 = cache.get("server1", "tool").unwrap();
        let schema2 = cache.get("server2", "tool").unwrap();
        assert_eq!(schema1.required, vec!["a"]);
        assert_eq!(schema2.required, vec!["b"]);
    }

    #[test]
    fn test_invalid_schema_returns_error() {
        let cache = ToolSchemaCache::new();
        // `type` must be a string or an array of strings, so this schema can never
        // compile. The test therefore fails loudly if `populate` returns Ok.
        let result = cache.populate(
            "s",
            &[ToolDefinition::new("t").with_input_schema(json!({ "type": 12 }))],
        );
        match result {
            Err(NikaError::McpProtocolError { reason, .. }) => {
                assert!(reason.contains("s.t"), "unexpected reason: {}", reason);
            }
            _ => panic!("expected Err(McpProtocolError) for a schema with a non-string `type`"),
        }
        assert!(cache.get("s", "t").is_none());
    }

    #[test]
    fn test_unresolvable_ref_does_not_panic() {
        let cache = ToolSchemaCache::new();
        // Whether a dangling `$ref` is rejected at compile time depends on the
        // jsonschema version. This only pins that it does not panic and that any
        // error is reported as McpProtocolError.
        let result = cache.populate(
            "s",
            &[ToolDefinition::new("t").with_input_schema(json!({
                "$ref": "#/definitions/nonexistent"
            }))],
        );
        if let Err(err) = result {
            assert!(matches!(err, NikaError::McpProtocolError { .. }));
        }
    }

    #[test]
    fn test_default_impl() {
        let cache = ToolSchemaCache::default();
        assert_eq!(cache.stats().tool_count, 0);
    }

    #[test]
    fn test_properties_extraction() {
        let cache = ToolSchemaCache::new();
        cache
            .populate(
                "s",
                &[ToolDefinition::new("t").with_input_schema(json!({
                    "type": "object",
                    "properties": {
                        "z_field": {},
                        "a_field": {},
                        "m_field": {}
                    }
                }))],
            )
            .unwrap();
        let schema = cache.get("s", "t").unwrap();
        assert_eq!(schema.properties.len(), 3);
        assert!(schema.properties.contains(&"z_field".to_string()));
        assert!(schema.properties.contains(&"a_field".to_string()));
        assert!(schema.properties.contains(&"m_field".to_string()));
    }
}