1use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::collections::HashMap;
13use std::sync::Arc;
14
15use crate::messaging::{AgentMessage, MessageContent, MessageMetadata, MessageRole};
16use crate::state::AgentStateSnapshot;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ToolParameterSchema {
21 #[serde(rename = "type")]
23 pub schema_type: String,
24
25 #[serde(skip_serializing_if = "Option::is_none")]
27 pub description: Option<String>,
28
29 #[serde(skip_serializing_if = "Option::is_none")]
31 pub properties: Option<HashMap<String, ToolParameterSchema>>,
32
33 #[serde(skip_serializing_if = "Option::is_none")]
35 pub required: Option<Vec<String>>,
36
37 #[serde(skip_serializing_if = "Option::is_none")]
39 pub items: Option<Box<ToolParameterSchema>>,
40
41 #[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
43 pub enum_values: Option<Vec<Value>>,
44
45 #[serde(skip_serializing_if = "Option::is_none")]
47 pub default: Option<Value>,
48
49 #[serde(flatten)]
51 pub additional: HashMap<String, Value>,
52}
53
54impl ToolParameterSchema {
55 pub fn string(description: impl Into<String>) -> Self {
57 Self {
58 schema_type: "string".to_string(),
59 description: Some(description.into()),
60 properties: None,
61 required: None,
62 items: None,
63 enum_values: None,
64 default: None,
65 additional: HashMap::new(),
66 }
67 }
68
69 pub fn number(description: impl Into<String>) -> Self {
71 Self {
72 schema_type: "number".to_string(),
73 description: Some(description.into()),
74 properties: None,
75 required: None,
76 items: None,
77 enum_values: None,
78 default: None,
79 additional: HashMap::new(),
80 }
81 }
82
83 pub fn integer(description: impl Into<String>) -> Self {
85 Self {
86 schema_type: "integer".to_string(),
87 description: Some(description.into()),
88 properties: None,
89 required: None,
90 items: None,
91 enum_values: None,
92 default: None,
93 additional: HashMap::new(),
94 }
95 }
96
97 pub fn boolean(description: impl Into<String>) -> Self {
99 Self {
100 schema_type: "boolean".to_string(),
101 description: Some(description.into()),
102 properties: None,
103 required: None,
104 items: None,
105 enum_values: None,
106 default: None,
107 additional: HashMap::new(),
108 }
109 }
110
111 pub fn object(
113 description: impl Into<String>,
114 properties: HashMap<String, ToolParameterSchema>,
115 required: Vec<String>,
116 ) -> Self {
117 Self {
118 schema_type: "object".to_string(),
119 description: Some(description.into()),
120 properties: Some(properties),
121 required: Some(required),
122 items: None,
123 enum_values: None,
124 default: None,
125 additional: HashMap::new(),
126 }
127 }
128
129 pub fn array(description: impl Into<String>, items: ToolParameterSchema) -> Self {
131 Self {
132 schema_type: "array".to_string(),
133 description: Some(description.into()),
134 properties: None,
135 required: None,
136 items: Some(Box::new(items)),
137 enum_values: None,
138 default: None,
139 additional: HashMap::new(),
140 }
141 }
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct ToolSchema {
147 pub name: String,
149
150 pub description: String,
152
153 pub parameters: ToolParameterSchema,
155}
156
157impl ToolSchema {
158 pub fn new(
159 name: impl Into<String>,
160 description: impl Into<String>,
161 parameters: ToolParameterSchema,
162 ) -> Self {
163 Self {
164 name: name.into(),
165 description: description.into(),
166 parameters,
167 }
168 }
169
170 pub fn no_params(name: impl Into<String>, description: impl Into<String>) -> Self {
172 Self {
173 name: name.into(),
174 description: description.into(),
175 parameters: ToolParameterSchema {
176 schema_type: "object".to_string(),
177 description: None,
178 properties: Some(HashMap::new()),
179 required: Some(Vec::new()),
180 items: None,
181 enum_values: None,
182 default: None,
183 additional: HashMap::new(),
184 },
185 }
186 }
187}
188
189#[derive(Clone)]
191pub struct ToolContext {
192 pub state: Arc<AgentStateSnapshot>,
194
195 pub state_handle: Option<Arc<std::sync::RwLock<AgentStateSnapshot>>>,
197
198 pub tool_call_id: Option<String>,
200}
201
202impl ToolContext {
203 pub fn new(state: Arc<AgentStateSnapshot>) -> Self {
205 Self {
206 state,
207 state_handle: None,
208 tool_call_id: None,
209 }
210 }
211
212 pub fn with_mutable_state(
214 state: Arc<AgentStateSnapshot>,
215 state_handle: Arc<std::sync::RwLock<AgentStateSnapshot>>,
216 ) -> Self {
217 Self {
218 state,
219 state_handle: Some(state_handle),
220 tool_call_id: None,
221 }
222 }
223
224 pub fn with_call_id(mut self, call_id: Option<String>) -> Self {
226 self.tool_call_id = call_id;
227 self
228 }
229
230 pub fn text_response(&self, content: impl Into<String>) -> AgentMessage {
232 AgentMessage {
233 role: MessageRole::Tool,
234 content: MessageContent::Text(content.into()),
235 metadata: self.tool_call_id.as_ref().map(|id| MessageMetadata {
236 tool_call_id: Some(id.clone()),
237 cache_control: None,
238 }),
239 }
240 }
241
242 pub fn json_response(&self, content: Value) -> AgentMessage {
244 AgentMessage {
245 role: MessageRole::Tool,
246 content: MessageContent::Json(content),
247 metadata: self.tool_call_id.as_ref().map(|id| MessageMetadata {
248 tool_call_id: Some(id.clone()),
249 cache_control: None,
250 }),
251 }
252 }
253}
254
255#[derive(Debug, Clone)]
257pub enum ToolResult {
258 Message(AgentMessage),
260
261 WithStateUpdate {
263 message: AgentMessage,
264 state_diff: crate::command::StateDiff,
265 },
266}
267
268impl ToolResult {
269 pub fn text(ctx: &ToolContext, content: impl Into<String>) -> Self {
271 Self::Message(ctx.text_response(content))
272 }
273
274 pub fn json(ctx: &ToolContext, content: Value) -> Self {
276 Self::Message(ctx.json_response(content))
277 }
278
279 pub fn with_state(message: AgentMessage, state_diff: crate::command::StateDiff) -> Self {
281 Self::WithStateUpdate {
282 message,
283 state_diff,
284 }
285 }
286}
287
288#[async_trait]
290pub trait Tool: Send + Sync {
291 fn schema(&self) -> ToolSchema;
293
294 async fn execute(&self, args: Value, ctx: ToolContext) -> anyhow::Result<ToolResult>;
296}
297
298pub type ToolBox = Arc<dyn Tool>;
300
301#[derive(Clone, Default)]
303pub struct ToolRegistry {
304 tools: HashMap<String, ToolBox>,
305}
306
307impl ToolRegistry {
308 pub fn new() -> Self {
310 Self {
311 tools: HashMap::new(),
312 }
313 }
314
315 pub fn register(&mut self, tool: ToolBox) -> &mut Self {
317 let name = tool.schema().name.clone();
318 self.tools.insert(name, tool);
319 self
320 }
321
322 pub fn register_all<I>(&mut self, tools: I) -> &mut Self
324 where
325 I: IntoIterator<Item = ToolBox>,
326 {
327 for tool in tools {
328 self.register(tool);
329 }
330 self
331 }
332
333 pub fn get(&self, name: &str) -> Option<&ToolBox> {
335 self.tools.get(name)
336 }
337
338 pub fn all(&self) -> Vec<&ToolBox> {
340 self.tools.values().collect()
341 }
342
343 pub fn schemas(&self) -> Vec<ToolSchema> {
345 self.tools.values().map(|t| t.schema()).collect()
346 }
347
348 pub fn names(&self) -> Vec<String> {
350 self.tools.keys().cloned().collect()
351 }
352
353 pub fn has(&self, name: &str) -> bool {
355 self.tools.contains_key(name)
356 }
357
358 pub fn len(&self) -> usize {
360 self.tools.len()
361 }
362
363 pub fn is_empty(&self) -> bool {
365 self.tools.is_empty()
366 }
367}