mofa_kernel/agent/components/
tool.rs1use crate::agent::context::AgentContext;
6use crate::agent::error::{AgentError, AgentResult};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12#[async_trait]
49pub trait Tool: Send + Sync {
50 fn name(&self) -> &str;
52
53 fn description(&self) -> &str;
55
56 fn parameters_schema(&self) -> serde_json::Value;
58
59 async fn execute(&self, input: ToolInput, ctx: &AgentContext) -> ToolResult;
61
62 fn metadata(&self) -> ToolMetadata {
64 ToolMetadata::default()
65 }
66
67 fn validate_input(&self, input: &ToolInput) -> AgentResult<()> {
69 let _ = input;
71 Ok(())
72 }
73
74 fn requires_confirmation(&self) -> bool {
76 false
77 }
78
79 fn to_llm_tool(&self) -> LLMTool {
81 LLMTool {
82 name: self.name().to_string(),
83 description: self.description().to_string(),
84 parameters: self.parameters_schema(),
85 }
86 }
87}
88
/// Arguments handed to [`Tool::execute`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInput {
    /// Structured arguments. For inputs built from raw text this is a JSON
    /// string holding that text.
    pub arguments: serde_json::Value,
    /// The original text when the input was created from a string, else `None`.
    pub raw_input: Option<String>,
}
97
98impl ToolInput {
99 pub fn from_json(arguments: serde_json::Value) -> Self {
101 Self {
102 arguments,
103 raw_input: None,
104 }
105 }
106
107 pub fn from_raw(raw: impl Into<String>) -> Self {
109 let raw = raw.into();
110 Self {
111 arguments: serde_json::Value::String(raw.clone()),
112 raw_input: Some(raw),
113 }
114 }
115
116 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
118 self.arguments
119 .get(key)
120 .and_then(|v| serde_json::from_value(v.clone()).ok())
121 }
122
123 pub fn get_str(&self, key: &str) -> Option<&str> {
125 self.arguments.get(key).and_then(|v| v.as_str())
126 }
127
128 pub fn get_number(&self, key: &str) -> Option<f64> {
130 self.arguments.get(key).and_then(|v| v.as_f64())
131 }
132
133 pub fn get_bool(&self, key: &str) -> Option<bool> {
135 self.arguments.get(key).and_then(|v| v.as_bool())
136 }
137}
138
139impl From<serde_json::Value> for ToolInput {
140 fn from(v: serde_json::Value) -> Self {
141 Self::from_json(v)
142 }
143}
144
145impl From<String> for ToolInput {
146 fn from(s: String) -> Self {
147 Self::from_raw(s)
148 }
149}
150
151impl From<&str> for ToolInput {
152 fn from(s: &str) -> Self {
153 Self::from_raw(s)
154 }
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct ToolResult {
160 pub success: bool,
162 pub output: serde_json::Value,
164 pub error: Option<String>,
166 pub metadata: HashMap<String, String>,
168}
169
170impl ToolResult {
171 pub fn success(output: serde_json::Value) -> Self {
173 Self {
174 success: true,
175 output,
176 error: None,
177 metadata: HashMap::new(),
178 }
179 }
180
181 pub fn success_text(text: impl Into<String>) -> Self {
183 Self::success(serde_json::Value::String(text.into()))
184 }
185
186 pub fn failure(error: impl Into<String>) -> Self {
188 Self {
189 success: false,
190 output: serde_json::Value::Null,
191 error: Some(error.into()),
192 metadata: HashMap::new(),
193 }
194 }
195
196 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
198 self.metadata.insert(key.into(), value.into());
199 self
200 }
201
202 pub fn as_text(&self) -> Option<&str> {
204 self.output.as_str()
205 }
206
207 pub fn to_string_output(&self) -> String {
209 if self.success {
210 match &self.output {
211 serde_json::Value::String(s) => s.clone(),
212 v => v.to_string(),
213 }
214 } else {
215 format!(
216 "Error: {}",
217 self.error.as_deref().unwrap_or("Unknown error")
218 )
219 }
220 }
221}
222
223#[derive(Debug, Clone, Default, Serialize, Deserialize)]
225pub struct ToolMetadata {
226 pub category: Option<String>,
228 pub tags: Vec<String>,
230 pub is_dangerous: bool,
232 pub requires_network: bool,
234 pub requires_filesystem: bool,
236 pub custom: HashMap<String, serde_json::Value>,
238}
239
240impl ToolMetadata {
241 pub fn new() -> Self {
243 Self::default()
244 }
245
246 pub fn with_category(mut self, category: impl Into<String>) -> Self {
248 self.category = Some(category.into());
249 self
250 }
251
252 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
254 self.tags.push(tag.into());
255 self
256 }
257
258 pub fn dangerous(mut self) -> Self {
260 self.is_dangerous = true;
261 self
262 }
263
264 pub fn needs_network(mut self) -> Self {
266 self.requires_network = true;
267 self
268 }
269
270 pub fn needs_filesystem(mut self) -> Self {
272 self.requires_filesystem = true;
273 self
274 }
275}
276
/// Owned snapshot of a tool's public description, as returned by
/// [`ToolRegistry::list`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDescriptor {
    /// The tool's name.
    pub name: String,
    /// The tool's human-readable description.
    pub description: String,
    /// The tool's JSON parameter schema.
    pub parameters_schema: serde_json::Value,
    /// The tool's metadata.
    pub metadata: ToolMetadata,
}
289
290impl ToolDescriptor {
291 pub fn from_tool(tool: &dyn Tool) -> Self {
293 Self {
294 name: tool.name().to_string(),
295 description: tool.description().to_string(),
296 parameters_schema: tool.parameters_schema(),
297 metadata: tool.metadata(),
298 }
299 }
300}
301
/// Name / description / parameter-schema triple built by
/// [`Tool::to_llm_tool`] and [`ToolRegistry::to_llm_tools`]. Presumably this
/// is the shape handed to an LLM's tool-calling interface; confirm with the
/// consumers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMTool {
    /// The tool's name.
    pub name: String,
    /// The tool's human-readable description.
    pub description: String,
    /// The tool's JSON parameter schema.
    pub parameters: serde_json::Value,
}
312
313#[async_trait]
319pub trait ToolRegistry: Send + Sync {
320 fn register(&mut self, tool: Arc<dyn Tool>) -> AgentResult<()>;
322
323 fn register_all(&mut self, tools: Vec<Arc<dyn Tool>>) -> AgentResult<()> {
325 for tool in tools {
326 self.register(tool)?;
327 }
328 Ok(())
329 }
330
331 fn get(&self, name: &str) -> Option<Arc<dyn Tool>>;
333
334 fn unregister(&mut self, name: &str) -> AgentResult<bool>;
336
337 fn list(&self) -> Vec<ToolDescriptor>;
339
340 fn list_names(&self) -> Vec<String>;
342
343 fn contains(&self, name: &str) -> bool;
345
346 fn count(&self) -> usize;
348
349 async fn execute(
351 &self,
352 name: &str,
353 input: ToolInput,
354 ctx: &AgentContext,
355 ) -> AgentResult<ToolResult> {
356 let tool = self
357 .get(name)
358 .ok_or_else(|| AgentError::ToolNotFound(name.to_string()))?;
359 tool.validate_input(&input)?;
360 Ok(tool.execute(input, ctx).await)
361 }
362
363 fn to_llm_tools(&self) -> Vec<LLMTool> {
365 self.list()
366 .iter()
367 .map(|d| LLMTool {
368 name: d.name.clone(),
369 description: d.description.clone(),
370 parameters: d.parameters_schema.clone(),
371 })
372 .collect()
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::context::AgentContext;

    /// Minimal tool relying on every default method of [`Tool`].
    struct EchoTool;

    #[async_trait::async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes its arguments"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(&self, input: ToolInput, _ctx: &AgentContext) -> ToolResult {
            ToolResult::success(input.arguments)
        }
    }

    #[test]
    fn test_tool_input_from_json() {
        let input = ToolInput::from_json(serde_json::json!({
            "name": "test",
            "count": 42
        }));

        assert_eq!(input.get_str("name"), Some("test"));
        assert_eq!(input.get_number("count"), Some(42.0));
    }

    #[test]
    fn test_tool_input_from_raw() {
        let input = ToolInput::from("hello");
        assert_eq!(input.raw_input.as_deref(), Some("hello"));
        assert_eq!(
            input.arguments,
            serde_json::Value::String("hello".to_string())
        );
        // Raw text is not parsed, so keyed lookups find nothing.
        assert_eq!(input.get_str("hello"), None);
    }

    #[test]
    fn test_tool_input_typed_get() {
        let input = ToolInput::from_json(serde_json::json!({
            "n": 7,
            "flag": true,
            "tags": ["a", "b"]
        }));

        assert_eq!(input.get::<u32>("n"), Some(7));
        assert_eq!(
            input.get::<Vec<String>>("tags"),
            Some(vec!["a".to_string(), "b".to_string()])
        );
        assert_eq!(input.get::<u32>("flag"), None);
        assert_eq!(input.get::<u32>("missing"), None);
        assert_eq!(input.get_bool("flag"), Some(true));
    }

    #[test]
    fn test_tool_result() {
        let success = ToolResult::success_text("OK");
        assert!(success.success);
        assert_eq!(success.as_text(), Some("OK"));

        let failure = ToolResult::failure("Something went wrong");
        assert!(!failure.success);
        assert!(failure.error.is_some());
    }

    #[test]
    fn test_tool_result_to_string_output() {
        assert_eq!(ToolResult::success_text("hi").to_string_output(), "hi");
        assert_eq!(
            ToolResult::success(serde_json::json!({ "a": 1 })).to_string_output(),
            r#"{"a":1}"#
        );
        assert_eq!(ToolResult::failure("boom").to_string_output(), "Error: boom");

        let mut anonymous = ToolResult::failure("x");
        anonymous.error = None;
        assert_eq!(anonymous.to_string_output(), "Error: Unknown error");

        let tagged = ToolResult::success_text("ok").with_metadata("k", "v");
        assert_eq!(tagged.metadata.get("k").map(String::as_str), Some("v"));
    }

    #[test]
    fn test_tool_metadata_builder() {
        let metadata = ToolMetadata::new()
            .with_category("io")
            .with_tag("a")
            .with_tag("b")
            .dangerous()
            .needs_network()
            .needs_filesystem();

        assert_eq!(metadata.category.as_deref(), Some("io"));
        assert_eq!(metadata.tags, vec!["a", "b"]);
        assert!(metadata.is_dangerous);
        assert!(metadata.requires_network);
        assert!(metadata.requires_filesystem);
        assert!(metadata.custom.is_empty());
    }

    #[test]
    fn test_tool_defaults_and_descriptors() {
        let tool = EchoTool;
        assert!(!tool.requires_confirmation());
        assert!(tool.validate_input(&ToolInput::from("anything")).is_ok());

        let llm = tool.to_llm_tool();
        assert_eq!(llm.name, "echo");
        assert_eq!(llm.description, "Echoes its arguments");
        assert_eq!(llm.parameters, serde_json::json!({ "type": "object" }));

        let descriptor = ToolDescriptor::from_tool(&tool);
        assert_eq!(descriptor.name, "echo");
        assert_eq!(descriptor.parameters_schema, llm.parameters);
        assert!(!descriptor.metadata.is_dangerous);
        assert!(descriptor.metadata.tags.is_empty());
    }
}