1use crate::{CallbackContext, EventActions, MemoryEntry, Result};
2use async_trait::async_trait;
3use serde_json::Value;
4use std::sync::Arc;
5
6#[async_trait]
7pub trait Tool: Send + Sync {
8 fn name(&self) -> &str;
9 fn description(&self) -> &str;
10
11 fn enhanced_description(&self) -> String {
16 self.description().to_string()
17 }
18
19 fn is_long_running(&self) -> bool {
23 false
24 }
25 fn parameters_schema(&self) -> Option<Value> {
26 None
27 }
28 fn response_schema(&self) -> Option<Value> {
29 None
30 }
31 async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value>;
32}
33
34#[async_trait]
35pub trait ToolContext: CallbackContext {
36 fn function_call_id(&self) -> &str;
37 fn actions(&self) -> EventActions;
39 fn set_actions(&self, actions: EventActions);
41 async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>>;
42}
43
44#[async_trait]
45pub trait Toolset: Send + Sync {
46 fn name(&self) -> &str;
47 async fn tools(&self, ctx: Arc<dyn crate::ReadonlyContext>) -> Result<Vec<Arc<dyn Tool>>>;
48}
49
50pub type ToolPredicate = Box<dyn Fn(&dyn Tool) -> bool + Send + Sync>;
51
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Content, EventActions, ReadonlyContext};
    use std::sync::Mutex;

    /// Minimal tool with a configurable name whose execution always yields `"result"`.
    struct TestTool {
        name: String,
    }

    /// In-memory context stub with fixed identifiers and empty memory.
    ///
    /// Actions live behind a `Mutex` because `ToolContext::set_actions` takes `&self`.
    /// (The former unused `config: RunConfig` field and its `#[allow(dead_code)]`
    /// were removed: nothing ever read it.)
    struct TestContext {
        content: Content,
        actions: Mutex<EventActions>,
    }

    impl TestContext {
        fn new() -> Self {
            Self {
                content: Content::new("user"),
                actions: Mutex::new(EventActions::default()),
            }
        }
    }

    #[async_trait]
    impl ReadonlyContext for TestContext {
        fn invocation_id(&self) -> &str {
            "test"
        }
        fn agent_name(&self) -> &str {
            "test"
        }
        fn user_id(&self) -> &str {
            "user"
        }
        fn app_name(&self) -> &str {
            "app"
        }
        fn session_id(&self) -> &str {
            "session"
        }
        fn branch(&self) -> &str {
            ""
        }
        fn user_content(&self) -> &Content {
            &self.content
        }
    }

    #[async_trait]
    impl CallbackContext for TestContext {
        fn artifacts(&self) -> Option<Arc<dyn crate::Artifacts>> {
            None
        }
    }

    #[async_trait]
    impl ToolContext for TestContext {
        fn function_call_id(&self) -> &str {
            "call-123"
        }
        fn actions(&self) -> EventActions {
            self.actions.lock().unwrap().clone()
        }
        fn set_actions(&self, actions: EventActions) {
            *self.actions.lock().unwrap() = actions;
        }
        async fn search_memory(&self, _query: &str) -> Result<Vec<crate::MemoryEntry>> {
            Ok(vec![])
        }
    }

    #[async_trait]
    impl Tool for TestTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "test tool"
        }

        async fn execute(&self, _ctx: Arc<dyn ToolContext>, _args: Value) -> Result<Value> {
            Ok(Value::String("result".to_string()))
        }
    }

    #[test]
    fn test_tool_trait() {
        let tool = TestTool { name: "test".to_string() };
        assert_eq!(tool.name(), "test");
        assert_eq!(tool.description(), "test tool");
        assert!(!tool.is_long_running());
    }

    #[test]
    fn test_tool_default_methods() {
        let tool = TestTool { name: "test".to_string() };
        // Defaults: the enhanced description mirrors `description`; no schemas.
        assert_eq!(tool.enhanced_description(), "test tool");
        assert!(tool.parameters_schema().is_none());
        assert!(tool.response_schema().is_none());
    }

    #[test]
    fn test_tool_predicate() {
        let pred: ToolPredicate = Box::new(|t: &dyn Tool| t.name() == "test");
        let keep = TestTool { name: "test".to_string() };
        let skip = TestTool { name: "other".to_string() };
        assert!(pred(&keep as &dyn Tool));
        assert!(!pred(&skip as &dyn Tool));
    }

    #[tokio::test]
    async fn test_tool_execute() {
        let tool = TestTool { name: "test".to_string() };
        let ctx = Arc::new(TestContext::new()) as Arc<dyn ToolContext>;
        let result = tool.execute(ctx, Value::Null).await.unwrap();
        assert_eq!(result, Value::String("result".to_string()));
    }

    #[tokio::test]
    async fn test_tool_context_through_trait_object() {
        let ctx: Arc<dyn ToolContext> = Arc::new(TestContext::new());
        assert_eq!(ctx.function_call_id(), "call-123");
        assert!(ctx.search_memory("anything").await.unwrap().is_empty());
    }
}