adk-tool 0.6.0

Tool system for Rust Agent Development Kit (ADK-Rust) agents (FunctionTool, MCP, Google Search)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
//! AgentTool - Use agents as callable tools
//!
//! This module provides `AgentTool` which wraps an `Agent` instance to make it
//! callable as a `Tool`. This enables powerful composition patterns where a
//! coordinator agent can invoke specialized sub-agents.
//!
//! # Example
//!
//! ```rust,ignore
//! use std::sync::Arc;
//! use adk_tool::AgentTool;
//! use adk_agent::LlmAgentBuilder;
//!
//! // Create a specialized agent
//! let math_agent = LlmAgentBuilder::new("math_expert")
//!     .description("Solves mathematical problems")
//!     .instruction("You are a math expert. Solve problems step by step.")
//!     .model(model.clone())
//!     .build()?;
//!
//! // Wrap it as a tool
//! let math_tool = AgentTool::new(Arc::new(math_agent));
//!
//! // Use in coordinator agent
//! let coordinator = LlmAgentBuilder::new("coordinator")
//!     .instruction("Help users by delegating to specialists")
//!     .tools(vec![Arc::new(math_tool)])
//!     .build()?;
//! ```

use adk_core::{
    Agent, Artifacts, CallbackContext, Content, Event, InvocationContext, Memory, Part,
    ReadonlyContext, Result, RunConfig, Session, State, Tool, ToolContext,
};
use async_trait::async_trait;
use futures::StreamExt;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::{Arc, atomic::AtomicBool};
use std::time::Duration;

/// Configuration options for AgentTool behavior.
///
/// The values produced by `AgentToolConfig::default()` are: summarization not
/// skipped, artifacts forwarded, no timeout, and no custom input/output schemas.
#[derive(Debug, Clone)]
pub struct AgentToolConfig {
    /// Skip summarization after sub-agent execution.
    /// When true, the intent is that the sub-agent's raw output is returned
    /// without a further summarization pass by the parent.
    /// NOTE(review): confirm where this flag is consumed by the tool/runner.
    pub skip_summarization: bool,

    /// Forward artifacts between parent and sub-agent.
    /// When true, the sub-agent's `artifacts()` returns the parent's artifact
    /// service; when false, the sub-agent sees no artifact service at all.
    pub forward_artifacts: bool,

    /// Optional timeout for sub-agent execution.
    /// When it elapses, the tool returns a JSON object with an `"error"` key
    /// instead of the sub-agent's response.
    pub timeout: Option<Duration>,

    /// Custom input schema for the tool.
    /// If None, defaults to an object schema with a single required string
    /// property named `request`.
    pub input_schema: Option<Value>,

    /// Custom output schema for the tool, reported verbatim as the tool's
    /// response schema (None means no response schema is advertised).
    pub output_schema: Option<Value>,
}

impl Default for AgentToolConfig {
    /// Baseline configuration: summarization is not skipped, the parent's
    /// artifacts are shared with the sub-agent, execution has no time limit,
    /// and neither an input nor an output schema override is set.
    fn default() -> Self {
        let forward_artifacts = true;
        Self {
            forward_artifacts,
            skip_summarization: false,
            timeout: None,
            input_schema: None,
            output_schema: None,
        }
    }
}

/// AgentTool wraps an Agent to make it callable as a Tool.
///
/// When the parent LLM generates a function call targeting this tool,
/// the framework executes the wrapped agent, captures its final response,
/// and returns it as the tool's result.
///
/// The tool's name and description are taken directly from the wrapped agent.
pub struct AgentTool {
    /// The agent run for every invocation of this tool.
    agent: Arc<dyn Agent>,
    /// Behavioral options (timeout, schemas, artifact forwarding, ...).
    config: AgentToolConfig,
}

impl AgentTool {
    /// Create a new AgentTool wrapping the given agent, using
    /// `AgentToolConfig::default()`.
    pub fn new(agent: Arc<dyn Agent>) -> Self {
        Self { agent, config: AgentToolConfig::default() }
    }

    /// Create a new AgentTool with an explicit configuration.
    pub fn with_config(agent: Arc<dyn Agent>, config: AgentToolConfig) -> Self {
        Self { agent, config }
    }

    /// Set whether to skip summarization (builder style).
    pub fn skip_summarization(mut self, skip: bool) -> Self {
        self.config.skip_summarization = skip;
        self
    }

