Skip to main content

oxi_ai/
tools.rs

1//! Tool definitions and validation
2
3use jsonschema::Validator;
4use serde::{Deserialize, Serialize};
5use serde_json::json;
6use serde_json::Value as JsonValue;
7use std::sync::Arc;
8use thiserror::Error;
9
/// Shared, thread-safe callback that receives progress messages as `String`s.
///
/// Because it is an `Arc`, cloning it is cheap and every clone invokes the
/// same underlying closure.
pub type ProgressCallback = Arc<dyn Fn(String) + Send + Sync>;

/// Wrap a closure into a [`ProgressCallback`].
///
/// Any `Fn(String)` that is `Send + Sync + 'static` qualifies.
pub fn progress_callback<F>(f: F) -> ProgressCallback
where
    F: Fn(String) + Send + Sync + 'static,
{
    Arc::new(f)
}
17
18/// Tool definition with JSON Schema parameters
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct Tool {
21    /// Tool name
22    pub name: String,
23
24    /// Human-readable description
25    pub description: String,
26
27    /// JSON Schema for parameters
28    pub parameters: JsonValue,
29}
30
31impl Tool {
32    /// Create a new tool with the given name, description, and JSON Schema.
33    ///
34    /// # Examples
35    ///
36    /// ```
37    /// use oxi_ai::Tool;
38    /// let tool = Tool::new(
39    ///     "read_file",
40    ///     "Read contents from a file",
41    ///     serde_json::json!({
42    ///         "type": "object",
43    ///         "properties": {
44    ///             "path": {
45    ///                 "type": "string",
46    ///                 "description": "File path to read"
47    ///             }
48    ///         },
49    ///         "required": ["path"]
50    ///     }),
51    /// );
52    /// ```
53    pub fn new(
54        name: impl Into<String>,
55        description: impl Into<String>,
56        parameters: JsonValue,
57    ) -> Self {
58        Self {
59            name: name.into(),
60            description: description.into(),
61            parameters,
62        }
63    }
64
65    /// Create a simple tool with a single string parameter
66    ///
67    /// # Examples
68    ///
69    /// ```
70    /// use oxi_ai::Tool;
71    /// let tool = Tool::with_string_param(
72    ///     "get_weather",
73    ///     "Get current weather",
74    ///     "location",
75    ///     "City name",
76    /// );
77    /// assert_eq!(tool.name, "get_weather");
78    /// ```
79    pub fn with_string_param(
80        name: impl Into<String>,
81        description: impl Into<String>,
82        param_name: impl Into<String>,
83        param_description: impl Into<String>,
84    ) -> Self {
85        let param_name = param_name.into();
86        let param_description = param_description.into();
87
88        // Build properties manually to avoid borrow issues
89        let mut properties = serde_json::Map::new();
90        properties.insert("type".to_string(), json!("object"));
91
92        let mut obj_properties = serde_json::Map::new();
93        obj_properties.insert(
94            param_name.clone(),
95            json!({
96                "type": "string",
97                "description": param_description
98            }),
99        );
100        properties.insert(
101            "properties".to_string(),
102            serde_json::Value::Object(obj_properties),
103        );
104
105        let required_arr =
106            serde_json::Value::Array(vec![serde_json::Value::String(param_name.clone())]);
107        properties.insert("required".to_string(), required_arr);
108
109        let params = serde_json::Value::Object(properties);
110        Self::new(name, description, params)
111    }
112
113    /// Validate arguments against the tool's JSON Schema
114    ///
115    /// # Examples
116    ///
117    /// ```
118    /// use oxi_ai::Tool;
119    /// let tool = Tool::with_string_param(
120    ///     "get_weather",
121    ///     "Get weather",
122    ///     "location",
123    ///     "City",
124    /// );
125    /// let result = tool.validate(&serde_json::json!({"location": "London"}));
126    /// assert!(result.is_ok());
127    /// ```
128    pub fn validate(&self, args: &JsonValue) -> Result<JsonValue, ToolValidationError> {
129        validate_args_internal(&self.parameters, args)
130    }
131
132    /// Check if this tool requires parameters
133    pub fn requires_parameters(&self) -> bool {
134        self.parameters
135            .get("required")
136            .and_then(|r| r.as_array())
137            .map(|arr| !arr.is_empty())
138            .unwrap_or(false)
139    }
140}
141
142/// Validation error
143#[derive(Error, Debug)]
144pub enum ToolValidationError {
145    #[error("Invalid JSON: {0}")]
146    /// invalid json variant.
147    InvalidJson(#[from] serde_json::Error),
148
149    #[error("Schema validation failed: {0}")]
150    /// schema validation variant.
151    SchemaValidation(String),
152
153    #[error("Missing required field: {0}")]
154    /// missing required field variant.
155    MissingRequiredField(String),
156}
157
158/// Validate tool arguments against a JSON Schema
159pub fn validate_args(tool: &Tool, args: &JsonValue) -> Result<JsonValue, ToolValidationError> {
160    validate_args_internal(&tool.parameters, args)
161}
162
163/// Internal validation implementation
164fn validate_args_internal(
165    schema: &JsonValue,
166    args: &JsonValue,
167) -> Result<JsonValue, ToolValidationError> {
168    let validator =
169        Validator::new(schema).map_err(|e| ToolValidationError::SchemaValidation(e.to_string()))?;
170
171    let validation_result = validator.validate(args);
172
173    match validation_result {
174        Ok(()) => Ok(args.clone()),
175        Err(errors) => {
176            // jsonschema returns an error with formatted message
177            let error_msg = format!("{}", errors);
178            Err(ToolValidationError::SchemaValidation(error_msg))
179        }
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Tool with a single required string parameter, shared by the tests below.
    fn weather_tool() -> Tool {
        Tool::with_string_param(
            "get_weather",
            "Get current weather for a location",
            "location",
            "City name or coordinates",
        )
    }

    #[test]
    fn test_tool_validation() {
        let args = json!({ "location": "London" });
        let validated = weather_tool()
            .validate(&args)
            .expect("valid arguments should pass validation");
        // Successful validation hands back the arguments unchanged.
        assert_eq!(validated, args);
    }

    #[test]
    fn test_tool_validation_failure() {
        // Missing required field.
        let result = weather_tool().validate(&json!({}));
        assert!(matches!(result, Err(ToolValidationError::SchemaValidation(_))));
    }

    #[test]
    fn test_tool_validation_wrong_type() {
        let result = weather_tool().validate(&json!({ "location": 42 }));
        assert!(matches!(result, Err(ToolValidationError::SchemaValidation(_))));
    }

    #[test]
    fn test_validate_args_matches_method() {
        let tool = weather_tool();
        let args = json!({ "location": "Paris" });
        assert_eq!(
            validate_args(&tool, &args).unwrap(),
            tool.validate(&args).unwrap()
        );
        assert!(validate_args(&tool, &json!({})).is_err());
    }

    #[test]
    fn test_with_string_param_schema_shape() {
        assert_eq!(
            weather_tool().parameters,
            json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "City name or coordinates"
                    }
                },
                "required": ["location"]
            })
        );
    }

    #[test]
    fn test_requires_parameters() {
        assert!(weather_tool().requires_parameters());

        let no_required = Tool::new("ping", "Health check", json!({ "type": "object" }));
        assert!(!no_required.requires_parameters());

        let empty_required = Tool::new(
            "ping",
            "Health check",
            json!({ "type": "object", "required": [] }),
        );
        assert!(!empty_required.requires_parameters());
    }
}