1use adk_core::{
2 AfterAgentCallback, AfterModelCallback, AfterToolCallback, AfterToolCallbackFull, Agent,
3 BeforeAgentCallback, BeforeModelCallback, BeforeModelResult, BeforeToolCallback,
4 CallbackContext, Content, Event, EventActions, FunctionResponseData, GlobalInstructionProvider,
5 InstructionProvider, InvocationContext, Llm, LlmRequest, LlmResponse, MemoryEntry,
6 OnToolErrorCallback, Part, ReadonlyContext, Result, RetryBudget, Tool,
7 ToolConfirmationDecision, ToolConfirmationPolicy, ToolConfirmationRequest, ToolContext,
8 ToolOutcome, Toolset,
9};
10use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index, select_skill_prompt_block};
11use async_stream::stream;
12use async_trait::async_trait;
13use std::sync::{Arc, Mutex};
14use tracing::Instrument;
15
16use crate::{
17 guardrails::{GuardrailSet, enforce_guardrails},
18 tool_call_markup::normalize_option_content,
19 workflow::with_user_content_override,
20};
21
/// Default upper bound on the number of model-request iterations a single
/// `LlmAgent` run may perform before it stops with an error.
pub const DEFAULT_MAX_ITERATIONS: u32 = 100;

/// Default value for `LlmAgentBuilder::tool_timeout`: five minutes.
pub const DEFAULT_TOOL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5 * 60);
27
28pub struct LlmAgent {
29 name: String,
30 description: String,
31 model: Arc<dyn Llm>,
32 instruction: Option<String>,
33 instruction_provider: Option<Arc<InstructionProvider>>,
34 global_instruction: Option<String>,
35 global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
36 skills_index: Option<Arc<SkillIndex>>,
37 skill_policy: SelectionPolicy,
38 max_skill_chars: usize,
39 #[allow(dead_code)] input_schema: Option<serde_json::Value>,
41 output_schema: Option<serde_json::Value>,
42 disallow_transfer_to_parent: bool,
43 disallow_transfer_to_peers: bool,
44 include_contents: adk_core::IncludeContents,
45 tools: Vec<Arc<dyn Tool>>,
46 #[allow(dead_code)] toolsets: Vec<Arc<dyn Toolset>>,
48 sub_agents: Vec<Arc<dyn Agent>>,
49 output_key: Option<String>,
50 generate_content_config: Option<adk_core::GenerateContentConfig>,
52 max_iterations: u32,
54 tool_timeout: std::time::Duration,
56 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
57 after_callbacks: Arc<Vec<AfterAgentCallback>>,
58 before_model_callbacks: Arc<Vec<BeforeModelCallback>>,
59 after_model_callbacks: Arc<Vec<AfterModelCallback>>,
60 before_tool_callbacks: Arc<Vec<BeforeToolCallback>>,
61 after_tool_callbacks: Arc<Vec<AfterToolCallback>>,
62 on_tool_error_callbacks: Arc<Vec<OnToolErrorCallback>>,
63 after_tool_callbacks_full: Arc<Vec<AfterToolCallbackFull>>,
65 default_retry_budget: Option<RetryBudget>,
67 tool_retry_budgets: std::collections::HashMap<String, RetryBudget>,
69 circuit_breaker_threshold: Option<u32>,
72 tool_confirmation_policy: ToolConfirmationPolicy,
73 input_guardrails: Arc<GuardrailSet>,
74 output_guardrails: Arc<GuardrailSet>,
75}
76
77impl std::fmt::Debug for LlmAgent {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("LlmAgent")
80 .field("name", &self.name)
81 .field("description", &self.description)
82 .field("model", &self.model.name())
83 .field("instruction", &self.instruction)
84 .field("tools_count", &self.tools.len())
85 .field("sub_agents_count", &self.sub_agents.len())
86 .finish()
87 }
88}
89
90impl LlmAgent {
91 async fn apply_input_guardrails(
92 ctx: Arc<dyn InvocationContext>,
93 input_guardrails: Arc<GuardrailSet>,
94 ) -> Result<Arc<dyn InvocationContext>> {
95 let content =
96 enforce_guardrails(input_guardrails.as_ref(), ctx.user_content(), "input").await?;
97 if content.role != ctx.user_content().role || content.parts != ctx.user_content().parts {
98 Ok(with_user_content_override(ctx, content))
99 } else {
100 Ok(ctx)
101 }
102 }
103
104 async fn apply_output_guardrails(
105 output_guardrails: &GuardrailSet,
106 content: Content,
107 ) -> Result<Content> {
108 enforce_guardrails(output_guardrails, &content, "output").await
109 }
110
111 fn history_parts_from_provider_metadata(
112 provider_metadata: Option<&serde_json::Value>,
113 ) -> Vec<Part> {
114 let Some(provider_metadata) = provider_metadata else {
115 return Vec::new();
116 };
117
118 let history_parts = provider_metadata
119 .get("conversation_history_parts")
120 .or_else(|| {
121 provider_metadata
122 .get("openai")
123 .and_then(|openai| openai.get("conversation_history_parts"))
124 })
125 .and_then(serde_json::Value::as_array);
126
127 history_parts
128 .into_iter()
129 .flatten()
130 .filter_map(|value| serde_json::from_value::<Part>(value.clone()).ok())
131 .collect()
132 }
133
134 fn augment_content_for_history(
135 content: &Content,
136 provider_metadata: Option<&serde_json::Value>,
137 ) -> Content {
138 let mut augmented = content.clone();
139 augmented.parts.extend(Self::history_parts_from_provider_metadata(provider_metadata));
140 augmented
141 }
142}
143
/// Fluent builder for [`LlmAgent`]; see `LlmAgentBuilder::new` for the defaults
/// and `LlmAgentBuilder::build` for validation.
pub struct LlmAgentBuilder {
    /// Agent name, fixed at construction.
    name: String,
    /// Optional description; `build` substitutes an empty string when unset.
    description: Option<String>,
    /// Model to drive; `build` fails when this is `None`.
    model: Option<Arc<dyn Llm>>,
    /// Instruction template (superseded by `instruction_provider` when both are set).
    instruction: Option<String>,
    /// Per-run instruction provider.
    instruction_provider: Option<Arc<InstructionProvider>>,
    /// Global instruction template (superseded by `global_instruction_provider`).
    global_instruction: Option<String>,
    /// Per-run global instruction provider.
    global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
    /// Optional skill index used to pick a skill block for the prompt.
    skills_index: Option<Arc<SkillIndex>>,
    /// Skill selection policy passed to `select_skill_prompt_block`.
    skill_policy: SelectionPolicy,
    /// Size limit for the selected skill block (2000 by default).
    max_skill_chars: usize,
    /// Input schema copied verbatim into the agent.
    input_schema: Option<serde_json::Value>,
    /// Output schema forced into each request's `response_schema`.
    output_schema: Option<serde_json::Value>,
    /// When true, the parent agent is removed from the transfer targets.
    disallow_transfer_to_parent: bool,
    /// When true, non-parent, non-sub-agent transfer targets are removed.
    disallow_transfer_to_peers: bool,
    /// How much session history is sent to the model.
    include_contents: adk_core::IncludeContents,
    /// Statically registered tools, in registration order.
    tools: Vec<Arc<dyn Tool>>,
    /// Toolsets resolved at the start of every run.
    toolsets: Vec<Arc<dyn Toolset>>,
    /// Sub-agents; names must be unique (checked by `build`).
    sub_agents: Vec<Arc<dyn Agent>>,
    /// Output key copied verbatim into the agent.
    output_key: Option<String>,
    /// Base generation config; individual sampling setters create it on demand.
    generate_content_config: Option<adk_core::GenerateContentConfig>,
    /// Iteration cap for the request loop (defaults to `DEFAULT_MAX_ITERATIONS`).
    max_iterations: u32,
    /// Tool timeout (defaults to `DEFAULT_TOOL_TIMEOUT`).
    tool_timeout: std::time::Duration,
    // Callback lists; `build` moves each into an `Arc` on the agent.
    before_callbacks: Vec<BeforeAgentCallback>,
    after_callbacks: Vec<AfterAgentCallback>,
    before_model_callbacks: Vec<BeforeModelCallback>,
    after_model_callbacks: Vec<AfterModelCallback>,
    before_tool_callbacks: Vec<BeforeToolCallback>,
    after_tool_callbacks: Vec<AfterToolCallback>,
    on_tool_error_callbacks: Vec<OnToolErrorCallback>,
    after_tool_callbacks_full: Vec<AfterToolCallbackFull>,
    /// Retry budget for tools without a per-tool entry (presumably — TODO confirm).
    default_retry_budget: Option<RetryBudget>,
    /// Per-tool retry budgets keyed by tool name.
    tool_retry_budgets: std::collections::HashMap<String, RetryBudget>,
    /// Failure threshold for the per-run circuit breaker; `None` disables it.
    circuit_breaker_threshold: Option<u32>,
    /// Which tool calls need confirmation (defaults to `Never`).
    tool_confirmation_policy: ToolConfirmationPolicy,
    /// Guardrails applied to the user content; `build` wraps the set in an `Arc`.
    input_guardrails: GuardrailSet,
    /// Guardrails applied to final model content; `build` wraps the set in an `Arc`.
    output_guardrails: GuardrailSet,
}
182
183impl LlmAgentBuilder {
184 pub fn new(name: impl Into<String>) -> Self {
185 Self {
186 name: name.into(),
187 description: None,
188 model: None,
189 instruction: None,
190 instruction_provider: None,
191 global_instruction: None,
192 global_instruction_provider: None,
193 skills_index: None,
194 skill_policy: SelectionPolicy::default(),
195 max_skill_chars: 2000,
196 input_schema: None,
197 output_schema: None,
198 disallow_transfer_to_parent: false,
199 disallow_transfer_to_peers: false,
200 include_contents: adk_core::IncludeContents::Default,
201 tools: Vec::new(),
202 toolsets: Vec::new(),
203 sub_agents: Vec::new(),
204 output_key: None,
205 generate_content_config: None,
206 max_iterations: DEFAULT_MAX_ITERATIONS,
207 tool_timeout: DEFAULT_TOOL_TIMEOUT,
208 before_callbacks: Vec::new(),
209 after_callbacks: Vec::new(),
210 before_model_callbacks: Vec::new(),
211 after_model_callbacks: Vec::new(),
212 before_tool_callbacks: Vec::new(),
213 after_tool_callbacks: Vec::new(),
214 on_tool_error_callbacks: Vec::new(),
215 after_tool_callbacks_full: Vec::new(),
216 default_retry_budget: None,
217 tool_retry_budgets: std::collections::HashMap::new(),
218 circuit_breaker_threshold: None,
219 tool_confirmation_policy: ToolConfirmationPolicy::Never,
220 input_guardrails: GuardrailSet::new(),
221 output_guardrails: GuardrailSet::new(),
222 }
223 }
224
225 pub fn description(mut self, desc: impl Into<String>) -> Self {
226 self.description = Some(desc.into());
227 self
228 }
229
230 pub fn model(mut self, model: Arc<dyn Llm>) -> Self {
231 self.model = Some(model);
232 self
233 }
234
235 pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
236 self.instruction = Some(instruction.into());
237 self
238 }
239
240 pub fn instruction_provider(mut self, provider: InstructionProvider) -> Self {
241 self.instruction_provider = Some(Arc::new(provider));
242 self
243 }
244
245 pub fn global_instruction(mut self, instruction: impl Into<String>) -> Self {
246 self.global_instruction = Some(instruction.into());
247 self
248 }
249
250 pub fn global_instruction_provider(mut self, provider: GlobalInstructionProvider) -> Self {
251 self.global_instruction_provider = Some(Arc::new(provider));
252 self
253 }
254
255 pub fn with_skills(mut self, index: SkillIndex) -> Self {
257 self.skills_index = Some(Arc::new(index));
258 self
259 }
260
261 pub fn with_auto_skills(self) -> Result<Self> {
263 self.with_skills_from_root(".")
264 }
265
266 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
268 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
269 self.skills_index = Some(Arc::new(index));
270 Ok(self)
271 }
272
273 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
275 self.skill_policy = policy;
276 self
277 }
278
279 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
281 self.max_skill_chars = max_chars;
282 self
283 }
284
285 pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
286 self.input_schema = Some(schema);
287 self
288 }
289
290 pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
291 self.output_schema = Some(schema);
292 self
293 }
294
295 pub fn disallow_transfer_to_parent(mut self, disallow: bool) -> Self {
296 self.disallow_transfer_to_parent = disallow;
297 self
298 }
299
300 pub fn disallow_transfer_to_peers(mut self, disallow: bool) -> Self {
301 self.disallow_transfer_to_peers = disallow;
302 self
303 }
304
305 pub fn include_contents(mut self, include: adk_core::IncludeContents) -> Self {
306 self.include_contents = include;
307 self
308 }
309
310 pub fn output_key(mut self, key: impl Into<String>) -> Self {
311 self.output_key = Some(key.into());
312 self
313 }
314
315 pub fn generate_content_config(mut self, config: adk_core::GenerateContentConfig) -> Self {
336 self.generate_content_config = Some(config);
337 self
338 }
339
340 pub fn temperature(mut self, temperature: f32) -> Self {
343 self.generate_content_config
344 .get_or_insert(adk_core::GenerateContentConfig::default())
345 .temperature = Some(temperature);
346 self
347 }
348
349 pub fn top_p(mut self, top_p: f32) -> Self {
351 self.generate_content_config
352 .get_or_insert(adk_core::GenerateContentConfig::default())
353 .top_p = Some(top_p);
354 self
355 }
356
357 pub fn top_k(mut self, top_k: i32) -> Self {
359 self.generate_content_config
360 .get_or_insert(adk_core::GenerateContentConfig::default())
361 .top_k = Some(top_k);
362 self
363 }
364
365 pub fn max_output_tokens(mut self, max_tokens: i32) -> Self {
367 self.generate_content_config
368 .get_or_insert(adk_core::GenerateContentConfig::default())
369 .max_output_tokens = Some(max_tokens);
370 self
371 }
372
373 pub fn max_iterations(mut self, max: u32) -> Self {
376 self.max_iterations = max;
377 self
378 }
379
380 pub fn tool_timeout(mut self, timeout: std::time::Duration) -> Self {
383 self.tool_timeout = timeout;
384 self
385 }
386
387 pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
388 self.tools.push(tool);
389 self
390 }
391
392 pub fn toolset(mut self, toolset: Arc<dyn Toolset>) -> Self {
398 self.toolsets.push(toolset);
399 self
400 }
401
402 pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
403 self.sub_agents.push(agent);
404 self
405 }
406
407 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
408 self.before_callbacks.push(callback);
409 self
410 }
411
412 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
413 self.after_callbacks.push(callback);
414 self
415 }
416
417 pub fn before_model_callback(mut self, callback: BeforeModelCallback) -> Self {
418 self.before_model_callbacks.push(callback);
419 self
420 }
421
422 pub fn after_model_callback(mut self, callback: AfterModelCallback) -> Self {
423 self.after_model_callbacks.push(callback);
424 self
425 }
426
427 pub fn before_tool_callback(mut self, callback: BeforeToolCallback) -> Self {
428 self.before_tool_callbacks.push(callback);
429 self
430 }
431
432 pub fn after_tool_callback(mut self, callback: AfterToolCallback) -> Self {
433 self.after_tool_callbacks.push(callback);
434 self
435 }
436
437 pub fn after_tool_callback_full(mut self, callback: AfterToolCallbackFull) -> Self {
452 self.after_tool_callbacks_full.push(callback);
453 self
454 }
455
456 pub fn on_tool_error(mut self, callback: OnToolErrorCallback) -> Self {
464 self.on_tool_error_callbacks.push(callback);
465 self
466 }
467
468 pub fn default_retry_budget(mut self, budget: RetryBudget) -> Self {
475 self.default_retry_budget = Some(budget);
476 self
477 }
478
479 pub fn tool_retry_budget(mut self, tool_name: impl Into<String>, budget: RetryBudget) -> Self {
484 self.tool_retry_budgets.insert(tool_name.into(), budget);
485 self
486 }
487
488 pub fn circuit_breaker_threshold(mut self, threshold: u32) -> Self {
495 self.circuit_breaker_threshold = Some(threshold);
496 self
497 }
498
499 pub fn tool_confirmation_policy(mut self, policy: ToolConfirmationPolicy) -> Self {
501 self.tool_confirmation_policy = policy;
502 self
503 }
504
505 pub fn require_tool_confirmation(mut self, tool_name: impl Into<String>) -> Self {
507 self.tool_confirmation_policy = self.tool_confirmation_policy.with_tool(tool_name);
508 self
509 }
510
511 pub fn require_tool_confirmation_for_all(mut self) -> Self {
513 self.tool_confirmation_policy = ToolConfirmationPolicy::Always;
514 self
515 }
516
517 pub fn input_guardrails(mut self, guardrails: GuardrailSet) -> Self {
526 self.input_guardrails = guardrails;
527 self
528 }
529
530 pub fn output_guardrails(mut self, guardrails: GuardrailSet) -> Self {
539 self.output_guardrails = guardrails;
540 self
541 }
542
543 pub fn build(self) -> Result<LlmAgent> {
544 let model = self.model.ok_or_else(|| adk_core::AdkError::agent("Model is required"))?;
545
546 let mut seen_names = std::collections::HashSet::new();
547 for agent in &self.sub_agents {
548 if !seen_names.insert(agent.name()) {
549 return Err(adk_core::AdkError::agent(format!(
550 "Duplicate sub-agent name: {}",
551 agent.name()
552 )));
553 }
554 }
555
556 Ok(LlmAgent {
557 name: self.name,
558 description: self.description.unwrap_or_default(),
559 model,
560 instruction: self.instruction,
561 instruction_provider: self.instruction_provider,
562 global_instruction: self.global_instruction,
563 global_instruction_provider: self.global_instruction_provider,
564 skills_index: self.skills_index,
565 skill_policy: self.skill_policy,
566 max_skill_chars: self.max_skill_chars,
567 input_schema: self.input_schema,
568 output_schema: self.output_schema,
569 disallow_transfer_to_parent: self.disallow_transfer_to_parent,
570 disallow_transfer_to_peers: self.disallow_transfer_to_peers,
571 include_contents: self.include_contents,
572 tools: self.tools,
573 toolsets: self.toolsets,
574 sub_agents: self.sub_agents,
575 output_key: self.output_key,
576 generate_content_config: self.generate_content_config,
577 max_iterations: self.max_iterations,
578 tool_timeout: self.tool_timeout,
579 before_callbacks: Arc::new(self.before_callbacks),
580 after_callbacks: Arc::new(self.after_callbacks),
581 before_model_callbacks: Arc::new(self.before_model_callbacks),
582 after_model_callbacks: Arc::new(self.after_model_callbacks),
583 before_tool_callbacks: Arc::new(self.before_tool_callbacks),
584 after_tool_callbacks: Arc::new(self.after_tool_callbacks),
585 on_tool_error_callbacks: Arc::new(self.on_tool_error_callbacks),
586 after_tool_callbacks_full: Arc::new(self.after_tool_callbacks_full),
587 default_retry_budget: self.default_retry_budget,
588 tool_retry_budgets: self.tool_retry_budgets,
589 circuit_breaker_threshold: self.circuit_breaker_threshold,
590 tool_confirmation_policy: self.tool_confirmation_policy,
591 input_guardrails: Arc::new(self.input_guardrails),
592 output_guardrails: Arc::new(self.output_guardrails),
593 })
594 }
595}
596
597struct AgentToolContext {
600 parent_ctx: Arc<dyn InvocationContext>,
601 function_call_id: String,
602 actions: Mutex<EventActions>,
603}
604
605impl AgentToolContext {
606 fn new(parent_ctx: Arc<dyn InvocationContext>, function_call_id: String) -> Self {
607 Self { parent_ctx, function_call_id, actions: Mutex::new(EventActions::default()) }
608 }
609
610 fn actions_guard(&self) -> std::sync::MutexGuard<'_, EventActions> {
611 self.actions.lock().unwrap_or_else(|e| e.into_inner())
612 }
613}
614
615#[async_trait]
616impl ReadonlyContext for AgentToolContext {
617 fn invocation_id(&self) -> &str {
618 self.parent_ctx.invocation_id()
619 }
620
621 fn agent_name(&self) -> &str {
622 self.parent_ctx.agent_name()
623 }
624
625 fn user_id(&self) -> &str {
626 self.parent_ctx.user_id()
628 }
629
630 fn app_name(&self) -> &str {
631 self.parent_ctx.app_name()
633 }
634
635 fn session_id(&self) -> &str {
636 self.parent_ctx.session_id()
638 }
639
640 fn branch(&self) -> &str {
641 self.parent_ctx.branch()
642 }
643
644 fn user_content(&self) -> &Content {
645 self.parent_ctx.user_content()
646 }
647}
648
649#[async_trait]
650impl CallbackContext for AgentToolContext {
651 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
652 self.parent_ctx.artifacts()
654 }
655}
656
657#[async_trait]
658impl ToolContext for AgentToolContext {
659 fn function_call_id(&self) -> &str {
660 &self.function_call_id
661 }
662
663 fn actions(&self) -> EventActions {
664 self.actions_guard().clone()
665 }
666
667 fn set_actions(&self, actions: EventActions) {
668 *self.actions_guard() = actions;
669 }
670
671 async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>> {
672 if let Some(memory) = self.parent_ctx.memory() {
674 memory.search(query).await
675 } else {
676 Ok(vec![])
677 }
678 }
679
680 fn user_scopes(&self) -> Vec<String> {
681 self.parent_ctx.user_scopes()
682 }
683}
684
/// A [`CallbackContext`] that forwards everything to `inner` and additionally
/// reports a tool outcome through `CallbackContext::tool_outcome`.
struct ToolCallbackContext {
    /// The wrapped context that answers all other queries.
    inner: Arc<dyn CallbackContext>,
    /// Outcome returned (cloned) from `tool_outcome`.
    outcome: ToolOutcome,
}
692
693#[async_trait]
694impl ReadonlyContext for ToolCallbackContext {
695 fn invocation_id(&self) -> &str {
696 self.inner.invocation_id()
697 }
698
699 fn agent_name(&self) -> &str {
700 self.inner.agent_name()
701 }
702
703 fn user_id(&self) -> &str {
704 self.inner.user_id()
705 }
706
707 fn app_name(&self) -> &str {
708 self.inner.app_name()
709 }
710
711 fn session_id(&self) -> &str {
712 self.inner.session_id()
713 }
714
715 fn branch(&self) -> &str {
716 self.inner.branch()
717 }
718
719 fn user_content(&self) -> &Content {
720 self.inner.user_content()
721 }
722}
723
724#[async_trait]
725impl CallbackContext for ToolCallbackContext {
726 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
727 self.inner.artifacts()
728 }
729
730 fn tool_outcome(&self) -> Option<ToolOutcome> {
731 Some(self.outcome.clone())
732 }
733}
734
735struct CircuitBreakerState {
745 threshold: u32,
746 failures: std::collections::HashMap<String, u32>,
748}
749
750impl CircuitBreakerState {
751 fn new(threshold: u32) -> Self {
752 Self { threshold, failures: std::collections::HashMap::new() }
753 }
754
755 fn is_open(&self, tool_name: &str) -> bool {
757 self.failures.get(tool_name).copied().unwrap_or(0) >= self.threshold
758 }
759
760 fn record(&mut self, outcome: &ToolOutcome) {
762 if outcome.success {
763 self.failures.remove(&outcome.tool_name);
764 } else {
765 let count = self.failures.entry(outcome.tool_name.clone()).or_insert(0);
766 *count += 1;
767 }
768 }
769}
770
771#[async_trait]
772impl Agent for LlmAgent {
773 fn name(&self) -> &str {
774 &self.name
775 }
776
777 fn description(&self) -> &str {
778 &self.description
779 }
780
781 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
782 &self.sub_agents
783 }
784
785 #[adk_telemetry::instrument(
786 skip(self, ctx),
787 fields(
788 agent.name = %self.name,
789 agent.description = %self.description,
790 invocation.id = %ctx.invocation_id(),
791 user.id = %ctx.user_id(),
792 session.id = %ctx.session_id()
793 )
794 )]
795 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
796 adk_telemetry::info!("Starting agent execution");
797 let ctx = Self::apply_input_guardrails(ctx, self.input_guardrails.clone()).await?;
798
799 let agent_name = self.name.clone();
800 let invocation_id = ctx.invocation_id().to_string();
801 let model = self.model.clone();
802 let tools = self.tools.clone();
803 let toolsets = self.toolsets.clone();
804 let sub_agents = self.sub_agents.clone();
805
806 let instruction = self.instruction.clone();
807 let instruction_provider = self.instruction_provider.clone();
808 let global_instruction = self.global_instruction.clone();
809 let global_instruction_provider = self.global_instruction_provider.clone();
810 let skills_index = self.skills_index.clone();
811 let skill_policy = self.skill_policy.clone();
812 let max_skill_chars = self.max_skill_chars;
813 let output_key = self.output_key.clone();
814 let output_schema = self.output_schema.clone();
815 let generate_content_config = self.generate_content_config.clone();
816 let include_contents = self.include_contents;
817 let max_iterations = self.max_iterations;
818 let tool_timeout = self.tool_timeout;
819 let before_agent_callbacks = self.before_callbacks.clone();
821 let after_agent_callbacks = self.after_callbacks.clone();
822 let before_model_callbacks = self.before_model_callbacks.clone();
823 let after_model_callbacks = self.after_model_callbacks.clone();
824 let before_tool_callbacks = self.before_tool_callbacks.clone();
825 let after_tool_callbacks = self.after_tool_callbacks.clone();
826 let on_tool_error_callbacks = self.on_tool_error_callbacks.clone();
827 let after_tool_callbacks_full = self.after_tool_callbacks_full.clone();
828 let default_retry_budget = self.default_retry_budget.clone();
829 let tool_retry_budgets = self.tool_retry_budgets.clone();
830 let circuit_breaker_threshold = self.circuit_breaker_threshold;
831 let tool_confirmation_policy = self.tool_confirmation_policy.clone();
832 let disallow_transfer_to_parent = self.disallow_transfer_to_parent;
833 let disallow_transfer_to_peers = self.disallow_transfer_to_peers;
834 let output_guardrails = self.output_guardrails.clone();
835
836 let s = stream! {
837 for callback in before_agent_callbacks.as_ref() {
841 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
842 Ok(Some(content)) => {
843 let mut early_event = Event::new(&invocation_id);
845 early_event.author = agent_name.clone();
846 early_event.llm_response.content = Some(content);
847 yield Ok(early_event);
848
849 for after_callback in after_agent_callbacks.as_ref() {
851 match after_callback(ctx.clone() as Arc<dyn CallbackContext>).await {
852 Ok(Some(after_content)) => {
853 let mut after_event = Event::new(&invocation_id);
854 after_event.author = agent_name.clone();
855 after_event.llm_response.content = Some(after_content);
856 yield Ok(after_event);
857 return;
858 }
859 Ok(None) => continue,
860 Err(e) => {
861 yield Err(e);
862 return;
863 }
864 }
865 }
866 return;
867 }
868 Ok(None) => {
869 continue;
871 }
872 Err(e) => {
873 yield Err(e);
875 return;
876 }
877 }
878 }
879
880 let mut prompt_preamble = Vec::new();
882
883 if let Some(index) = &skills_index {
887 let user_query = ctx
888 .user_content()
889 .parts
890 .iter()
891 .filter_map(|part| match part {
892 Part::Text { text } => Some(text.as_str()),
893 _ => None,
894 })
895 .collect::<Vec<_>>()
896 .join("\n");
897
898 if let Some((_matched, skill_block)) = select_skill_prompt_block(
899 index.as_ref(),
900 &user_query,
901 &skill_policy,
902 max_skill_chars,
903 ) {
904 prompt_preamble.push(Content {
905 role: "user".to_string(),
906 parts: vec![Part::Text { text: skill_block }],
907 });
908 }
909 }
910
911 if let Some(provider) = &global_instruction_provider {
914 let global_inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
916 if !global_inst.is_empty() {
917 prompt_preamble.push(Content {
918 role: "user".to_string(),
919 parts: vec![Part::Text { text: global_inst }],
920 });
921 }
922 } else if let Some(ref template) = global_instruction {
923 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
925 if !processed.is_empty() {
926 prompt_preamble.push(Content {
927 role: "user".to_string(),
928 parts: vec![Part::Text { text: processed }],
929 });
930 }
931 }
932
933 if let Some(provider) = &instruction_provider {
936 let inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
938 if !inst.is_empty() {
939 prompt_preamble.push(Content {
940 role: "user".to_string(),
941 parts: vec![Part::Text { text: inst }],
942 });
943 }
944 } else if let Some(ref template) = instruction {
945 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
947 if !processed.is_empty() {
948 prompt_preamble.push(Content {
949 role: "user".to_string(),
950 parts: vec![Part::Text { text: processed }],
951 });
952 }
953 }
954
955 let session_history = if !ctx.run_config().transfer_targets.is_empty() {
961 ctx.session().conversation_history_for_agent(&agent_name)
962 } else {
963 ctx.session().conversation_history()
964 };
965 let mut session_history = session_history;
966 let current_user_content = ctx.user_content().clone();
967 if let Some(index) = session_history.iter().rposition(|content| content.role == "user") {
968 session_history[index] = current_user_content.clone();
969 } else {
970 session_history.push(current_user_content.clone());
971 }
972
973 let mut conversation_history = match include_contents {
976 adk_core::IncludeContents::None => {
977 let mut filtered = prompt_preamble.clone();
978 filtered.push(current_user_content);
979 filtered
980 }
981 adk_core::IncludeContents::Default => {
982 let mut full_history = prompt_preamble;
983 full_history.extend(session_history);
984 full_history
985 }
986 };
987
988 let mut resolved_tools: Vec<Arc<dyn Tool>> = tools.clone();
991 let static_tool_names: std::collections::HashSet<String> =
992 tools.iter().map(|t| t.name().to_string()).collect();
993
994 let mut toolset_source: std::collections::HashMap<String, String> =
996 std::collections::HashMap::new();
997
998 for toolset in &toolsets {
999 let toolset_tools = match toolset
1000 .tools(ctx.clone() as Arc<dyn ReadonlyContext>)
1001 .await
1002 {
1003 Ok(t) => t,
1004 Err(e) => {
1005 yield Err(e);
1006 return;
1007 }
1008 };
1009 for tool in &toolset_tools {
1010 let name = tool.name().to_string();
1011 if static_tool_names.contains(&name) {
1013 yield Err(adk_core::AdkError::agent(format!(
1014 "Duplicate tool name '{name}': conflict between static tool and toolset '{}'",
1015 toolset.name()
1016 )));
1017 return;
1018 }
1019 if let Some(other_toolset_name) = toolset_source.get(&name) {
1021 yield Err(adk_core::AdkError::agent(format!(
1022 "Duplicate tool name '{name}': conflict between toolset '{}' and toolset '{}'",
1023 other_toolset_name,
1024 toolset.name()
1025 )));
1026 return;
1027 }
1028 toolset_source.insert(name, toolset.name().to_string());
1029 resolved_tools.push(tool.clone());
1030 }
1031 }
1032
1033 let tool_map: std::collections::HashMap<String, Arc<dyn Tool>> = resolved_tools
1035 .iter()
1036 .map(|t| (t.name().to_string(), t.clone()))
1037 .collect();
1038
1039 let collect_long_running_ids = |content: &Content| -> Vec<String> {
1041 content.parts.iter()
1042 .filter_map(|p| {
1043 if let Part::FunctionCall { name, .. } = p {
1044 if let Some(tool) = tool_map.get(name) {
1045 if tool.is_long_running() {
1046 return Some(name.clone());
1047 }
1048 }
1049 }
1050 None
1051 })
1052 .collect()
1053 };
1054
1055 let mut tool_declarations = std::collections::HashMap::new();
1060 for tool in &resolved_tools {
1061 tool_declarations.insert(tool.name().to_string(), tool.declaration());
1062 }
1063
1064 let mut valid_transfer_targets: Vec<String> = sub_agents
1069 .iter()
1070 .map(|a| a.name().to_string())
1071 .collect();
1072
1073 let run_config_targets = &ctx.run_config().transfer_targets;
1075 let parent_agent_name = ctx.run_config().parent_agent.clone();
1076 let sub_agent_names: std::collections::HashSet<&str> = sub_agents
1077 .iter()
1078 .map(|a| a.name())
1079 .collect();
1080
1081 for target in run_config_targets {
1082 if sub_agent_names.contains(target.as_str()) {
1084 continue;
1085 }
1086
1087 let is_parent = parent_agent_name.as_deref() == Some(target.as_str());
1089 if is_parent && disallow_transfer_to_parent {
1090 continue;
1091 }
1092 if !is_parent && disallow_transfer_to_peers {
1093 continue;
1094 }
1095
1096 valid_transfer_targets.push(target.clone());
1097 }
1098
1099 if !valid_transfer_targets.is_empty() {
1101 let transfer_tool_name = "transfer_to_agent";
1102 let transfer_tool_decl = serde_json::json!({
1103 "name": transfer_tool_name,
1104 "description": format!(
1105 "Transfer execution to another agent. Valid targets: {}",
1106 valid_transfer_targets.join(", ")
1107 ),
1108 "parameters": {
1109 "type": "object",
1110 "properties": {
1111 "agent_name": {
1112 "type": "string",
1113 "description": "The name of the agent to transfer to.",
1114 "enum": valid_transfer_targets
1115 }
1116 },
1117 "required": ["agent_name"]
1118 }
1119 });
1120 tool_declarations.insert(transfer_tool_name.to_string(), transfer_tool_decl);
1121 }
1122
1123
1124 let mut circuit_breaker_state = circuit_breaker_threshold.map(CircuitBreakerState::new);
1127
1128 let mut iteration = 0;
1130
1131 loop {
1132 iteration += 1;
1133 if iteration > max_iterations {
1134 yield Err(adk_core::AdkError::agent(
1135 format!("Max iterations ({max_iterations}) exceeded")
1136 ));
1137 return;
1138 }
1139
1140 let config = match (&generate_content_config, &output_schema) {
1147 (Some(base), Some(schema)) => {
1148 let mut merged = base.clone();
1149 merged.response_schema = Some(schema.clone());
1150 Some(merged)
1151 }
1152 (Some(base), None) => Some(base.clone()),
1153 (None, Some(schema)) => Some(adk_core::GenerateContentConfig {
1154 response_schema: Some(schema.clone()),
1155 ..Default::default()
1156 }),
1157 (None, None) => None,
1158 };
1159
1160 let config = if let Some(ref cached) = ctx.run_config().cached_content {
1162 let mut cfg = config.unwrap_or_default();
1163 if cfg.cached_content.is_none() {
1165 cfg.cached_content = Some(cached.clone());
1166 }
1167 Some(cfg)
1168 } else {
1169 config
1170 };
1171
1172 let request = LlmRequest {
1173 model: model.name().to_string(),
1174 contents: conversation_history.clone(),
1175 tools: tool_declarations.clone(),
1176 config,
1177 };
1178
1179 let mut current_request = request;
1182 let mut model_response_override = None;
1183 for callback in before_model_callbacks.as_ref() {
1184 match callback(ctx.clone() as Arc<dyn CallbackContext>, current_request.clone()).await {
1185 Ok(BeforeModelResult::Continue(modified_request)) => {
1186 current_request = modified_request;
1188 }
1189 Ok(BeforeModelResult::Skip(response)) => {
1190 model_response_override = Some(response);
1192 break;
1193 }
1194 Err(e) => {
1195 yield Err(e);
1197 return;
1198 }
1199 }
1200 }
1201 let request = current_request;
1202
1203 let mut accumulated_content: Option<Content> = None;
1205 let mut final_provider_metadata: Option<serde_json::Value> = None;
1206
1207 if let Some(cached_response) = model_response_override {
1208 accumulated_content = cached_response.content.clone();
1211 final_provider_metadata = cached_response.provider_metadata.clone();
1212 normalize_option_content(&mut accumulated_content);
1213 if let Some(content) = accumulated_content.take() {
1214 let has_function_calls = content
1215 .parts
1216 .iter()
1217 .any(|part| matches!(part, Part::FunctionCall { .. }));
1218 let content = if has_function_calls {
1219 content
1220 } else {
1221 Self::apply_output_guardrails(output_guardrails.as_ref(), content).await?
1222 };
1223 accumulated_content = Some(content);
1224 }
1225
1226 let mut cached_event = Event::new(&invocation_id);
1227 cached_event.author = agent_name.clone();
1228 cached_event.llm_response.content = accumulated_content.clone();
1229 cached_event.llm_response.provider_metadata = cached_response.provider_metadata.clone();
1230 cached_event.llm_request = Some(serde_json::to_string(&request).unwrap_or_default());
1231 cached_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), serde_json::to_string(&request).unwrap_or_default());
1232 cached_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(&cached_response).unwrap_or_default());
1233
1234 if let Some(ref content) = accumulated_content {
1236 cached_event.long_running_tool_ids = collect_long_running_ids(content);
1237 }
1238
1239 yield Ok(cached_event);
1240 } else {
1241 let request_json = serde_json::to_string(&request).unwrap_or_default();
1243
1244 let llm_ts = std::time::SystemTime::now()
1246 .duration_since(std::time::UNIX_EPOCH)
1247 .unwrap_or_default()
1248 .as_nanos();
1249 let llm_event_id = format!("{}_llm_{}", invocation_id, llm_ts);
1250 let llm_span = tracing::info_span!(
1251 "call_llm",
1252 "gcp.vertex.agent.event_id" = %llm_event_id,
1253 "gcp.vertex.agent.invocation_id" = %invocation_id,
1254 "gcp.vertex.agent.session_id" = %ctx.session_id(),
1255 "gen_ai.conversation.id" = %ctx.session_id(),
1256 "gcp.vertex.agent.llm_request" = %request_json,
1257 "gcp.vertex.agent.llm_response" = tracing::field::Empty );
1259 let _llm_guard = llm_span.enter();
1260
1261 use adk_core::StreamingMode;
1263 let streaming_mode = ctx.run_config().streaming_mode;
1264 let should_stream_to_client = matches!(streaming_mode, StreamingMode::SSE | StreamingMode::Bidi)
1265 && output_guardrails.is_empty();
1266
1267 let mut response_stream = model.generate_content(request, true).await?;
1269
1270 use futures::StreamExt;
1271
1272 let mut last_chunk: Option<LlmResponse> = None;
1274
1275 while let Some(chunk_result) = response_stream.next().await {
1277 let mut chunk = match chunk_result {
1278 Ok(c) => c,
1279 Err(e) => {
1280 yield Err(e);
1281 return;
1282 }
1283 };
1284
1285 for callback in after_model_callbacks.as_ref() {
1288 match callback(ctx.clone() as Arc<dyn CallbackContext>, chunk.clone()).await {
1289 Ok(Some(modified_chunk)) => {
1290 chunk = modified_chunk;
1292 break;
1293 }
1294 Ok(None) => {
1295 continue;
1297 }
1298 Err(e) => {
1299 yield Err(e);
1301 return;
1302 }
1303 }
1304 }
1305
1306 normalize_option_content(&mut chunk.content);
1307
1308 if let Some(chunk_content) = chunk.content.clone() {
1310 if let Some(ref mut acc) = accumulated_content {
1311 acc.parts.extend(chunk_content.parts);
1312 } else {
1313 accumulated_content = Some(chunk_content);
1314 }
1315 }
1316
1317 if should_stream_to_client {
1319 let mut partial_event = Event::with_id(&llm_event_id, &invocation_id);
1320 partial_event.author = agent_name.clone();
1321 partial_event.llm_request = Some(request_json.clone());
1322 partial_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), request_json.clone());
1323 partial_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(&chunk).unwrap_or_default());
1324 partial_event.llm_response.partial = chunk.partial;
1325 partial_event.llm_response.turn_complete = chunk.turn_complete;
1326 partial_event.llm_response.finish_reason = chunk.finish_reason;
1327 partial_event.llm_response.usage_metadata = chunk.usage_metadata.clone();
1328 partial_event.llm_response.content = chunk.content.clone();
1329 partial_event.llm_response.provider_metadata = chunk.provider_metadata.clone();
1330
1331 if let Some(ref content) = chunk.content {
1333 partial_event.long_running_tool_ids = collect_long_running_ids(content);
1334 }
1335
1336 yield Ok(partial_event);
1337 }
1338
1339 last_chunk = Some(chunk.clone());
1341
1342 if chunk.turn_complete {
1344 break;
1345 }
1346 }
1347
1348 if !should_stream_to_client {
1350 if let Some(content) = accumulated_content.take() {
1351 let has_function_calls = content
1352 .parts
1353 .iter()
1354 .any(|part| matches!(part, Part::FunctionCall { .. }));
1355 let content = if has_function_calls {
1356 content
1357 } else {
1358 Self::apply_output_guardrails(output_guardrails.as_ref(), content).await?
1359 };
1360 accumulated_content = Some(content);
1361 }
1362
1363 let mut final_event = Event::with_id(&llm_event_id, &invocation_id);
1364 final_event.author = agent_name.clone();
1365 final_event.llm_request = Some(request_json.clone());
1366 final_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), request_json.clone());
1367 final_event.llm_response.content = accumulated_content.clone();
1368 final_event.llm_response.partial = false;
1369 final_event.llm_response.turn_complete = true;
1370
1371 if let Some(ref last) = last_chunk {
1373 final_event.llm_response.finish_reason = last.finish_reason;
1374 final_event.llm_response.usage_metadata = last.usage_metadata.clone();
1375 final_event.llm_response.provider_metadata = last.provider_metadata.clone();
1376 final_provider_metadata = last.provider_metadata.clone();
1377 final_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(last).unwrap_or_default());
1378 }
1379
1380 if let Some(ref content) = accumulated_content {
1382 final_event.long_running_tool_ids = collect_long_running_ids(content);
1383 }
1384
1385 yield Ok(final_event);
1386 }
1387
1388 if let Some(ref content) = accumulated_content {
1390 let response_json = serde_json::to_string(content).unwrap_or_default();
1391 llm_span.record("gcp.vertex.agent.llm_response", &response_json);
1392 }
1393 }
1394
1395 let function_call_names: Vec<String> = accumulated_content.as_ref()
1397 .map(|c| c.parts.iter()
1398 .filter_map(|p| {
1399 if let Part::FunctionCall { name, .. } = p {
1400 Some(name.clone())
1401 } else {
1402 None
1403 }
1404 })
1405 .collect())
1406 .unwrap_or_default();
1407
1408 let has_function_calls = !function_call_names.is_empty();
1409
1410 let all_calls_are_long_running = has_function_calls && function_call_names.iter().all(|name| {
1414 tool_map.get(name)
1415 .map(|t| t.is_long_running())
1416 .unwrap_or(false)
1417 });
1418
1419 if let Some(ref content) = accumulated_content {
1421 conversation_history.push(Self::augment_content_for_history(
1422 content,
1423 final_provider_metadata.as_ref(),
1424 ));
1425
1426 if let Some(ref output_key) = output_key {
1428 if !has_function_calls { let mut text_parts = String::new();
1430 for part in &content.parts {
1431 if let Part::Text { text } = part {
1432 text_parts.push_str(text);
1433 }
1434 }
1435 if !text_parts.is_empty() {
1436 let mut state_event = Event::new(&invocation_id);
1438 state_event.author = agent_name.clone();
1439 state_event.actions.state_delta.insert(
1440 output_key.clone(),
1441 serde_json::Value::String(text_parts),
1442 );
1443 yield Ok(state_event);
1444 }
1445 }
1446 }
1447 }
1448
1449 if !has_function_calls {
1450 if let Some(ref content) = accumulated_content {
1453 let response_json = serde_json::to_string(content).unwrap_or_default();
1454 tracing::Span::current().record("gcp.vertex.agent.llm_response", &response_json);
1455 }
1456
1457 tracing::info!(agent.name = %agent_name, "Agent execution complete");
1458 break;
1459 }
1460
1461 if let Some(content) = &accumulated_content {
1463 let mut tool_call_index = 0usize;
1464 for part in &content.parts {
1465 if let Part::FunctionCall { name, args, id, .. } = part {
1466 let fallback_call_id =
1467 format!("{}_{}_{}", invocation_id, name, tool_call_index);
1468 tool_call_index += 1;
1469 let function_call_id = id.clone().unwrap_or(fallback_call_id);
1470
1471 if name == "transfer_to_agent" {
1473 let target_agent = args.get("agent_name")
1474 .and_then(|v| v.as_str())
1475 .unwrap_or_default()
1476 .to_string();
1477
1478 let valid_target = valid_transfer_targets.iter().any(|n| n == &target_agent);
1481 if !valid_target {
1482 let error_content = Content {
1484 role: "function".to_string(),
1485 parts: vec![Part::FunctionResponse {
1486 function_response: FunctionResponseData {
1487 name: name.clone(),
1488 response: serde_json::json!({
1489 "error": format!(
1490 "Agent '{}' not found. Available agents: {:?}",
1491 target_agent,
1492 valid_transfer_targets
1493 )
1494 }),
1495 },
1496 id: id.clone(),
1497 }],
1498 };
1499 conversation_history.push(error_content.clone());
1500
1501 let mut error_event = Event::new(&invocation_id);
1502 error_event.author = agent_name.clone();
1503 error_event.llm_response.content = Some(error_content);
1504 yield Ok(error_event);
1505 continue;
1506 }
1507
1508 let mut transfer_event = Event::new(&invocation_id);
1509 transfer_event.author = agent_name.clone();
1510 transfer_event.actions.transfer_to_agent = Some(target_agent);
1511
1512 yield Ok(transfer_event);
1513 return;
1514 }
1515
1516 let mut tool_actions = EventActions::default();
1519 let mut response_content: Option<Content> = None;
1520 let mut run_after_tool_callbacks = true;
1521 let mut tool_outcome_for_callback: Option<ToolOutcome> = None;
1522 let mut executed_tool: Option<Arc<dyn Tool>> = None;
1524 let mut executed_tool_response: Option<serde_json::Value> = None;
1525
1526 if tool_confirmation_policy.requires_confirmation(name) {
1531 match ctx.run_config().tool_confirmation_decisions.get(name).copied()
1532 {
1533 Some(ToolConfirmationDecision::Approve) => {
1534 tool_actions.tool_confirmation_decision =
1535 Some(ToolConfirmationDecision::Approve);
1536 }
1537 Some(ToolConfirmationDecision::Deny) => {
1538 tool_actions.tool_confirmation_decision =
1539 Some(ToolConfirmationDecision::Deny);
1540 response_content = Some(Content {
1541 role: "function".to_string(),
1542 parts: vec![Part::FunctionResponse {
1543 function_response: FunctionResponseData {
1544 name: name.clone(),
1545 response: serde_json::json!({
1546 "error": format!(
1547 "Tool '{}' execution denied by confirmation policy",
1548 name
1549 ),
1550 }),
1551 },
1552 id: id.clone(),
1553 }],
1554 });
1555 run_after_tool_callbacks = false;
1556 }
1557 None => {
1558 let mut confirmation_event = Event::new(&invocation_id);
1559 confirmation_event.author = agent_name.clone();
1560 confirmation_event.llm_response.interrupted = true;
1561 confirmation_event.llm_response.turn_complete = true;
1562 confirmation_event.llm_response.content = Some(Content {
1563 role: "model".to_string(),
1564 parts: vec![Part::Text {
1565 text: format!(
1566 "Tool confirmation required for '{}'. Provide approve/deny decision to continue.",
1567 name
1568 ),
1569 }],
1570 });
1571 confirmation_event.actions.tool_confirmation =
1572 Some(ToolConfirmationRequest {
1573 tool_name: name.clone(),
1574 function_call_id: Some(function_call_id),
1575 args: args.clone(),
1576 });
1577 yield Ok(confirmation_event);
1578 return;
1579 }
1580 }
1581 }
1582
1583 if response_content.is_none() {
1584 for callback in before_tool_callbacks.as_ref() {
1585 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
1586 Ok(Some(content)) => {
1587 response_content = Some(content);
1588 break;
1589 }
1590 Ok(None) => continue,
1591 Err(e) => {
1592 yield Err(e);
1593 return;
1594 }
1595 }
1596 }
1597 }
1598
1599 if response_content.is_none() {
1601 if let Some(ref cb_state) = circuit_breaker_state {
1605 if cb_state.is_open(name) {
1606 let error_msg = format!(
1607 "Tool '{}' is temporarily disabled after {} consecutive failures",
1608 name, cb_state.threshold
1609 );
1610 tracing::warn!(tool.name = %name, "circuit breaker open, skipping tool execution");
1611 response_content = Some(Content {
1612 role: "function".to_string(),
1613 parts: vec![Part::FunctionResponse {
1614 function_response: FunctionResponseData {
1615 name: name.clone(),
1616 response: serde_json::json!({ "error": error_msg }),
1617 },
1618 id: id.clone(),
1619 }],
1620 });
1621 run_after_tool_callbacks = false;
1622 }
1623 }
1624 }
1625
1626 if response_content.is_none() {
1627 if let Some(tool) = tool_map.get(name) {
1628 if tool.is_builtin() {
1631 adk_telemetry::debug!(tool.name = %name, "skipping built-in tool execution");
1632 continue;
1633 }
1634 let tool_ctx: Arc<dyn ToolContext> = Arc::new(AgentToolContext::new(
1636 ctx.clone(),
1637 function_call_id.clone(),
1638 ));
1639
1640 let span_name = format!("execute_tool {}", name);
1642 let tool_span = tracing::info_span!(
1643 "",
1644 otel.name = %span_name,
1645 tool.name = %name,
1646 "gcp.vertex.agent.event_id" = %format!("{}_{}", invocation_id, name),
1647 "gcp.vertex.agent.invocation_id" = %invocation_id,
1648 "gcp.vertex.agent.session_id" = %ctx.session_id(),
1649 "gen_ai.conversation.id" = %ctx.session_id()
1650 );
1651
1652 let budget = tool_retry_budgets.get(name)
1656 .or(default_retry_budget.as_ref());
1657 let max_attempts = budget.map(|b| b.max_retries + 1).unwrap_or(1);
1658 let retry_delay = budget.map(|b| b.delay).unwrap_or_default();
1659
1660 let tool_clone = tool.clone();
1663 let tool_start = std::time::Instant::now();
1664
1665 let mut last_error = String::new();
1666 let mut final_attempt: u32 = 0;
1667 let mut retry_result: Option<serde_json::Value> = None;
1668
1669 for attempt in 0..max_attempts {
1670 final_attempt = attempt;
1671
1672 if attempt > 0 {
1673 tokio::time::sleep(retry_delay).await;
1674 }
1675
1676 match async {
1677 tracing::info!(tool.name = %name, tool.args = %args, attempt = attempt, "tool_call");
1678 let exec_future = tool_clone.execute(tool_ctx.clone(), args.clone());
1679 tokio::time::timeout(tool_timeout, exec_future).await
1680 }.instrument(tool_span.clone()).await {
1681 Ok(Ok(value)) => {
1682 tracing::info!(tool.name = %name, tool.result = %value, "tool_result");
1683 retry_result = Some(value);
1684 break;
1685 }
1686 Ok(Err(e)) => {
1687 last_error = e.to_string();
1688 if attempt + 1 < max_attempts {
1689 tracing::warn!(tool.name = %name, attempt = attempt, error = %last_error, "tool execution failed, retrying");
1690 } else {
1691 tracing::warn!(tool.name = %name, error = %last_error, "tool_error");
1692 }
1693 }
1694 Err(_) => {
1695 last_error = format!(
1696 "Tool '{}' timed out after {} seconds",
1697 name, tool_timeout.as_secs()
1698 );
1699 if attempt + 1 < max_attempts {
1700 tracing::warn!(tool.name = %name, attempt = attempt, timeout_secs = tool_timeout.as_secs(), "tool timed out, retrying");
1701 } else {
1702 tracing::warn!(tool.name = %name, timeout_secs = tool_timeout.as_secs(), "tool_timeout");
1703 }
1704 }
1705 }
1706 }
1707
1708 let tool_duration = tool_start.elapsed();
1709
1710 let (tool_success, tool_error_message, function_response) = match retry_result {
1712 Some(value) => (true, None, value),
1713 None => (false, Some(last_error.clone()), serde_json::json!({ "error": last_error })),
1714 };
1715
1716 let outcome = ToolOutcome {
1719 tool_name: name.clone(),
1720 tool_args: args.clone(),
1721 success: tool_success,
1722 duration: tool_duration,
1723 error_message: tool_error_message.clone(),
1724 attempt: final_attempt,
1725 };
1726 tool_outcome_for_callback = Some(outcome);
1727
1728 if let Some(ref mut cb_state) = circuit_breaker_state {
1730 cb_state.record(tool_outcome_for_callback.as_ref().unwrap());
1731 }
1732
1733 let final_function_response = if !tool_success {
1738 let mut fallback_result = None;
1739 let error_msg = tool_error_message.clone().unwrap_or_default();
1740 for callback in on_tool_error_callbacks.as_ref() {
1741 match callback(
1742 ctx.clone() as Arc<dyn CallbackContext>,
1743 tool.clone(),
1744 args.clone(),
1745 error_msg.clone(),
1746 ).await {
1747 Ok(Some(result)) => {
1748 fallback_result = Some(result);
1749 break;
1750 }
1751 Ok(None) => continue,
1752 Err(e) => {
1753 tracing::warn!(error = %e, "on_tool_error callback failed");
1754 break;
1755 }
1756 }
1757 }
1758 if let Some(fallback) = fallback_result {
1759 fallback
1760 } else {
1761 function_response
1762 }
1763 } else {
1764 function_response
1765 };
1766
1767 let confirmation_decision =
1768 tool_actions.tool_confirmation_decision;
1769 tool_actions = tool_ctx.actions();
1770 if tool_actions.tool_confirmation_decision.is_none() {
1771 tool_actions.tool_confirmation_decision =
1772 confirmation_decision;
1773 }
1774 executed_tool = Some(tool.clone());
1776 executed_tool_response = Some(final_function_response.clone());
1777 response_content = Some(Content {
1778 role: "function".to_string(),
1779 parts: vec![Part::FunctionResponse {
1780 function_response: FunctionResponseData {
1781 name: name.clone(),
1782 response: final_function_response,
1783 },
1784 id: id.clone(),
1785 }],
1786 });
1787 } else {
1788 response_content = Some(Content {
1789 role: "function".to_string(),
1790 parts: vec![Part::FunctionResponse {
1791 function_response: FunctionResponseData {
1792 name: name.clone(),
1793 response: serde_json::json!({
1794 "error": format!("Tool {} not found", name)
1795 }),
1796 },
1797 id: id.clone(),
1798 }],
1799 });
1800 }
1801 }
1802
1803 let mut response_content = response_content.expect("tool response content is set");
1806 if run_after_tool_callbacks {
1807 let cb_ctx: Arc<dyn CallbackContext> = match tool_outcome_for_callback {
1809 Some(outcome) => Arc::new(ToolCallbackContext {
1810 inner: ctx.clone() as Arc<dyn CallbackContext>,
1811 outcome,
1812 }),
1813 None => ctx.clone() as Arc<dyn CallbackContext>,
1814 };
1815 for callback in after_tool_callbacks.as_ref() {
1816 match callback(cb_ctx.clone()).await {
1817 Ok(Some(modified_content)) => {
1818 response_content = modified_content;
1819 break;
1820 }
1821 Ok(None) => continue,
1822 Err(e) => {
1823 yield Err(e);
1824 return;
1825 }
1826 }
1827 }
1828
1829 if let (Some(tool_ref), Some(tool_resp)) = (&executed_tool, executed_tool_response) {
1834 for callback in after_tool_callbacks_full.as_ref() {
1835 match callback(
1836 cb_ctx.clone(),
1837 tool_ref.clone(),
1838 args.clone(),
1839 tool_resp.clone(),
1840 ).await {
1841 Ok(Some(modified_value)) => {
1842 response_content = Content {
1844 role: "function".to_string(),
1845 parts: vec![Part::FunctionResponse {
1846 function_response: FunctionResponseData {
1847 name: name.clone(),
1848 response: modified_value,
1849 },
1850 id: id.clone(),
1851 }],
1852 };
1853 break;
1854 }
1855 Ok(None) => continue,
1856 Err(e) => {
1857 yield Err(e);
1858 return;
1859 }
1860 }
1861 }
1862 }
1863 }
1864
1865 let mut tool_event = Event::new(&invocation_id);
1867 tool_event.author = agent_name.clone();
1868 tool_event.actions = tool_actions.clone();
1869 tool_event.llm_response.content = Some(response_content.clone());
1870 yield Ok(tool_event);
1871
1872 if tool_actions.escalate || tool_actions.skip_summarization {
1874 return;
1876 }
1877
1878 conversation_history.push(response_content);
1880 }
1881 }
1882 }
1883
1884 if all_calls_are_long_running {
1888 }
1892 }
1893
1894 for callback in after_agent_callbacks.as_ref() {
1897 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
1898 Ok(Some(content)) => {
1899 let mut after_event = Event::new(&invocation_id);
1901 after_event.author = agent_name.clone();
1902 after_event.llm_response.content = Some(content);
1903 yield Ok(after_event);
1904 break; }
1906 Ok(None) => {
1907 continue;
1909 }
1910 Err(e) => {
1911 yield Err(e);
1913 return;
1914 }
1915 }
1916 }
1917 };
1918
1919 Ok(Box::pin(s))
1920 }
1921}