    /// Set whether to forward the parent's artifacts to the sub-agent (builder style).
    pub fn forward_artifacts(mut self, forward: bool) -> Self {
        self.config.forward_artifacts = forward;
        self
    }

    /// Set a timeout for sub-agent execution (builder style).
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.config.timeout = Some(timeout);
        self
    }

    /// Set a custom input schema, replacing the default `{request: string}` schema.
    pub fn input_schema(mut self, schema: Value) -> Self {
        self.config.input_schema = Some(schema);
        self
    }

    /// Set a custom output schema, reported as the tool's response schema.
    pub fn output_schema(mut self, schema: Value) -> Self {
        self.config.output_schema = Some(schema);
        self
    }

    /// Generate the default parameters schema for this agent tool: an object
    /// with a single required string property `request`.
    fn default_parameters_schema(&self) -> Value {
        json!({
            "type": "object",
            "properties": {
                "request": {
                    "type": "string",
                    "description": format!("The request to send to the {} agent", self.agent.name())
                }
            },
            "required": ["request"]
        })
    }

    /// Extract the request text from the tool arguments.
    ///
    /// Resolution order:
    /// 1. a string-valued `"request"` field, returned verbatim;
    /// 2. with a custom input schema, the whole `args` as compact JSON;
    /// 3. a bare JSON string, returned verbatim;
    /// 4. for an object, the first string-valued field in the map's iteration order;
    /// 5. otherwise, `args` as compact JSON.
    fn extract_request(&self, args: &Value) -> String {
        if let Some(request) = args.get("request").and_then(Value::as_str) {
            return request.to_string();
        }

        // A custom schema means the sub-agent expects the structured arguments,
        // so hand them over intact instead of guessing at a single field.
        if self.config.input_schema.is_some() {
            return args.to_string();
        }

        // `Value`'s `Display` renders compact JSON, identical to `serde_json::to_string`.
        match args {
            Value::String(s) => s.clone(),
            Value::Object(map) => map
                .values()
                .find_map(Value::as_str)
                .map_or_else(|| args.to_string(), str::to_string),
            _ => args.to_string(),
        }
    }

    /// Extract the final response text from agent events.
    ///
    /// Prefers the text parts of the last event for which `is_final_response()`
    /// is true, joined with `"\n"`. If that event carries no text (or there is
    /// no final event), falls back to the most recent event that does carry
    /// text, joined the same way. Returns
    /// `{"response": "No response from agent"}` when no event has any text.
    fn extract_response(events: &[Event]) -> Value {
        // All text parts of one event, in order (empty if it has no content).
        fn text_parts(event: &Event) -> Vec<&str> {
            event
                .llm_response
                .content
                .as_ref()
                .map(|content| {
                    content
                        .parts
                        .iter()
                        .filter_map(|part| match part {
                            Part::Text { text } => Some(text.as_str()),
                            _ => None,
                        })
                        .collect()
                })
                .unwrap_or_default()
        }

        let final_texts = events
            .iter()
            .rev()
            .find(|event| event.is_final_response())
            .map(text_parts)
            .unwrap_or_default();

        let texts = if final_texts.is_empty() {
            events
                .iter()
                .rev()
                .map(text_parts)
                .find(|texts| !texts.is_empty())
                .unwrap_or_default()
        } else {
            final_texts
        };

        if texts.is_empty() {
            json!({ "response": "No response from agent" })
        } else {
            json!({ "response": texts.join("\n") })
        }
    }
}

#[async_trait]
impl Tool for AgentTool {
    fn name(&self) -> &str {
        self.agent.name()
    }

    fn description(&self) -> &str {
        self.agent.description()
    }

    fn parameters_schema(&self) -> Option<Value> {
        Some(self.config.input_schema.clone().unwrap_or_else(|| self.default_parameters_schema()))
    }

    fn response_schema(&self) -> Option<Value> {
        self.config.output_schema.clone()
    }

    fn is_long_running(&self) -> bool {
        // Agent execution could take time, but we wait for completion
        false
    }

