1use crate::{CallbackContext, EventActions, MemoryEntry, Result};
2use async_trait::async_trait;
3use serde_json::Value;
4use std::sync::Arc;
5
6#[async_trait]
7pub trait Tool: Send + Sync {
8 fn name(&self) -> &str;
9 fn description(&self) -> &str;
10
11 fn enhanced_description(&self) -> String {
16 self.description().to_string()
17 }
18
19 fn is_long_running(&self) -> bool {
23 false
24 }
25 fn parameters_schema(&self) -> Option<Value> {
26 None
27 }
28 fn response_schema(&self) -> Option<Value> {
29 None
30 }
31
32 fn required_scopes(&self) -> &[&str] {
46 &[]
47 }
48
49 async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value>;
50}
51
52#[async_trait]
53pub trait ToolContext: CallbackContext {
54 fn function_call_id(&self) -> &str;
55 fn actions(&self) -> EventActions;
57 fn set_actions(&self, actions: EventActions);
59 async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>>;
60
61 fn user_scopes(&self) -> Vec<String> {
68 vec![]
69 }
70}
71
72#[async_trait]
73pub trait Toolset: Send + Sync {
74 fn name(&self) -> &str;
75 async fn tools(&self, ctx: Arc<dyn crate::ReadonlyContext>) -> Result<Vec<Arc<dyn Tool>>>;
76}
77
78pub type ToolPredicate = Box<dyn Fn(&dyn Tool) -> bool + Send + Sync>;
79
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Content, EventActions, ReadonlyContext, RunConfig};
    use std::sync::Mutex;

    /// Minimal [`Tool`] with a configurable name whose `execute` always succeeds.
    struct TestTool {
        name: String,
    }

    /// Builds the tool fixture shared by the tests below.
    fn sample_tool() -> TestTool {
        TestTool { name: String::from("test") }
    }

    #[async_trait]
    impl Tool for TestTool {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn description(&self) -> &str {
            "test tool"
        }

        async fn execute(&self, _ctx: Arc<dyn ToolContext>, _args: Value) -> Result<Value> {
            let payload = String::from("result");
            Ok(Value::String(payload))
        }
    }

    /// In-memory context that returns fixed identifiers and keeps its actions behind a mutex.
    // `config` is stored but never read by these tests, hence the lint allowance.
    #[allow(dead_code)]
    struct TestContext {
        content: Content,
        config: RunConfig,
        actions: Mutex<EventActions>,
    }

    impl TestContext {
        /// Creates a context with "user" content, the default run config and default actions.
        fn new() -> Self {
            let content = Content::new("user");
            let config = RunConfig::default();
            let actions = Mutex::new(EventActions::default());
            Self { content, config, actions }
        }
    }

    #[async_trait]
    impl ReadonlyContext for TestContext {
        fn invocation_id(&self) -> &str {
            "test"
        }

        fn agent_name(&self) -> &str {
            "test"
        }

        fn user_id(&self) -> &str {
            "user"
        }

        fn app_name(&self) -> &str {
            "app"
        }

        fn session_id(&self) -> &str {
            "session"
        }

        fn branch(&self) -> &str {
            ""
        }

        fn user_content(&self) -> &Content {
            &self.content
        }
    }

    #[async_trait]
    impl CallbackContext for TestContext {
        fn artifacts(&self) -> Option<Arc<dyn crate::Artifacts>> {
            None
        }
    }

    #[async_trait]
    impl ToolContext for TestContext {
        fn function_call_id(&self) -> &str {
            "call-123"
        }

        fn actions(&self) -> EventActions {
            // Clone out of the guard so the lock is released when this method returns.
            let guard = self.actions.lock().unwrap();
            guard.clone()
        }

        fn set_actions(&self, actions: EventActions) {
            let mut slot = self.actions.lock().unwrap();
            *slot = actions;
        }

        async fn search_memory(&self, _query: &str) -> Result<Vec<MemoryEntry>> {
            Ok(Vec::new())
        }
    }

    /// The required accessors return what the impl provides and the default
    /// `is_long_running` is `false`.
    #[test]
    fn test_tool_trait() {
        let tool = sample_tool();

        assert_eq!(tool.name(), "test");
        assert_eq!(tool.description(), "test tool");
        assert!(!tool.is_long_running());
    }

    /// `execute` can be driven through an `Arc<dyn ToolContext>` and yields the fixed result.
    #[tokio::test]
    async fn test_tool_execute() {
        let tool = sample_tool();
        let ctx: Arc<dyn ToolContext> = Arc::new(TestContext::new());

        let output = tool.execute(ctx, Value::Null).await.unwrap();

        assert_eq!(output, Value::String(String::from("result")));
    }
}