1use std::sync::Arc;
11
12use async_trait::async_trait;
13use serde_json::json;
14use tokio::sync::broadcast;
15
16use rs_genai::session::SessionEvent;
17
18use crate::agent::Agent;
19use crate::agent_session::{AgentSession, NoOpSessionWriter};
20use crate::context::{AgentEvent, InvocationContext};
21use crate::error::ToolError;
22use crate::tool::ToolFunction;
23
24pub struct AgentTool {
30 agent: Arc<dyn Agent>,
31 description: String,
32 parameters: Option<serde_json::Value>,
33}
34
35impl AgentTool {
36 pub fn new(agent: impl Agent + 'static) -> Self {
38 let description = format!("Delegate to the {} agent", agent.name());
39 Self {
40 agent: Arc::new(agent),
41 description,
42 parameters: Some(json!({
43 "type": "object",
44 "properties": {
45 "request": {
46 "type": "string",
47 "description": "The request to send to the agent"
48 }
49 },
50 "required": ["request"]
51 })),
52 }
53 }
54
55 pub fn from_arc(agent: Arc<dyn Agent>) -> Self {
57 let description = format!("Delegate to the {} agent", agent.name());
58 Self {
59 agent,
60 description,
61 parameters: Some(json!({
62 "type": "object",
63 "properties": {
64 "request": {
65 "type": "string",
66 "description": "The request to send to the agent"
67 }
68 },
69 "required": ["request"]
70 })),
71 }
72 }
73
74 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
76 self.description = desc.into();
77 self
78 }
79
80 pub fn with_parameters(mut self, params: serde_json::Value) -> Self {
82 self.parameters = Some(params);
83 self
84 }
85}
86
87#[async_trait]
88impl ToolFunction for AgentTool {
89 fn name(&self) -> &str {
90 self.agent.name()
91 }
92
93 fn description(&self) -> &str {
94 &self.description
95 }
96
97 fn parameters(&self) -> Option<serde_json::Value> {
98 self.parameters.clone()
99 }
100
101 async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
102 let start = std::time::Instant::now();
103 let agent_name = self.agent.name().to_string();
104
105 crate::telemetry::logging::log_agent_tool_dispatch("parent", &agent_name);
107
108 let (event_tx, _) = broadcast::channel::<SessionEvent>(64);
110 let noop_writer: Arc<dyn rs_genai::session::SessionWriter> = Arc::new(NoOpSessionWriter);
111 let isolated_session = AgentSession::from_writer(noop_writer, event_tx);
112
113 if let Some(request) = args.get("request").and_then(|r| r.as_str()) {
115 isolated_session.state().set("request_text", request);
116 }
117 isolated_session.state().set("request", &args);
118
119 let mut ctx = InvocationContext::new(isolated_session);
121
122 let mut events = ctx.subscribe();
124
125 let agent = self.agent.clone();
127 let run_result = tokio::spawn(async move { agent.run_live(&mut ctx).await }).await;
128
129 let mut output_parts = Vec::new();
131 while let Ok(event) = events.try_recv() {
132 match event {
133 AgentEvent::Session(SessionEvent::TextDelta(text)) => {
134 output_parts.push(text);
135 }
136 AgentEvent::Session(SessionEvent::TextComplete(text)) => {
137 if output_parts.is_empty() {
138 output_parts.push(text);
139 }
140 }
143 _ => {}
144 }
145 }
146
147 let elapsed = start.elapsed();
148 crate::telemetry::metrics::record_agent_tool_dispatch(
149 "parent",
150 &agent_name,
151 elapsed.as_millis() as f64,
152 );
153
154 match run_result {
156 Ok(Ok(())) => {
157 let output = if output_parts.is_empty() {
158 json!({"status": "completed"})
159 } else {
160 json!({"result": output_parts.join("")})
161 };
162 Ok(output)
163 }
164 Ok(Err(e)) => Err(ToolError::ExecutionFailed(format!(
165 "Agent '{}' failed: {}",
166 agent_name, e
167 ))),
168 Err(e) => Err(ToolError::ExecutionFailed(format!(
169 "Agent '{}' task panicked: {}",
170 agent_name, e
171 ))),
172 }
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AgentError;

    /// Emits `"Echo: <request_text>"` as a single text delta, then a turn-complete event.
    struct EchoAgent {
        name: String,
    }

    #[async_trait]
    impl Agent for EchoAgent {
        fn name(&self) -> &str {
            &self.name
        }
        async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
            let request = ctx
                .state()
                .get::<String>("request_text")
                .unwrap_or_else(|| "no request".to_string());
            ctx.emit(AgentEvent::Session(SessionEvent::TextDelta(format!(
                "Echo: {}",
                request
            ))));
            ctx.emit(AgentEvent::Session(SessionEvent::TurnComplete));
            Ok(())
        }
    }

    /// Always fails with a recognizable error message.
    struct FailingAgent;

    #[async_trait]
    impl Agent for FailingAgent {
        fn name(&self) -> &str {
            "failing"
        }
        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Err(AgentError::Other("intentional failure".to_string()))
        }
    }

    /// Succeeds without emitting any events.
    struct SilentAgent;

    #[async_trait]
    impl Agent for SilentAgent {
        fn name(&self) -> &str {
            "silent"
        }
        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Ok(())
        }
    }

    #[test]
    fn agent_tool_takes_name_and_description_from_agent() {
        let agent = EchoAgent {
            name: "echo".to_string(),
        };
        let tool = AgentTool::new(agent);

        assert_eq!(tool.name(), "echo");
        assert!(tool.description().contains("echo"));
    }

    #[tokio::test]
    async fn agent_tool_runs_agent_in_isolation() {
        // Reports whether a marker from an earlier call is visible, then sets it.
        struct MarkerAgent;

        #[async_trait]
        impl Agent for MarkerAgent {
            fn name(&self) -> &str {
                "marker"
            }
            async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
                let seen = ctx.state().get::<String>("marker").is_some();
                ctx.state().set("marker", "set");
                ctx.emit(AgentEvent::Session(SessionEvent::TextDelta(format!(
                    "seen={}",
                    seen
                ))));
                Ok(())
            }
        }

        let tool = AgentTool::new(MarkerAgent);
        let first = tool.call(json!({"request": "one"})).await.unwrap();
        let second = tool.call(json!({"request": "two"})).await.unwrap();

        // Each call gets a fresh session, so the marker never carries over.
        assert_eq!(first["result"], "seen=false");
        assert_eq!(second["result"], "seen=false");
    }

    #[tokio::test]
    async fn agent_tool_collects_text_output() {
        let agent = EchoAgent {
            name: "echo".to_string(),
        };
        let tool = AgentTool::new(agent);

        let result = tool.call(json!({"request": "hello world"})).await.unwrap();
        assert_eq!(result["result"], "Echo: hello world");
    }

    #[tokio::test]
    async fn agent_tool_propagates_errors() {
        let tool = AgentTool::new(FailingAgent);
        let result = tool.call(json!({"request": "test"})).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        match err {
            ToolError::ExecutionFailed(msg) => {
                assert!(msg.contains("intentional failure"));
            }
            other => panic!("expected ExecutionFailed, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn agent_tool_returns_completed_when_no_output() {
        let tool = AgentTool::new(SilentAgent);
        let result = tool.call(json!({"request": "test"})).await.unwrap();
        assert_eq!(result["status"], "completed");
    }

    #[tokio::test]
    async fn agent_tool_state_injection() {
        struct StateCheckAgent;

        #[async_trait]
        impl Agent for StateCheckAgent {
            fn name(&self) -> &str {
                "state_check"
            }
            async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
                let request_text = ctx.state().get::<String>("request_text");
                let request = ctx.state().get::<serde_json::Value>("request");

                assert!(request_text.is_some());
                assert!(request.is_some());
                assert_eq!(request_text.unwrap(), "check state");

                ctx.emit(AgentEvent::Session(SessionEvent::TextDelta(
                    "state ok".to_string(),
                )));
                Ok(())
            }
        }

        let tool = AgentTool::new(StateCheckAgent);
        let result = tool.call(json!({"request": "check state"})).await.unwrap();
        assert_eq!(result["result"], "state ok");
    }

    #[test]
    fn agent_tool_with_custom_description() {
        let tool = AgentTool::new(SilentAgent).with_description("Custom description");
        assert_eq!(tool.description(), "Custom description");
    }

    #[test]
    fn agent_tool_with_custom_parameters() {
        let params = json!({
            "type": "object",
            "properties": {
                "query": { "type": "string" }
            }
        });
        let tool = AgentTool::new(SilentAgent).with_parameters(params.clone());
        assert_eq!(tool.parameters().unwrap(), params);
    }
}