    #[adk_telemetry::instrument(
        skip(self, ctx, args),
        fields(
            agent_tool.name = %self.agent.name(),
            agent_tool.description = %self.agent.description(),
            function_call.id = %ctx.function_call_id()
        )
    )]
    async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
        adk_telemetry::debug!("Executing agent tool: {}", self.agent.name());

        // Extract the request from args
        let request_text = self.extract_request(&args);

        // Create user content for the sub-agent
        let user_content = Content::new("user").with_text(&request_text);

        // Create an isolated context for the sub-agent
        let sub_ctx = Arc::new(AgentToolInvocationContext::new(
            ctx.clone(),
            self.agent.clone(),
            user_content.clone(),
            self.config.forward_artifacts,
        ));

        // Execute the sub-agent
        let execution = async {
            let mut event_stream = self.agent.run(sub_ctx.clone()).await?;

            // Collect all events
            let mut events = Vec::new();
            let mut state_delta = HashMap::new();
            let mut artifact_delta = HashMap::new();

            while let Some(result) = event_stream.next().await {
                match result {
                    Ok(event) => {
                        // Merge state deltas
                        state_delta.extend(event.actions.state_delta.clone());
                        artifact_delta.extend(event.actions.artifact_delta.clone());
                        events.push(event);
                    }
                    Err(e) => {
                        adk_telemetry::error!("Error in sub-agent execution: {}", e);
                        return Err(e);
                    }
                }
            }

            Ok((events, state_delta, artifact_delta))
        };

        // Apply timeout if configured
        let result = if let Some(timeout_duration) = self.config.timeout {
            match tokio::time::timeout(timeout_duration, execution).await {
                Ok(r) => r,
                Err(_) => {
                    return Ok(json!({
                        "error": "Agent execution timed out",
                        "agent": self.agent.name()
                    }));
                }
            }
        } else {
            execution.await
        };

        match result {
            Ok((events, state_delta, artifact_delta)) => {
                // Forward state_delta and artifact_delta to parent context
                if !state_delta.is_empty() || !artifact_delta.is_empty() {
                    let mut parent_actions = ctx.actions();
                    parent_actions.state_delta.extend(state_delta);
                    parent_actions.artifact_delta.extend(artifact_delta);
                    ctx.set_actions(parent_actions);
                }

                // Extract and return the response
                let response = Self::extract_response(&events);

                adk_telemetry::debug!(
                    "Agent tool {} completed with {} events",
                    self.agent.name(),
                    events.len()
                );

                Ok(response)
            }
            Err(e) => Ok(json!({
                "error": format!("Agent execution failed: {}", e),
                "agent": self.agent.name()
            })),
        }
    }
}

// Internal context for sub-agent execution.
//
// Identity (user id, app name) and, optionally, artifacts come from the parent
// tool context; everything else (invocation id, session, end flag) is private
// to this sub-agent run.
struct AgentToolInvocationContext {
    // Context of the tool call that spawned this sub-agent run.
    parent_ctx: Arc<dyn ToolContext>,
    // The agent being executed as a tool.
    agent: Arc<dyn Agent>,
    // The synthesized "user" message carrying the tool request.
    user_content: Content,
    // Generated per run as "agent-tool-<uuid v4>".
    invocation_id: String,
    // Set by `end_invocation`, read by `ended`.
    ended: Arc<AtomicBool>,
    // Whether `artifacts()` exposes the parent's artifact service.
    forward_artifacts: bool,
    // Fresh, empty session used only by this sub-agent run.
    session: Arc<AgentToolSession>,
}

impl AgentToolInvocationContext {
    /// Build an isolated context for one sub-agent run.
    ///
    /// A fresh invocation id ("agent-tool-<uuid>") and a fresh empty session are
    /// generated; the end flag starts out cleared.
    fn new(
        parent_ctx: Arc<dyn ToolContext>,
        agent: Arc<dyn Agent>,
        user_content: Content,
        forward_artifacts: bool,
    ) -> Self {
        let invocation_id = format!("agent-tool-{}", uuid::Uuid::new_v4());
        let session = Arc::new(AgentToolSession::new());
        let ended = Arc::new(AtomicBool::new(false));
        Self { parent_ctx, agent, user_content, invocation_id, ended, forward_artifacts, session }
    }
}

#[async_trait]
impl ReadonlyContext for AgentToolInvocationContext {
    fn invocation_id(&self) -> &str {
        &self.invocation_id
    }

    fn agent_name(&self) -> &str {
        self.agent.name()
    }

    /// User id of the parent tool call.
    fn user_id(&self) -> &str {
        self.parent_ctx.user_id()
    }

