1use adk_core::{
2 AfterAgentCallback, AfterModelCallback, AfterToolCallback, AfterToolCallbackFull, Agent,
3 BeforeAgentCallback, BeforeModelCallback, BeforeModelResult, BeforeToolCallback,
4 CallbackContext, Content, Event, EventActions, FunctionResponseData, GlobalInstructionProvider,
5 InstructionProvider, InvocationContext, Llm, LlmRequest, LlmResponse, MemoryEntry,
6 OnToolErrorCallback, Part, ReadonlyContext, Result, RetryBudget, Tool, ToolCallbackContext,
7 ToolConfirmationDecision, ToolConfirmationPolicy, ToolConfirmationRequest, ToolContext,
8 ToolExecutionStrategy, ToolOutcome, Toolset,
9};
10use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index, select_skill_prompt_block};
11use async_stream::stream;
12use async_trait::async_trait;
13use std::sync::{Arc, Mutex};
14use tracing::Instrument;
15
16use crate::{
17 guardrails::{GuardrailSet, enforce_guardrails},
18 tool_call_markup::normalize_option_content,
19 workflow::with_user_content_override,
20};
21
/// Default cap on model-call iterations in one agent run; exceeding it ends
/// the run with an error.
pub const DEFAULT_MAX_ITERATIONS: u32 = 100;

