1use jsonschema::Validator;
4use serde::{Deserialize, Serialize};
5use serde_json::json;
6use serde_json::Value as JsonValue;
7use std::sync::Arc;
8use thiserror::Error;
9
/// Shared, thread-safe sink for human-readable progress messages.
///
/// Stored behind an [`Arc`] so one callback can be cloned cheaply and handed to
/// several consumers; the `Send + Sync` bounds allow those clones to cross
/// thread boundaries.
pub type ProgressCallback = Arc<dyn Fn(String) + Send + Sync>;

/// Wraps a closure into a [`ProgressCallback`].
///
/// Convenience so callers can write `progress_callback(|msg| ...)` instead of
/// spelling out `Arc::new(...)` and the trait-object coercion themselves.
pub fn progress_callback<F>(f: F) -> ProgressCallback
where
    F: Fn(String) + Send + Sync + 'static,
{
    Arc::new(f) as ProgressCallback
}
17
/// A tool definition: a name, a free-text description, and a JSON Schema that
/// describes the arguments the tool accepts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Identifier of the tool.
    pub name: String,

    /// Free-text description of what the tool does.
    pub description: String,

    /// JSON Schema that call arguments are checked against by `Tool::validate`.
    /// `Tool::with_string_param` builds an `{"type": "object", ...}` schema;
    /// `Tool::new` stores whatever schema it is given.
    pub parameters: JsonValue,
}
30
31impl Tool {
32 pub fn new(
54 name: impl Into<String>,
55 description: impl Into<String>,
56 parameters: JsonValue,
57 ) -> Self {
58 Self {
59 name: name.into(),
60 description: description.into(),
61 parameters,
62 }
63 }
64
65 pub fn with_string_param(
80 name: impl Into<String>,
81 description: impl Into<String>,
82 param_name: impl Into<String>,
83 param_description: impl Into<String>,
84 ) -> Self {
85 let param_name = param_name.into();
86 let param_description = param_description.into();
87
88 let mut properties = serde_json::Map::new();
90 properties.insert("type".to_string(), json!("object"));
91
92 let mut obj_properties = serde_json::Map::new();
93 obj_properties.insert(
94 param_name.clone(),
95 json!({
96 "type": "string",
97 "description": param_description
98 }),
99 );
100 properties.insert(
101 "properties".to_string(),
102 serde_json::Value::Object(obj_properties),
103 );
104
105 let required_arr =
106 serde_json::Value::Array(vec![serde_json::Value::String(param_name.clone())]);
107 properties.insert("required".to_string(), required_arr);
108
109 let params = serde_json::Value::Object(properties);
110 Self::new(name, description, params)
111 }
112
113 pub fn validate(&self, args: &JsonValue) -> Result<JsonValue, ToolValidationError> {
129 validate_args_internal(&self.parameters, args)
130 }
131
132 pub fn requires_parameters(&self) -> bool {
134 self.parameters
135 .get("required")
136 .and_then(|r| r.as_array())
137 .map(|arr| !arr.is_empty())
138 .unwrap_or(false)
139 }
140}
141
/// Errors produced while validating tool-call arguments.
#[derive(Error, Debug)]
pub enum ToolValidationError {
    /// Wraps a `serde_json` error; `#[from]` lets `?` convert one automatically.
    #[error("Invalid JSON: {0}")]
    InvalidJson(#[from] serde_json::Error),

    /// The parameter schema could not be compiled, or the arguments did not
    /// satisfy it. Holds the validator's error text.
    #[error("Schema validation failed: {0}")]
    SchemaValidation(String),

    /// A required field is absent.
    /// NOTE(review): not constructed in the code visible here — a missing
    /// `required` property is reported as `SchemaValidation` instead. Confirm
    /// whether other code produces this variant.
    #[error("Missing required field: {0}")]
    MissingRequiredField(String),
}
157
158pub fn validate_args(tool: &Tool, args: &JsonValue) -> Result<JsonValue, ToolValidationError> {
160 validate_args_internal(&tool.parameters, args)
161}
162
163fn validate_args_internal(
165 schema: &JsonValue,
166 args: &JsonValue,
167) -> Result<JsonValue, ToolValidationError> {
168 let validator =
169 Validator::new(schema).map_err(|e| ToolValidationError::SchemaValidation(e.to_string()))?;
170
171 let validation_result = validator.validate(args);
172
173 match validation_result {
174 Ok(()) => Ok(args.clone()),
175 Err(errors) => {
176 let error_msg = format!("{}", errors);
178 Err(ToolValidationError::SchemaValidation(error_msg))
179 }
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the single-string-parameter tool shared by the tests below.
    fn weather_tool() -> Tool {
        Tool::with_string_param(
            "get_weather",
            "Get current weather for a location",
            "location",
            "City name or coordinates",
        )
    }

    #[test]
    fn test_tool_validation() {
        let tool = weather_tool();
        let valid_args = serde_json::json!({
            "location": "London"
        });

        let result = tool.validate(&valid_args);
        // On success the arguments come back unchanged.
        assert_eq!(result.unwrap(), valid_args);
    }

    #[test]
    fn test_tool_validation_failure() {
        let tool = weather_tool();
        let invalid_args = serde_json::json!({});

        let err = tool.validate(&invalid_args).unwrap_err();
        assert!(matches!(err, ToolValidationError::SchemaValidation(_)));
        // The message should name the missing property.
        assert!(err.to_string().contains("location"));
    }

    #[test]
    fn test_tool_validation_wrong_type() {
        let tool = weather_tool();
        let result = tool.validate(&serde_json::json!({ "location": 42 }));
        assert!(matches!(
            result,
            Err(ToolValidationError::SchemaValidation(_))
        ));
    }

    #[test]
    fn test_validate_args_matches_method() {
        let tool = weather_tool();
        let args = serde_json::json!({ "location": "Paris" });
        assert_eq!(
            validate_args(&tool, &args).unwrap(),
            tool.validate(&args).unwrap()
        );
    }

    #[test]
    fn test_with_string_param_schema() {
        let tool = weather_tool();
        assert_eq!(
            tool.parameters,
            serde_json::json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "City name or coordinates"
                    }
                },
                "required": ["location"]
            })
        );
    }

    #[test]
    fn test_requires_parameters() {
        assert!(weather_tool().requires_parameters());

        let no_params = Tool::new(
            "ping",
            "Takes no arguments",
            serde_json::json!({ "type": "object" }),
        );
        assert!(!no_params.requires_parameters());
    }
}