// hehe_core/tool/schema.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::collections::HashMap;
4
5#[derive(Clone, Debug, Serialize, Deserialize)]
6#[serde(rename_all = "lowercase")]
7pub enum JsonSchemaType {
8    String,
9    Number,
10    Integer,
11    Boolean,
12    Array,
13    Object,
14    Null,
15}
16
17#[derive(Clone, Debug, Serialize, Deserialize)]
18pub struct ToolParameter {
19    #[serde(rename = "type")]
20    pub schema_type: JsonSchemaType,
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub description: Option<String>,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub default: Option<Value>,
25    #[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
26    pub enum_values: Option<Vec<Value>>,
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub items: Option<Box<ToolParameter>>,
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub properties: Option<HashMap<String, ToolParameter>>,
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub required: Option<Vec<String>>,
33}
34
35impl ToolParameter {
36    pub fn string() -> Self {
37        Self {
38            schema_type: JsonSchemaType::String,
39            description: None,
40            default: None,
41            enum_values: None,
42            items: None,
43            properties: None,
44            required: None,
45        }
46    }
47
48    pub fn number() -> Self {
49        Self {
50            schema_type: JsonSchemaType::Number,
51            ..Self::string()
52        }
53    }
54
55    pub fn integer() -> Self {
56        Self {
57            schema_type: JsonSchemaType::Integer,
58            ..Self::string()
59        }
60    }
61
62    pub fn boolean() -> Self {
63        Self {
64            schema_type: JsonSchemaType::Boolean,
65            ..Self::string()
66        }
67    }
68
69    pub fn array(items: ToolParameter) -> Self {
70        Self {
71            schema_type: JsonSchemaType::Array,
72            items: Some(Box::new(items)),
73            ..Self::string()
74        }
75    }
76
77    pub fn object() -> Self {
78        Self {
79            schema_type: JsonSchemaType::Object,
80            properties: Some(HashMap::new()),
81            required: Some(vec![]),
82            ..Self::string()
83        }
84    }
85
86    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
87        self.description = Some(desc.into());
88        self
89    }
90
91    pub fn with_default(mut self, default: Value) -> Self {
92        self.default = Some(default);
93        self
94    }
95
96    pub fn with_enum(mut self, values: Vec<Value>) -> Self {
97        self.enum_values = Some(values);
98        self
99    }
100
101    pub fn with_property(mut self, name: impl Into<String>, param: ToolParameter) -> Self {
102        if let Some(props) = &mut self.properties {
103            props.insert(name.into(), param);
104        }
105        self
106    }
107
108    pub fn with_required(mut self, name: impl Into<String>) -> Self {
109        if let Some(req) = &mut self.required {
110            req.push(name.into());
111        }
112        self
113    }
114}
115
116#[derive(Clone, Debug, Serialize, Deserialize)]
117pub struct ToolDefinition {
118    pub name: String,
119    pub description: String,
120    pub parameters: ToolParameter,
121    #[serde(default)]
122    pub dangerous: bool,
123    #[serde(skip_serializing_if = "Option::is_none")]
124    pub category: Option<String>,
125    #[serde(skip_serializing_if = "Option::is_none")]
126    pub version: Option<String>,
127}
128
129impl ToolDefinition {
130    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
131        Self {
132            name: name.into(),
133            description: description.into(),
134            parameters: ToolParameter::object(),
135            dangerous: false,
136            category: None,
137            version: None,
138        }
139    }
140
141    pub fn with_parameters(mut self, params: ToolParameter) -> Self {
142        self.parameters = params;
143        self
144    }
145
146    pub fn with_param(mut self, name: impl Into<String>, param: ToolParameter) -> Self {
147        if let Some(props) = &mut self.parameters.properties {
148            props.insert(name.into(), param);
149        }
150        self
151    }
152
153    pub fn with_required_param(self, name: impl Into<String>, param: ToolParameter) -> Self {
154        let name = name.into();
155        self.with_param(name.clone(), param).require_param(name)
156    }
157
158    pub fn require_param(mut self, name: impl Into<String>) -> Self {
159        if let Some(req) = &mut self.parameters.required {
160            req.push(name.into());
161        }
162        self
163    }
164
165    pub fn dangerous(mut self) -> Self {
166        self.dangerous = true;
167        self
168    }
169
170    pub fn with_category(mut self, category: impl Into<String>) -> Self {
171        self.category = Some(category.into());
172        self
173    }
174
175    pub fn with_version(mut self, version: impl Into<String>) -> Self {
176        self.version = Some(version.into());
177        self
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that the builder records name, metadata, properties, defaults
    /// and the exact list of required parameters.
    #[test]
    fn test_tool_definition() {
        let tool = ToolDefinition::new("read_file", "Read contents of a file")
            .with_required_param(
                "path",
                ToolParameter::string().with_description("File path to read"),
            )
            .with_param(
                "encoding",
                ToolParameter::string()
                    .with_description("File encoding")
                    .with_default(Value::String("utf-8".into())),
            )
            .with_category("filesystem");

        assert_eq!(tool.name, "read_file");
        assert_eq!(tool.description, "Read contents of a file");
        assert!(!tool.dangerous);
        assert_eq!(tool.category.as_deref(), Some("filesystem"));
        assert!(tool.version.is_none());
        assert!(matches!(
            tool.parameters.schema_type,
            JsonSchemaType::Object
        ));

        let props = tool
            .parameters
            .properties
            .as_ref()
            .expect("object schema carries a property map");
        assert_eq!(props.len(), 2);
        assert!(props.contains_key("path"));
        assert!(props.contains_key("encoding"));
        assert_eq!(
            props["path"].description.as_deref(),
            Some("File path to read")
        );
        assert_eq!(
            props["encoding"].default,
            Some(Value::String("utf-8".into()))
        );

        // Only `path` was added through `with_required_param`.
        let required = tool
            .parameters
            .required
            .as_ref()
            .expect("object schema carries a required list");
        assert_eq!(*required, vec!["path".to_string()]);
    }

    /// Checks the metadata builders that the main test does not touch.
    #[test]
    fn test_metadata_builders() {
        let tool = ToolDefinition::new("delete_file", "Delete a file")
            .dangerous()
            .with_version("1.0.0");

        assert!(tool.dangerous);
        assert_eq!(tool.version.as_deref(), Some("1.0.0"));
        assert!(tool.category.is_none());
    }
}