1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::collections::HashMap;
4
5#[derive(Clone, Debug, Serialize, Deserialize)]
6#[serde(rename_all = "lowercase")]
7pub enum JsonSchemaType {
8 String,
9 Number,
10 Integer,
11 Boolean,
12 Array,
13 Object,
14 Null,
15}
16
17#[derive(Clone, Debug, Serialize, Deserialize)]
18pub struct ToolParameter {
19 #[serde(rename = "type")]
20 pub schema_type: JsonSchemaType,
21 #[serde(skip_serializing_if = "Option::is_none")]
22 pub description: Option<String>,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub default: Option<Value>,
25 #[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
26 pub enum_values: Option<Vec<Value>>,
27 #[serde(skip_serializing_if = "Option::is_none")]
28 pub items: Option<Box<ToolParameter>>,
29 #[serde(skip_serializing_if = "Option::is_none")]
30 pub properties: Option<HashMap<String, ToolParameter>>,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub required: Option<Vec<String>>,
33}
34
35impl ToolParameter {
36 pub fn string() -> Self {
37 Self {
38 schema_type: JsonSchemaType::String,
39 description: None,
40 default: None,
41 enum_values: None,
42 items: None,
43 properties: None,
44 required: None,
45 }
46 }
47
48 pub fn number() -> Self {
49 Self {
50 schema_type: JsonSchemaType::Number,
51 ..Self::string()
52 }
53 }
54
55 pub fn integer() -> Self {
56 Self {
57 schema_type: JsonSchemaType::Integer,
58 ..Self::string()
59 }
60 }
61
62 pub fn boolean() -> Self {
63 Self {
64 schema_type: JsonSchemaType::Boolean,
65 ..Self::string()
66 }
67 }
68
69 pub fn array(items: ToolParameter) -> Self {
70 Self {
71 schema_type: JsonSchemaType::Array,
72 items: Some(Box::new(items)),
73 ..Self::string()
74 }
75 }
76
77 pub fn object() -> Self {
78 Self {
79 schema_type: JsonSchemaType::Object,
80 properties: Some(HashMap::new()),
81 required: Some(vec![]),
82 ..Self::string()
83 }
84 }
85
86 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
87 self.description = Some(desc.into());
88 self
89 }
90
91 pub fn with_default(mut self, default: Value) -> Self {
92 self.default = Some(default);
93 self
94 }
95
96 pub fn with_enum(mut self, values: Vec<Value>) -> Self {
97 self.enum_values = Some(values);
98 self
99 }
100
101 pub fn with_property(mut self, name: impl Into<String>, param: ToolParameter) -> Self {
102 if let Some(props) = &mut self.properties {
103 props.insert(name.into(), param);
104 }
105 self
106 }
107
108 pub fn with_required(mut self, name: impl Into<String>) -> Self {
109 if let Some(req) = &mut self.required {
110 req.push(name.into());
111 }
112 self
113 }
114}
115
/// Describes a callable tool: its name, a human-readable description, and a JSON Schema
/// (`parameters`) for its arguments, plus some metadata.
///
/// Field order and the serde attributes below define the serialized form.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Identifier of the tool.
    pub name: String,
    /// Human-readable explanation of what the tool does.
    pub description: String,
    /// Schema of the tool's arguments; [`ToolDefinition::new`] starts it as an empty object.
    pub parameters: ToolParameter,
    /// Flag marking the tool as dangerous. It defaults to `false` when absent on input and
    /// is always serialized.
    /// NOTE(review): how this flag is enforced (e.g. confirmation prompts) is decided by
    /// callers outside this file — confirm.
    #[serde(default)]
    pub dangerous: bool,
    /// Optional grouping label (e.g. `"filesystem"`); omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub category: Option<String>,
    /// Optional version string; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version: Option<String>,
}
128
129impl ToolDefinition {
130 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
131 Self {
132 name: name.into(),
133 description: description.into(),
134 parameters: ToolParameter::object(),
135 dangerous: false,
136 category: None,
137 version: None,
138 }
139 }
140
141 pub fn with_parameters(mut self, params: ToolParameter) -> Self {
142 self.parameters = params;
143 self
144 }
145
146 pub fn with_param(mut self, name: impl Into<String>, param: ToolParameter) -> Self {
147 if let Some(props) = &mut self.parameters.properties {
148 props.insert(name.into(), param);
149 }
150 self
151 }
152
153 pub fn with_required_param(self, name: impl Into<String>, param: ToolParameter) -> Self {
154 let name = name.into();
155 self.with_param(name.clone(), param).require_param(name)
156 }
157
158 pub fn require_param(mut self, name: impl Into<String>) -> Self {
159 if let Some(req) = &mut self.parameters.required {
160 req.push(name.into());
161 }
162 self
163 }
164
165 pub fn dangerous(mut self) -> Self {
166 self.dangerous = true;
167 self
168 }
169
170 pub fn with_category(mut self, category: impl Into<String>) -> Self {
171 self.category = Some(category.into());
172 self
173 }
174
175 pub fn with_version(mut self, version: impl Into<String>) -> Self {
176 self.version = Some(version.into());
177 self
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a two-argument file-reading tool and checks its metadata and parameter schema.
    #[test]
    fn test_tool_definition() {
        // Argument schemas are built up front so the tool builder chain stays short.
        let path_schema = ToolParameter::string().with_description("File path to read");
        let encoding_schema = ToolParameter::string()
            .with_description("File encoding")
            .with_default(Value::String("utf-8".to_owned()));

        let tool = ToolDefinition::new("read_file", "Read contents of a file")
            .with_required_param("path", path_schema)
            .with_param("encoding", encoding_schema)
            .with_category("filesystem");

        // Top-level metadata.
        assert_eq!(tool.name, "read_file");
        assert!(!tool.dangerous);
        assert_eq!(tool.category.as_deref(), Some("filesystem"));

        // Both arguments are declared as properties.
        let schema = &tool.parameters;
        let properties = schema.properties.as_ref().unwrap();
        for &key in &["path", "encoding"] {
            assert!(properties.contains_key(key));
        }

        // Only `path` was requested as required.
        let required = schema.required.as_ref().unwrap();
        assert!(required.iter().any(|entry| entry == "path"));
    }
}