rs_adk/
text_agent_tool.rs1use std::sync::Arc;
12
13use async_trait::async_trait;
14use serde_json::json;
15
16use crate::error::ToolError;
17use crate::state::State;
18use crate::text::TextAgent;
19use crate::tool::ToolFunction;
20
21pub struct TextAgentTool {
47 name: String,
48 description: String,
49 agent: Arc<dyn TextAgent>,
50 parameters: serde_json::Value,
51 state: State,
52}
53
54impl TextAgentTool {
55 pub fn new(
60 name: impl Into<String>,
61 description: impl Into<String>,
62 agent: impl TextAgent + 'static,
63 state: State,
64 ) -> Self {
65 Self {
66 name: name.into(),
67 description: description.into(),
68 agent: Arc::new(agent),
69 parameters: json!({
70 "type": "object",
71 "properties": {
72 "request": {
73 "type": "string",
74 "description": "The request to process"
75 }
76 },
77 "required": ["request"]
78 }),
79 state,
80 }
81 }
82
83 pub fn from_arc(
85 name: impl Into<String>,
86 description: impl Into<String>,
87 agent: Arc<dyn TextAgent>,
88 state: State,
89 ) -> Self {
90 Self {
91 name: name.into(),
92 description: description.into(),
93 agent,
94 parameters: json!({
95 "type": "object",
96 "properties": {
97 "request": {
98 "type": "string",
99 "description": "The request to process"
100 }
101 },
102 "required": ["request"]
103 }),
104 state,
105 }
106 }
107
108 pub fn with_parameters(mut self, params: serde_json::Value) -> Self {
110 self.parameters = params;
111 self
112 }
113}
114
115#[async_trait]
116impl ToolFunction for TextAgentTool {
117 fn name(&self) -> &str {
118 &self.name
119 }
120
121 fn description(&self) -> &str {
122 &self.description
123 }
124
125 fn parameters(&self) -> Option<serde_json::Value> {
126 Some(self.parameters.clone())
127 }
128
129 async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
130 if let Some(request) = args.get("request").and_then(|r| r.as_str()) {
132 self.state.set("input", request);
133 }
134 self.state.set("agent_tool_args", &args);
135
136 let result = self
138 .agent
139 .run(&self.state)
140 .await
141 .map_err(|e| ToolError::ExecutionFailed(format!("{e}")))?;
142
143 Ok(json!({"result": result}))
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AgentError;

    /// Replies `"Echo: <input>"`, or `"Echo: no input"` when `input` is unset.
    struct EchoTextAgent;

    #[async_trait]
    impl TextAgent for EchoTextAgent {
        fn name(&self) -> &str {
            "echo"
        }
        async fn run(&self, state: &State) -> Result<String, AgentError> {
            let input = state
                .get::<String>("input")
                .unwrap_or_else(|| "no input".into());
            Ok(format!("Echo: {input}"))
        }
    }

    /// Reads `parent_key` and writes `child_wrote` and `child_output`, to check that
    /// the state is shared in both directions.
    struct StatefulAgent;

    #[async_trait]
    impl TextAgent for StatefulAgent {
        fn name(&self) -> &str {
            "stateful"
        }
        async fn run(&self, state: &State) -> Result<String, AgentError> {
            let parent_val = state
                .get::<String>("parent_key")
                .unwrap_or_else(|| "missing".into());

            state.set("child_wrote", true);
            state.set("child_output", "from child agent");

            Ok(format!("Parent said: {parent_val}"))
        }
    }

    /// Always fails, to exercise error mapping.
    struct FailingTextAgent;

    #[async_trait]
    impl TextAgent for FailingTextAgent {
        fn name(&self) -> &str {
            "failing"
        }
        async fn run(&self, _state: &State) -> Result<String, AgentError> {
            Err(AgentError::Other("intentional failure".into()))
        }
    }

    #[tokio::test]
    async fn basic_dispatch() {
        let state = State::new();
        let tool = TextAgentTool::new("echo", "Echo tool", EchoTextAgent, state);

        let result = tool.call(json!({"request": "hello"})).await.unwrap();
        assert_eq!(result["result"], "Echo: hello");
    }

    #[tokio::test]
    async fn tool_metadata() {
        let state = State::new();
        let tool = TextAgentTool::new("my_tool", "Does things", EchoTextAgent, state);

        assert_eq!(tool.name(), "my_tool");
        assert_eq!(tool.description(), "Does things");
        let params = tool
            .parameters()
            .expect("TextAgentTool always exposes a parameter schema");
        assert_eq!(params["type"], "object");
        assert!(params["properties"]["request"].is_object());
        assert_eq!(params["required"], json!(["request"]));
    }

    #[tokio::test]
    async fn default_schema_matches_across_constructors() {
        let owned = TextAgentTool::new("echo", "Echo", EchoTextAgent, State::new());
        let agent: Arc<dyn TextAgent> = Arc::new(EchoTextAgent);
        let shared = TextAgentTool::from_arc("echo", "Echo", agent, State::new());

        assert_eq!(owned.parameters(), shared.parameters());
    }

    #[tokio::test]
    async fn state_shared_bidirectionally() {
        let state = State::new();
        state.set("parent_key", "hello from parent");

        let tool = TextAgentTool::new("stateful", "Stateful tool", StatefulAgent, state.clone());

        let result = tool.call(json!({"request": "test"})).await.unwrap();
        assert_eq!(result["result"], "Parent said: hello from parent");

        assert_eq!(state.get::<bool>("child_wrote"), Some(true));
        assert_eq!(
            state.get::<String>("child_output"),
            Some("from child agent".into())
        );
    }

    #[tokio::test]
    async fn error_propagation() {
        let state = State::new();
        let tool = TextAgentTool::new("failing", "Fails", FailingTextAgent, state);

        match tool.call(json!({"request": "test"})).await {
            Err(ToolError::ExecutionFailed(msg)) => {
                assert!(
                    msg.contains("intentional failure"),
                    "unexpected message: {msg}"
                );
            }
            other => panic!("expected ExecutionFailed, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn custom_parameters() {
        let state = State::new();
        let params = json!({
            "type": "object",
            "properties": {
                "query": { "type": "string" },
                "limit": { "type": "integer" }
            }
        });
        let tool = TextAgentTool::new("custom", "Custom params", EchoTextAgent, state)
            .with_parameters(params.clone());

        assert_eq!(tool.parameters().unwrap(), params);
    }

    #[tokio::test]
    async fn args_injected_into_state() {
        let state = State::new();
        let tool = TextAgentTool::new("echo", "Echo", EchoTextAgent, state.clone());

        tool.call(json!({"request": "injected"})).await.unwrap();

        assert_eq!(state.get::<String>("input"), Some("injected".into()));
        let args = state
            .get::<serde_json::Value>("agent_tool_args")
            .expect("call always records its arguments");
        assert_eq!(args["request"], "injected");
    }

    #[tokio::test]
    async fn non_string_request_leaves_input_unset() {
        let state = State::new();
        let tool = TextAgentTool::new("echo", "Echo", EchoTextAgent, state.clone());

        // A missing `request` and a non-string `request` must both skip the `input` key.
        let result = tool
            .call(json!({"query": "no request field", "request": 42}))
            .await
            .unwrap();
        assert_eq!(result["result"], "Echo: no input");
        assert_eq!(state.get::<String>("input"), None);

        // The raw arguments are still recorded for agents with custom schemas.
        let args = state
            .get::<serde_json::Value>("agent_tool_args")
            .expect("call always records its arguments");
        assert_eq!(args["query"], "no request field");
        assert_eq!(args["request"], 42);
    }

    #[tokio::test]
    async fn from_arc_constructor() {
        let state = State::new();
        let agent: Arc<dyn TextAgent> = Arc::new(EchoTextAgent);
        let tool = TextAgentTool::from_arc("echo", "Echo tool", agent, state);

        let result = tool.call(json!({"request": "arc test"})).await.unwrap();
        assert_eq!(result["result"], "Echo: arc test");
    }
}