adk_tool/
agent_tool.rs

1//! AgentTool - Use agents as callable tools
2//!
3//! This module provides `AgentTool` which wraps an `Agent` instance to make it
4//! callable as a `Tool`. This enables powerful composition patterns where a
5//! coordinator agent can invoke specialized sub-agents.
6//!
7//! # Example
8//!
9//! ```rust,ignore
10//! use adk_tool::AgentTool;
11//! use adk_agent::LlmAgentBuilder;
12//!
13//! // Create a specialized agent
14//! let math_agent = LlmAgentBuilder::new("math_expert")
15//!     .description("Solves mathematical problems")
16//!     .instruction("You are a math expert. Solve problems step by step.")
17//!     .model(model.clone())
18//!     .build()?;
19//!
20//! // Wrap it as a tool
21//! let math_tool = AgentTool::new(Arc::new(math_agent));
22//!
23//! // Use in coordinator agent
24//! let coordinator = LlmAgentBuilder::new("coordinator")
25//!     .instruction("Help users by delegating to specialists")
26//!     .tools(vec![Arc::new(math_tool)])
27//!     .build()?;
28//! ```
29
30use adk_core::{
31    Agent, Artifacts, CallbackContext, Content, Event, InvocationContext, Memory, Part,
32    ReadonlyContext, Result, RunConfig, Session, State, Tool, ToolContext,
33};
34use async_trait::async_trait;
35use futures::StreamExt;
36use serde_json::{json, Value};
37use std::collections::HashMap;
38use std::sync::{atomic::AtomicBool, Arc};
39use std::time::Duration;
40
41/// Configuration options for AgentTool behavior.
42#[derive(Debug, Clone)]
43pub struct AgentToolConfig {
44    /// Skip summarization after sub-agent execution.
45    /// When true, returns the raw output from the sub-agent.
46    pub skip_summarization: bool,
47
48    /// Forward artifacts between parent and sub-agent.
49    /// When true, the sub-agent can access parent's artifacts.
50    pub forward_artifacts: bool,
51
52    /// Optional timeout for sub-agent execution.
53    pub timeout: Option<Duration>,
54
55    /// Custom input schema for the tool.
56    /// If None, defaults to `{"request": "string"}`.
57    pub input_schema: Option<Value>,
58
59    /// Custom output schema for the tool.
60    pub output_schema: Option<Value>,
61}
62
63impl Default for AgentToolConfig {
64    fn default() -> Self {
65        Self {
66            skip_summarization: false,
67            forward_artifacts: true,
68            timeout: None,
69            input_schema: None,
70            output_schema: None,
71        }
72    }
73}
74
/// AgentTool wraps an Agent to make it callable as a Tool.
///
/// When the parent LLM generates a function call targeting this tool, the wrapped agent
/// runs in an isolated context (a fresh session with empty conversation history). Its
/// events are collected, and the text of its last final response is returned as
/// `{"response": ...}`. Execution failures and timeouts are reported as
/// `{"error": ..., "agent": ...}` values rather than as `Err`.
pub struct AgentTool {
    // The sub-agent invoked on every call; it also supplies the tool's name and description.
    agent: Arc<dyn Agent>,
    // Behavior knobs: timeout, schemas, artifact forwarding, summarization flag.
    config: AgentToolConfig,
}
84
85impl AgentTool {
86    /// Create a new AgentTool wrapping the given agent.
87    pub fn new(agent: Arc<dyn Agent>) -> Self {
88        Self { agent, config: AgentToolConfig::default() }
89    }
90
91    /// Create a new AgentTool with custom configuration.
92    pub fn with_config(agent: Arc<dyn Agent>, config: AgentToolConfig) -> Self {
93        Self { agent, config }
94    }
95
96    /// Set whether to skip summarization.
97    pub fn skip_summarization(mut self, skip: bool) -> Self {
98        self.config.skip_summarization = skip;
99        self
100    }
101
102    /// Set whether to forward artifacts.
103    pub fn forward_artifacts(mut self, forward: bool) -> Self {
104        self.config.forward_artifacts = forward;
105        self
106    }
107
108    /// Set timeout for sub-agent execution.
109    pub fn timeout(mut self, timeout: Duration) -> Self {
110        self.config.timeout = Some(timeout);
111        self
112    }
113
114    /// Set custom input schema.
115    pub fn input_schema(mut self, schema: Value) -> Self {
116        self.config.input_schema = Some(schema);
117        self
118    }
119
120    /// Set custom output schema.
121    pub fn output_schema(mut self, schema: Value) -> Self {
122        self.config.output_schema = Some(schema);
123        self
124    }
125
126    /// Generate the default parameters schema for this agent tool.
127    fn default_parameters_schema(&self) -> Value {
128        json!({
129            "type": "object",
130            "properties": {
131                "request": {
132                    "type": "string",
133                    "description": format!("The request to send to the {} agent", self.agent.name())
134                }
135            },
136            "required": ["request"]
137        })
138    }
139
140    /// Extract the request text from the tool arguments.
141    fn extract_request(&self, args: &Value) -> String {
142        // Try to get "request" field first
143        if let Some(request) = args.get("request").and_then(|v| v.as_str()) {
144            return request.to_string();
145        }
146
147        // If custom schema, try to serialize the whole args
148        if self.config.input_schema.is_some() {
149            return serde_json::to_string(args).unwrap_or_default();
150        }
151
152        // Fallback: convert args to string
153        match args {
154            Value::String(s) => s.clone(),
155            Value::Object(map) => {
156                // Try to find any string field
157                for value in map.values() {
158                    if let Value::String(s) = value {
159                        return s.clone();
160                    }
161                }
162                serde_json::to_string(args).unwrap_or_default()
163            }
164            _ => serde_json::to_string(args).unwrap_or_default(),
165        }
166    }
167
168    /// Extract the final response text from agent events.
169    fn extract_response(events: &[Event]) -> Value {
170        // Collect all text responses from final events
171        let mut responses = Vec::new();
172
173        for event in events.iter().rev() {
174            if event.is_final_response() {
175                if let Some(content) = &event.llm_response.content {
176                    for part in &content.parts {
177                        if let Part::Text { text } = part {
178                            responses.push(text.clone());
179                        }
180                    }
181                }
182                break; // Only get the last final response
183            }
184        }
185
186        if responses.is_empty() {
187            // Try to get any text from the last event
188            if let Some(last_event) = events.last() {
189                if let Some(content) = &last_event.llm_response.content {
190                    for part in &content.parts {
191                        if let Part::Text { text } = part {
192                            return json!({ "response": text });
193                        }
194                    }
195                }
196            }
197            json!({ "response": "No response from agent" })
198        } else {
199            json!({ "response": responses.join("\n") })
200        }
201    }
202}
203
204#[async_trait]
205impl Tool for AgentTool {
206    fn name(&self) -> &str {
207        self.agent.name()
208    }
209
210    fn description(&self) -> &str {
211        self.agent.description()
212    }
213
214    fn parameters_schema(&self) -> Option<Value> {
215        Some(self.config.input_schema.clone().unwrap_or_else(|| self.default_parameters_schema()))
216    }
217
218    fn response_schema(&self) -> Option<Value> {
219        self.config.output_schema.clone()
220    }
221
222    fn is_long_running(&self) -> bool {
223        // Agent execution could take time, but we wait for completion
224        false
225    }
226
227    #[adk_telemetry::instrument(
228        skip(self, ctx, args),
229        fields(
230            agent_tool.name = %self.agent.name(),
231            agent_tool.description = %self.agent.description(),
232            function_call.id = %ctx.function_call_id()
233        )
234    )]
235    async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
236        adk_telemetry::debug!("Executing agent tool: {}", self.agent.name());
237
238        // Extract the request from args
239        let request_text = self.extract_request(&args);
240
241        // Create user content for the sub-agent
242        let user_content = Content::new("user").with_text(&request_text);
243
244        // Create an isolated context for the sub-agent
245        let sub_ctx = Arc::new(AgentToolInvocationContext::new(
246            ctx.clone(),
247            self.agent.clone(),
248            user_content.clone(),
249            self.config.forward_artifacts,
250        ));
251
252        // Execute the sub-agent
253        let execution = async {
254            let mut event_stream = self.agent.run(sub_ctx.clone()).await?;
255
256            // Collect all events
257            let mut events = Vec::new();
258            let mut state_delta = HashMap::new();
259            let mut artifact_delta = HashMap::new();
260
261            while let Some(result) = event_stream.next().await {
262                match result {
263                    Ok(event) => {
264                        // Merge state deltas
265                        state_delta.extend(event.actions.state_delta.clone());
266                        artifact_delta.extend(event.actions.artifact_delta.clone());
267                        events.push(event);
268                    }
269                    Err(e) => {
270                        adk_telemetry::error!("Error in sub-agent execution: {}", e);
271                        return Err(e);
272                    }
273                }
274            }
275
276            Ok((events, state_delta, artifact_delta))
277        };
278
279        // Apply timeout if configured
280        let result = if let Some(timeout_duration) = self.config.timeout {
281            match tokio::time::timeout(timeout_duration, execution).await {
282                Ok(r) => r,
283                Err(_) => {
284                    return Ok(json!({
285                        "error": "Agent execution timed out",
286                        "agent": self.agent.name()
287                    }));
288                }
289            }
290        } else {
291            execution.await
292        };
293
294        match result {
295            Ok((events, _state_delta, _artifact_delta)) => {
296                // TODO: Forward state_delta and artifact_delta back to parent context
297                // This would require extending ToolContext or EventActions
298
299                // Extract and return the response
300                let response = Self::extract_response(&events);
301
302                adk_telemetry::debug!(
303                    "Agent tool {} completed with {} events",
304                    self.agent.name(),
305                    events.len()
306                );
307
308                Ok(response)
309            }
310            Err(e) => Ok(json!({
311                "error": format!("Agent execution failed: {}", e),
312                "agent": self.agent.name()
313            })),
314        }
315    }
316}
317
// Internal context for sub-agent execution.
//
// This is the context handed to the sub-agent's `run`. User and app identity, and
// optionally artifacts, are taken from the parent tool context. The session, invocation
// id and end-of-invocation flag belong to this sub-run only.
struct AgentToolInvocationContext {
    // Context of the tool call that spawned this sub-run.
    parent_ctx: Arc<dyn ToolContext>,
    // The agent being run as a tool.
    agent: Arc<dyn Agent>,
    // The user message built from the tool-call arguments.
    user_content: Content,
    // Unique `agent-tool-<uuid>` id for this sub-run.
    invocation_id: String,
    // Set by `end_invocation`, read by `ended`.
    ended: Arc<AtomicBool>,
    // Whether `artifacts()` exposes the parent's artifact store.
    forward_artifacts: bool,
    // Fresh, empty session separate from the parent's.
    session: Arc<AgentToolSession>,
}
328
329impl AgentToolInvocationContext {
330    fn new(
331        parent_ctx: Arc<dyn ToolContext>,
332        agent: Arc<dyn Agent>,
333        user_content: Content,
334        forward_artifacts: bool,
335    ) -> Self {
336        let invocation_id = format!("agent-tool-{}", uuid::Uuid::new_v4());
337        Self {
338            parent_ctx,
339            agent,
340            user_content,
341            invocation_id,
342            ended: Arc::new(AtomicBool::new(false)),
343            forward_artifacts,
344            session: Arc::new(AgentToolSession::new()),
345        }
346    }
347}
348
349#[async_trait]
350impl ReadonlyContext for AgentToolInvocationContext {
351    fn invocation_id(&self) -> &str {
352        &self.invocation_id
353    }
354
355    fn agent_name(&self) -> &str {
356        self.agent.name()
357    }
358
359    fn user_id(&self) -> &str {
360        self.parent_ctx.user_id()
361    }
362
363    fn app_name(&self) -> &str {
364        self.parent_ctx.app_name()
365    }
366
367    fn session_id(&self) -> &str {
368        // Use a unique session ID for the sub-agent
369        &self.invocation_id
370    }
371
372    fn branch(&self) -> &str {
373        ""
374    }
375
376    fn user_content(&self) -> &Content {
377        &self.user_content
378    }
379}
380
381#[async_trait]
382impl CallbackContext for AgentToolInvocationContext {
383    fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
384        if self.forward_artifacts {
385            self.parent_ctx.artifacts()
386        } else {
387            None
388        }
389    }
390}
391
392#[async_trait]
393impl InvocationContext for AgentToolInvocationContext {
394    fn agent(&self) -> Arc<dyn Agent> {
395        self.agent.clone()
396    }
397
398    fn memory(&self) -> Option<Arc<dyn Memory>> {
399        // Sub-agents don't have direct memory access in this implementation
400        // Could be extended to forward memory if needed
401        None
402    }
403
404    fn session(&self) -> &dyn Session {
405        self.session.as_ref()
406    }
407
408    fn run_config(&self) -> &RunConfig {
409        // Use default config for sub-agent
410        // This is a static reference issue - we return a reference to a static default
411        static DEFAULT_CONFIG: RunConfig =
412            RunConfig { streaming_mode: adk_core::StreamingMode::Auto };
413        &DEFAULT_CONFIG
414    }
415
416    fn end_invocation(&self) {
417        self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
418    }
419
420    fn ended(&self) -> bool {
421        self.ended.load(std::sync::atomic::Ordering::SeqCst)
422    }
423}
424
// Minimal in-memory session for sub-agent execution.
//
// Each sub-run gets its own instance. State starts empty and lives only as long as the
// session, and the conversation history is always empty.
struct AgentToolSession {
    // Unique `agent-tool-session-<uuid>` id.
    id: String,
    // Key/value storage backing the `State` impl.
    state: std::sync::RwLock<HashMap<String, Value>>,
}
430
431impl AgentToolSession {
432    fn new() -> Self {
433        Self {
434            id: format!("agent-tool-session-{}", uuid::Uuid::new_v4()),
435            state: Default::default(),
436        }
437    }
438}
439
440impl Session for AgentToolSession {
441    fn id(&self) -> &str {
442        &self.id
443    }
444
445    fn app_name(&self) -> &str {
446        "agent-tool"
447    }
448
449    fn user_id(&self) -> &str {
450        "agent-tool-user"
451    }
452
453    fn state(&self) -> &dyn State {
454        self
455    }
456
457    fn conversation_history(&self) -> Vec<Content> {
458        // Sub-agent starts with empty history
459        Vec::new()
460    }
461}
462
463impl State for AgentToolSession {
464    fn get(&self, key: &str) -> Option<Value> {
465        self.state.read().ok()?.get(key).cloned()
466    }
467
468    fn set(&mut self, key: String, value: Value) {
469        if let Ok(mut state) = self.state.write() {
470            state.insert(key, value);
471        }
472    }
473
474    fn all(&self) -> HashMap<String, Value> {
475        self.state.read().ok().map(|s| s.clone()).unwrap_or_default()
476    }
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    /// Agent stub that emits a single "Mock response" event.
    struct MockAgent {
        name: String,
        description: String,
    }

    #[async_trait]
    impl Agent for MockAgent {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            &self.description
        }

        fn sub_agents(&self) -> &[Arc<dyn Agent>] {
            &[]
        }

        async fn run(&self, _ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
            use async_stream::stream;

            let name = self.name.clone();
            let s = stream! {
                let mut event = Event::new("mock-inv");
                event.author = name;
                event.llm_response.content = Some(Content::new("model").with_text("Mock response"));
                yield Ok(event);
            };

            Ok(Box::pin(s))
        }
    }

    /// Build a `MockAgent` behind an `Arc`.
    fn mock_agent(name: &str, description: &str) -> Arc<MockAgent> {
        Arc::new(MockAgent { name: name.to_string(), description: description.to_string() })
    }

    #[test]
    fn test_agent_tool_creation() {
        let tool = AgentTool::new(mock_agent("test_agent", "A test agent"));
        assert_eq!(tool.name(), "test_agent");
        assert_eq!(tool.description(), "A test agent");
    }

    #[test]
    fn test_agent_tool_config() {
        let tool = AgentTool::new(mock_agent("test", "test"))
            .skip_summarization(true)
            .forward_artifacts(false)
            .timeout(Duration::from_secs(30));

        assert!(tool.config.skip_summarization);
        assert!(!tool.config.forward_artifacts);
        assert_eq!(tool.config.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_parameters_schema() {
        let tool = AgentTool::new(mock_agent("calculator", "Performs calculations"));
        let schema = tool.parameters_schema().unwrap();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["request"].is_object());
    }

    #[test]
    fn test_custom_parameters_schema_is_used() {
        let schema = json!({"type": "object", "properties": {"query": {"type": "string"}}});
        let tool = AgentTool::new(mock_agent("test", "test")).input_schema(schema.clone());
        assert_eq!(tool.parameters_schema(), Some(schema));
    }

    #[test]
    fn test_response_schema() {
        let tool = AgentTool::new(mock_agent("test", "test"));
        assert_eq!(tool.response_schema(), None);

        let schema = json!({"type": "string"});
        let tool = tool.output_schema(schema.clone());
        assert_eq!(tool.response_schema(), Some(schema));
    }

    #[test]
    fn test_extract_request() {
        let tool = AgentTool::new(mock_agent("test", "test"));

        // Test with request field
        let args = json!({"request": "solve 2+2"});
        assert_eq!(tool.extract_request(&args), "solve 2+2");

        // Test with string value
        let args = json!("direct request");
        assert_eq!(tool.extract_request(&args), "direct request");
    }

    #[test]
    fn test_extract_request_fallbacks() {
        let tool = AgentTool::new(mock_agent("test", "test"));

        // A single misnamed string field is still treated as the request.
        assert_eq!(tool.extract_request(&json!({"question": "what is 2+2"})), "what is 2+2");

        // Non-string, non-object arguments are serialized.
        assert_eq!(tool.extract_request(&json!(42)), "42");
    }

    #[test]
    fn test_extract_request_custom_schema_serializes_args() {
        let tool =
            AgentTool::new(mock_agent("test", "test")).input_schema(json!({"type": "object"}));
        assert_eq!(tool.extract_request(&json!({"query": "hi"})), r#"{"query":"hi"}"#);
    }

    #[test]
    fn test_extract_response() {
        let mut event = Event::new("inv-123");
        event.llm_response.content = Some(Content::new("model").with_text("The answer is 4"));

        let events = vec![event];
        let response = AgentTool::extract_response(&events);

        assert_eq!(response["response"], "The answer is 4");
    }

    #[test]
    fn test_extract_response_without_events() {
        let response = AgentTool::extract_response(&[]);
        assert_eq!(response["response"], "No response from agent");
    }
}