Skip to main content

adk_tool/
agent_tool.rs

//! AgentTool - Use agents as callable tools
//!
//! This module provides `AgentTool`, which wraps an `Agent` instance to make it
//! callable as a `Tool`. This enables composition patterns in which a
//! coordinator agent invokes specialized sub-agents.
//!
//! # Example
//!
//! ```rust,ignore
//! use std::sync::Arc;
//!
//! use adk_tool::AgentTool;
//! use adk_agent::LlmAgentBuilder;
//!
//! // Create a specialized agent
//! let math_agent = LlmAgentBuilder::new("math_expert")
//!     .description("Solves mathematical problems")
//!     .instruction("You are a math expert. Solve problems step by step.")
//!     .model(model.clone())
//!     .build()?;
//!
//! // Wrap it as a tool
//! let math_tool = AgentTool::new(Arc::new(math_agent));
//!
//! // Use in coordinator agent
//! let coordinator = LlmAgentBuilder::new("coordinator")
//!     .instruction("Help users by delegating to specialists")
//!     .tools(vec![Arc::new(math_tool)])
//!     .build()?;
//! ```
29
30use adk_core::{
31    Agent, Artifacts, CallbackContext, Content, Event, InvocationContext, Memory, Part,
32    ReadonlyContext, Result, RunConfig, Session, State, Tool, ToolContext,
33};
34use async_trait::async_trait;
35use futures::StreamExt;
36use serde_json::{Value, json};
37use std::collections::HashMap;
38use std::sync::{Arc, atomic::AtomicBool};
39use std::time::Duration;
40
41/// Configuration options for AgentTool behavior.
42#[derive(Debug, Clone)]
43pub struct AgentToolConfig {
44    /// Skip summarization after sub-agent execution.
45    /// When true, returns the raw output from the sub-agent.
46    pub skip_summarization: bool,
47
48    /// Forward artifacts between parent and sub-agent.
49    /// When true, the sub-agent can access parent's artifacts.
50    pub forward_artifacts: bool,
51
52    /// Optional timeout for sub-agent execution.
53    pub timeout: Option<Duration>,
54
55    /// Custom input schema for the tool.
56    /// If None, defaults to `{"request": "string"}`.
57    pub input_schema: Option<Value>,
58
59    /// Custom output schema for the tool.
60    pub output_schema: Option<Value>,
61}
62
63impl Default for AgentToolConfig {
64    fn default() -> Self {
65        Self {
66            skip_summarization: false,
67            forward_artifacts: true,
68            timeout: None,
69            input_schema: None,
70            output_schema: None,
71        }
72    }
73}
74
75/// AgentTool wraps an Agent to make it callable as a Tool.
76///
77/// When the parent LLM generates a function call targeting this tool,
78/// the framework executes the wrapped agent, captures its final response,
79/// and returns it as the tool's result.
80pub struct AgentTool {
81    agent: Arc<dyn Agent>,
82    config: AgentToolConfig,
83}
84
85impl AgentTool {
86    /// Create a new AgentTool wrapping the given agent.
87    pub fn new(agent: Arc<dyn Agent>) -> Self {
88        Self { agent, config: AgentToolConfig::default() }
89    }
90
91    /// Create a new AgentTool with custom configuration.
92    pub fn with_config(agent: Arc<dyn Agent>, config: AgentToolConfig) -> Self {
93        Self { agent, config }
94    }
95
96    /// Set whether to skip summarization.
97    pub fn skip_summarization(mut self, skip: bool) -> Self {
98        self.config.skip_summarization = skip;
99        self
100    }
101
102    /// Set whether to forward artifacts.
103    pub fn forward_artifacts(mut self, forward: bool) -> Self {
104        self.config.forward_artifacts = forward;
105        self
106    }
107
108    /// Set timeout for sub-agent execution.
109    pub fn timeout(mut self, timeout: Duration) -> Self {
110        self.config.timeout = Some(timeout);
111        self
112    }
113
114    /// Set custom input schema.
115    pub fn input_schema(mut self, schema: Value) -> Self {
116        self.config.input_schema = Some(schema);
117        self
118    }
119
120    /// Set custom output schema.
121    pub fn output_schema(mut self, schema: Value) -> Self {
122        self.config.output_schema = Some(schema);
123        self
124    }
125
126    /// Generate the default parameters schema for this agent tool.
127    fn default_parameters_schema(&self) -> Value {
128        json!({
129            "type": "object",
130            "properties": {
131                "request": {
132                    "type": "string",
133                    "description": format!("The request to send to the {} agent", self.agent.name())
134                }
135            },
136            "required": ["request"]
137        })
138    }
139
140    /// Extract the request text from the tool arguments.
141    fn extract_request(&self, args: &Value) -> String {
142        // Try to get "request" field first
143        if let Some(request) = args.get("request").and_then(|v| v.as_str()) {
144            return request.to_string();
145        }
146
147        // If custom schema, try to serialize the whole args
148        if self.config.input_schema.is_some() {
149            return serde_json::to_string(args).unwrap_or_default();
150        }
151
152        // Fallback: convert args to string
153        match args {
154            Value::String(s) => s.clone(),
155            Value::Object(map) => {
156                // Try to find any string field
157                for value in map.values() {
158                    if let Value::String(s) = value {
159                        return s.clone();
160                    }
161                }
162                serde_json::to_string(args).unwrap_or_default()
163            }
164            _ => serde_json::to_string(args).unwrap_or_default(),
165        }
166    }
167
168    /// Extract the final response text from agent events.
169    fn extract_response(events: &[Event]) -> Value {
170        // Collect all text responses from final events
171        let mut responses = Vec::new();
172
173        for event in events.iter().rev() {
174            if event.is_final_response() {
175                if let Some(content) = &event.llm_response.content {
176                    for part in &content.parts {
177                        if let Part::Text { text } = part {
178                            responses.push(text.clone());
179                        }
180                    }
181                }
182                break; // Only get the last final response
183            }
184        }
185
186        if responses.is_empty() {
187            // Try to get any text from the last event
188            if let Some(last_event) = events.last() {
189                if let Some(content) = &last_event.llm_response.content {
190                    for part in &content.parts {
191                        if let Part::Text { text } = part {
192                            return json!({ "response": text });
193                        }
194                    }
195                }
196            }
197            json!({ "response": "No response from agent" })
198        } else {
199            json!({ "response": responses.join("\n") })
200        }
201    }
202}
203
204#[async_trait]
205impl Tool for AgentTool {
206    fn name(&self) -> &str {
207        self.agent.name()
208    }
209
210    fn description(&self) -> &str {
211        self.agent.description()
212    }
213
214    fn parameters_schema(&self) -> Option<Value> {
215        Some(self.config.input_schema.clone().unwrap_or_else(|| self.default_parameters_schema()))
216    }
217
218    fn response_schema(&self) -> Option<Value> {
219        self.config.output_schema.clone()
220    }
221
222    fn is_long_running(&self) -> bool {
223        // Agent execution could take time, but we wait for completion
224        false
225    }
226
227    #[adk_telemetry::instrument(
228        skip(self, ctx, args),
229        fields(
230            agent_tool.name = %self.agent.name(),
231            agent_tool.description = %self.agent.description(),
232            function_call.id = %ctx.function_call_id()
233        )
234    )]
235    async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
236        adk_telemetry::debug!("Executing agent tool: {}", self.agent.name());
237
238        // Extract the request from args
239        let request_text = self.extract_request(&args);
240
241        // Create user content for the sub-agent
242        let user_content = Content::new("user").with_text(&request_text);
243
244        // Create an isolated context for the sub-agent
245        let sub_ctx = Arc::new(AgentToolInvocationContext::new(
246            ctx.clone(),
247            self.agent.clone(),
248            user_content.clone(),
249            self.config.forward_artifacts,
250        ));
251
252        // Execute the sub-agent
253        let execution = async {
254            let mut event_stream = self.agent.run(sub_ctx.clone()).await?;
255
256            // Collect all events
257            let mut events = Vec::new();
258            let mut state_delta = HashMap::new();
259            let mut artifact_delta = HashMap::new();
260
261            while let Some(result) = event_stream.next().await {
262                match result {
263                    Ok(event) => {
264                        // Merge state deltas
265                        state_delta.extend(event.actions.state_delta.clone());
266                        artifact_delta.extend(event.actions.artifact_delta.clone());
267                        events.push(event);
268                    }
269                    Err(e) => {
270                        adk_telemetry::error!("Error in sub-agent execution: {}", e);
271                        return Err(e);
272                    }
273                }
274            }
275
276            Ok((events, state_delta, artifact_delta))
277        };
278
279        // Apply timeout if configured
280        let result = if let Some(timeout_duration) = self.config.timeout {
281            match tokio::time::timeout(timeout_duration, execution).await {
282                Ok(r) => r,
283                Err(_) => {
284                    return Ok(json!({
285                        "error": "Agent execution timed out",
286                        "agent": self.agent.name()
287                    }));
288                }
289            }
290        } else {
291            execution.await
292        };
293
294        match result {
295            Ok((events, state_delta, artifact_delta)) => {
296                // Forward state_delta and artifact_delta to parent context
297                if !state_delta.is_empty() || !artifact_delta.is_empty() {
298                    let mut parent_actions = ctx.actions();
299                    parent_actions.state_delta.extend(state_delta);
300                    parent_actions.artifact_delta.extend(artifact_delta);
301                    ctx.set_actions(parent_actions);
302                }
303
304                // Extract and return the response
305                let response = Self::extract_response(&events);
306
307                adk_telemetry::debug!(
308                    "Agent tool {} completed with {} events",
309                    self.agent.name(),
310                    events.len()
311                );
312
313                Ok(response)
314            }
315            Err(e) => Ok(json!({
316                "error": format!("Agent execution failed: {}", e),
317                "agent": self.agent.name()
318            })),
319        }
320    }
321}
322
323// Internal context for sub-agent execution
324struct AgentToolInvocationContext {
325    parent_ctx: Arc<dyn ToolContext>,
326    agent: Arc<dyn Agent>,
327    user_content: Content,
328    invocation_id: String,
329    ended: Arc<AtomicBool>,
330    forward_artifacts: bool,
331    session: Arc<AgentToolSession>,
332}
333
334impl AgentToolInvocationContext {
335    fn new(
336        parent_ctx: Arc<dyn ToolContext>,
337        agent: Arc<dyn Agent>,
338        user_content: Content,
339        forward_artifacts: bool,
340    ) -> Self {
341        let invocation_id = format!("agent-tool-{}", uuid::Uuid::new_v4());
342        Self {
343            parent_ctx,
344            agent,
345            user_content,
346            invocation_id,
347            ended: Arc::new(AtomicBool::new(false)),
348            forward_artifacts,
349            session: Arc::new(AgentToolSession::new()),
350        }
351    }
352}
353
354#[async_trait]
355impl ReadonlyContext for AgentToolInvocationContext {
356    fn invocation_id(&self) -> &str {
357        &self.invocation_id
358    }
359
360    fn agent_name(&self) -> &str {
361        self.agent.name()
362    }
363
364    fn user_id(&self) -> &str {
365        self.parent_ctx.user_id()
366    }
367
368    fn app_name(&self) -> &str {
369        self.parent_ctx.app_name()
370    }
371
372    fn session_id(&self) -> &str {
373        // Use a unique session ID for the sub-agent
374        &self.invocation_id
375    }
376
377    fn branch(&self) -> &str {
378        ""
379    }
380
381    fn user_content(&self) -> &Content {
382        &self.user_content
383    }
384}
385
386#[async_trait]
387impl CallbackContext for AgentToolInvocationContext {
388    fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
389        if self.forward_artifacts { self.parent_ctx.artifacts() } else { None }
390    }
391}
392
393#[async_trait]
394impl InvocationContext for AgentToolInvocationContext {
395    fn agent(&self) -> Arc<dyn Agent> {
396        self.agent.clone()
397    }
398
399    fn memory(&self) -> Option<Arc<dyn Memory>> {
400        // Sub-agents don't have direct memory access in this implementation
401        // Could be extended to forward memory if needed
402        None
403    }
404
405    fn session(&self) -> &dyn Session {
406        self.session.as_ref()
407    }
408
409    fn run_config(&self) -> &RunConfig {
410        // Use None streaming mode for sub-agent so responses are fully accumulated
411        // before being returned. SSE mode yields partial chunks which makes
412        // extract_response unable to capture the complete text.
413        static AGENT_TOOL_CONFIG: std::sync::OnceLock<RunConfig> = std::sync::OnceLock::new();
414        AGENT_TOOL_CONFIG.get_or_init(|| RunConfig {
415            streaming_mode: adk_core::StreamingMode::None,
416            ..RunConfig::default()
417        })
418    }
419
420    fn end_invocation(&self) {
421        self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
422    }
423
424    fn ended(&self) -> bool {
425        self.ended.load(std::sync::atomic::Ordering::SeqCst)
426    }
427}
428
429// Minimal session for sub-agent execution
430struct AgentToolSession {
431    id: String,
432    state: std::sync::RwLock<HashMap<String, Value>>,
433}
434
435impl AgentToolSession {
436    fn new() -> Self {
437        Self {
438            id: format!("agent-tool-session-{}", uuid::Uuid::new_v4()),
439            state: Default::default(),
440        }
441    }
442}
443
444impl Session for AgentToolSession {
445    fn id(&self) -> &str {
446        &self.id
447    }
448
449    fn app_name(&self) -> &str {
450        "agent-tool"
451    }
452
453    fn user_id(&self) -> &str {
454        "agent-tool-user"
455    }
456
457    fn state(&self) -> &dyn State {
458        self
459    }
460
461    fn conversation_history(&self) -> Vec<Content> {
462        // Sub-agent starts with empty history
463        Vec::new()
464    }
465}
466
467impl State for AgentToolSession {
468    fn get(&self, key: &str) -> Option<Value> {
469        self.state.read().ok()?.get(key).cloned()
470    }
471
472    fn set(&mut self, key: String, value: Value) {
473        if let Err(msg) = adk_core::validate_state_key(&key) {
474            tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
475            return;
476        }
477        if let Ok(mut state) = self.state.write() {
478            state.insert(key, value);
479        }
480    }
481
482    fn all(&self) -> HashMap<String, Value> {
483        self.state.read().ok().map(|s| s.clone()).unwrap_or_default()
484    }
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal agent whose run emits one event with the text "Mock response".
    struct MockAgent {
        name: String,
        description: String,
    }

    /// Builds a shared `MockAgent` with the given name and description.
    fn mock_agent(name: &str, description: &str) -> Arc<MockAgent> {
        Arc::new(MockAgent { name: name.to_string(), description: description.to_string() })
    }

    #[async_trait]
    impl Agent for MockAgent {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn description(&self) -> &str {
            self.description.as_str()
        }

        fn sub_agents(&self) -> &[Arc<dyn Agent>] {
            &[]
        }

        async fn run(&self, _ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
            use async_stream::stream;

            let author = self.name.clone();
            let events = stream! {
                let mut event = Event::new("mock-inv");
                event.author = author;
                event.llm_response.content = Some(Content::new("model").with_text("Mock response"));
                yield Ok(event);
            };

            Ok(Box::pin(events))
        }
    }

    #[test]
    fn test_agent_tool_creation() {
        let tool = AgentTool::new(mock_agent("test_agent", "A test agent"));

        assert_eq!(tool.name(), "test_agent");
        assert_eq!(tool.description(), "A test agent");
    }

    #[test]
    fn test_agent_tool_config() {
        let tool = AgentTool::new(mock_agent("test", "test"))
            .skip_summarization(true)
            .forward_artifacts(false)
            .timeout(Duration::from_secs(30));

        let config = &tool.config;
        assert!(config.skip_summarization);
        assert!(!config.forward_artifacts);
        assert_eq!(config.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_parameters_schema() {
        let schema = AgentTool::new(mock_agent("calculator", "Performs calculations"))
            .parameters_schema()
            .expect("default schema is always present");

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["request"].is_object());
    }

    #[test]
    fn test_extract_request() {
        let tool = AgentTool::new(mock_agent("test", "test"));

        // (arguments, expected request text)
        let cases = [
            (json!({"request": "solve 2+2"}), "solve 2+2"),
            (json!("direct request"), "direct request"),
        ];
        for (args, expected) in &cases {
            assert_eq!(tool.extract_request(args), *expected);
        }
    }

    #[test]
    fn test_extract_response() {
        let mut event = Event::new("inv-123");
        event.llm_response.content = Some(Content::new("model").with_text("The answer is 4"));

        let response = AgentTool::extract_response(&[event]);

        assert_eq!(response["response"], "The answer is 4");
    }
}