1use adk_core::{
31 Agent, Artifacts, CallbackContext, Content, Event, InvocationContext, Memory, Part,
32 ReadonlyContext, Result, RunConfig, Session, State, Tool, ToolContext,
33};
34use async_trait::async_trait;
35use futures::StreamExt;
36use serde_json::{Value, json};
37use std::collections::HashMap;
38use std::sync::{Arc, atomic::AtomicBool};
39use std::time::Duration;
40
41#[derive(Debug, Clone)]
43pub struct AgentToolConfig {
44 pub skip_summarization: bool,
47
48 pub forward_artifacts: bool,
51
52 pub timeout: Option<Duration>,
54
55 pub input_schema: Option<Value>,
58
59 pub output_schema: Option<Value>,
61}
62
63impl Default for AgentToolConfig {
64 fn default() -> Self {
65 Self {
66 skip_summarization: false,
67 forward_artifacts: true,
68 timeout: None,
69 input_schema: None,
70 output_schema: None,
71 }
72 }
73}
74
75pub struct AgentTool {
81 agent: Arc<dyn Agent>,
82 config: AgentToolConfig,
83}
84
85impl AgentTool {
86 pub fn new(agent: Arc<dyn Agent>) -> Self {
88 Self { agent, config: AgentToolConfig::default() }
89 }
90
91 pub fn with_config(agent: Arc<dyn Agent>, config: AgentToolConfig) -> Self {
93 Self { agent, config }
94 }
95
96 pub fn skip_summarization(mut self, skip: bool) -> Self {
98 self.config.skip_summarization = skip;
99 self
100 }
101
102 pub fn forward_artifacts(mut self, forward: bool) -> Self {
104 self.config.forward_artifacts = forward;
105 self
106 }
107
108 pub fn timeout(mut self, timeout: Duration) -> Self {
110 self.config.timeout = Some(timeout);
111 self
112 }
113
114 pub fn input_schema(mut self, schema: Value) -> Self {
116 self.config.input_schema = Some(schema);
117 self
118 }
119
120 pub fn output_schema(mut self, schema: Value) -> Self {
122 self.config.output_schema = Some(schema);
123 self
124 }
125
126 fn default_parameters_schema(&self) -> Value {
128 json!({
129 "type": "object",
130 "properties": {
131 "request": {
132 "type": "string",
133 "description": format!("The request to send to the {} agent", self.agent.name())
134 }
135 },
136 "required": ["request"]
137 })
138 }
139
140 fn extract_request(&self, args: &Value) -> String {
142 if let Some(request) = args.get("request").and_then(|v| v.as_str()) {
144 return request.to_string();
145 }
146
147 if self.config.input_schema.is_some() {
149 return serde_json::to_string(args).unwrap_or_default();
150 }
151
152 match args {
154 Value::String(s) => s.clone(),
155 Value::Object(map) => {
156 for value in map.values() {
158 if let Value::String(s) = value {
159 return s.clone();
160 }
161 }
162 serde_json::to_string(args).unwrap_or_default()
163 }
164 _ => serde_json::to_string(args).unwrap_or_default(),
165 }
166 }
167
168 fn extract_response(events: &[Event]) -> Value {
170 let mut responses = Vec::new();
172
173 for event in events.iter().rev() {
174 if event.is_final_response() {
175 if let Some(content) = &event.llm_response.content {
176 for part in &content.parts {
177 if let Part::Text { text } = part {
178 responses.push(text.clone());
179 }
180 }
181 }
182 break; }
184 }
185
186 if responses.is_empty() {
187 if let Some(last_event) = events.last() {
189 if let Some(content) = &last_event.llm_response.content {
190 for part in &content.parts {
191 if let Part::Text { text } = part {
192 return json!({ "response": text });
193 }
194 }
195 }
196 }
197 json!({ "response": "No response from agent" })
198 } else {
199 json!({ "response": responses.join("\n") })
200 }
201 }
202}
203
204#[async_trait]
205impl Tool for AgentTool {
206 fn name(&self) -> &str {
207 self.agent.name()
208 }
209
210 fn description(&self) -> &str {
211 self.agent.description()
212 }
213
214 fn parameters_schema(&self) -> Option<Value> {
215 Some(self.config.input_schema.clone().unwrap_or_else(|| self.default_parameters_schema()))
216 }
217
218 fn response_schema(&self) -> Option<Value> {
219 self.config.output_schema.clone()
220 }
221
222 fn is_long_running(&self) -> bool {
223 false
225 }
226
227 #[adk_telemetry::instrument(
228 skip(self, ctx, args),
229 fields(
230 agent_tool.name = %self.agent.name(),
231 agent_tool.description = %self.agent.description(),
232 function_call.id = %ctx.function_call_id()
233 )
234 )]
235 async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
236 adk_telemetry::debug!("Executing agent tool: {}", self.agent.name());
237
238 let request_text = self.extract_request(&args);
240
241 let user_content = Content::new("user").with_text(&request_text);
243
244 let sub_ctx = Arc::new(AgentToolInvocationContext::new(
246 ctx.clone(),
247 self.agent.clone(),
248 user_content.clone(),
249 self.config.forward_artifacts,
250 ));
251
252 let execution = async {
254 let mut event_stream = self.agent.run(sub_ctx.clone()).await?;
255
256 let mut events = Vec::new();
258 let mut state_delta = HashMap::new();
259 let mut artifact_delta = HashMap::new();
260
261 while let Some(result) = event_stream.next().await {
262 match result {
263 Ok(event) => {
264 state_delta.extend(event.actions.state_delta.clone());
266 artifact_delta.extend(event.actions.artifact_delta.clone());
267 events.push(event);
268 }
269 Err(e) => {
270 adk_telemetry::error!("Error in sub-agent execution: {}", e);
271 return Err(e);
272 }
273 }
274 }
275
276 Ok((events, state_delta, artifact_delta))
277 };
278
279 let result = if let Some(timeout_duration) = self.config.timeout {
281 match tokio::time::timeout(timeout_duration, execution).await {
282 Ok(r) => r,
283 Err(_) => {
284 return Ok(json!({
285 "error": "Agent execution timed out",
286 "agent": self.agent.name()
287 }));
288 }
289 }
290 } else {
291 execution.await
292 };
293
294 match result {
295 Ok((events, state_delta, artifact_delta)) => {
296 if !state_delta.is_empty() || !artifact_delta.is_empty() {
298 let mut parent_actions = ctx.actions();
299 parent_actions.state_delta.extend(state_delta);
300 parent_actions.artifact_delta.extend(artifact_delta);
301 ctx.set_actions(parent_actions);
302 }
303
304 let response = Self::extract_response(&events);
306
307 adk_telemetry::debug!(
308 "Agent tool {} completed with {} events",
309 self.agent.name(),
310 events.len()
311 );
312
313 Ok(response)
314 }
315 Err(e) => Ok(json!({
316 "error": format!("Agent execution failed: {}", e),
317 "agent": self.agent.name()
318 })),
319 }
320 }
321}
322
323struct AgentToolInvocationContext {
325 parent_ctx: Arc<dyn ToolContext>,
326 agent: Arc<dyn Agent>,
327 user_content: Content,
328 invocation_id: String,
329 ended: Arc<AtomicBool>,
330 forward_artifacts: bool,
331 session: Arc<AgentToolSession>,
332}
333
334impl AgentToolInvocationContext {
335 fn new(
336 parent_ctx: Arc<dyn ToolContext>,
337 agent: Arc<dyn Agent>,
338 user_content: Content,
339 forward_artifacts: bool,
340 ) -> Self {
341 let invocation_id = format!("agent-tool-{}", uuid::Uuid::new_v4());
342 Self {
343 parent_ctx,
344 agent,
345 user_content,
346 invocation_id,
347 ended: Arc::new(AtomicBool::new(false)),
348 forward_artifacts,
349 session: Arc::new(AgentToolSession::new()),
350 }
351 }
352}
353
354#[async_trait]
355impl ReadonlyContext for AgentToolInvocationContext {
356 fn invocation_id(&self) -> &str {
357 &self.invocation_id
358 }
359
360 fn agent_name(&self) -> &str {
361 self.agent.name()
362 }
363
364 fn user_id(&self) -> &str {
365 self.parent_ctx.user_id()
366 }
367
368 fn app_name(&self) -> &str {
369 self.parent_ctx.app_name()
370 }
371
372 fn session_id(&self) -> &str {
373 &self.invocation_id
375 }
376
377 fn branch(&self) -> &str {
378 ""
379 }
380
381 fn user_content(&self) -> &Content {
382 &self.user_content
383 }
384}
385
386#[async_trait]
387impl CallbackContext for AgentToolInvocationContext {
388 fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
389 if self.forward_artifacts { self.parent_ctx.artifacts() } else { None }
390 }
391}
392
393#[async_trait]
394impl InvocationContext for AgentToolInvocationContext {
395 fn agent(&self) -> Arc<dyn Agent> {
396 self.agent.clone()
397 }
398
399 fn memory(&self) -> Option<Arc<dyn Memory>> {
400 None
403 }
404
405 fn session(&self) -> &dyn Session {
406 self.session.as_ref()
407 }
408
409 fn run_config(&self) -> &RunConfig {
410 static AGENT_TOOL_CONFIG: std::sync::OnceLock<RunConfig> = std::sync::OnceLock::new();
414 AGENT_TOOL_CONFIG.get_or_init(|| {
415 RunConfig::builder().streaming_mode(adk_core::StreamingMode::None).build()
416 })
417 }
418
419 fn end_invocation(&self) {
420 self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
421 }
422
423 fn ended(&self) -> bool {
424 self.ended.load(std::sync::atomic::Ordering::SeqCst)
425 }
426}
427
428struct AgentToolSession {
430 id: String,
431 state: std::sync::RwLock<HashMap<String, Value>>,
432}
433
434impl AgentToolSession {
435 fn new() -> Self {
436 Self {
437 id: format!("agent-tool-session-{}", uuid::Uuid::new_v4()),
438 state: Default::default(),
439 }
440 }
441}
442
443impl Session for AgentToolSession {
444 fn id(&self) -> &str {
445 &self.id
446 }
447
448 fn app_name(&self) -> &str {
449 "agent-tool"
450 }
451
452 fn user_id(&self) -> &str {
453 "agent-tool-user"
454 }
455
456 fn state(&self) -> &dyn State {
457 self
458 }
459
460 fn conversation_history(&self) -> Vec<Content> {
461 Vec::new()
463 }
464}
465
466impl State for AgentToolSession {
467 fn get(&self, key: &str) -> Option<Value> {
468 self.state.read().ok()?.get(key).cloned()
469 }
470
471 fn set(&mut self, key: String, value: Value) {
472 if let Err(msg) = adk_core::validate_state_key(&key) {
473 tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
474 return;
475 }
476 if let Ok(mut state) = self.state.write() {
477 state.insert(key, value);
478 }
479 }
480
481 fn all(&self) -> HashMap<String, Value> {
482 self.state.read().ok().map(|s| s.clone()).unwrap_or_default()
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;

    /// Agent stub that emits exactly one model event with the text "Mock response".
    struct MockAgent {
        name: String,
        description: String,
    }

    impl MockAgent {
        /// Convenience constructor returning the stub already wrapped in an `Arc`.
        fn shared(name: &str, description: &str) -> Arc<Self> {
            Arc::new(Self { name: name.to_string(), description: description.to_string() })
        }
    }

    #[async_trait]
    impl Agent for MockAgent {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            &self.description
        }

        fn sub_agents(&self) -> &[Arc<dyn Agent>] {
            &[]
        }

        async fn run(&self, _ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
            use async_stream::stream;

            let author = self.name.clone();
            let single_event = stream! {
                let mut event = Event::new("mock-inv");
                event.author = author;
                event.llm_response.content = Some(Content::new("model").with_text("Mock response"));
                yield Ok(event);
            };

            Ok(Box::pin(single_event))
        }
    }

    #[test]
    fn test_agent_tool_creation() {
        let tool = AgentTool::new(MockAgent::shared("test_agent", "A test agent"));

        assert_eq!(tool.name(), "test_agent");
        assert_eq!(tool.description(), "A test agent");
    }

    #[test]
    fn test_agent_tool_config() {
        let tool = AgentTool::new(MockAgent::shared("test", "test"))
            .skip_summarization(true)
            .forward_artifacts(false)
            .timeout(Duration::from_secs(30));

        let cfg = &tool.config;
        assert!(cfg.skip_summarization);
        assert!(!cfg.forward_artifacts);
        assert_eq!(cfg.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_parameters_schema() {
        let tool = AgentTool::new(MockAgent::shared("calculator", "Performs calculations"));
        let schema = tool.parameters_schema().expect("AgentTool always reports a schema");

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["request"].is_object());
    }

    #[test]
    fn test_extract_request() {
        let tool = AgentTool::new(MockAgent::shared("test", "test"));

        // Object carrying an explicit `request` field.
        assert_eq!(tool.extract_request(&json!({"request": "solve 2+2"})), "solve 2+2");

        // Bare JSON string.
        assert_eq!(tool.extract_request(&json!("direct request")), "direct request");
    }

    #[test]
    fn test_extract_response() {
        let mut answer = Event::new("inv-123");
        answer.llm_response.content = Some(Content::new("model").with_text("The answer is 4"));

        let response = AgentTool::extract_response(&[answer]);

        assert_eq!(response["response"], "The answer is 4");
    }
}