use jsonschema::Validator;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::Value as JsonValue;
use std::sync::Arc;
use thiserror::Error;
/// Reference-counted, `Send + Sync` callback that is invoked with an owned `String`
/// (presumably a progress message, going by the name).
pub type ProgressCallback = Arc<dyn Fn(String) + Send + Sync>;

/// Wraps the closure `f` in an [`Arc`] so it can be cloned and shared as a
/// [`ProgressCallback`].
pub fn progress_callback<F>(f: F) -> ProgressCallback
where
    F: Fn(String) + Send + Sync + 'static,
{
    Arc::new(f)
}
/// A named tool together with the schema its arguments are validated against.
///
/// Derives `Serialize`/`Deserialize`, so a `Tool` can be converted to and from
/// serde-supported formats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
/// Name of the tool (e.g. `get_weather` in this file's tests).
pub name: String,
/// Free-text description of the tool.
pub description: String,
/// JSON value handed to the schema validator by `Tool::validate`; the
/// constructors in this file build it as a JSON Schema object.
pub parameters: JsonValue,
}
impl Tool {
pub fn new(
name: impl Into<String>,
description: impl Into<String>,
parameters: JsonValue,
) -> Self {
Self {
name: name.into(),
description: description.into(),
parameters,
}
}
pub fn with_string_param(
name: impl Into<String>,
description: impl Into<String>,
param_name: impl Into<String>,
param_description: impl Into<String>,
) -> Self {
let param_name = param_name.into();
let param_description = param_description.into();
let mut properties = serde_json::Map::new();
properties.insert("type".to_string(), json!("object"));
let mut obj_properties = serde_json::Map::new();
obj_properties.insert(
param_name.clone(),
json!({
"type": "string",
"description": param_description
}),
);
properties.insert(
"properties".to_string(),
serde_json::Value::Object(obj_properties),
);
let required_arr =
serde_json::Value::Array(vec![serde_json::Value::String(param_name.clone())]);
properties.insert("required".to_string(), required_arr);
let params = serde_json::Value::Object(properties);
Self::new(name, description, params)
}
pub fn validate(&self, args: &JsonValue) -> Result<JsonValue, ValidationError> {
validate_args_internal(&self.parameters, args)
}
pub fn requires_parameters(&self) -> bool {
self.parameters
.get("required")
.and_then(|r| r.as_array())
.map(|arr| !arr.is_empty())
.unwrap_or(false)
}
}
/// Errors reported while validating tool arguments.
#[derive(Error, Debug)]
pub enum ValidationError {
/// Wraps a `serde_json::Error`; `#[from]` lets `?` convert one automatically.
/// NOTE(review): not constructed anywhere in the visible code — confirm it is used by callers.
#[error("Invalid JSON: {0}")]
InvalidJson(#[from] serde_json::Error),
/// The schema could not be built into a validator, or the arguments did not
/// conform to it. Carries the rendered message of the underlying error.
#[error("Schema validation failed: {0}")]
SchemaValidation(String),
/// A required field is missing; the payload is presumably the field name.
/// NOTE(review): not constructed anywhere in the visible code — missing fields
/// currently surface as `SchemaValidation`.
#[error("Missing required field: {0}")]
MissingRequiredField(String),
}
/// Validates `args` against the `parameters` schema of `tool`.
///
/// Free-function counterpart of `Tool::validate`; both forward to the same
/// internal routine. Returns a clone of `args` on success.
///
/// # Errors
///
/// Returns [`ValidationError::SchemaValidation`] when the schema cannot be
/// turned into a validator or when `args` do not satisfy it.
pub fn validate_args(tool: &Tool, args: &JsonValue) -> Result<JsonValue, ValidationError> {
    let schema = &tool.parameters;
    validate_args_internal(schema, args)
}
/// Builds a validator from `schema` and checks `args` against it.
///
/// The validator is constructed on every call. On success a clone of `args`
/// is returned.
///
/// # Errors
///
/// Returns [`ValidationError::SchemaValidation`] carrying the rendered error
/// message, both when `schema` cannot be built into a validator and when
/// `args` fail validation.
fn validate_args_internal(
    schema: &JsonValue,
    args: &JsonValue,
) -> Result<JsonValue, ValidationError> {
    let compiled = match Validator::new(schema) {
        Ok(compiled) => compiled,
        Err(build_err) => {
            return Err(ValidationError::SchemaValidation(build_err.to_string()));
        }
    };

    // Guard clause: report the validation error, otherwise fall through to success.
    if let Err(violation) = compiled.validate(args) {
        return Err(ValidationError::SchemaValidation(violation.to_string()));
    }

    Ok(args.clone())
}
/// Builds a JSON Schema object from `(name, type, description)` triples.
///
/// Every field becomes an entry under `properties` with the given `type` and
/// `description`, and every field name is listed under `required`.
///
/// If a name occurs more than once, the last triple defines the property
/// (later inserts replace earlier ones) and the name appears in `required`
/// only once, because JSON Schema requires the elements of `required` to be
/// unique.
///
/// An empty `fields` slice yields an object schema with empty `properties`
/// and an empty `required` array.
// NOTE(review): no caller is visible in this file; confirm the `allow` below
// is still needed.
#[allow(dead_code)]
pub fn create_schema(fields: &[(&str, &str, &str)]) -> JsonValue {
    let mut properties = serde_json::Map::new();
    let mut required: Vec<&str> = Vec::with_capacity(fields.len());

    for &(name, schema_type, description) in fields {
        let prop = serde_json::json!({
            "type": schema_type,
            "description": description
        });
        // `insert` returns the previous value when the key already existed, so
        // a repeated name replaces its property but is not listed twice.
        if properties.insert(name.to_string(), prop).is_none() {
            required.push(name);
        }
    }

    serde_json::json!({
        "type": "object",
        "properties": properties,
        "required": required
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the single-parameter weather tool shared by the tests below.
    fn weather_tool() -> Tool {
        Tool::with_string_param(
            "get_weather",
            "Get current weather for a location",
            "location",
            "City name or coordinates",
        )
    }

    /// Arguments that supply the required `location` string are accepted.
    #[test]
    fn test_tool_validation() {
        let args = serde_json::json!({
            "location": "London"
        });
        assert!(weather_tool().validate(&args).is_ok());
    }

    /// An empty object omits the required `location` and is rejected.
    #[test]
    fn test_tool_validation_failure() {
        let args = serde_json::json!({});
        assert!(weather_tool().validate(&args).is_err());
    }
}