use crate::mcp::client::error::{ClientError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Metadata describing a single tool stored in a `ToolRegistry`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInfo {
/// Name of the tool; `ToolRegistry::add_tool` uses it as the lookup key.
pub name: String,
/// Optional description of the tool.
pub description: Option<String>,
/// Optional JSON schema for the tool's parameters. When this is `None`,
/// `ToolRegistry::validate_parameters` accepts any parameters for the tool.
pub input_schema: Option<serde_json::Value>,
}
/// In-memory collection of `ToolInfo` entries that can be looked up by tool
/// name and used to validate call parameters against each tool's schema.
#[derive(Debug, Clone)]
pub struct ToolRegistry {
/// Registered tools, keyed by `ToolInfo::name`.
tools: HashMap<String, ToolInfo>,
}
impl ToolRegistry {
/// Creates a registry that contains no tools.
pub fn new() -> Self {
    let tools = HashMap::new();
    Self { tools }
}
/// Registers `tool` under its own `name`.
///
/// A tool that was already registered under the same name is replaced.
pub fn add_tool(&mut self, tool: ToolInfo) {
    let key = tool.name.clone();
    self.tools.insert(key, tool);
}
/// Looks up the tool registered as `name`.
///
/// Returns `None` when no tool with that name has been registered.
pub fn get_tool(&self, name: &str) -> Option<&ToolInfo> {
    let registered = &self.tools;
    registered.get(name)
}
/// Reports whether a tool named `name` is registered.
pub fn has_tool(&self, name: &str) -> bool {
    self.get_tool(name).is_some()
}
/// Returns the names of all registered tools in ascending lexicographic order.
///
/// The backing `HashMap` iterates in an unspecified order, so the names are
/// sorted here to give callers (and any message that prints this list) a
/// stable, reproducible result.
pub fn tool_names(&self) -> Vec<String> {
    let mut names: Vec<String> = self.tools.keys().cloned().collect();
    // Keys are unique, so an unstable sort cannot reorder equal elements.
    names.sort_unstable();
    names
}
/// Validates `params` against the input schema of the tool named `tool_name`.
///
/// A tool without an input schema accepts any parameters. Otherwise the
/// overall shape is checked first; when `params` is a JSON object, its
/// required fields and per-parameter types are checked as well.
///
/// # Errors
///
/// Returns `ClientError::Validation` when the tool is not registered (the
/// message lists the available tool names) or when any of the schema checks
/// rejects `params`.
pub fn validate_parameters(&self, tool_name: &str, params: &serde_json::Value) -> Result<()> {
    // One lookup serves both the existence check and the schema access, so
    // there is no `unwrap` that could panic.
    let Some(tool_info) = self.get_tool(tool_name) else {
        return Err(ClientError::Validation(format!(
            "Unknown tool: '{}'. Available tools: {:?}",
            tool_name,
            self.tool_names()
        )));
    };
    let Some(schema) = &tool_info.input_schema else {
        return Ok(());
    };
    self.validate_basic_structure(schema, params, tool_name)?;
    // Field-level checks only apply to object parameters.
    if let Some(params_obj) = params.as_object() {
        self.check_required_fields(schema, params_obj)?;
        self.validate_parameter_types(schema, params_obj)?;
    }
    Ok(())
}
/// Checks the overall shape of `params` against `schema`.
///
/// Two rules are enforced:
/// 1. If the schema's `"type"` is `"object"`, `params` must be a JSON object
///    or `null`.
/// 2. If `params` is `null` or an empty object, the schema must not list any
///    required field name.
///
/// Entries of `"required"` that are not strings are ignored, matching how
/// `check_required_fields` treats them, so a malformed schema no longer
/// produces a misleading "required parameter 'unknown' is missing" error.
///
/// # Errors
///
/// Returns `ClientError::Validation` describing the first violated rule.
fn validate_basic_structure(
    &self,
    schema: &serde_json::Value,
    params: &serde_json::Value,
    tool_name: &str,
) -> Result<()> {
    let expects_object =
        schema.get("type").and_then(serde_json::Value::as_str) == Some("object");
    if expects_object && !params.is_object() && !params.is_null() {
        return Err(ClientError::Validation(format!(
            "Tool '{}' expects object parameters, got: {}",
            tool_name, params
        )));
    }

    // Only "no parameters supplied" (null or `{}`) is inspected further here.
    let no_params = match params {
        serde_json::Value::Null => true,
        serde_json::Value::Object(fields) => fields.is_empty(),
        _ => false,
    };
    if !no_params {
        return Ok(());
    }

    // With nothing supplied, the first required field name is necessarily missing.
    let first_required = schema
        .get("required")
        .and_then(serde_json::Value::as_array)
        .and_then(|entries| entries.iter().find_map(serde_json::Value::as_str));
    match first_required {
        Some(field_name) => Err(ClientError::Validation(format!(
            "required parameter '{}' is missing",
            field_name
        ))),
        None => Ok(()),
    }
}
/// Ensures every field name listed in the schema's `"required"` array is a
/// key of `params_obj`.
///
/// A missing or non-array `"required"` entry means nothing is required, and
/// array entries that are not strings are skipped.
///
/// # Errors
///
/// Returns `ClientError::Validation` naming the first required field (in
/// schema order) that is absent from `params_obj`.
fn check_required_fields(
    &self,
    schema: &serde_json::Value,
    params_obj: &serde_json::Map<String, serde_json::Value>,
) -> Result<()> {
    let first_missing = schema
        .get("required")
        .and_then(serde_json::Value::as_array)
        .into_iter()
        .flatten()
        .filter_map(serde_json::Value::as_str)
        .find(|field_name| !params_obj.contains_key(*field_name));

    match first_missing {
        Some(field_name) => Err(ClientError::Validation(format!(
            "required parameter '{}' is missing",
            field_name
        ))),
        None => Ok(()),
    }
}
/// Checks each supplied parameter against its entry in the schema's
/// `"properties"` object.
///
/// When the schema has no `"properties"` object, nothing is checked. When it
/// does, every key of `params_obj` must be declared there and its value must
/// pass `validate_single_parameter_type`.
///
/// # Errors
///
/// Returns `ClientError::Validation` for the first parameter that is not
/// declared in `"properties"` or whose value has the wrong type.
fn validate_parameter_types(
    &self,
    schema: &serde_json::Value,
    params_obj: &serde_json::Map<String, serde_json::Value>,
) -> Result<()> {
    let Some(properties_obj) = schema
        .get("properties")
        .and_then(serde_json::Value::as_object)
    else {
        return Ok(());
    };

    for (param_name, param_value) in params_obj {
        // A single lookup both detects unknown parameters and yields the
        // parameter's schema, so no `unwrap` is needed.
        let Some(param_schema) = properties_obj.get(param_name) else {
            return Err(ClientError::Validation(format!(
                "unknown parameter '{}'",
                param_name
            )));
        };
        self.validate_single_parameter_type(param_name, param_value, param_schema)?;
    }
    Ok(())
}
/// Checks that `param_value` has the JSON type declared by `param_schema`.
///
/// The check is skipped when `param_schema` has no `"type"` entry or when
/// that entry is not a string. A declared `"number"` accepts both integer
/// and non-integer numbers; every other declared type must equal the type
/// name reported by `get_json_value_type`.
///
/// # Errors
///
/// Returns `ClientError::Validation` naming the parameter, the expected type
/// and the actual type when they do not match.
fn validate_single_parameter_type(
    &self,
    param_name: &str,
    param_value: &serde_json::Value,
    param_schema: &serde_json::Value,
) -> Result<()> {
    let Some(declared) = param_schema
        .get("type")
        .and_then(serde_json::Value::as_str)
    else {
        return Ok(());
    };

    let found = self.get_json_value_type(param_value);
    let acceptable = if declared == "number" {
        found == "number" || found == "integer"
    } else {
        declared == found
    };
    if acceptable {
        return Ok(());
    }

    // The message reports a declared "integer" as "number".
    // NOTE(review): a non-integer number passed for an "integer" parameter
    // therefore reads "should be a number, got number" — confirm this
    // wording is intended.
    let shown = if declared == "integer" {
        "number"
    } else {
        declared
    };
    Err(ClientError::Validation(format!(
        "parameter '{}' should be a {}, got {}",
        param_name, shown, found
    )))
}
fn get_json_value_type(&self, value: &serde_json::Value) -> &'static str {
match value {
serde_json::Value::Null => "null",
serde_json::Value::Bool(_) => "boolean",
serde_json::Value::Number(n) => {
if n.is_i64() || n.is_u64() {
"integer"
} else {
"number"
}
}
serde_json::Value::String(_) => "string",
serde_json::Value::Array(_) => "array",
serde_json::Value::Object(_) => "object",
}
}
/// Replaces the registry's contents with tools converted from `rmcp_tools`.
///
/// All previously registered tools are removed first. Each incoming tool is
/// converted to a `ToolInfo` whose `input_schema` is always `Some`, holding
/// a copy of the tool's schema object. If several incoming tools share a
/// name, the last one wins.
pub fn update_from_rmcp_tools(&mut self, rmcp_tools: Vec<rmcp::model::Tool>) {
    self.tools.clear();
    let converted = rmcp_tools.into_iter().map(|source| {
        let info = ToolInfo {
            name: source.name.to_string(),
            description: source.description.map(|text| text.to_string()),
            input_schema: Some(serde_json::Value::Object((*source.input_schema).clone())),
        };
        (info.name.clone(), info)
    });
    self.tools.extend(converted);
}
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a registry holding a single `get_pet_by_id` tool that uses
    /// `input_schema` as its parameter schema.
    fn registry_with_pet_tool(input_schema: serde_json::Value) -> ToolRegistry {
        let mut registry = ToolRegistry::new();
        registry.add_tool(ToolInfo {
            name: "get_pet_by_id".to_string(),
            description: Some("Get a pet by ID".to_string()),
            input_schema: Some(input_schema),
        });
        registry
    }

    /// Registration makes a tool discoverable by name and by listing.
    #[test]
    fn test_tool_registry_basic() {
        let registry = registry_with_pet_tool(json!({
            "type": "object",
            "properties": {
                "id": {"type": "integer"}
            },
            "required": ["id"]
        }));

        assert!(registry.has_tool("get_pet_by_id"));
        assert!(!registry.has_tool("nonexistent_tool"));
        assert_eq!(registry.tool_names(), vec!["get_pet_by_id"]);
    }

    /// An object schema accepts objects and null, and rejects arrays and
    /// calls to unregistered tools.
    #[test]
    fn test_parameter_validation() {
        let registry = registry_with_pet_tool(json!({"type": "object"}));

        let object_params = registry.validate_parameters("get_pet_by_id", &json!({"id": 123}));
        assert!(object_params.is_ok());

        let null_params = registry.validate_parameters("get_pet_by_id", &json!(null));
        assert!(null_params.is_ok());

        let array_params = registry.validate_parameters("get_pet_by_id", &json!([1, 2, 3]));
        assert!(array_params.is_err());

        let unknown_tool = registry.validate_parameters("unknown_tool", &json!({}));
        assert!(unknown_tool.is_err());
    }
}