Skip to main content

rs_adk/
agent_tool.rs

1//! AgentTool — wraps an Agent as a ToolFunction for "agent as a tool" dispatch.
2//!
3//! When the live model calls this tool, the wrapped agent runs in an isolated
4//! context (no live WebSocket). The agent's text output is collected and returned
5//! as the tool result. State changes stay in the isolated context and are not propagated back to the parent.
6//!
7//! This bridges live<->non-live: the wrapped agent can use regular Gemini API,
8//! external services, or pure computation — it doesn't need a WebSocket.
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use serde_json::json;
14use tokio::sync::broadcast;
15
16use rs_genai::session::SessionEvent;
17
18use crate::agent::Agent;
19use crate::agent_session::{AgentSession, NoOpSessionWriter};
20use crate::context::{AgentEvent, InvocationContext};
21use crate::error::ToolError;
22use crate::tool::ToolFunction;
23
24/// Wraps an Agent as a ToolFunction for "agent as a tool" dispatch.
25///
26/// When the live model calls this tool, the wrapped agent runs in an isolated
27/// context (no live WebSocket). The agent's text output is collected and returned
28/// as the tool result.
29pub struct AgentTool {
30    agent: Arc<dyn Agent>,
31    description: String,
32    parameters: Option<serde_json::Value>,
33}
34
35impl AgentTool {
36    /// Create a new AgentTool wrapping the given agent.
37    pub fn new(agent: impl Agent + 'static) -> Self {
38        let description = format!("Delegate to the {} agent", agent.name());
39        Self {
40            agent: Arc::new(agent),
41            description,
42            parameters: Some(json!({
43                "type": "object",
44                "properties": {
45                    "request": {
46                        "type": "string",
47                        "description": "The request to send to the agent"
48                    }
49                },
50                "required": ["request"]
51            })),
52        }
53    }
54
55    /// Create from an already-Arc'd agent.
56    pub fn from_arc(agent: Arc<dyn Agent>) -> Self {
57        let description = format!("Delegate to the {} agent", agent.name());
58        Self {
59            agent,
60            description,
61            parameters: Some(json!({
62                "type": "object",
63                "properties": {
64                    "request": {
65                        "type": "string",
66                        "description": "The request to send to the agent"
67                    }
68                },
69                "required": ["request"]
70            })),
71        }
72    }
73
74    /// Override the tool description.
75    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
76        self.description = desc.into();
77        self
78    }
79
80    /// Override the tool parameters schema.
81    pub fn with_parameters(mut self, params: serde_json::Value) -> Self {
82        self.parameters = Some(params);
83        self
84    }
85}
86
87#[async_trait]
88impl ToolFunction for AgentTool {
89    fn name(&self) -> &str {
90        self.agent.name()
91    }
92
93    fn description(&self) -> &str {
94        &self.description
95    }
96
97    fn parameters(&self) -> Option<serde_json::Value> {
98        self.parameters.clone()
99    }
100
101    async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
102        let start = std::time::Instant::now();
103        let agent_name = self.agent.name().to_string();
104
105        // Telemetry
106        crate::telemetry::logging::log_agent_tool_dispatch("parent", &agent_name);
107
108        // 1. Create isolated context with NoOpSessionWriter
109        let (event_tx, _) = broadcast::channel::<SessionEvent>(64);
110        let noop_writer: Arc<dyn rs_genai::session::SessionWriter> = Arc::new(NoOpSessionWriter);
111        let isolated_session = AgentSession::from_writer(noop_writer, event_tx);
112
113        // 2. Inject args into state
114        if let Some(request) = args.get("request").and_then(|r| r.as_str()) {
115            isolated_session.state().set("request_text", request);
116        }
117        isolated_session.state().set("request", &args);
118
119        // 3. Create isolated InvocationContext
120        let mut ctx = InvocationContext::new(isolated_session);
121
122        // 4. Subscribe to events before running (to collect text output)
123        let mut events = ctx.subscribe();
124
125        // 5. Run the agent
126        let agent = self.agent.clone();
127        let run_result = tokio::spawn(async move { agent.run_live(&mut ctx).await }).await;
128
129        // 6. Collect text output from events
130        let mut output_parts = Vec::new();
131        while let Ok(event) = events.try_recv() {
132            match event {
133                AgentEvent::Session(SessionEvent::TextDelta(text)) => {
134                    output_parts.push(text);
135                }
136                AgentEvent::Session(SessionEvent::TextComplete(text)) => {
137                    if output_parts.is_empty() {
138                        output_parts.push(text);
139                    }
140                    // If we already have deltas, TextComplete is the full assembled text
141                    // Don't double-count — deltas already captured incrementally
142                }
143                _ => {}
144            }
145        }
146
147        let elapsed = start.elapsed();
148        crate::telemetry::metrics::record_agent_tool_dispatch(
149            "parent",
150            &agent_name,
151            elapsed.as_millis() as f64,
152        );
153
154        // 7. Handle result
155        match run_result {
156            Ok(Ok(())) => {
157                let output = if output_parts.is_empty() {
158                    json!({"status": "completed"})
159                } else {
160                    json!({"result": output_parts.join("")})
161                };
162                Ok(output)
163            }
164            Ok(Err(e)) => Err(ToolError::ExecutionFailed(format!(
165                "Agent '{}' failed: {}",
166                agent_name, e
167            ))),
168            Err(e) => Err(ToolError::ExecutionFailed(format!(
169                "Agent '{}' task panicked: {}",
170                agent_name, e
171            ))),
172        }
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AgentError;

    /// Echoes `request_text` from state back as a single `TextDelta`.
    struct EchoAgent {
        name: String,
    }

    #[async_trait]
    impl Agent for EchoAgent {
        fn name(&self) -> &str {
            &self.name
        }
        async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
            let request = ctx
                .state()
                .get::<String>("request_text")
                .unwrap_or_else(|| "no request".to_string());
            ctx.emit(AgentEvent::Session(SessionEvent::TextDelta(format!(
                "Echo: {}",
                request
            ))));
            ctx.emit(AgentEvent::Session(SessionEvent::TurnComplete));
            Ok(())
        }
    }

    /// Always fails with a recognizable error message.
    struct FailingAgent;

    #[async_trait]
    impl Agent for FailingAgent {
        fn name(&self) -> &str {
            "failing"
        }
        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Err(AgentError::Other("intentional failure".to_string()))
        }
    }

    /// Succeeds without emitting any events.
    struct SilentAgent;

    #[async_trait]
    impl Agent for SilentAgent {
        fn name(&self) -> &str {
            "silent"
        }
        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Ok(())
        }
    }

    /// Streams its answer as several deltas, then sends the assembled text.
    struct ChunkedAgent;

    #[async_trait]
    impl Agent for ChunkedAgent {
        fn name(&self) -> &str {
            "chunked"
        }
        async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
            for chunk in ["Hel", "lo"] {
                ctx.emit(AgentEvent::Session(SessionEvent::TextDelta(
                    chunk.to_string(),
                )));
            }
            ctx.emit(AgentEvent::Session(SessionEvent::TextComplete(
                "Hello".to_string(),
            )));
            Ok(())
        }
    }

    /// Emits only a final `TextComplete`, without streaming deltas.
    struct CompleteOnlyAgent;

    #[async_trait]
    impl Agent for CompleteOnlyAgent {
        fn name(&self) -> &str {
            "complete_only"
        }
        async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
            ctx.emit(AgentEvent::Session(SessionEvent::TextComplete(
                "Done".to_string(),
            )));
            Ok(())
        }
    }

    #[test]
    fn agent_tool_exposes_agent_name_and_default_description() {
        let tool = AgentTool::new(EchoAgent {
            name: "echo".to_string(),
        });

        assert_eq!(tool.name(), "echo");
        assert_eq!(tool.description(), "Delegate to the echo agent");
    }

    #[test]
    fn agent_tool_default_parameters_require_request() {
        let params = AgentTool::new(SilentAgent)
            .parameters()
            .expect("default schema is always set");
        assert_eq!(params["type"], "object");
        assert_eq!(params["properties"]["request"]["type"], "string");
        assert_eq!(params["required"], json!(["request"]));
    }

    #[test]
    fn agent_tool_from_arc_matches_new() {
        let from_arc = AgentTool::from_arc(Arc::new(SilentAgent));
        let from_new = AgentTool::new(SilentAgent);
        assert_eq!(from_arc.name(), from_new.name());
        assert_eq!(from_arc.description(), from_new.description());
        assert_eq!(from_arc.parameters(), from_new.parameters());
    }

    #[tokio::test]
    async fn agent_tool_collects_text_output() {
        let tool = AgentTool::new(EchoAgent {
            name: "echo".to_string(),
        });

        let result = tool.call(json!({"request": "hello world"})).await.unwrap();
        assert_eq!(result["result"], "Echo: hello world");
    }

    #[tokio::test]
    async fn agent_tool_concatenates_deltas_without_duplicating_complete_text() {
        let tool = AgentTool::new(ChunkedAgent);
        let result = tool.call(json!({"request": "hi"})).await.unwrap();
        assert_eq!(result["result"], "Hello");
    }

    #[tokio::test]
    async fn agent_tool_uses_text_complete_when_no_deltas() {
        let tool = AgentTool::new(CompleteOnlyAgent);
        let result = tool.call(json!({"request": "hi"})).await.unwrap();
        assert_eq!(result["result"], "Done");
    }

    #[tokio::test]
    async fn agent_tool_propagates_errors() {
        let tool = AgentTool::new(FailingAgent);
        let err = tool
            .call(json!({"request": "test"}))
            .await
            .expect_err("failing agent must produce a tool error");
        match err {
            ToolError::ExecutionFailed(msg) => {
                assert!(msg.contains("intentional failure"));
            }
            other => panic!("expected ExecutionFailed, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn agent_tool_returns_completed_when_no_output() {
        let tool = AgentTool::new(SilentAgent);
        let result = tool.call(json!({"request": "test"})).await.unwrap();
        assert_eq!(result["status"], "completed");
    }

    #[tokio::test]
    async fn agent_tool_state_injection() {
        // Verify that both the raw args and the extracted request text are in state.
        struct StateCheckAgent;

        #[async_trait]
        impl Agent for StateCheckAgent {
            fn name(&self) -> &str {
                "state_check"
            }
            async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
                let request_text = ctx.state().get::<String>("request_text");
                let request = ctx.state().get::<serde_json::Value>("request");

                assert_eq!(request_text.as_deref(), Some("check state"));
                assert_eq!(request, Some(json!({"request": "check state"})));

                ctx.emit(AgentEvent::Session(SessionEvent::TextDelta(
                    "state ok".to_string(),
                )));
                Ok(())
            }
        }

        let tool = AgentTool::new(StateCheckAgent);
        let result = tool.call(json!({"request": "check state"})).await.unwrap();
        assert_eq!(result["result"], "state ok");
    }

    #[test]
    fn agent_tool_with_custom_description() {
        let tool = AgentTool::new(SilentAgent).with_description("Custom description");
        assert_eq!(tool.description(), "Custom description");
    }

    #[test]
    fn agent_tool_with_custom_parameters() {
        let params = json!({
            "type": "object",
            "properties": {
                "query": { "type": "string" }
            }
        });
        let tool = AgentTool::new(SilentAgent).with_parameters(params.clone());
        assert_eq!(tool.parameters().unwrap(), params);
    }
}