    /// App name of the parent tool call.
    ///
    /// NOTE(review): the sub-agent's `session().app_name()`/`user_id()` return
    /// fixed placeholders ("agent-tool"/"agent-tool-user") rather than these
    /// parent values; consider aligning them.
    fn app_name(&self) -> &str {
        self.parent_ctx.app_name()
    }

    /// Id of the session handed to the sub-agent, so `session_id()` and
    /// `session().id()` agree (previously the invocation id was returned here,
    /// which differed from the session's own id).
    fn session_id(&self) -> &str {
        &self.session.id
    }

    fn branch(&self) -> &str {
        ""
    }

    fn user_content(&self) -> &Content {
        &self.user_content
    }
}

#[async_trait]
impl CallbackContext for AgentToolInvocationContext {
    /// The parent's artifact service when forwarding is enabled; `None` otherwise.
    fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
        if !self.forward_artifacts {
            return None;
        }
        self.parent_ctx.artifacts()
    }
}

#[async_trait]
impl InvocationContext for AgentToolInvocationContext {
    fn agent(&self) -> Arc<dyn Agent> {
        Arc::clone(&self.agent)
    }

    /// No memory service is exposed to the sub-agent in this implementation;
    /// forwarding the parent's memory could be added if ever needed.
    fn memory(&self) -> Option<Arc<dyn Memory>> {
        None
    }

    fn session(&self) -> &dyn Session {
        &*self.session
    }

    /// A shared, lazily-built config that disables streaming.
    ///
    /// Streaming (SSE) would yield partial chunks, and `extract_response` needs
    /// each event's text to be complete, so the sub-agent always runs with
    /// `StreamingMode::None`.
    fn run_config(&self) -> &RunConfig {
        static NON_STREAMING: std::sync::OnceLock<RunConfig> = std::sync::OnceLock::new();
        NON_STREAMING.get_or_init(|| RunConfig {
            streaming_mode: adk_core::StreamingMode::None,
            ..Default::default()
        })
    }

    fn end_invocation(&self) {
        use std::sync::atomic::Ordering;
        self.ended.store(true, Ordering::SeqCst);
    }

    fn ended(&self) -> bool {
        use std::sync::atomic::Ordering;
        self.ended.load(Ordering::SeqCst)
    }
}

// Minimal session for sub-agent execution: a generated id plus an in-memory
// key/value state map. It has no conversation history of its own.
struct AgentToolSession {
    // Generated as "agent-tool-session-<uuid v4>".
    id: String,
    // Session state; the lock lets reads happen through `&self`.
    state: std::sync::RwLock<HashMap<String, Value>>,
}

impl AgentToolSession {
    fn new() -> Self {
        Self {
            id: format!("agent-tool-session-{}", uuid::Uuid::new_v4()),
            state: Default::default(),
        }
    }
}

impl Session for AgentToolSession {
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Fixed placeholder app name for the throwaway sub-agent session.
    fn app_name(&self) -> &str {
        "agent-tool"
    }

    /// Fixed placeholder user id for the throwaway sub-agent session.
    fn user_id(&self) -> &str {
        "agent-tool-user"
    }

    /// The session is its own state store.
    fn state(&self) -> &dyn State {
        self as &dyn State
    }

    /// Always empty: the sub-agent sees only the request passed in its user content.
    fn conversation_history(&self) -> Vec<Content> {
        vec![]
    }
}

impl State for AgentToolSession {
    /// Return a clone of the value stored under `key`, if any.
    fn get(&self, key: &str) -> Option<Value> {
        // A poisoned lock still guards a consistent map (insert/clone cannot
        // leave it half-updated), so recover it rather than pretending it is empty.
        self.state.read().unwrap_or_else(std::sync::PoisonError::into_inner).get(key).cloned()
    }

    /// Insert `value` under `key`; keys rejected by `validate_state_key` are
    /// logged and ignored.
    fn set(&mut self, key: String, value: Value) {
        if let Err(msg) = adk_core::validate_state_key(&key) {
            tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
            return;
        }
        // `&mut self` already guarantees exclusive access, so no lock is taken;
        // poisoning is recovered so the write is never silently dropped.
        self.state.get_mut().unwrap_or_else(std::sync::PoisonError::into_inner).insert(key, value);
    }