/// Default tool timeout used by [`LlmAgentBuilder::new`]: five minutes.
pub const DEFAULT_TOOL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5 * 60);
27
28pub struct LlmAgent {
29 name: String,
30 description: String,
31 model: Arc<dyn Llm>,
32 instruction: Option<String>,
33 instruction_provider: Option<Arc<InstructionProvider>>,
34 global_instruction: Option<String>,
35 global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
36 skills_index: Option<Arc<SkillIndex>>,
37 skill_policy: SelectionPolicy,
38 max_skill_chars: usize,
39 #[allow(dead_code)] input_schema: Option<serde_json::Value>,
41 output_schema: Option<serde_json::Value>,
42 disallow_transfer_to_parent: bool,
43 disallow_transfer_to_peers: bool,
44 include_contents: adk_core::IncludeContents,
45 tools: Vec<Arc<dyn Tool>>,
46 #[allow(dead_code)] toolsets: Vec<Arc<dyn Toolset>>,
48 sub_agents: Vec<Arc<dyn Agent>>,
49 output_key: Option<String>,
50 generate_content_config: Option<adk_core::GenerateContentConfig>,
52 max_iterations: u32,
54 tool_timeout: std::time::Duration,
56 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
57 after_callbacks: Arc<Vec<AfterAgentCallback>>,
58 before_model_callbacks: Arc<Vec<BeforeModelCallback>>,
59 after_model_callbacks: Arc<Vec<AfterModelCallback>>,
60 before_tool_callbacks: Arc<Vec<BeforeToolCallback>>,
61 after_tool_callbacks: Arc<Vec<AfterToolCallback>>,
62 on_tool_error_callbacks: Arc<Vec<OnToolErrorCallback>>,
63 after_tool_callbacks_full: Arc<Vec<AfterToolCallbackFull>>,
65 default_retry_budget: Option<RetryBudget>,
67 tool_retry_budgets: std::collections::HashMap<String, RetryBudget>,
69 circuit_breaker_threshold: Option<u32>,
72 tool_confirmation_policy: ToolConfirmationPolicy,
73 tool_execution_strategy: Option<ToolExecutionStrategy>,
76 input_guardrails: Arc<GuardrailSet>,
77 output_guardrails: Arc<GuardrailSet>,
78}
79
80impl std::fmt::Debug for LlmAgent {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 f.debug_struct("LlmAgent")
83 .field("name", &self.name)
84 .field("description", &self.description)
85 .field("model", &self.model.name())
86 .field("instruction", &self.instruction)
87 .field("tools_count", &self.tools.len())
88 .field("sub_agents_count", &self.sub_agents.len())
89 .finish()
90 }
91}
92
93impl LlmAgent {
94 async fn apply_input_guardrails(
95 ctx: Arc<dyn InvocationContext>,
96 input_guardrails: Arc<GuardrailSet>,
97 ) -> Result<Arc<dyn InvocationContext>> {
98 let content =
99 enforce_guardrails(input_guardrails.as_ref(), ctx.user_content(), "input").await?;
100 if content.role != ctx.user_content().role || content.parts != ctx.user_content().parts {
101 Ok(with_user_content_override(ctx, content))
102 } else {
103 Ok(ctx)
104 }
105 }
106
107 async fn apply_output_guardrails(
108 output_guardrails: &GuardrailSet,
109 content: Content,
110 ) -> Result<Content> {
111 enforce_guardrails(output_guardrails, &content, "output").await
112 }
113
114 fn history_parts_from_provider_metadata(
115 provider_metadata: Option<&serde_json::Value>,
116 ) -> Vec<Part> {
117 let Some(provider_metadata) = provider_metadata else {
118 return Vec::new();
119 };
120
121 let history_parts = provider_metadata
122 .get("conversation_history_parts")
123 .or_else(|| {
124 provider_metadata
125 .get("openai")
126 .and_then(|openai| openai.get("conversation_history_parts"))
127 })
128 .and_then(serde_json::Value::as_array);
129
130 history_parts
131 .into_iter()
132 .flatten()
133 .filter_map(|value| serde_json::from_value::<Part>(value.clone()).ok())
134 .collect()
135 }
136
137 fn augment_content_for_history(
138 content: &Content,
139 provider_metadata: Option<&serde_json::Value>,
140 ) -> Content {
141 let mut augmented = content.clone();
142 augmented.parts.extend(Self::history_parts_from_provider_metadata(provider_metadata));
143 augmented
144 }
145}
146
/// Builder for [`LlmAgent`]: start with [`LlmAgentBuilder::new`], chain the
/// setters, and finish with [`LlmAgentBuilder::build`].
pub struct LlmAgentBuilder {
    name: String,
    // Becomes an empty string in `build` when unset.
    description: Option<String>,
    // Required: `build` fails when this is still `None`.
    model: Option<Arc<dyn Llm>>,
    instruction: Option<String>,
    instruction_provider: Option<Arc<InstructionProvider>>,
    global_instruction: Option<String>,
    global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
    skills_index: Option<Arc<SkillIndex>>,
    skill_policy: SelectionPolicy,
    // Character budget for skill selection; `new` sets 2000.
    max_skill_chars: usize,
    input_schema: Option<serde_json::Value>,
    output_schema: Option<serde_json::Value>,
    disallow_transfer_to_parent: bool,
    disallow_transfer_to_peers: bool,
    include_contents: adk_core::IncludeContents,
    tools: Vec<Arc<dyn Tool>>,
    toolsets: Vec<Arc<dyn Toolset>>,
    // Names must be unique; checked in `build`.
    sub_agents: Vec<Arc<dyn Agent>>,
    output_key: Option<String>,
    // Created on demand by the sampling setters (temperature, top_p, ...).
    generate_content_config: Option<adk_core::GenerateContentConfig>,
    max_iterations: u32,
    tool_timeout: std::time::Duration,
    // Callback lists are wrapped in `Arc` when the agent is built.
    before_callbacks: Vec<BeforeAgentCallback>,
    after_callbacks: Vec<AfterAgentCallback>,
    before_model_callbacks: Vec<BeforeModelCallback>,
    after_model_callbacks: Vec<AfterModelCallback>,
    before_tool_callbacks: Vec<BeforeToolCallback>,
    after_tool_callbacks: Vec<AfterToolCallback>,
    on_tool_error_callbacks: Vec<OnToolErrorCallback>,
    after_tool_callbacks_full: Vec<AfterToolCallbackFull>,
    default_retry_budget: Option<RetryBudget>,
    // Per-tool overrides keyed by tool name.
    tool_retry_budgets: std::collections::HashMap<String, RetryBudget>,
    circuit_breaker_threshold: Option<u32>,
    tool_confirmation_policy: ToolConfirmationPolicy,
    tool_execution_strategy: Option<ToolExecutionStrategy>,
    input_guardrails: GuardrailSet,
    output_guardrails: GuardrailSet,
}
186
187impl LlmAgentBuilder {
188 pub fn new(name: impl Into<String>) -> Self {
189 Self {
190 name: name.into(),
191 description: None,
192 model: None,
193 instruction: None,
194 instruction_provider: None,
195 global_instruction: None,
196 global_instruction_provider: None,
197 skills_index: None,
198 skill_policy: SelectionPolicy::default(),
199 max_skill_chars: 2000,
200 input_schema: None,
201 output_schema: None,
202 disallow_transfer_to_parent: false,
203 disallow_transfer_to_peers: false,
204 include_contents: adk_core::IncludeContents::Default,
205 tools: Vec::new(),
206 toolsets: Vec::new(),
207 sub_agents: Vec::new(),
208 output_key: None,
209 generate_content_config: None,
210 max_iterations: DEFAULT_MAX_ITERATIONS,
211 tool_timeout: DEFAULT_TOOL_TIMEOUT,
212 before_callbacks: Vec::new(),
213 after_callbacks: Vec::new(),
214 before_model_callbacks: Vec::new(),
215 after_model_callbacks: Vec::new(),
216 before_tool_callbacks: Vec::new(),
217 after_tool_callbacks: Vec::new(),
218 on_tool_error_callbacks: Vec::new(),
219 after_tool_callbacks_full: Vec::new(),
220 default_retry_budget: None,
221 tool_retry_budgets: std::collections::HashMap::new(),
222 circuit_breaker_threshold: None,
223 tool_confirmation_policy: ToolConfirmationPolicy::Never,
224 tool_execution_strategy: None,
225 input_guardrails: GuardrailSet::new(),
226 output_guardrails: GuardrailSet::new(),
227 }
228 }
229
230 pub fn description(mut self, desc: impl Into<String>) -> Self {
231 self.description = Some(desc.into());
232 self
233 }
234
235 pub fn model(mut self, model: Arc<dyn Llm>) -> Self {
236 self.model = Some(model);
237 self
238 }
239
240 pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
241 self.instruction = Some(instruction.into());
242 self
243 }
244
245 pub fn instruction_provider(mut self, provider: InstructionProvider) -> Self {
246 self.instruction_provider = Some(Arc::new(provider));
247 self
248 }
249
250 pub fn global_instruction(mut self, instruction: impl Into<String>) -> Self {
251 self.global_instruction = Some(instruction.into());
252 self
253 }
254
255 pub fn global_instruction_provider(mut self, provider: GlobalInstructionProvider) -> Self {
256 self.global_instruction_provider = Some(Arc::new(provider));
257 self
258 }
259
260 pub fn with_skills(mut self, index: SkillIndex) -> Self {
262 self.skills_index = Some(Arc::new(index));
263 self
264 }
265
266 pub fn with_auto_skills(self) -> Result<Self> {
268 self.with_skills_from_root(".")
269 }
270
271 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
273 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
274 self.skills_index = Some(Arc::new(index));
275 Ok(self)
276 }
277
278 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
280 self.skill_policy = policy;
281 self
282 }
283
284 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
286 self.max_skill_chars = max_chars;
287 self
288 }
289
290 pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
291 self.input_schema = Some(schema);
292 self
293 }
294
295 pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
296 self.output_schema = Some(schema);
297 self
298 }
299
300 pub fn disallow_transfer_to_parent(mut self, disallow: bool) -> Self {
301 self.disallow_transfer_to_parent = disallow;
302 self
303 }
304
305 pub fn disallow_transfer_to_peers(mut self, disallow: bool) -> Self {
306 self.disallow_transfer_to_peers = disallow;
307 self
308 }
309
310 pub fn include_contents(mut self, include: adk_core::IncludeContents) -> Self {
311 self.include_contents = include;
312 self
313 }
314
315 pub fn output_key(mut self, key: impl Into<String>) -> Self {
316 self.output_key = Some(key.into());
317 self
318 }
319
320 pub fn generate_content_config(mut self, config: adk_core::GenerateContentConfig) -> Self {
341 self.generate_content_config = Some(config);
342 self
343 }
344
345 pub fn temperature(mut self, temperature: f32) -> Self {
348 self.generate_content_config
349 .get_or_insert(adk_core::GenerateContentConfig::default())
350 .temperature = Some(temperature);
351 self
352 }
353
354 pub fn top_p(mut self, top_p: f32) -> Self {
356 self.generate_content_config
357 .get_or_insert(adk_core::GenerateContentConfig::default())
358 .top_p = Some(top_p);
359 self
360 }
361
362 pub fn top_k(mut self, top_k: i32) -> Self {
364 self.generate_content_config
365 .get_or_insert(adk_core::GenerateContentConfig::default())
366 .top_k = Some(top_k);
367 self
368 }
369
370 pub fn max_output_tokens(mut self, max_tokens: i32) -> Self {
372 self.generate_content_config
373 .get_or_insert(adk_core::GenerateContentConfig::default())
374 .max_output_tokens = Some(max_tokens);
375 self
376 }
377
378 pub fn max_iterations(mut self, max: u32) -> Self {
381 self.max_iterations = max;
382 self
383 }
384
385 pub fn tool_timeout(mut self, timeout: std::time::Duration) -> Self {
388 self.tool_timeout = timeout;
389 self
390 }
391
392 pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
393 self.tools.push(tool);
394 self
395 }
396
397 pub fn toolset(mut self, toolset: Arc<dyn Toolset>) -> Self {
403 self.toolsets.push(toolset);
404 self
405 }
406
407 pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
408 self.sub_agents.push(agent);
409 self
410 }
411
412 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
413 self.before_callbacks.push(callback);
414 self
415 }
416
417 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
418 self.after_callbacks.push(callback);
419 self
420 }
421
422 pub fn before_model_callback(mut self, callback: BeforeModelCallback) -> Self {
423 self.before_model_callbacks.push(callback);
424 self
425 }
426
427 pub fn after_model_callback(mut self, callback: AfterModelCallback) -> Self {
428 self.after_model_callbacks.push(callback);
429 self
430 }
431
432 pub fn before_tool_callback(mut self, callback: BeforeToolCallback) -> Self {
433 self.before_tool_callbacks.push(callback);
434 self
435 }
436
437 pub fn after_tool_callback(mut self, callback: AfterToolCallback) -> Self {
438 self.after_tool_callbacks.push(callback);
439 self
440 }
441
442 pub fn after_tool_callback_full(mut self, callback: AfterToolCallbackFull) -> Self {
457 self.after_tool_callbacks_full.push(callback);
458 self
459 }
460
461 pub fn on_tool_error(mut self, callback: OnToolErrorCallback) -> Self {
469 self.on_tool_error_callbacks.push(callback);
470 self
471 }
472
473 pub fn default_retry_budget(mut self, budget: RetryBudget) -> Self {
480 self.default_retry_budget = Some(budget);
481 self
482 }
483
484 pub fn tool_retry_budget(mut self, tool_name: impl Into<String>, budget: RetryBudget) -> Self {
489 self.tool_retry_budgets.insert(tool_name.into(), budget);
490 self
491 }
492
493 pub fn circuit_breaker_threshold(mut self, threshold: u32) -> Self {
500 self.circuit_breaker_threshold = Some(threshold);
501 self
502 }
503
504 pub fn tool_confirmation_policy(mut self, policy: ToolConfirmationPolicy) -> Self {
506 self.tool_confirmation_policy = policy;
507 self
508 }
509
510 pub fn require_tool_confirmation(mut self, tool_name: impl Into<String>) -> Self {
512 self.tool_confirmation_policy = self.tool_confirmation_policy.with_tool(tool_name);
513 self
514 }
515
516 pub fn require_tool_confirmation_for_all(mut self) -> Self {
518 self.tool_confirmation_policy = ToolConfirmationPolicy::Always;
519 self
520 }
521
522 pub fn tool_execution_strategy(mut self, strategy: ToolExecutionStrategy) -> Self {
528 self.tool_execution_strategy = Some(strategy);
529 self
530 }
531
532 pub fn input_guardrails(mut self, guardrails: GuardrailSet) -> Self {
541 self.input_guardrails = guardrails;
542 self
543 }
544
545 pub fn output_guardrails(mut self, guardrails: GuardrailSet) -> Self {
554 self.output_guardrails = guardrails;
555 self
556 }
557
558 pub fn build(self) -> Result<LlmAgent> {
559 let model = self.model.ok_or_else(|| adk_core::AdkError::agent("Model is required"))?;
560
561 let mut seen_names = std::collections::HashSet::new();
562 for agent in &self.sub_agents {
563 if !seen_names.insert(agent.name()) {
564 return Err(adk_core::AdkError::agent(format!(
565 "Duplicate sub-agent name: {}",
566 agent.name()
567 )));
568 }
569 }
570
571 Ok(LlmAgent {
572 name: self.name,
573 description: self.description.unwrap_or_default(),
574 model,
575 instruction: self.instruction,
576 instruction_provider: self.instruction_provider,
577 global_instruction: self.global_instruction,
578 global_instruction_provider: self.global_instruction_provider,
579 skills_index: self.skills_index,
580 skill_policy: self.skill_policy,
581 max_skill_chars: self.max_skill_chars,
582 input_schema: self.input_schema,
583 output_schema: self.output_schema,
584 disallow_transfer_to_parent: self.disallow_transfer_to_parent,
585 disallow_transfer_to_peers: self.disallow_transfer_to_peers,
586 include_contents: self.include_contents,
587 tools: self.tools,
588 toolsets: self.toolsets,
589 sub_agents: self.sub_agents,
590 output_key: self.output_key,
591 generate_content_config: self.generate_content_config,
592 max_iterations: self.max_iterations,
593 tool_timeout: self.tool_timeout,
594 before_callbacks: Arc::new(self.before_callbacks),
595 after_callbacks: Arc::new(self.after_callbacks),
596 before_model_callbacks: Arc::new(self.before_model_callbacks),
597 after_model_callbacks: Arc::new(self.after_model_callbacks),
598 before_tool_callbacks: Arc::new(self.before_tool_callbacks),
599 after_tool_callbacks: Arc::new(self.after_tool_callbacks),
600 on_tool_error_callbacks: Arc::new(self.on_tool_error_callbacks),
601 after_tool_callbacks_full: Arc::new(self.after_tool_callbacks_full),
602 default_retry_budget: self.default_retry_budget,
603 tool_retry_budgets: self.tool_retry_budgets,
604 circuit_breaker_threshold: self.circuit_breaker_threshold,
605 tool_confirmation_policy: self.tool_confirmation_policy,
606 tool_execution_strategy: self.tool_execution_strategy,
607 input_guardrails: Arc::new(self.input_guardrails),
608 output_guardrails: Arc::new(self.output_guardrails),
609 })
610 }
611}
612
/// [`ToolContext`] bound to a single function call (`function_call_id`).
///
/// Read-only and callback queries are forwarded to `parent_ctx`; event
/// actions set by the tool are buffered locally in `actions`.
struct AgentToolContext {
    parent_ctx: Arc<dyn InvocationContext>,
    function_call_id: String,
    // Behind a Mutex because `set_actions` takes `&self`.
    actions: Mutex<EventActions>,
}
620
621impl AgentToolContext {
622 fn new(parent_ctx: Arc<dyn InvocationContext>, function_call_id: String) -> Self {
623 Self { parent_ctx, function_call_id, actions: Mutex::new(EventActions::default()) }
624 }
625
626 fn actions_guard(&self) -> std::sync::MutexGuard<'_, EventActions> {
627 self.actions.lock().unwrap_or_else(|e| e.into_inner())
628 }
629}
630
631#[async_trait]
632impl ReadonlyContext for AgentToolContext {
633 fn invocation_id(&self) -> &str {
634 self.parent_ctx.invocation_id()
635 }
636
637 fn agent_name(&self) -> &str {
638 self.parent_ctx.agent_name()
639 }
640
641 fn user_id(&self) -> &str {
642 self.parent_ctx.user_id()
644 }
645
646 fn app_name(&self) -> &str {
647 self.parent_ctx.app_name()
649 }
650
651 fn session_id(&self) -> &str {
652 self.parent_ctx.session_id()
654 }
655
656 fn branch(&self) -> &str {
657 self.parent_ctx.branch()
658 }
659
660 fn user_content(&self) -> &Content {
661 self.parent_ctx.user_content()
662 }
663}
664
665#[async_trait]
666impl CallbackContext for AgentToolContext {
667 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
668 self.parent_ctx.artifacts()
670 }
671
672 fn shared_state(&self) -> Option<Arc<adk_core::SharedState>> {
673 self.parent_ctx.shared_state()
674 }
675}
676
677#[async_trait]
678impl ToolContext for AgentToolContext {
679 fn function_call_id(&self) -> &str {
680 &self.function_call_id
681 }
682
683 fn actions(&self) -> EventActions {
684 self.actions_guard().clone()
685 }
686
687 fn set_actions(&self, actions: EventActions) {
688 *self.actions_guard() = actions;
689 }
690
691 async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>> {
692 if let Some(memory) = self.parent_ctx.memory() {
694 memory.search(query).await
695 } else {
696 Ok(vec![])
697 }
698 }
699
700 fn user_scopes(&self) -> Vec<String> {
701 self.parent_ctx.user_scopes()
702 }
703}
704
/// Wraps a [`CallbackContext`] and additionally exposes a [`ToolOutcome`]
/// through `tool_outcome()`.
struct ToolOutcomeCallbackContext {
    // Context that the read-only queries are delegated to.
    inner: Arc<dyn CallbackContext>,
    // Returned (cloned) from `tool_outcome()`.
    outcome: ToolOutcome,
}
712
713#[async_trait]
714impl ReadonlyContext for ToolOutcomeCallbackContext {
715 fn invocation_id(&self) -> &str {
716 self.inner.invocation_id()
717 }
718
719 fn agent_name(&self) -> &str {
720 self.inner.agent_name()
721 }
722
723 fn user_id(&self) -> &str {
724 self.inner.user_id()
725 }
726
727 fn app_name(&self) -> &str {
728 self.inner.app_name()
729 }
730
731 fn session_id(&self) -> &str {
732 self.inner.session_id()
733 }
734
735 fn branch(&self) -> &str {
736 self.inner.branch()
737 }
738
739 fn user_content(&self) -> &Content {
740 self.inner.user_content()
741 }
742}
743
744#[async_trait]
745impl CallbackContext for ToolOutcomeCallbackContext {
746 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
747 self.inner.artifacts()
748 }
749
750 fn tool_outcome(&self) -> Option<ToolOutcome> {
751 Some(self.outcome.clone())
752 }
753}
754
/// Per-tool failure counter backing the circuit breaker.
///
/// A tool's count is incremented on each failed outcome and removed on a
/// successful one, so it tracks consecutive failures. The breaker for that
/// tool is open once the count reaches `threshold`.
struct CircuitBreakerState {
    threshold: u32,
    // Consecutive failure counts keyed by tool name.
    failures: std::collections::HashMap<String, u32>,
}
769
770impl CircuitBreakerState {
771 fn new(threshold: u32) -> Self {
772 Self { threshold, failures: std::collections::HashMap::new() }
773 }
774
775 fn is_open(&self, tool_name: &str) -> bool {
777 self.failures.get(tool_name).copied().unwrap_or(0) >= self.threshold
778 }
779
780 fn record(&mut self, outcome: &ToolOutcome) {
782 if outcome.success {
783 self.failures.remove(&outcome.tool_name);
784 } else {
785 let count = self.failures.entry(outcome.tool_name.clone()).or_insert(0);
786 *count += 1;
787 }
788 }
789}
790
791#[async_trait]
792impl Agent for LlmAgent {
793 fn name(&self) -> &str {
794 &self.name
795 }
796
797 fn description(&self) -> &str {
798 &self.description
799 }
800
801 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
802 &self.sub_agents
803 }
804
805 #[adk_telemetry::instrument(
806 skip(self, ctx),
807 fields(
808 agent.name = %self.name,
809 agent.description = %self.description,
810 invocation.id = %ctx.invocation_id(),
811 user.id = %ctx.user_id(),
812 session.id = %ctx.session_id()
813 )
814 )]
815 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
816 adk_telemetry::info!("Starting agent execution");
817 let ctx = Self::apply_input_guardrails(ctx, self.input_guardrails.clone()).await?;
818
819 let agent_name = self.name.clone();
820 let invocation_id = ctx.invocation_id().to_string();
821 let model = self.model.clone();
822 let tools = self.tools.clone();
823 let toolsets = self.toolsets.clone();
824 let sub_agents = self.sub_agents.clone();
825
826 let instruction = self.instruction.clone();
827 let instruction_provider = self.instruction_provider.clone();
828 let global_instruction = self.global_instruction.clone();
829 let global_instruction_provider = self.global_instruction_provider.clone();
830 let skills_index = self.skills_index.clone();
831 let skill_policy = self.skill_policy.clone();
832 let max_skill_chars = self.max_skill_chars;
833 let output_key = self.output_key.clone();
834 let output_schema = self.output_schema.clone();
835 let generate_content_config = self.generate_content_config.clone();
836 let include_contents = self.include_contents;
837 let max_iterations = self.max_iterations;
838 let tool_timeout = self.tool_timeout;
839 let before_agent_callbacks = self.before_callbacks.clone();
841 let after_agent_callbacks = self.after_callbacks.clone();
842 let before_model_callbacks = self.before_model_callbacks.clone();
843 let after_model_callbacks = self.after_model_callbacks.clone();
844 let before_tool_callbacks = self.before_tool_callbacks.clone();
845 let after_tool_callbacks = self.after_tool_callbacks.clone();
846 let on_tool_error_callbacks = self.on_tool_error_callbacks.clone();
847 let after_tool_callbacks_full = self.after_tool_callbacks_full.clone();
848 let default_retry_budget = self.default_retry_budget.clone();
849 let tool_retry_budgets = self.tool_retry_budgets.clone();
850 let circuit_breaker_threshold = self.circuit_breaker_threshold;
851 let tool_confirmation_policy = self.tool_confirmation_policy.clone();
852 let disallow_transfer_to_parent = self.disallow_transfer_to_parent;
853 let disallow_transfer_to_peers = self.disallow_transfer_to_peers;
854 let output_guardrails = self.output_guardrails.clone();
855 let agent_tool_execution_strategy = self.tool_execution_strategy;
856
857 let s = stream! {
858 for callback in before_agent_callbacks.as_ref() {
862 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
863 Ok(Some(content)) => {
864 let mut early_event = Event::new(&invocation_id);
866 early_event.author = agent_name.clone();
867 early_event.llm_response.content = Some(content);
868 yield Ok(early_event);
869
870 for after_callback in after_agent_callbacks.as_ref() {
872 match after_callback(ctx.clone() as Arc<dyn CallbackContext>).await {
873 Ok(Some(after_content)) => {
874 let mut after_event = Event::new(&invocation_id);
875 after_event.author = agent_name.clone();
876 after_event.llm_response.content = Some(after_content);
877 yield Ok(after_event);
878 return;
879 }
880 Ok(None) => continue,
881 Err(e) => {
882 yield Err(e);
883 return;
884 }
885 }
886 }
887 return;
888 }
889 Ok(None) => {
890 continue;
892 }
893 Err(e) => {
894 yield Err(e);
896 return;
897 }
898 }
899 }
900
901 let mut prompt_preamble = Vec::new();
903
904 if let Some(index) = &skills_index {
908 let user_query = ctx
909 .user_content()
910 .parts
911 .iter()
912 .filter_map(|part| match part {
913 Part::Text { text } => Some(text.as_str()),
914 _ => None,
915 })
916 .collect::<Vec<_>>()
917 .join("\n");
918
919 if let Some((_matched, skill_block)) = select_skill_prompt_block(
920 index.as_ref(),
921 &user_query,
922 &skill_policy,
923 max_skill_chars,
924 ) {
925 prompt_preamble.push(Content {
926 role: "user".to_string(),
927 parts: vec![Part::Text { text: skill_block }],
928 });
929 }
930 }
931
932 if let Some(provider) = &global_instruction_provider {
935 let global_inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
937 if !global_inst.is_empty() {
938 prompt_preamble.push(Content {
939 role: "user".to_string(),
940 parts: vec![Part::Text { text: global_inst }],
941 });
942 }
943 } else if let Some(ref template) = global_instruction {
944 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
946 if !processed.is_empty() {
947 prompt_preamble.push(Content {
948 role: "user".to_string(),
949 parts: vec![Part::Text { text: processed }],
950 });
951 }
952 }
953
954 if let Some(provider) = &instruction_provider {
957 let inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
959 if !inst.is_empty() {
960 prompt_preamble.push(Content {
961 role: "user".to_string(),
962 parts: vec![Part::Text { text: inst }],
963 });
964 }
965 } else if let Some(ref template) = instruction {
966 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
968 if !processed.is_empty() {
969 prompt_preamble.push(Content {
970 role: "user".to_string(),
971 parts: vec![Part::Text { text: processed }],
972 });
973 }
974 }
975
976 let session_history = if !ctx.run_config().transfer_targets.is_empty() {
982 ctx.session().conversation_history_for_agent(&agent_name)
983 } else {
984 ctx.session().conversation_history()
985 };
986 let mut session_history = session_history;
987 let current_user_content = ctx.user_content().clone();
988 if let Some(index) = session_history.iter().rposition(|content| content.role == "user") {
989 session_history[index] = current_user_content.clone();
990 } else {
991 session_history.push(current_user_content.clone());
992 }
993
994 let mut conversation_history = match include_contents {
997 adk_core::IncludeContents::None => {
998 let mut filtered = prompt_preamble.clone();
999 filtered.push(current_user_content);
1000 filtered
1001 }
1002 adk_core::IncludeContents::Default => {
1003 let mut full_history = prompt_preamble;
1004 full_history.extend(session_history);
1005 full_history
1006 }
1007 };
1008
1009 let mut resolved_tools: Vec<Arc<dyn Tool>> = tools.clone();
1012 let static_tool_names: std::collections::HashSet<String> =
1013 tools.iter().map(|t| t.name().to_string()).collect();
1014
1015 let mut toolset_source: std::collections::HashMap<String, String> =
1017 std::collections::HashMap::new();
1018
1019 for toolset in &toolsets {
1020 let toolset_tools = match toolset
1021 .tools(ctx.clone() as Arc<dyn ReadonlyContext>)
1022 .await
1023 {
1024 Ok(t) => t,
1025 Err(e) => {
1026 yield Err(e);
1027 return;
1028 }
1029 };
1030 for tool in &toolset_tools {
1031 let name = tool.name().to_string();
1032 if static_tool_names.contains(&name) {
1034 yield Err(adk_core::AdkError::agent(format!(
1035 "Duplicate tool name '{name}': conflict between static tool and toolset '{}'",
1036 toolset.name()
1037 )));
1038 return;
1039 }
1040 if let Some(other_toolset_name) = toolset_source.get(&name) {
1042 yield Err(adk_core::AdkError::agent(format!(
1043 "Duplicate tool name '{name}': conflict between toolset '{}' and toolset '{}'",
1044 other_toolset_name,
1045 toolset.name()
1046 )));
1047 return;
1048 }
1049 toolset_source.insert(name, toolset.name().to_string());
1050 resolved_tools.push(tool.clone());
1051 }
1052 }
1053
1054 let tool_map: std::collections::HashMap<String, Arc<dyn Tool>> = resolved_tools
1056 .iter()
1057 .map(|t| (t.name().to_string(), t.clone()))
1058 .collect();
1059
1060 let collect_long_running_ids = |content: &Content| -> Vec<String> {
1062 content.parts.iter()
1063 .filter_map(|p| {
1064 if let Part::FunctionCall { name, .. } = p {
1065 if let Some(tool) = tool_map.get(name) {
1066 if tool.is_long_running() {
1067 return Some(name.clone());
1068 }
1069 }
1070 }
1071 None
1072 })
1073 .collect()
1074 };
1075
1076 let mut tool_declarations = std::collections::HashMap::new();
1081 for tool in &resolved_tools {
1082 tool_declarations.insert(tool.name().to_string(), tool.declaration());
1083 }
1084
1085 let mut valid_transfer_targets: Vec<String> = sub_agents
1090 .iter()
1091 .map(|a| a.name().to_string())
1092 .collect();
1093
1094 let run_config_targets = &ctx.run_config().transfer_targets;
1096 let parent_agent_name = ctx.run_config().parent_agent.clone();
1097 let sub_agent_names: std::collections::HashSet<&str> = sub_agents
1098 .iter()
1099 .map(|a| a.name())
1100 .collect();
1101
1102 for target in run_config_targets {
1103 if sub_agent_names.contains(target.as_str()) {
1105 continue;
1106 }
1107
1108 let is_parent = parent_agent_name.as_deref() == Some(target.as_str());
1110 if is_parent && disallow_transfer_to_parent {
1111 continue;
1112 }
1113 if !is_parent && disallow_transfer_to_peers {
1114 continue;
1115 }
1116
1117 valid_transfer_targets.push(target.clone());
1118 }
1119
1120 if !valid_transfer_targets.is_empty() {
1122 let transfer_tool_name = "transfer_to_agent";
1123 let transfer_tool_decl = serde_json::json!({
1124 "name": transfer_tool_name,
1125 "description": format!(
1126 "Transfer execution to another agent. Valid targets: {}",
1127 valid_transfer_targets.join(", ")
1128 ),
1129 "parameters": {
1130 "type": "object",
1131 "properties": {
1132 "agent_name": {
1133 "type": "string",
1134 "description": "The name of the agent to transfer to.",
1135 "enum": valid_transfer_targets
1136 }
1137 },
1138 "required": ["agent_name"]
1139 }
1140 });
1141 tool_declarations.insert(transfer_tool_name.to_string(), transfer_tool_decl);
1142 }
1143
1144
1145 let mut circuit_breaker_state = circuit_breaker_threshold.map(CircuitBreakerState::new);
1148
1149 let mut iteration = 0;
1151
1152 loop {
1153 iteration += 1;
1154 if iteration > max_iterations {
1155 yield Err(adk_core::AdkError::agent(
1156 format!("Max iterations ({max_iterations}) exceeded")
1157 ));
1158 return;
1159 }
1160
1161 let config = match (&generate_content_config, &output_schema) {
1168 (Some(base), Some(schema)) => {
1169 let mut merged = base.clone();
1170 merged.response_schema = Some(schema.clone());
1171 Some(merged)
1172 }
1173 (Some(base), None) => Some(base.clone()),
1174 (None, Some(schema)) => Some(adk_core::GenerateContentConfig {
1175 response_schema: Some(schema.clone()),
1176 ..Default::default()
1177 }),
1178 (None, None) => None,
1179 };
1180
1181 let config = if let Some(ref cached) = ctx.run_config().cached_content {
1183 let mut cfg = config.unwrap_or_default();
1184 if cfg.cached_content.is_none() {
1186 cfg.cached_content = Some(cached.clone());
1187 }
1188 Some(cfg)
1189 } else {
1190 config
1191 };
1192
1193 let request = LlmRequest {
1194 model: model.name().to_string(),
1195 contents: conversation_history.clone(),
1196 tools: tool_declarations.clone(),
1197 config,
1198 };
1199
1200 let mut current_request = request;
1203 let mut model_response_override = None;
1204 for callback in before_model_callbacks.as_ref() {
1205 match callback(ctx.clone() as Arc<dyn CallbackContext>, current_request.clone()).await {
1206 Ok(BeforeModelResult::Continue(modified_request)) => {
1207 current_request = modified_request;
1209 }
1210 Ok(BeforeModelResult::Skip(response)) => {
1211 model_response_override = Some(response);
1213 break;
1214 }
1215 Err(e) => {
1216 yield Err(e);
1218 return;
1219 }
1220 }
1221 }
1222 let request = current_request;
1223
1224 let mut accumulated_content: Option<Content> = None;
1226 let mut final_provider_metadata: Option<serde_json::Value> = None;
1227
1228 if let Some(cached_response) = model_response_override {
1229 accumulated_content = cached_response.content.clone();
1232 final_provider_metadata = cached_response.provider_metadata.clone();
1233 normalize_option_content(&mut accumulated_content);
1234 if let Some(content) = accumulated_content.take() {
1235 let has_function_calls = content
1236 .parts
1237 .iter()
1238 .any(|part| matches!(part, Part::FunctionCall { .. }));
1239 let content = if has_function_calls {
1240 content
1241 } else {
1242 Self::apply_output_guardrails(output_guardrails.as_ref(), content).await?
1243 };
1244 accumulated_content = Some(content);
1245 }
1246
1247 let mut cached_event = Event::new(&invocation_id);
1248 cached_event.author = agent_name.clone();
1249 cached_event.llm_response.content = accumulated_content.clone();
1250 cached_event.llm_response.provider_metadata = cached_response.provider_metadata.clone();
1251 cached_event.llm_request = Some(serde_json::to_string(&request).unwrap_or_default());
1252 cached_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), serde_json::to_string(&request).unwrap_or_default());
1253 cached_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(&cached_response).unwrap_or_default());
1254
1255 if let Some(ref content) = accumulated_content {
1257 cached_event.long_running_tool_ids = collect_long_running_ids(content);
1258 }
1259
1260 yield Ok(cached_event);
1261 } else {
1262 let request_json = serde_json::to_string(&request).unwrap_or_default();
1264
1265 let llm_ts = std::time::SystemTime::now()
1267 .duration_since(std::time::UNIX_EPOCH)
1268 .unwrap_or_default()
1269 .as_nanos();
1270 let llm_event_id = format!("{}_llm_{}", invocation_id, llm_ts);
1271 let llm_span = tracing::info_span!(
1272 "call_llm",
1273 "gcp.vertex.agent.event_id" = %llm_event_id,
1274 "gcp.vertex.agent.invocation_id" = %invocation_id,
1275 "gcp.vertex.agent.session_id" = %ctx.session_id(),
1276 "gen_ai.conversation.id" = %ctx.session_id(),
1277 "gcp.vertex.agent.llm_request" = %request_json,
1278 "gcp.vertex.agent.llm_response" = tracing::field::Empty );
1280 let _llm_guard = llm_span.enter();
1281
1282 use adk_core::StreamingMode;
1284 let streaming_mode = ctx.run_config().streaming_mode;
1285 let should_stream_to_client = matches!(streaming_mode, StreamingMode::SSE | StreamingMode::Bidi)
1286 && output_guardrails.is_empty();
1287
1288 let mut response_stream = model.generate_content(request, true).await?;
1290
1291 use futures::StreamExt;
1292
1293 let mut last_chunk: Option<LlmResponse> = None;
1295
1296 while let Some(chunk_result) = response_stream.next().await {
1298 let mut chunk = match chunk_result {
1299 Ok(c) => c,
1300 Err(e) => {
1301 yield Err(e);
1302 return;
1303 }
1304 };
1305
1306 for callback in after_model_callbacks.as_ref() {
1309 match callback(ctx.clone() as Arc<dyn CallbackContext>, chunk.clone()).await {
1310 Ok(Some(modified_chunk)) => {
1311 chunk = modified_chunk;
1313 break;
1314 }
1315 Ok(None) => {
1316 continue;
1318 }
1319 Err(e) => {
1320 yield Err(e);
1322 return;
1323 }
1324 }
1325 }
1326
1327 normalize_option_content(&mut chunk.content);
1328
1329 if let Some(chunk_content) = chunk.content.clone() {
1331 if let Some(ref mut acc) = accumulated_content {
1332 acc.parts.extend(chunk_content.parts);
1333 } else {
1334 accumulated_content = Some(chunk_content);
1335 }
1336 }
1337
1338 if should_stream_to_client {
1340 let mut partial_event = Event::with_id(&llm_event_id, &invocation_id);
1341 partial_event.author = agent_name.clone();
1342 partial_event.llm_request = Some(request_json.clone());
1343 partial_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), request_json.clone());
1344 partial_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(&chunk).unwrap_or_default());
1345 partial_event.llm_response.partial = chunk.partial;
1346 partial_event.llm_response.turn_complete = chunk.turn_complete;
1347 partial_event.llm_response.finish_reason = chunk.finish_reason;
1348 partial_event.llm_response.usage_metadata = chunk.usage_metadata.clone();
1349 partial_event.llm_response.content = chunk.content.clone();
1350 partial_event.llm_response.provider_metadata = chunk.provider_metadata.clone();
1351
1352 if let Some(ref content) = chunk.content {
1354 partial_event.long_running_tool_ids = collect_long_running_ids(content);
1355 }
1356
1357 yield Ok(partial_event);
1358 }
1359
1360 last_chunk = Some(chunk.clone());
1362
1363 if chunk.turn_complete {
1365 break;
1366 }
1367 }
1368
1369 if !should_stream_to_client {
1371 if let Some(content) = accumulated_content.take() {
1372 let has_function_calls = content
1373 .parts
1374 .iter()
1375 .any(|part| matches!(part, Part::FunctionCall { .. }));
1376 let content = if has_function_calls {
1377 content
1378 } else {
1379 Self::apply_output_guardrails(output_guardrails.as_ref(), content).await?
1380 };
1381 accumulated_content = Some(content);
1382 }
1383
1384 let mut final_event = Event::with_id(&llm_event_id, &invocation_id);
1385 final_event.author = agent_name.clone();
1386 final_event.llm_request = Some(request_json.clone());
1387 final_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), request_json.clone());
1388 final_event.llm_response.content = accumulated_content.clone();
1389 final_event.llm_response.partial = false;
1390 final_event.llm_response.turn_complete = true;
1391
1392 if let Some(ref last) = last_chunk {
1394 final_event.llm_response.finish_reason = last.finish_reason;
1395 final_event.llm_response.usage_metadata = last.usage_metadata.clone();
1396 final_event.llm_response.provider_metadata = last.provider_metadata.clone();
1397 final_provider_metadata = last.provider_metadata.clone();
1398 final_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(last).unwrap_or_default());
1399 }
1400
1401 if let Some(ref content) = accumulated_content {
1403 final_event.long_running_tool_ids = collect_long_running_ids(content);
1404 }
1405
1406 yield Ok(final_event);
1407 }
1408
1409 if let Some(ref content) = accumulated_content {
1411 let response_json = serde_json::to_string(content).unwrap_or_default();
1412 llm_span.record("gcp.vertex.agent.llm_response", &response_json);
1413 }
1414 }
1415
1416 let function_call_names: Vec<String> = accumulated_content.as_ref()
1418 .map(|c| c.parts.iter()
1419 .filter_map(|p| {
1420 if let Part::FunctionCall { name, .. } = p {
1421 Some(name.clone())
1422 } else {
1423 None
1424 }
1425 })
1426 .collect())
1427 .unwrap_or_default();
1428
1429 let has_function_calls = !function_call_names.is_empty();
1430
1431 let all_calls_are_long_running = has_function_calls && function_call_names.iter().all(|name| {
1435 tool_map.get(name)
1436 .map(|t| t.is_long_running())
1437 .unwrap_or(false)
1438 });
1439
1440 if let Some(ref content) = accumulated_content {
1442 conversation_history.push(Self::augment_content_for_history(
1443 content,
1444 final_provider_metadata.as_ref(),
1445 ));
1446
1447 if let Some(ref output_key) = output_key {
1449 if !has_function_calls { let mut text_parts = String::new();
1451 for part in &content.parts {
1452 if let Part::Text { text } = part {
1453 text_parts.push_str(text);
1454 }
1455 }
1456 if !text_parts.is_empty() {
1457 let mut state_event = Event::new(&invocation_id);
1459 state_event.author = agent_name.clone();
1460 state_event.actions.state_delta.insert(
1461 output_key.clone(),
1462 serde_json::Value::String(text_parts),
1463 );
1464 yield Ok(state_event);
1465 }
1466 }
1467 }
1468 }
1469
1470 if !has_function_calls {
1471 if let Some(ref content) = accumulated_content {
1474 let response_json = serde_json::to_string(content).unwrap_or_default();
1475 tracing::Span::current().record("gcp.vertex.agent.llm_response", &response_json);
1476 }
1477
1478 tracing::info!(agent.name = %agent_name, "Agent execution complete");
1479 break;
1480 }
1481
1482 if let Some(content) = &accumulated_content {
1484 let strategy = agent_tool_execution_strategy
1487 .unwrap_or(ToolExecutionStrategy::Sequential);
1488
1489 let mut fc_parts: Vec<(usize, String, serde_json::Value, Option<String>, String)> = Vec::new();
1493 {
1494 let mut tci = 0usize;
1495 for part in &content.parts {
1496 if let Part::FunctionCall { name, args, id, .. } = part {
1497 let fallback = format!("{}_{}_{}", invocation_id, name, tci);
1498 let fcid = id.clone().unwrap_or(fallback);
1499 fc_parts.push((tci, name.clone(), args.clone(), id.clone(), fcid));
1500 tci += 1;
1501 }
1502 }
1503 }
1504
1505 let mut transfer_handled = false;
1509 for (_, fc_name, fc_args, fc_id, _) in &fc_parts {
1510 if fc_name == "transfer_to_agent" {
1511 let target_agent = fc_args.get("agent_name")
1512 .and_then(|v| v.as_str())
1513 .unwrap_or_default()
1514 .to_string();
1515
1516 let valid_target = valid_transfer_targets.iter().any(|n| n == &target_agent);
1517 if !valid_target {
1518 let error_content = Content {
1519 role: "function".to_string(),
1520 parts: vec![Part::FunctionResponse {
1521 function_response: FunctionResponseData::new(
1522 fc_name.clone(),
1523 serde_json::json!({
1524 "error": format!(
1525 "Agent '{}' not found. Available agents: {:?}",
1526 target_agent, valid_transfer_targets
1527 )
1528 }),
1529 ),
1530 id: fc_id.clone(),
1531 }],
1532 };
1533 conversation_history.push(error_content.clone());
1534 let mut error_event = Event::new(&invocation_id);
1535 error_event.author = agent_name.clone();
1536 error_event.llm_response.content = Some(error_content);
1537 yield Ok(error_event);
1538 continue;
1539 }
1540
1541 let mut transfer_event = Event::new(&invocation_id);
1542 transfer_event.author = agent_name.clone();
1543 transfer_event.actions.transfer_to_agent = Some(target_agent);
1544 yield Ok(transfer_event);
1545 transfer_handled = true;
1546 break;
1547 }
1548 }
1549 if transfer_handled {
1550 return;
1551 }
1552
1553 let fc_parts: Vec<_> = fc_parts.into_iter().filter(|(_, fc_name, _, _, _)| {
1555 if fc_name == "transfer_to_agent" {
1556 return false;
1557 }
1558 if let Some(tool) = tool_map.get(fc_name) {
1559 if tool.is_builtin() {
1560 adk_telemetry::debug!(tool.name = %fc_name, "skipping built-in tool execution");
1561 return false;
1562 }
1563 }
1564 true
1565 }).collect();
1566
1567 let mut confirmation_interrupted = false;
1571 for (_, fc_name, fc_args, _, fc_call_id) in &fc_parts {
1572 if tool_confirmation_policy.requires_confirmation(fc_name)
1573 && ctx.run_config().tool_confirmation_decisions.get(fc_name).copied().is_none()
1574 {
1575 let mut ce = Event::new(&invocation_id);
1576 ce.author = agent_name.clone();
1577 ce.llm_response.interrupted = true;
1578 ce.llm_response.turn_complete = true;
1579 ce.llm_response.content = Some(Content {
1580 role: "model".to_string(),
1581 parts: vec![Part::Text {
1582 text: format!(
1583 "Tool confirmation required for '{}'. Provide approve/deny decision to continue.",
1584 fc_name
1585 ),
1586 }],
1587 });
1588 ce.actions.tool_confirmation = Some(ToolConfirmationRequest {
1589 tool_name: fc_name.clone(),
1590 function_call_id: Some(fc_call_id.clone()),
1591 args: fc_args.clone(),
1592 });
1593 yield Ok(ce);
1594 confirmation_interrupted = true;
1595 break;
1596 }
1597 }
1598 if confirmation_interrupted {
1599 return;
1600 }
1601
1602 let cb_mutex = std::sync::Mutex::new(circuit_breaker_state.take());
1604
1605 let execute_one_tool = |idx: usize, name: String, args: serde_json::Value,
1610 id: Option<String>, function_call_id: String| {
1611 let ctx = ctx.clone();
1612 let tool_map = &tool_map;
1613 let tool_retry_budgets = &tool_retry_budgets;
1614 let default_retry_budget = &default_retry_budget;
1615 let before_tool_callbacks = &before_tool_callbacks;
1616 let after_tool_callbacks = &after_tool_callbacks;
1617 let after_tool_callbacks_full = &after_tool_callbacks_full;
1618 let on_tool_error_callbacks = &on_tool_error_callbacks;
1619 let tool_confirmation_policy = &tool_confirmation_policy;
1620 let cb_mutex = &cb_mutex;
1621 let invocation_id = &invocation_id;
1622 async move {
1623 let mut tool_actions = EventActions::default();
1624 let mut response_content: Option<Content> = None;
1625 let mut run_after_tool_callbacks = true;
1626 let mut tool_outcome_for_callback: Option<ToolOutcome> = None;
1627 let mut executed_tool: Option<Arc<dyn Tool>> = None;
1628 let mut executed_tool_response: Option<serde_json::Value> = None;
1629
1630 if tool_confirmation_policy.requires_confirmation(&name) {
1632 match ctx.run_config().tool_confirmation_decisions.get(&name).copied() {
1633 Some(ToolConfirmationDecision::Approve) => {
1634 tool_actions.tool_confirmation_decision =
1635 Some(ToolConfirmationDecision::Approve);
1636 }
1637 Some(ToolConfirmationDecision::Deny) => {
1638 tool_actions.tool_confirmation_decision =
1639 Some(ToolConfirmationDecision::Deny);
1640 response_content = Some(Content {
1641 role: "function".to_string(),
1642 parts: vec![Part::FunctionResponse {
1643 function_response: FunctionResponseData::new(
1644 name.clone(),
1645 serde_json::json!({
1646 "error": format!("Tool '{}' execution denied by confirmation policy", name)
1647 }),
1648 ),
1649 id: id.clone(),
1650 }],
1651 });
1652 run_after_tool_callbacks = false;
1653 }
1654 None => {
1655 response_content = Some(Content {
1656 role: "function".to_string(),
1657 parts: vec![Part::FunctionResponse {
1658 function_response: FunctionResponseData::new(
1659 name.clone(),
1660 serde_json::json!({
1661 "error": format!("Tool '{}' requires confirmation", name)
1662 }),
1663 ),
1664 id: id.clone(),
1665 }],
1666 });
1667 run_after_tool_callbacks = false;
1668 }
1669 }
1670 }
1671
1672 if response_content.is_none() {
1674 let tool_ctx = Arc::new(ToolCallbackContext::new(
1675 ctx.clone(),
1676 name.clone(),
1677 args.clone(),
1678 ));
1679 for callback in before_tool_callbacks.as_ref() {
1680 match callback(tool_ctx.clone() as Arc<dyn CallbackContext>).await {
1681 Ok(Some(c)) => { response_content = Some(c); break; }
1682 Ok(None) => continue,
1683 Err(e) => {
1684 response_content = Some(Content {
1685 role: "function".to_string(),
1686 parts: vec![Part::FunctionResponse {
1687 function_response: FunctionResponseData::new(
1688 name.clone(),
1689 serde_json::json!({ "error": e.to_string() }),
1690 ),
1691 id: id.clone(),
1692 }],
1693 });
1694 run_after_tool_callbacks = false;
1695 break;
1696 }
1697 }
1698 }
1699 }
1700
1701 if response_content.is_none() {
1703 let guard = cb_mutex.lock().unwrap_or_else(|e| e.into_inner());
1704 if let Some(ref cb_state) = *guard {
1705 if cb_state.is_open(&name) {
1706 let msg = format!(
1707 "Tool '{}' is temporarily disabled after {} consecutive failures",
1708 name, cb_state.threshold
1709 );
1710 tracing::warn!(tool.name = %name, "circuit breaker open, skipping tool execution");
1711 response_content = Some(Content {
1712 role: "function".to_string(),
1713 parts: vec![Part::FunctionResponse {
1714 function_response: FunctionResponseData::new(
1715 name.clone(),
1716 serde_json::json!({ "error": msg }),
1717 ),
1718 id: id.clone(),
1719 }],
1720 });
1721 run_after_tool_callbacks = false;
1722 }
1723 }
1724 drop(guard);
1725 }
1726
1727 if response_content.is_none() {
1729 if let Some(tool) = tool_map.get(&name) {
1730 let tool_ctx: Arc<dyn ToolContext> = Arc::new(
1731 AgentToolContext::new(ctx.clone(), function_call_id.clone()),
1732 );
1733 let span_name = format!("execute_tool {name}");
1734 let tool_span = tracing::info_span!(
1735 "",
1736 otel.name = %span_name,
1737 tool.name = %name,
1738 "gcp.vertex.agent.event_id" = %format!("{}_{}", invocation_id, name),
1739 "gcp.vertex.agent.invocation_id" = %invocation_id,
1740 "gcp.vertex.agent.session_id" = %ctx.session_id(),
1741 "gen_ai.conversation.id" = %ctx.session_id()
1742 );
1743
1744 let budget = tool_retry_budgets.get(&name)
1745 .or(default_retry_budget.as_ref());
1746 let max_attempts = budget.map(|b| b.max_retries + 1).unwrap_or(1);
1747 let retry_delay = budget.map(|b| b.delay).unwrap_or_default();
1748
1749 let tool_clone = tool.clone();
1750 let tool_start = std::time::Instant::now();
1751 let mut last_error = String::new();
1752 let mut final_attempt: u32 = 0;
1753 let mut retry_result: Option<serde_json::Value> = None;
1754
1755 for attempt in 0..max_attempts {
1756 final_attempt = attempt;
1757 if attempt > 0 {
1758 tokio::time::sleep(retry_delay).await;
1759 }
1760 match async {
1761 tracing::info!(tool.name = %name, tool.args = %args, attempt = attempt, "tool_call");
1762 let exec_future = tool_clone.execute(tool_ctx.clone(), args.clone());
1763 tokio::time::timeout(tool_timeout, exec_future).await
1764 }.instrument(tool_span.clone()).await {
1765 Ok(Ok(value)) => {
1766 tracing::info!(tool.name = %name, tool.result = %value, "tool_result");
1767 retry_result = Some(value);
1768 break;
1769 }
1770 Ok(Err(e)) => {
1771 last_error = e.to_string();
1772 if attempt + 1 < max_attempts {
1773 tracing::warn!(tool.name = %name, attempt = attempt, error = %last_error, "tool execution failed, retrying");
1774 } else {
1775 tracing::warn!(tool.name = %name, error = %last_error, "tool_error");
1776 }
1777 }
1778 Err(_) => {
1779 last_error = format!(
1780 "Tool '{}' timed out after {} seconds",
1781 name, tool_timeout.as_secs()
1782 );
1783 if attempt + 1 < max_attempts {
1784 tracing::warn!(tool.name = %name, attempt = attempt, timeout_secs = tool_timeout.as_secs(), "tool timed out, retrying");
1785 } else {
1786 tracing::warn!(tool.name = %name, timeout_secs = tool_timeout.as_secs(), "tool_timeout");
1787 }
1788 }
1789 }
1790 }
1791
1792 let tool_duration = tool_start.elapsed();
1793 let (tool_success, tool_error_message, function_response) = match retry_result {
1794 Some(value) => (true, None, value),
1795 None => (false, Some(last_error.clone()), serde_json::json!({ "error": last_error })),
1796 };
1797
1798 let outcome = ToolOutcome {
1799 tool_name: name.clone(),
1800 tool_args: args.clone(),
1801 success: tool_success,
1802 duration: tool_duration,
1803 error_message: tool_error_message.clone(),
1804 attempt: final_attempt,
1805 };
1806 tool_outcome_for_callback = Some(outcome);
1807
1808 {
1810 let mut guard = cb_mutex.lock().unwrap_or_else(|e| e.into_inner());
1811 if let Some(ref mut cb_state) = *guard {
1812 cb_state.record(tool_outcome_for_callback.as_ref().unwrap());
1813 }
1814 }
1815
1816 let final_function_response = if !tool_success {
1818 let mut fallback_result = None;
1819 let error_msg = tool_error_message.clone().unwrap_or_default();
1820 for callback in on_tool_error_callbacks.as_ref() {
1821 match callback(
1822 ctx.clone() as Arc<dyn CallbackContext>,
1823 tool.clone(),
1824 args.clone(),
1825 error_msg.clone(),
1826 ).await {
1827 Ok(Some(result)) => { fallback_result = Some(result); break; }
1828 Ok(None) => continue,
1829 Err(e) => { tracing::warn!(error = %e, "on_tool_error callback failed"); break; }
1830 }
1831 }
1832 fallback_result.unwrap_or(function_response)
1833 } else {
1834 function_response
1835 };
1836
1837 let confirmation_decision = tool_actions.tool_confirmation_decision;
1838 tool_actions = tool_ctx.actions();
1839 if tool_actions.tool_confirmation_decision.is_none() {
1840 tool_actions.tool_confirmation_decision = confirmation_decision;
1841 }
1842 executed_tool = Some(tool.clone());
1843 executed_tool_response = Some(final_function_response.clone());
1844 response_content = Some(Content {
1845 role: "function".to_string(),
1846 parts: vec![Part::FunctionResponse {
1847 function_response: FunctionResponseData::from_tool_result(
1848 name.clone(),
1849 final_function_response,
1850 ),
1851 id: id.clone(),
1852 }],
1853 });
1854 } else {
1855 response_content = Some(Content {
1856 role: "function".to_string(),
1857 parts: vec![Part::FunctionResponse {
1858 function_response: FunctionResponseData::new(
1859 name.clone(),
1860 serde_json::json!({
1861 "error": format!("Tool {} not found", name)
1862 }),
1863 ),
1864 id: id.clone(),
1865 }],
1866 });
1867 }
1868 }
1869
1870 let mut response_content = response_content.expect("tool response content is set");
1872 if run_after_tool_callbacks {
1873 let outcome_ctx: Arc<dyn CallbackContext> = match tool_outcome_for_callback {
1874 Some(outcome) => Arc::new(ToolOutcomeCallbackContext {
1875 inner: ctx.clone() as Arc<dyn CallbackContext>,
1876 outcome,
1877 }),
1878 None => ctx.clone() as Arc<dyn CallbackContext>,
1879 };
1880 let cb_ctx: Arc<dyn CallbackContext> = Arc::new(ToolCallbackContext::new(
1881 outcome_ctx,
1882 name.clone(),
1883 args.clone(),
1884 ));
1885 for callback in after_tool_callbacks.as_ref() {
1886 match callback(cb_ctx.clone()).await {
1887 Ok(Some(modified)) => { response_content = modified; break; }
1888 Ok(None) => continue,
1889 Err(e) => {
1890 response_content = Content {
1891 role: "function".to_string(),
1892 parts: vec![Part::FunctionResponse {
1893 function_response: FunctionResponseData::new(
1894 name.clone(),
1895 serde_json::json!({ "error": e.to_string() }),
1896 ),
1897 id: id.clone(),
1898 }],
1899 };
1900 break;
1901 }
1902 }
1903 }
1904 if let (Some(tool_ref), Some(tool_resp)) = (&executed_tool, executed_tool_response) {
1905 for callback in after_tool_callbacks_full.as_ref() {
1906 match callback(
1907 cb_ctx.clone(), tool_ref.clone(), args.clone(), tool_resp.clone(),
1908 ).await {
1909 Ok(Some(modified_value)) => {
1910 response_content = Content {
1911 role: "function".to_string(),
1912 parts: vec![Part::FunctionResponse {
1913 function_response: FunctionResponseData::from_tool_result(
1914 name.clone(),
1915 modified_value,
1916 ),
1917 id: id.clone(),
1918 }],
1919 };
1920 break;
1921 }
1922 Ok(None) => continue,
1923 Err(e) => {
1924 response_content = Content {
1925 role: "function".to_string(),
1926 parts: vec![Part::FunctionResponse {
1927 function_response: FunctionResponseData::new(
1928 name.clone(),
1929 serde_json::json!({ "error": e.to_string() }),
1930 ),
1931 id: id.clone(),
1932 }],
1933 };
1934 break;
1935 }
1936 }
1937 }
1938 }
1939 }
1940
1941 let escalate_or_skip = tool_actions.escalate || tool_actions.skip_summarization;
1942 (idx, response_content, tool_actions, escalate_or_skip)
1943 }
1944 };
1945
1946 let results: Vec<(usize, Content, EventActions, bool)> = match strategy {
1948 ToolExecutionStrategy::Sequential => {
1949 let mut results = Vec::with_capacity(fc_parts.len());
1950 for (idx, name, args, id, fcid) in fc_parts {
1951 results.push(execute_one_tool(idx, name, args, id, fcid).await);
1952 }
1953 results
1954 }
1955 ToolExecutionStrategy::Parallel => {
1956 let futs: Vec<_> = fc_parts.into_iter()
1957 .map(|(idx, name, args, id, fcid)| execute_one_tool(idx, name, args, id, fcid))
1958 .collect();
1959 futures::future::join_all(futs).await
1960 }
1961 ToolExecutionStrategy::Auto => {
1962 let mut read_only_fcs = Vec::new();
1964 let mut mutable_fcs = Vec::new();
1965 for fc in fc_parts {
1966 let is_ro = tool_map.get(&fc.1)
1967 .map(|t| t.is_read_only())
1968 .unwrap_or(false);
1969 if is_ro { read_only_fcs.push(fc); } else { mutable_fcs.push(fc); }
1970 }
1971 let mut all_results = Vec::new();
1972 if !read_only_fcs.is_empty() {
1974 let ro_futs: Vec<_> = read_only_fcs.into_iter()
1975 .map(|(idx, name, args, id, fcid)| execute_one_tool(idx, name, args, id, fcid))
1976 .collect();
1977 all_results.extend(futures::future::join_all(ro_futs).await);
1978 }
1979 for (idx, name, args, id, fcid) in mutable_fcs {
1981 all_results.push(execute_one_tool(idx, name, args, id, fcid).await);
1982 }
1983 all_results.sort_by_key(|r| r.0);
1985 all_results
1986 }
1987 };
1988
1989 circuit_breaker_state = cb_mutex.into_inner().unwrap_or_else(|e| e.into_inner());
1991
1992 for (_, response_content, tool_actions, escalate_or_skip) in results {
1994 let mut tool_event = Event::new(&invocation_id);
1995 tool_event.author = agent_name.clone();
1996 tool_event.actions = tool_actions;
1997 tool_event.llm_response.content = Some(response_content.clone());
1998 yield Ok(tool_event);
1999
2000 if escalate_or_skip {
2001 return;
2002 }
2003
2004 conversation_history.push(response_content);
2005 }
2006 }
2007
2008 if all_calls_are_long_running {
2012 }
2016 }
2017
2018 for callback in after_agent_callbacks.as_ref() {
2021 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
2022 Ok(Some(content)) => {
2023 let mut after_event = Event::new(&invocation_id);
2025 after_event.author = agent_name.clone();
2026 after_event.llm_response.content = Some(content);
2027 yield Ok(after_event);
2028 break; }
2030 Ok(None) => {
2031 continue;
2033 }
2034 Err(e) => {
2035 yield Err(e);
2037 return;
2038 }
2039 }
2040 }
2041 };
2042
2043 Ok(Box::pin(s))
2044 }
2045}