    /// Return a snapshot copy of the whole state map.
    fn all(&self) -> HashMap<String, Value> {
        let guard = self.state.read().unwrap_or_else(std::sync::PoisonError::into_inner);
        guard.clone()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal agent that emits a single model event with the text "Mock response".
    struct MockAgent {
        name: String,
        description: String,
    }

    #[async_trait]
    impl Agent for MockAgent {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            &self.description
        }

        fn sub_agents(&self) -> &[Arc<dyn Agent>] {
            &[]
        }

        async fn run(&self, _ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
            use async_stream::stream;

            let name = self.name.clone();
            let s = stream! {
                let mut event = Event::new("mock-inv");
                event.author = name;
                event.llm_response.content = Some(Content::new("model").with_text("Mock response"));
                yield Ok(event);
            };

            Ok(Box::pin(s))
        }
    }

    /// Build a shared `MockAgent` with the given name and description.
    fn mock_agent(name: &str, description: &str) -> Arc<MockAgent> {
        Arc::new(MockAgent { name: name.to_string(), description: description.to_string() })
    }

    #[test]
    fn test_agent_tool_creation() {
        let tool = AgentTool::new(mock_agent("test_agent", "A test agent"));
        assert_eq!(tool.name(), "test_agent");
        assert_eq!(tool.description(), "A test agent");
    }

    #[test]
    fn test_default_config() {
        let config = AgentToolConfig::default();
        assert!(!config.skip_summarization);
        assert!(config.forward_artifacts);
        assert!(config.timeout.is_none());
        assert!(config.input_schema.is_none());
        assert!(config.output_schema.is_none());
    }

    #[test]
    fn test_agent_tool_config() {
        let tool = AgentTool::new(mock_agent("test", "test"))
            .skip_summarization(true)
            .forward_artifacts(false)
            .timeout(Duration::from_secs(30));

        assert!(tool.config.skip_summarization);
        assert!(!tool.config.forward_artifacts);
        assert_eq!(tool.config.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_parameters_schema() {
        let tool = AgentTool::new(mock_agent("calculator", "Performs calculations"));
        let schema = tool.parameters_schema().unwrap();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["request"].is_object());
        assert_eq!(schema["required"], json!(["request"]));
        assert!(tool.response_schema().is_none());
    }

    #[test]
    fn test_custom_schemas_are_reported_verbatim() {
        let input = json!({"type": "object", "properties": {"x": {"type": "integer"}}});
        let output = json!({"type": "object"});
        let tool = AgentTool::new(mock_agent("test", "test"))
            .input_schema(input.clone())
            .output_schema(output.clone());

        assert_eq!(tool.parameters_schema(), Some(input));
        assert_eq!(tool.response_schema(), Some(output));
    }

    #[test]
    fn test_extract_request() {
        let tool = AgentTool::new(mock_agent("test", "test"));

        // Test with request field
        assert_eq!(tool.extract_request(&json!({"request": "solve 2+2"})), "solve 2+2");

        // Test with string value
        assert_eq!(tool.extract_request(&json!("direct request")), "direct request");

        // Objects without "request" fall back to their (only) string field.
        assert_eq!(tool.extract_request(&json!({"query": "what is 2+2", "n": 1})), "what is 2+2");

        // Non-string scalars are serialized as JSON.
        assert_eq!(tool.extract_request(&json!(42)), "42");
    }

    #[test]
    fn test_extract_request_with_custom_schema() {
        let tool =
            AgentTool::new(mock_agent("test", "test")).input_schema(json!({"type": "object"}));

        // A string "request" field still wins.
        assert_eq!(tool.extract_request(&json!({"request": "hi"})), "hi");

        // Otherwise the whole argument object is forwarded as compact JSON.
        assert_eq!(tool.extract_request(&json!({"x": 1, "y": "two"})), r#"{"x":1,"y":"two"}"#);
    }

    #[test]
    fn test_extract_response() {
        let mut event = Event::new("inv-123");
        event.llm_response.content = Some(Content::new("model").with_text("The answer is 4"));

        let events = vec![event];
        let response = AgentTool::extract_response(&events);

        assert_eq!(response["response"], "The answer is 4");
    }

    #[test]
    fn test_extract_response_without_events() {
        let response = AgentTool::extract_response(&[]);
        assert_eq!(response["response"], "No response from agent");
    }
}