1use adk_core::{
2 AfterAgentCallback, AfterModelCallback, AfterToolCallback, AfterToolCallbackFull, Agent,
3 BeforeAgentCallback, BeforeModelCallback, BeforeModelResult, BeforeToolCallback,
4 CallbackContext, Content, Event, EventActions, FunctionResponseData, GlobalInstructionProvider,
5 InstructionProvider, InvocationContext, Llm, LlmRequest, LlmResponse, MemoryEntry,
6 OnToolErrorCallback, Part, ReadonlyContext, Result, RetryBudget, Tool, ToolCallbackContext,
7 ToolConfirmationDecision, ToolConfirmationPolicy, ToolConfirmationRequest, ToolContext,
8 ToolExecutionStrategy, ToolOutcome, Toolset,
9};
10use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index, select_skill_prompt_block};
11use async_stream::stream;
12use async_trait::async_trait;
13use std::sync::{Arc, Mutex};
14use tracing::Instrument;
15
16use crate::{
17 guardrails::{GuardrailSet, enforce_guardrails},
18 tool_call_markup::normalize_option_content,
19 workflow::with_user_content_override,
20};
21
/// Default cap on agent-loop iterations per run.
///
/// Each pass of the loop bumps a counter; once it exceeds this bound the run ends with a
/// "Max iterations" error.
pub const DEFAULT_MAX_ITERATIONS: u32 = 100;

/// Default value of the builder's `tool_timeout` setting (300 seconds).
pub const DEFAULT_TOOL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5 * 60);
27
28pub struct LlmAgent {
29 name: String,
30 description: String,
31 model: Arc<dyn Llm>,
32 instruction: Option<String>,
33 instruction_provider: Option<Arc<InstructionProvider>>,
34 global_instruction: Option<String>,
35 global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
36 skills_index: Option<Arc<SkillIndex>>,
37 skill_policy: SelectionPolicy,
38 max_skill_chars: usize,
39 #[allow(dead_code)] input_schema: Option<serde_json::Value>,
41 output_schema: Option<serde_json::Value>,
42 disallow_transfer_to_parent: bool,
43 disallow_transfer_to_peers: bool,
44 include_contents: adk_core::IncludeContents,
45 tools: Vec<Arc<dyn Tool>>,
46 #[allow(dead_code)] toolsets: Vec<Arc<dyn Toolset>>,
48 sub_agents: Vec<Arc<dyn Agent>>,
49 output_key: Option<String>,
50 generate_content_config: Option<adk_core::GenerateContentConfig>,
52 max_iterations: u32,
54 tool_timeout: std::time::Duration,
56 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
57 after_callbacks: Arc<Vec<AfterAgentCallback>>,
58 before_model_callbacks: Arc<Vec<BeforeModelCallback>>,
59 after_model_callbacks: Arc<Vec<AfterModelCallback>>,
60 before_tool_callbacks: Arc<Vec<BeforeToolCallback>>,
61 after_tool_callbacks: Arc<Vec<AfterToolCallback>>,
62 on_tool_error_callbacks: Arc<Vec<OnToolErrorCallback>>,
63 after_tool_callbacks_full: Arc<Vec<AfterToolCallbackFull>>,
65 default_retry_budget: Option<RetryBudget>,
67 tool_retry_budgets: std::collections::HashMap<String, RetryBudget>,
69 circuit_breaker_threshold: Option<u32>,
72 tool_confirmation_policy: ToolConfirmationPolicy,
73 tool_execution_strategy: Option<ToolExecutionStrategy>,
76 input_guardrails: Arc<GuardrailSet>,
77 output_guardrails: Arc<GuardrailSet>,
78}
79
80impl std::fmt::Debug for LlmAgent {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 f.debug_struct("LlmAgent")
83 .field("name", &self.name)
84 .field("description", &self.description)
85 .field("model", &self.model.name())
86 .field("instruction", &self.instruction)
87 .field("tools_count", &self.tools.len())
88 .field("sub_agents_count", &self.sub_agents.len())
89 .finish()
90 }
91}
92
93impl LlmAgent {
94 async fn apply_input_guardrails(
95 ctx: Arc<dyn InvocationContext>,
96 input_guardrails: Arc<GuardrailSet>,
97 ) -> Result<Arc<dyn InvocationContext>> {
98 let content =
99 enforce_guardrails(input_guardrails.as_ref(), ctx.user_content(), "input").await?;
100 if content.role != ctx.user_content().role || content.parts != ctx.user_content().parts {
101 Ok(with_user_content_override(ctx, content))
102 } else {
103 Ok(ctx)
104 }
105 }
106
107 async fn apply_output_guardrails(
108 output_guardrails: &GuardrailSet,
109 content: Content,
110 ) -> Result<Content> {
111 enforce_guardrails(output_guardrails, &content, "output").await
112 }
113
114 fn history_parts_from_provider_metadata(
115 provider_metadata: Option<&serde_json::Value>,
116 ) -> Vec<Part> {
117 let Some(provider_metadata) = provider_metadata else {
118 return Vec::new();
119 };
120
121 let history_parts = provider_metadata
122 .get("conversation_history_parts")
123 .or_else(|| {
124 provider_metadata
125 .get("openai")
126 .and_then(|openai| openai.get("conversation_history_parts"))
127 })
128 .and_then(serde_json::Value::as_array);
129
130 history_parts
131 .into_iter()
132 .flatten()
133 .filter_map(|value| serde_json::from_value::<Part>(value.clone()).ok())
134 .collect()
135 }
136
137 fn augment_content_for_history(
138 content: &Content,
139 provider_metadata: Option<&serde_json::Value>,
140 ) -> Content {
141 let mut augmented = content.clone();
142 augmented.parts.extend(Self::history_parts_from_provider_metadata(provider_metadata));
143 augmented
144 }
145}
146
/// Builder for [`LlmAgent`].
///
/// Every setter consumes and returns the builder. `build` validates the configuration and
/// produces the agent.
pub struct LlmAgentBuilder {
    /// Agent name, fixed at construction.
    name: String,
    /// Builds as an empty string when unset.
    description: Option<String>,
    /// Required: `build` fails when no model has been set.
    model: Option<Arc<dyn Llm>>,
    instruction: Option<String>,
    instruction_provider: Option<Arc<InstructionProvider>>,
    global_instruction: Option<String>,
    global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
    skills_index: Option<Arc<SkillIndex>>,
    skill_policy: SelectionPolicy,
    /// Character budget for skill-block selection (2000 unless overridden).
    max_skill_chars: usize,
    input_schema: Option<serde_json::Value>,
    output_schema: Option<serde_json::Value>,
    disallow_transfer_to_parent: bool,
    disallow_transfer_to_peers: bool,
    include_contents: adk_core::IncludeContents,
    tools: Vec<Arc<dyn Tool>>,
    toolsets: Vec<Arc<dyn Toolset>>,
    sub_agents: Vec<Arc<dyn Agent>>,
    output_key: Option<String>,
    /// Created on demand by the individual sampling setters.
    generate_content_config: Option<adk_core::GenerateContentConfig>,
    max_iterations: u32,
    tool_timeout: std::time::Duration,
    // Callback lists are plain `Vec`s here; `build` wraps each one in an `Arc`.
    before_callbacks: Vec<BeforeAgentCallback>,
    after_callbacks: Vec<AfterAgentCallback>,
    before_model_callbacks: Vec<BeforeModelCallback>,
    after_model_callbacks: Vec<AfterModelCallback>,
    before_tool_callbacks: Vec<BeforeToolCallback>,
    after_tool_callbacks: Vec<AfterToolCallback>,
    on_tool_error_callbacks: Vec<OnToolErrorCallback>,
    after_tool_callbacks_full: Vec<AfterToolCallbackFull>,
    default_retry_budget: Option<RetryBudget>,
    /// Per-tool retry budgets keyed by tool name.
    tool_retry_budgets: std::collections::HashMap<String, RetryBudget>,
    circuit_breaker_threshold: Option<u32>,
    tool_confirmation_policy: ToolConfirmationPolicy,
    tool_execution_strategy: Option<ToolExecutionStrategy>,
    input_guardrails: GuardrailSet,
    output_guardrails: GuardrailSet,
}
186
187impl LlmAgentBuilder {
188 pub fn new(name: impl Into<String>) -> Self {
189 Self {
190 name: name.into(),
191 description: None,
192 model: None,
193 instruction: None,
194 instruction_provider: None,
195 global_instruction: None,
196 global_instruction_provider: None,
197 skills_index: None,
198 skill_policy: SelectionPolicy::default(),
199 max_skill_chars: 2000,
200 input_schema: None,
201 output_schema: None,
202 disallow_transfer_to_parent: false,
203 disallow_transfer_to_peers: false,
204 include_contents: adk_core::IncludeContents::Default,
205 tools: Vec::new(),
206 toolsets: Vec::new(),
207 sub_agents: Vec::new(),
208 output_key: None,
209 generate_content_config: None,
210 max_iterations: DEFAULT_MAX_ITERATIONS,
211 tool_timeout: DEFAULT_TOOL_TIMEOUT,
212 before_callbacks: Vec::new(),
213 after_callbacks: Vec::new(),
214 before_model_callbacks: Vec::new(),
215 after_model_callbacks: Vec::new(),
216 before_tool_callbacks: Vec::new(),
217 after_tool_callbacks: Vec::new(),
218 on_tool_error_callbacks: Vec::new(),
219 after_tool_callbacks_full: Vec::new(),
220 default_retry_budget: None,
221 tool_retry_budgets: std::collections::HashMap::new(),
222 circuit_breaker_threshold: None,
223 tool_confirmation_policy: ToolConfirmationPolicy::Never,
224 tool_execution_strategy: None,
225 input_guardrails: GuardrailSet::new(),
226 output_guardrails: GuardrailSet::new(),
227 }
228 }
229
230 pub fn description(mut self, desc: impl Into<String>) -> Self {
231 self.description = Some(desc.into());
232 self
233 }
234
235 pub fn model(mut self, model: Arc<dyn Llm>) -> Self {
236 self.model = Some(model);
237 self
238 }
239
240 pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
241 self.instruction = Some(instruction.into());
242 self
243 }
244
245 pub fn instruction_provider(mut self, provider: InstructionProvider) -> Self {
246 self.instruction_provider = Some(Arc::new(provider));
247 self
248 }
249
250 pub fn global_instruction(mut self, instruction: impl Into<String>) -> Self {
251 self.global_instruction = Some(instruction.into());
252 self
253 }
254
255 pub fn global_instruction_provider(mut self, provider: GlobalInstructionProvider) -> Self {
256 self.global_instruction_provider = Some(Arc::new(provider));
257 self
258 }
259
260 pub fn with_skills(mut self, index: SkillIndex) -> Self {
262 self.skills_index = Some(Arc::new(index));
263 self
264 }
265
266 pub fn with_auto_skills(self) -> Result<Self> {
268 self.with_skills_from_root(".")
269 }
270
271 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
273 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
274 self.skills_index = Some(Arc::new(index));
275 Ok(self)
276 }
277
278 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
280 self.skill_policy = policy;
281 self
282 }
283
284 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
286 self.max_skill_chars = max_chars;
287 self
288 }
289
290 pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
291 self.input_schema = Some(schema);
292 self
293 }
294
295 pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
296 self.output_schema = Some(schema);
297 self
298 }
299
300 pub fn disallow_transfer_to_parent(mut self, disallow: bool) -> Self {
301 self.disallow_transfer_to_parent = disallow;
302 self
303 }
304
305 pub fn disallow_transfer_to_peers(mut self, disallow: bool) -> Self {
306 self.disallow_transfer_to_peers = disallow;
307 self
308 }
309
310 pub fn include_contents(mut self, include: adk_core::IncludeContents) -> Self {
311 self.include_contents = include;
312 self
313 }
314
315 pub fn output_key(mut self, key: impl Into<String>) -> Self {
316 self.output_key = Some(key.into());
317 self
318 }
319
320 pub fn generate_content_config(mut self, config: adk_core::GenerateContentConfig) -> Self {
341 self.generate_content_config = Some(config);
342 self
343 }
344
345 pub fn temperature(mut self, temperature: f32) -> Self {
348 self.generate_content_config
349 .get_or_insert(adk_core::GenerateContentConfig::default())
350 .temperature = Some(temperature);
351 self
352 }
353
354 pub fn top_p(mut self, top_p: f32) -> Self {
356 self.generate_content_config
357 .get_or_insert(adk_core::GenerateContentConfig::default())
358 .top_p = Some(top_p);
359 self
360 }
361
362 pub fn top_k(mut self, top_k: i32) -> Self {
364 self.generate_content_config
365 .get_or_insert(adk_core::GenerateContentConfig::default())
366 .top_k = Some(top_k);
367 self
368 }
369
370 pub fn max_output_tokens(mut self, max_tokens: i32) -> Self {
372 self.generate_content_config
373 .get_or_insert(adk_core::GenerateContentConfig::default())
374 .max_output_tokens = Some(max_tokens);
375 self
376 }
377
378 pub fn max_iterations(mut self, max: u32) -> Self {
381 self.max_iterations = max;
382 self
383 }
384
385 pub fn tool_timeout(mut self, timeout: std::time::Duration) -> Self {
388 self.tool_timeout = timeout;
389 self
390 }
391
392 pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
393 self.tools.push(tool);
394 self
395 }
396
397 pub fn toolset(mut self, toolset: Arc<dyn Toolset>) -> Self {
403 self.toolsets.push(toolset);
404 self
405 }
406
407 pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
408 self.sub_agents.push(agent);
409 self
410 }
411
412 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
413 self.before_callbacks.push(callback);
414 self
415 }
416
417 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
418 self.after_callbacks.push(callback);
419 self
420 }
421
422 pub fn before_model_callback(mut self, callback: BeforeModelCallback) -> Self {
423 self.before_model_callbacks.push(callback);
424 self
425 }
426
427 pub fn after_model_callback(mut self, callback: AfterModelCallback) -> Self {
428 self.after_model_callbacks.push(callback);
429 self
430 }
431
432 pub fn before_tool_callback(mut self, callback: BeforeToolCallback) -> Self {
433 self.before_tool_callbacks.push(callback);
434 self
435 }
436
437 pub fn after_tool_callback(mut self, callback: AfterToolCallback) -> Self {
438 self.after_tool_callbacks.push(callback);
439 self
440 }
441
442 pub fn after_tool_callback_full(mut self, callback: AfterToolCallbackFull) -> Self {
457 self.after_tool_callbacks_full.push(callback);
458 self
459 }
460
461 pub fn on_tool_error(mut self, callback: OnToolErrorCallback) -> Self {
469 self.on_tool_error_callbacks.push(callback);
470 self
471 }
472
473 pub fn default_retry_budget(mut self, budget: RetryBudget) -> Self {
480 self.default_retry_budget = Some(budget);
481 self
482 }
483
484 pub fn tool_retry_budget(mut self, tool_name: impl Into<String>, budget: RetryBudget) -> Self {
489 self.tool_retry_budgets.insert(tool_name.into(), budget);
490 self
491 }
492
493 pub fn circuit_breaker_threshold(mut self, threshold: u32) -> Self {
500 self.circuit_breaker_threshold = Some(threshold);
501 self
502 }
503
504 pub fn tool_confirmation_policy(mut self, policy: ToolConfirmationPolicy) -> Self {
506 self.tool_confirmation_policy = policy;
507 self
508 }
509
510 pub fn require_tool_confirmation(mut self, tool_name: impl Into<String>) -> Self {
512 self.tool_confirmation_policy = self.tool_confirmation_policy.with_tool(tool_name);
513 self
514 }
515
516 pub fn require_tool_confirmation_for_all(mut self) -> Self {
518 self.tool_confirmation_policy = ToolConfirmationPolicy::Always;
519 self
520 }
521
522 pub fn tool_execution_strategy(mut self, strategy: ToolExecutionStrategy) -> Self {
528 self.tool_execution_strategy = Some(strategy);
529 self
530 }
531
532 pub fn input_guardrails(mut self, guardrails: GuardrailSet) -> Self {
541 self.input_guardrails = guardrails;
542 self
543 }
544
545 pub fn output_guardrails(mut self, guardrails: GuardrailSet) -> Self {
554 self.output_guardrails = guardrails;
555 self
556 }
557
558 pub fn build(self) -> Result<LlmAgent> {
559 let model = self.model.ok_or_else(|| adk_core::AdkError::agent("Model is required"))?;
560
561 let mut seen_names = std::collections::HashSet::new();
562 for agent in &self.sub_agents {
563 if !seen_names.insert(agent.name()) {
564 return Err(adk_core::AdkError::agent(format!(
565 "Duplicate sub-agent name: {}",
566 agent.name()
567 )));
568 }
569 }
570
571 Ok(LlmAgent {
572 name: self.name,
573 description: self.description.unwrap_or_default(),
574 model,
575 instruction: self.instruction,
576 instruction_provider: self.instruction_provider,
577 global_instruction: self.global_instruction,
578 global_instruction_provider: self.global_instruction_provider,
579 skills_index: self.skills_index,
580 skill_policy: self.skill_policy,
581 max_skill_chars: self.max_skill_chars,
582 input_schema: self.input_schema,
583 output_schema: self.output_schema,
584 disallow_transfer_to_parent: self.disallow_transfer_to_parent,
585 disallow_transfer_to_peers: self.disallow_transfer_to_peers,
586 include_contents: self.include_contents,
587 tools: self.tools,
588 toolsets: self.toolsets,
589 sub_agents: self.sub_agents,
590 output_key: self.output_key,
591 generate_content_config: self.generate_content_config,
592 max_iterations: self.max_iterations,
593 tool_timeout: self.tool_timeout,
594 before_callbacks: Arc::new(self.before_callbacks),
595 after_callbacks: Arc::new(self.after_callbacks),
596 before_model_callbacks: Arc::new(self.before_model_callbacks),
597 after_model_callbacks: Arc::new(self.after_model_callbacks),
598 before_tool_callbacks: Arc::new(self.before_tool_callbacks),
599 after_tool_callbacks: Arc::new(self.after_tool_callbacks),
600 on_tool_error_callbacks: Arc::new(self.on_tool_error_callbacks),
601 after_tool_callbacks_full: Arc::new(self.after_tool_callbacks_full),
602 default_retry_budget: self.default_retry_budget,
603 tool_retry_budgets: self.tool_retry_budgets,
604 circuit_breaker_threshold: self.circuit_breaker_threshold,
605 tool_confirmation_policy: self.tool_confirmation_policy,
606 tool_execution_strategy: self.tool_execution_strategy,
607 input_guardrails: Arc::new(self.input_guardrails),
608 output_guardrails: Arc::new(self.output_guardrails),
609 })
610 }
611}
612
/// Tool context built from a parent invocation context and a function-call id.
///
/// Read-only queries are delegated to the parent context, while the [`EventActions`] set by
/// a tool are kept locally behind a mutex.
struct AgentToolContext {
    /// Invocation context that identity, memory, artifacts, scopes and secrets come from.
    parent_ctx: Arc<dyn InvocationContext>,
    /// Identifier of the function call this context serves.
    function_call_id: String,
    /// Actions recorded through `set_actions` and read back through `actions`.
    actions: Mutex<EventActions>,
}
620
621impl AgentToolContext {
622 fn new(parent_ctx: Arc<dyn InvocationContext>, function_call_id: String) -> Self {
623 Self { parent_ctx, function_call_id, actions: Mutex::new(EventActions::default()) }
624 }
625
626 fn actions_guard(&self) -> std::sync::MutexGuard<'_, EventActions> {
627 self.actions.lock().unwrap_or_else(|e| e.into_inner())
628 }
629}
630
631#[async_trait]
632impl ReadonlyContext for AgentToolContext {
633 fn invocation_id(&self) -> &str {
634 self.parent_ctx.invocation_id()
635 }
636
637 fn agent_name(&self) -> &str {
638 self.parent_ctx.agent_name()
639 }
640
641 fn user_id(&self) -> &str {
642 self.parent_ctx.user_id()
644 }
645
646 fn app_name(&self) -> &str {
647 self.parent_ctx.app_name()
649 }
650
651 fn session_id(&self) -> &str {
652 self.parent_ctx.session_id()
654 }
655
656 fn branch(&self) -> &str {
657 self.parent_ctx.branch()
658 }
659
660 fn user_content(&self) -> &Content {
661 self.parent_ctx.user_content()
662 }
663}
664
665#[async_trait]
666impl CallbackContext for AgentToolContext {
667 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
668 self.parent_ctx.artifacts()
670 }
671
672 fn shared_state(&self) -> Option<Arc<adk_core::SharedState>> {
673 self.parent_ctx.shared_state()
674 }
675}
676
677#[async_trait]
678impl ToolContext for AgentToolContext {
679 fn function_call_id(&self) -> &str {
680 &self.function_call_id
681 }
682
683 fn actions(&self) -> EventActions {
684 self.actions_guard().clone()
685 }
686
687 fn set_actions(&self, actions: EventActions) {
688 *self.actions_guard() = actions;
689 }
690
691 async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>> {
692 if let Some(memory) = self.parent_ctx.memory() {
694 memory.search(query).await
695 } else {
696 Ok(vec![])
697 }
698 }
699
700 fn user_scopes(&self) -> Vec<String> {
701 self.parent_ctx.user_scopes()
702 }
703
704 async fn get_secret(&self, name: &str) -> Result<Option<String>> {
705 self.parent_ctx.get_secret(name).await
706 }
707}
708
/// Callback context that wraps another [`CallbackContext`].
///
/// In addition to what the wrapped context provides, it reports the [`ToolOutcome`] it was
/// constructed with through `tool_outcome()`.
struct ToolOutcomeCallbackContext {
    /// Context that identity and artifact lookups are forwarded to.
    inner: Arc<dyn CallbackContext>,
    /// Outcome returned (as a clone) from `tool_outcome()`.
    outcome: ToolOutcome,
}
716
717#[async_trait]
718impl ReadonlyContext for ToolOutcomeCallbackContext {
719 fn invocation_id(&self) -> &str {
720 self.inner.invocation_id()
721 }
722
723 fn agent_name(&self) -> &str {
724 self.inner.agent_name()
725 }
726
727 fn user_id(&self) -> &str {
728 self.inner.user_id()
729 }
730
731 fn app_name(&self) -> &str {
732 self.inner.app_name()
733 }
734
735 fn session_id(&self) -> &str {
736 self.inner.session_id()
737 }
738
739 fn branch(&self) -> &str {
740 self.inner.branch()
741 }
742
743 fn user_content(&self) -> &Content {
744 self.inner.user_content()
745 }
746}
747
748#[async_trait]
749impl CallbackContext for ToolOutcomeCallbackContext {
750 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
751 self.inner.artifacts()
752 }
753
754 fn tool_outcome(&self) -> Option<ToolOutcome> {
755 Some(self.outcome.clone())
756 }
757}
758
759struct CircuitBreakerState {
769 threshold: u32,
770 failures: std::collections::HashMap<String, u32>,
772}
773
774impl CircuitBreakerState {
775 fn new(threshold: u32) -> Self {
776 Self { threshold, failures: std::collections::HashMap::new() }
777 }
778
779 fn is_open(&self, tool_name: &str) -> bool {
781 self.failures.get(tool_name).copied().unwrap_or(0) >= self.threshold
782 }
783
784 fn record(&mut self, outcome: &ToolOutcome) {
786 if outcome.success {
787 self.failures.remove(&outcome.tool_name);
788 } else {
789 let count = self.failures.entry(outcome.tool_name.clone()).or_insert(0);
790 *count += 1;
791 }
792 }
793}
794
795#[async_trait]
796impl Agent for LlmAgent {
797 fn name(&self) -> &str {
798 &self.name
799 }
800
801 fn description(&self) -> &str {
802 &self.description
803 }
804
805 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
806 &self.sub_agents
807 }
808
809 #[adk_telemetry::instrument(
810 skip(self, ctx),
811 fields(
812 agent.name = %self.name,
813 agent.description = %self.description,
814 invocation.id = %ctx.invocation_id(),
815 user.id = %ctx.user_id(),
816 session.id = %ctx.session_id()
817 )
818 )]
819 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
820 adk_telemetry::info!("Starting agent execution");
821 let ctx = Self::apply_input_guardrails(ctx, self.input_guardrails.clone()).await?;
822
823 let agent_name = self.name.clone();
824 let invocation_id = ctx.invocation_id().to_string();
825 let model = self.model.clone();
826 let tools = self.tools.clone();
827 let toolsets = self.toolsets.clone();
828 let sub_agents = self.sub_agents.clone();
829
830 let instruction = self.instruction.clone();
831 let instruction_provider = self.instruction_provider.clone();
832 let global_instruction = self.global_instruction.clone();
833 let global_instruction_provider = self.global_instruction_provider.clone();
834 let skills_index = self.skills_index.clone();
835 let skill_policy = self.skill_policy.clone();
836 let max_skill_chars = self.max_skill_chars;
837 let output_key = self.output_key.clone();
838 let output_schema = self.output_schema.clone();
839 let generate_content_config = self.generate_content_config.clone();
840 let include_contents = self.include_contents;
841 let max_iterations = self.max_iterations;
842 let tool_timeout = self.tool_timeout;
843 let before_agent_callbacks = self.before_callbacks.clone();
845 let after_agent_callbacks = self.after_callbacks.clone();
846 let before_model_callbacks = self.before_model_callbacks.clone();
847 let after_model_callbacks = self.after_model_callbacks.clone();
848 let before_tool_callbacks = self.before_tool_callbacks.clone();
849 let after_tool_callbacks = self.after_tool_callbacks.clone();
850 let on_tool_error_callbacks = self.on_tool_error_callbacks.clone();
851 let after_tool_callbacks_full = self.after_tool_callbacks_full.clone();
852 let default_retry_budget = self.default_retry_budget.clone();
853 let tool_retry_budgets = self.tool_retry_budgets.clone();
854 let circuit_breaker_threshold = self.circuit_breaker_threshold;
855 let tool_confirmation_policy = self.tool_confirmation_policy.clone();
856 let disallow_transfer_to_parent = self.disallow_transfer_to_parent;
857 let disallow_transfer_to_peers = self.disallow_transfer_to_peers;
858 let output_guardrails = self.output_guardrails.clone();
859 let agent_tool_execution_strategy = self.tool_execution_strategy;
860
861 let s = stream! {
862 for callback in before_agent_callbacks.as_ref() {
866 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
867 Ok(Some(content)) => {
868 let mut early_event = Event::new(&invocation_id);
870 early_event.author = agent_name.clone();
871 early_event.llm_response.content = Some(content);
872 yield Ok(early_event);
873
874 for after_callback in after_agent_callbacks.as_ref() {
876 match after_callback(ctx.clone() as Arc<dyn CallbackContext>).await {
877 Ok(Some(after_content)) => {
878 let mut after_event = Event::new(&invocation_id);
879 after_event.author = agent_name.clone();
880 after_event.llm_response.content = Some(after_content);
881 yield Ok(after_event);
882 return;
883 }
884 Ok(None) => continue,
885 Err(e) => {
886 yield Err(e);
887 return;
888 }
889 }
890 }
891 return;
892 }
893 Ok(None) => {
894 continue;
896 }
897 Err(e) => {
898 yield Err(e);
900 return;
901 }
902 }
903 }
904
905 let mut prompt_preamble = Vec::new();
907
908 if let Some(index) = &skills_index {
912 let user_query = ctx
913 .user_content()
914 .parts
915 .iter()
916 .filter_map(|part| match part {
917 Part::Text { text } => Some(text.as_str()),
918 _ => None,
919 })
920 .collect::<Vec<_>>()
921 .join("\n");
922
923 if let Some((_matched, skill_block)) = select_skill_prompt_block(
924 index.as_ref(),
925 &user_query,
926 &skill_policy,
927 max_skill_chars,
928 ) {
929 prompt_preamble.push(Content {
930 role: "user".to_string(),
931 parts: vec![Part::Text { text: skill_block }],
932 });
933 }
934 }
935
936 if let Some(provider) = &global_instruction_provider {
939 let global_inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
941 if !global_inst.is_empty() {
942 prompt_preamble.push(Content {
943 role: "user".to_string(),
944 parts: vec![Part::Text { text: global_inst }],
945 });
946 }
947 } else if let Some(ref template) = global_instruction {
948 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
950 if !processed.is_empty() {
951 prompt_preamble.push(Content {
952 role: "user".to_string(),
953 parts: vec![Part::Text { text: processed }],
954 });
955 }
956 }
957
958 if let Some(provider) = &instruction_provider {
961 let inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
963 if !inst.is_empty() {
964 prompt_preamble.push(Content {
965 role: "user".to_string(),
966 parts: vec![Part::Text { text: inst }],
967 });
968 }
969 } else if let Some(ref template) = instruction {
970 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
972 if !processed.is_empty() {
973 prompt_preamble.push(Content {
974 role: "user".to_string(),
975 parts: vec![Part::Text { text: processed }],
976 });
977 }
978 }
979
980 let session_history = if !ctx.run_config().transfer_targets.is_empty() {
986 ctx.session().conversation_history_for_agent(&agent_name)
987 } else {
988 ctx.session().conversation_history()
989 };
990 let mut session_history = session_history;
991 let current_user_content = ctx.user_content().clone();
992 if let Some(index) = session_history.iter().rposition(|content| content.role == "user") {
993 session_history[index] = current_user_content.clone();
994 } else {
995 session_history.push(current_user_content.clone());
996 }
997
998 let mut conversation_history = match include_contents {
1001 adk_core::IncludeContents::None => {
1002 let mut filtered = prompt_preamble.clone();
1003 filtered.push(current_user_content);
1004 filtered
1005 }
1006 adk_core::IncludeContents::Default => {
1007 let mut full_history = prompt_preamble;
1008 full_history.extend(session_history);
1009 full_history
1010 }
1011 };
1012
1013 let mut resolved_tools: Vec<Arc<dyn Tool>> = tools.clone();
1016 let static_tool_names: std::collections::HashSet<String> =
1017 tools.iter().map(|t| t.name().to_string()).collect();
1018
1019 let mut toolset_source: std::collections::HashMap<String, String> =
1021 std::collections::HashMap::new();
1022
1023 for toolset in &toolsets {
1024 let toolset_tools = match toolset
1025 .tools(ctx.clone() as Arc<dyn ReadonlyContext>)
1026 .await
1027 {
1028 Ok(t) => t,
1029 Err(e) => {
1030 yield Err(e);
1031 return;
1032 }
1033 };
1034 for tool in &toolset_tools {
1035 let name = tool.name().to_string();
1036 if static_tool_names.contains(&name) {
1038 yield Err(adk_core::AdkError::agent(format!(
1039 "Duplicate tool name '{name}': conflict between static tool and toolset '{}'",
1040 toolset.name()
1041 )));
1042 return;
1043 }
1044 if let Some(other_toolset_name) = toolset_source.get(&name) {
1046 yield Err(adk_core::AdkError::agent(format!(
1047 "Duplicate tool name '{name}': conflict between toolset '{}' and toolset '{}'",
1048 other_toolset_name,
1049 toolset.name()
1050 )));
1051 return;
1052 }
1053 toolset_source.insert(name, toolset.name().to_string());
1054 resolved_tools.push(tool.clone());
1055 }
1056 }
1057
1058 let tool_map: std::collections::HashMap<String, Arc<dyn Tool>> = resolved_tools
1060 .iter()
1061 .map(|t| (t.name().to_string(), t.clone()))
1062 .collect();
1063
1064 let collect_long_running_ids = |content: &Content| -> Vec<String> {
1066 content.parts.iter()
1067 .filter_map(|p| {
1068 if let Part::FunctionCall { name, .. } = p {
1069 if let Some(tool) = tool_map.get(name) {
1070 if tool.is_long_running() {
1071 return Some(name.clone());
1072 }
1073 }
1074 }
1075 None
1076 })
1077 .collect()
1078 };
1079
1080 let mut tool_declarations = std::collections::HashMap::new();
1085 for tool in &resolved_tools {
1086 tool_declarations.insert(tool.name().to_string(), tool.declaration());
1087 }
1088
1089 let mut valid_transfer_targets: Vec<String> = sub_agents
1094 .iter()
1095 .map(|a| a.name().to_string())
1096 .collect();
1097
1098 let run_config_targets = &ctx.run_config().transfer_targets;
1100 let parent_agent_name = ctx.run_config().parent_agent.clone();
1101 let sub_agent_names: std::collections::HashSet<&str> = sub_agents
1102 .iter()
1103 .map(|a| a.name())
1104 .collect();
1105
1106 for target in run_config_targets {
1107 if sub_agent_names.contains(target.as_str()) {
1109 continue;
1110 }
1111
1112 let is_parent = parent_agent_name.as_deref() == Some(target.as_str());
1114 if is_parent && disallow_transfer_to_parent {
1115 continue;
1116 }
1117 if !is_parent && disallow_transfer_to_peers {
1118 continue;
1119 }
1120
1121 valid_transfer_targets.push(target.clone());
1122 }
1123
1124 if !valid_transfer_targets.is_empty() {
1126 let transfer_tool_name = "transfer_to_agent";
1127 let transfer_tool_decl = serde_json::json!({
1128 "name": transfer_tool_name,
1129 "description": format!(
1130 "Transfer execution to another agent. Valid targets: {}",
1131 valid_transfer_targets.join(", ")
1132 ),
1133 "parameters": {
1134 "type": "object",
1135 "properties": {
1136 "agent_name": {
1137 "type": "string",
1138 "description": "The name of the agent to transfer to.",
1139 "enum": valid_transfer_targets
1140 }
1141 },
1142 "required": ["agent_name"]
1143 }
1144 });
1145 tool_declarations.insert(transfer_tool_name.to_string(), transfer_tool_decl);
1146 }
1147
1148
1149 let mut circuit_breaker_state = circuit_breaker_threshold.map(CircuitBreakerState::new);
1152
1153 let mut iteration = 0;
1155
1156 loop {
1157 iteration += 1;
1158 if iteration > max_iterations {
1159 yield Err(adk_core::AdkError::agent(
1160 format!("Max iterations ({max_iterations}) exceeded")
1161 ));
1162 return;
1163 }
1164
1165 let config = match (&generate_content_config, &output_schema) {
1172 (Some(base), Some(schema)) => {
1173 let mut merged = base.clone();
1174 merged.response_schema = Some(schema.clone());
1175 Some(merged)
1176 }
1177 (Some(base), None) => Some(base.clone()),
1178 (None, Some(schema)) => Some(adk_core::GenerateContentConfig {
1179 response_schema: Some(schema.clone()),
1180 ..Default::default()
1181 }),
1182 (None, None) => None,
1183 };
1184
1185 let config = if let Some(ref cached) = ctx.run_config().cached_content {
1187 let mut cfg = config.unwrap_or_default();
1188 if cfg.cached_content.is_none() {
1190 cfg.cached_content = Some(cached.clone());
1191 }
1192 Some(cfg)
1193 } else {
1194 config
1195 };
1196
1197 let request = LlmRequest {
1198 model: model.name().to_string(),
1199 contents: conversation_history.clone(),
1200 tools: tool_declarations.clone(),
1201 config,
1202 };
1203
1204 let mut current_request = request;
1207 let mut model_response_override = None;
1208 for callback in before_model_callbacks.as_ref() {
1209 match callback(ctx.clone() as Arc<dyn CallbackContext>, current_request.clone()).await {
1210 Ok(BeforeModelResult::Continue(modified_request)) => {
1211 current_request = modified_request;
1213 }
1214 Ok(BeforeModelResult::Skip(response)) => {
1215 model_response_override = Some(response);
1217 break;
1218 }
1219 Err(e) => {
1220 yield Err(e);
1222 return;
1223 }
1224 }
1225 }
1226 let request = current_request;
1227
1228 let mut accumulated_content: Option<Content> = None;
1230 let mut final_provider_metadata: Option<serde_json::Value> = None;
1231
1232 if let Some(cached_response) = model_response_override {
1233 accumulated_content = cached_response.content.clone();
1236 final_provider_metadata = cached_response.provider_metadata.clone();
1237 normalize_option_content(&mut accumulated_content);
1238 if let Some(content) = accumulated_content.take() {
1239 let has_function_calls = content
1240 .parts
1241 .iter()
1242 .any(|part| matches!(part, Part::FunctionCall { .. }));
1243 let content = if has_function_calls {
1244 content
1245 } else {
1246 Self::apply_output_guardrails(output_guardrails.as_ref(), content).await?
1247 };
1248 accumulated_content = Some(content);
1249 }
1250
1251 let mut cached_event = Event::new(&invocation_id);
1252 cached_event.author = agent_name.clone();
1253 cached_event.llm_response.content = accumulated_content.clone();
1254 cached_event.llm_response.provider_metadata = cached_response.provider_metadata.clone();
1255 cached_event.llm_request = Some(serde_json::to_string(&request).unwrap_or_default());
1256 cached_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), serde_json::to_string(&request).unwrap_or_default());
1257 cached_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(&cached_response).unwrap_or_default());
1258
1259 if let Some(ref content) = accumulated_content {
1261 cached_event.long_running_tool_ids = collect_long_running_ids(content);
1262 }
1263
1264 yield Ok(cached_event);
1265 } else {
1266 let request_json = serde_json::to_string(&request).unwrap_or_default();
1268
1269 let llm_ts = std::time::SystemTime::now()
1271 .duration_since(std::time::UNIX_EPOCH)
1272 .unwrap_or_default()
1273 .as_nanos();
1274 let llm_event_id = format!("{}_llm_{}", invocation_id, llm_ts);
1275 let llm_span = tracing::info_span!(
1276 "call_llm",
1277 "gcp.vertex.agent.event_id" = %llm_event_id,
1278 "gcp.vertex.agent.invocation_id" = %invocation_id,
1279 "gcp.vertex.agent.session_id" = %ctx.session_id(),
1280 "gen_ai.conversation.id" = %ctx.session_id(),
1281 "gcp.vertex.agent.llm_request" = %request_json,
1282 "gcp.vertex.agent.llm_response" = tracing::field::Empty );
1284 let _llm_guard = llm_span.enter();
1285
1286 use adk_core::StreamingMode;
1288 let streaming_mode = ctx.run_config().streaming_mode;
1289 let should_stream_to_client = matches!(streaming_mode, StreamingMode::SSE | StreamingMode::Bidi)
1290 && output_guardrails.is_empty();
1291
1292 let mut response_stream = model.generate_content(request, true).await?;
1294
1295 use futures::StreamExt;
1296
1297 let mut last_chunk: Option<LlmResponse> = None;
1299
1300 while let Some(chunk_result) = response_stream.next().await {
1302 let mut chunk = match chunk_result {
1303 Ok(c) => c,
1304 Err(e) => {
1305 yield Err(e);
1306 return;
1307 }
1308 };
1309
1310 for callback in after_model_callbacks.as_ref() {
1313 match callback(ctx.clone() as Arc<dyn CallbackContext>, chunk.clone()).await {
1314 Ok(Some(modified_chunk)) => {
1315 chunk = modified_chunk;
1317 break;
1318 }
1319 Ok(None) => {
1320 continue;
1322 }
1323 Err(e) => {
1324 yield Err(e);
1326 return;
1327 }
1328 }
1329 }
1330
1331 normalize_option_content(&mut chunk.content);
1332
1333 if let Some(chunk_content) = chunk.content.clone() {
1335 if let Some(ref mut acc) = accumulated_content {
1336 acc.parts.extend(chunk_content.parts);
1337 } else {
1338 accumulated_content = Some(chunk_content);
1339 }
1340 }
1341
1342 if should_stream_to_client {
1344 let mut partial_event = Event::with_id(&llm_event_id, &invocation_id);
1345 partial_event.author = agent_name.clone();
1346 partial_event.llm_request = Some(request_json.clone());
1347 partial_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), request_json.clone());
1348 partial_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(&chunk).unwrap_or_default());
1349 partial_event.llm_response.partial = chunk.partial;
1350 partial_event.llm_response.turn_complete = chunk.turn_complete;
1351 partial_event.llm_response.finish_reason = chunk.finish_reason;
1352 partial_event.llm_response.usage_metadata = chunk.usage_metadata.clone();
1353 partial_event.llm_response.content = chunk.content.clone();
1354 partial_event.llm_response.provider_metadata = chunk.provider_metadata.clone();
1355
1356 if let Some(ref content) = chunk.content {
1358 partial_event.long_running_tool_ids = collect_long_running_ids(content);
1359 }
1360
1361 yield Ok(partial_event);
1362 }
1363
1364 last_chunk = Some(chunk.clone());
1366
1367 if chunk.turn_complete {
1369 break;
1370 }
1371 }
1372
1373 if !should_stream_to_client {
1375 if let Some(content) = accumulated_content.take() {
1376 let has_function_calls = content
1377 .parts
1378 .iter()
1379 .any(|part| matches!(part, Part::FunctionCall { .. }));
1380 let content = if has_function_calls {
1381 content
1382 } else {
1383 Self::apply_output_guardrails(output_guardrails.as_ref(), content).await?
1384 };
1385 accumulated_content = Some(content);
1386 }
1387
1388 let mut final_event = Event::with_id(&llm_event_id, &invocation_id);
1389 final_event.author = agent_name.clone();
1390 final_event.llm_request = Some(request_json.clone());
1391 final_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), request_json.clone());
1392 final_event.llm_response.content = accumulated_content.clone();
1393 final_event.llm_response.partial = false;
1394 final_event.llm_response.turn_complete = true;
1395
1396 if let Some(ref last) = last_chunk {
1398 final_event.llm_response.finish_reason = last.finish_reason;
1399 final_event.llm_response.usage_metadata = last.usage_metadata.clone();
1400 final_event.llm_response.provider_metadata = last.provider_metadata.clone();
1401 final_provider_metadata = last.provider_metadata.clone();
1402 final_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(last).unwrap_or_default());
1403 }
1404
1405 if let Some(ref content) = accumulated_content {
1407 final_event.long_running_tool_ids = collect_long_running_ids(content);
1408 }
1409
1410 yield Ok(final_event);
1411 }
1412
1413 if let Some(ref content) = accumulated_content {
1415 let response_json = serde_json::to_string(content).unwrap_or_default();
1416 llm_span.record("gcp.vertex.agent.llm_response", &response_json);
1417 }
1418 }
1419
1420 let function_call_names: Vec<String> = accumulated_content.as_ref()
1422 .map(|c| c.parts.iter()
1423 .filter_map(|p| {
1424 if let Part::FunctionCall { name, .. } = p {
1425 Some(name.clone())
1426 } else {
1427 None
1428 }
1429 })
1430 .collect())
1431 .unwrap_or_default();
1432
1433 let has_function_calls = !function_call_names.is_empty();
1434
1435 let all_calls_are_long_running = has_function_calls && function_call_names.iter().all(|name| {
1439 tool_map.get(name)
1440 .map(|t| t.is_long_running())
1441 .unwrap_or(false)
1442 });
1443
1444 if let Some(ref content) = accumulated_content {
1446 conversation_history.push(Self::augment_content_for_history(
1447 content,
1448 final_provider_metadata.as_ref(),
1449 ));
1450
1451 if let Some(ref output_key) = output_key {
1453 if !has_function_calls { let mut text_parts = String::new();
1455 for part in &content.parts {
1456 if let Part::Text { text } = part {
1457 text_parts.push_str(text);
1458 }
1459 }
1460 if !text_parts.is_empty() {
1461 let mut state_event = Event::new(&invocation_id);
1463 state_event.author = agent_name.clone();
1464 state_event.actions.state_delta.insert(
1465 output_key.clone(),
1466 serde_json::Value::String(text_parts),
1467 );
1468 yield Ok(state_event);
1469 }
1470 }
1471 }
1472 }
1473
1474 if !has_function_calls {
1475 if let Some(ref content) = accumulated_content {
1478 let response_json = serde_json::to_string(content).unwrap_or_default();
1479 tracing::Span::current().record("gcp.vertex.agent.llm_response", &response_json);
1480 }
1481
1482 tracing::info!(agent.name = %agent_name, "Agent execution complete");
1483 break;
1484 }
1485
1486 if let Some(content) = &accumulated_content {
1488 let strategy = agent_tool_execution_strategy
1491 .unwrap_or(ToolExecutionStrategy::Sequential);
1492
1493 let mut fc_parts: Vec<(usize, String, serde_json::Value, Option<String>, String)> = Vec::new();
1497 {
1498 let mut tci = 0usize;
1499 for part in &content.parts {
1500 if let Part::FunctionCall { name, args, id, .. } = part {
1501 let fallback = format!("{}_{}_{}", invocation_id, name, tci);
1502 let fcid = id.clone().unwrap_or(fallback);
1503 fc_parts.push((tci, name.clone(), args.clone(), id.clone(), fcid));
1504 tci += 1;
1505 }
1506 }
1507 }
1508
1509 let mut transfer_handled = false;
1513 for (_, fc_name, fc_args, fc_id, _) in &fc_parts {
1514 if fc_name == "transfer_to_agent" {
1515 let target_agent = fc_args.get("agent_name")
1516 .and_then(|v| v.as_str())
1517 .unwrap_or_default()
1518 .to_string();
1519
1520 let valid_target = valid_transfer_targets.iter().any(|n| n == &target_agent);
1521 if !valid_target {
1522 let error_content = Content {
1523 role: "function".to_string(),
1524 parts: vec![Part::FunctionResponse {
1525 function_response: FunctionResponseData::new(
1526 fc_name.clone(),
1527 serde_json::json!({
1528 "error": format!(
1529 "Agent '{}' not found. Available agents: {:?}",
1530 target_agent, valid_transfer_targets
1531 )
1532 }),
1533 ),
1534 id: fc_id.clone(),
1535 }],
1536 };
1537 conversation_history.push(error_content.clone());
1538 let mut error_event = Event::new(&invocation_id);
1539 error_event.author = agent_name.clone();
1540 error_event.llm_response.content = Some(error_content);
1541 yield Ok(error_event);
1542 continue;
1543 }
1544
1545 let mut transfer_event = Event::new(&invocation_id);
1546 transfer_event.author = agent_name.clone();
1547 transfer_event.actions.transfer_to_agent = Some(target_agent);
1548 yield Ok(transfer_event);
1549 transfer_handled = true;
1550 break;
1551 }
1552 }
1553 if transfer_handled {
1554 return;
1555 }
1556
1557 let fc_parts: Vec<_> = fc_parts.into_iter().filter(|(_, fc_name, _, _, _)| {
1559 if fc_name == "transfer_to_agent" {
1560 return false;
1561 }
1562 if let Some(tool) = tool_map.get(fc_name) {
1563 if tool.is_builtin() {
1564 adk_telemetry::debug!(tool.name = %fc_name, "skipping built-in tool execution");
1565 return false;
1566 }
1567 }
1568 true
1569 }).collect();
1570
1571 let mut confirmation_interrupted = false;
1575 for (_, fc_name, fc_args, _, fc_call_id) in &fc_parts {
1576 if tool_confirmation_policy.requires_confirmation(fc_name)
1577 && ctx.run_config().tool_confirmation_decisions.get(fc_name).copied().is_none()
1578 {
1579 let mut ce = Event::new(&invocation_id);
1580 ce.author = agent_name.clone();
1581 ce.llm_response.interrupted = true;
1582 ce.llm_response.turn_complete = true;
1583 ce.llm_response.content = Some(Content {
1584 role: "model".to_string(),
1585 parts: vec![Part::Text {
1586 text: format!(
1587 "Tool confirmation required for '{}'. Provide approve/deny decision to continue.",
1588 fc_name
1589 ),
1590 }],
1591 });
1592 ce.actions.tool_confirmation = Some(ToolConfirmationRequest {
1593 tool_name: fc_name.clone(),
1594 function_call_id: Some(fc_call_id.clone()),
1595 args: fc_args.clone(),
1596 });
1597 yield Ok(ce);
1598 confirmation_interrupted = true;
1599 break;
1600 }
1601 }
1602 if confirmation_interrupted {
1603 return;
1604 }
1605
1606 let cb_mutex = std::sync::Mutex::new(circuit_breaker_state.take());
1608
// Runs ONE model-requested function call end to end and resolves to
// `(idx, response_content, tool_actions, escalate_or_skip)`.
//
// Stages run in order. Once a stage sets `response_content`, every later stage
// is skipped, because each one is guarded by `response_content.is_none()`:
//   1. confirmation policy (only if `requires_confirmation(&name)`): a Deny, or
//      no recorded decision, yields an error response;
//   2. before-tool callbacks: the first `Ok(Some(content))` replaces the tool
//      call, and an `Err` yields an error response;
//   3. circuit breaker: an open breaker for `name` yields an error response;
//   4. execution: the retry/timeout loop, breaker bookkeeping and
//      `on_tool_error` fallbacks, or a "not found" error if `name` is unknown.
// Afterwards the after-tool callbacks may rewrite the response, unless a stage
// above cleared `run_after_tool_callbacks` (stage 1, the `Err` case of
// stage 2, and stage 3 clear it).
//
// `idx` is passed through untouched so callers can restore the model's call
// order (the Auto strategy sorts results by it). The closure borrows locals
// of the enclosing scope (`&tool_map`, `&cb_mutex`, ...), so the futures it
// returns must be awaited while those locals are still alive.
1609 let execute_one_tool = |idx: usize, name: String, args: serde_json::Value,
1614 id: Option<String>, function_call_id: String| {
// Re-bind the captures as a clone / references so the `async move` block
// below takes these handles rather than the enclosing locals. This is what
// lets the closure be invoked once per function call.
1615 let ctx = ctx.clone();
1616 let tool_map = &tool_map;
1617 let tool_retry_budgets = &tool_retry_budgets;
1618 let default_retry_budget = &default_retry_budget;
1619 let before_tool_callbacks = &before_tool_callbacks;
1620 let after_tool_callbacks = &after_tool_callbacks;
1621 let after_tool_callbacks_full = &after_tool_callbacks_full;
1622 let on_tool_error_callbacks = &on_tool_error_callbacks;
1623 let tool_confirmation_policy = &tool_confirmation_policy;
1624 let cb_mutex = &cb_mutex;
1625 let invocation_id = &invocation_id;
1626 async move {
// Per-call state. `tool_outcome_for_callback`, `executed_tool` and
// `executed_tool_response` are only populated when the tool actually runs
// (stage 4, found branch).
1627 let mut tool_actions = EventActions::default();
1628 let mut response_content: Option<Content> = None;
1629 let mut run_after_tool_callbacks = true;
1630 let mut tool_outcome_for_callback: Option<ToolOutcome> = None;
1631 let mut executed_tool: Option<Arc<dyn Tool>> = None;
1632 let mut executed_tool_response: Option<serde_json::Value> = None;
1633
// Stage 1: confirmation policy. Approve only records the decision and falls
// through. Deny, or a missing decision, produces an error response and
// disables the after-tool callbacks.
1634 if tool_confirmation_policy.requires_confirmation(&name) {
1636 match ctx.run_config().tool_confirmation_decisions.get(&name).copied() {
1637 Some(ToolConfirmationDecision::Approve) => {
1638 tool_actions.tool_confirmation_decision =
1639 Some(ToolConfirmationDecision::Approve);
1640 }
1641 Some(ToolConfirmationDecision::Deny) => {
1642 tool_actions.tool_confirmation_decision =
1643 Some(ToolConfirmationDecision::Deny);
1644 response_content = Some(Content {
1645 role: "function".to_string(),
1646 parts: vec![Part::FunctionResponse {
1647 function_response: FunctionResponseData::new(
1648 name.clone(),
1649 serde_json::json!({
1650 "error": format!("Tool '{}' execution denied by confirmation policy", name)
1651 }),
1652 ),
1653 id: id.clone(),
1654 }],
1655 });
1656 run_after_tool_callbacks = false;
1657 }
1658 None => {
1659 response_content = Some(Content {
1660 role: "function".to_string(),
1661 parts: vec![Part::FunctionResponse {
1662 function_response: FunctionResponseData::new(
1663 name.clone(),
1664 serde_json::json!({
1665 "error": format!("Tool '{}' requires confirmation", name)
1666 }),
1667 ),
1668 id: id.clone(),
1669 }],
1670 });
1671 run_after_tool_callbacks = false;
1672 }
1673 }
1674 }
1675
// Stage 2: before-tool callbacks, run in iteration order.
//   - `Ok(Some(c))` short-circuits with `c` as the response; after-tool
//     callbacks still run.
//   - `Ok(None)` defers to the next callback.
//   - `Err` becomes an error response and disables the after-tool callbacks.
1676 if response_content.is_none() {
1678 let tool_ctx = Arc::new(ToolCallbackContext::new(
1679 ctx.clone(),
1680 name.clone(),
1681 args.clone(),
1682 ));
1683 for callback in before_tool_callbacks.as_ref() {
1684 match callback(tool_ctx.clone() as Arc<dyn CallbackContext>).await {
1685 Ok(Some(c)) => { response_content = Some(c); break; }
1686 Ok(None) => continue,
1687 Err(e) => {
1688 response_content = Some(Content {
1689 role: "function".to_string(),
1690 parts: vec![Part::FunctionResponse {
1691 function_response: FunctionResponseData::new(
1692 name.clone(),
1693 serde_json::json!({ "error": e.to_string() }),
1694 ),
1695 id: id.clone(),
1696 }],
1697 });
1698 run_after_tool_callbacks = false;
1699 break;
1700 }
1701 }
1702 }
1703 }
1704
// Stage 3: circuit breaker check. A poisoned mutex is recovered with
// `into_inner`, and the guard is dropped explicitly with no `.await` in
// between.
// NOTE(review): this check and the `record` after execution take the lock
// separately. Under the Parallel/Auto strategies, several in-flight calls to
// the same tool can therefore pass this check before any failure is recorded.
1705 if response_content.is_none() {
1707 let guard = cb_mutex.lock().unwrap_or_else(|e| e.into_inner());
1708 if let Some(ref cb_state) = *guard {
1709 if cb_state.is_open(&name) {
1710 let msg = format!(
1711 "Tool '{}' is temporarily disabled after {} consecutive failures",
1712 name, cb_state.threshold
1713 );
1714 tracing::warn!(tool.name = %name, "circuit breaker open, skipping tool execution");
1715 response_content = Some(Content {
1716 role: "function".to_string(),
1717 parts: vec![Part::FunctionResponse {
1718 function_response: FunctionResponseData::new(
1719 name.clone(),
1720 serde_json::json!({ "error": msg }),
1721 ),
1722 id: id.clone(),
1723 }],
1724 });
1725 run_after_tool_callbacks = false;
1726 }
1727 }
1728 drop(guard);
1729 }
1730
// Stage 4: execute the tool, or report it as not found.
1731 if response_content.is_none() {
1733 if let Some(tool) = tool_map.get(&name) {
1734 let tool_ctx: Arc<dyn ToolContext> = Arc::new(
1735 AgentToolContext::new(ctx.clone(), function_call_id.clone()),
1736 );
// NOTE(review): the event_id is `<invocation_id>_<name>`, so repeated calls
// to the same tool in one turn share it; `function_call_id` would
// disambiguate them.
1737 let span_name = format!("execute_tool {name}");
1738 let tool_span = tracing::info_span!(
1739 "",
1740 otel.name = %span_name,
1741 tool.name = %name,
1742 "gcp.vertex.agent.event_id" = %format!("{}_{}", invocation_id, name),
1743 "gcp.vertex.agent.invocation_id" = %invocation_id,
1744 "gcp.vertex.agent.session_id" = %ctx.session_id(),
1745 "gen_ai.conversation.id" = %ctx.session_id()
1746 );
1747
// Retry budget: the per-tool entry wins, otherwise the default. This gives
// `max_retries + 1` attempts in total (1 when there is no budget).
// `retry_delay` is slept before every attempt except the first.
1748 let budget = tool_retry_budgets.get(&name)
1749 .or(default_retry_budget.as_ref());
1750 let max_attempts = budget.map(|b| b.max_retries + 1).unwrap_or(1);
1751 let retry_delay = budget.map(|b| b.delay).unwrap_or_default();
1752
1753 let tool_clone = tool.clone();
1754 let tool_start = std::time::Instant::now();
1755 let mut last_error = String::new();
1756 let mut final_attempt: u32 = 0;
1757 let mut retry_result: Option<serde_json::Value> = None;
1758
// Each attempt is instrumented with `tool_span` and bounded by
// `tool_timeout`. Both tool errors and timeouts are retried.
// On exit, `retry_result` is Some on success; otherwise `last_error` holds
// the last failure and `final_attempt` is the 0-based index of the last try.
1759 for attempt in 0..max_attempts {
1760 final_attempt = attempt;
1761 if attempt > 0 {
1762 tokio::time::sleep(retry_delay).await;
1763 }
1764 match async {
1765 tracing::info!(tool.name = %name, tool.args = %args, attempt = attempt, "tool_call");
1766 let exec_future = tool_clone.execute(tool_ctx.clone(), args.clone());
1767 tokio::time::timeout(tool_timeout, exec_future).await
1768 }.instrument(tool_span.clone()).await {
1769 Ok(Ok(value)) => {
1770 tracing::info!(tool.name = %name, tool.result = %value, "tool_result");
1771 retry_result = Some(value);
1772 break;
1773 }
1774 Ok(Err(e)) => {
1775 last_error = e.to_string();
1776 if attempt + 1 < max_attempts {
1777 tracing::warn!(tool.name = %name, attempt = attempt, error = %last_error, "tool execution failed, retrying");
1778 } else {
1779 tracing::warn!(tool.name = %name, error = %last_error, "tool_error");
1780 }
1781 }
// Timed out.
// NOTE(review): `as_secs()` truncates, so a sub-second `tool_timeout` is
// reported as "0 seconds".
1782 Err(_) => {
1783 last_error = format!(
1784 "Tool '{}' timed out after {} seconds",
1785 name, tool_timeout.as_secs()
1786 );
1787 if attempt + 1 < max_attempts {
1788 tracing::warn!(tool.name = %name, attempt = attempt, timeout_secs = tool_timeout.as_secs(), "tool timed out, retrying");
1789 } else {
1790 tracing::warn!(tool.name = %name, timeout_secs = tool_timeout.as_secs(), "tool_timeout");
1791 }
1792 }
1793 }
1794 }
1795
// A failure is reported to the model as `{"error": last_error}`, unless an
// on_tool_error callback below supplies a replacement value.
1796 let tool_duration = tool_start.elapsed();
1797 let (tool_success, tool_error_message, function_response) = match retry_result {
1798 Some(value) => (true, None, value),
1799 None => (false, Some(last_error.clone()), serde_json::json!({ "error": last_error })),
1800 };
1801
// The outcome feeds the circuit breaker and is exposed to after-tool
// callbacks through ToolOutcomeCallbackContext.
1802 let outcome = ToolOutcome {
1803 tool_name: name.clone(),
1804 tool_args: args.clone(),
1805 success: tool_success,
1806 duration: tool_duration,
1807 error_message: tool_error_message.clone(),
1808 attempt: final_attempt,
1809 };
1810 tool_outcome_for_callback = Some(outcome);
1811
// Scoped block so the mutex guard is released before the awaits below.
// The `unwrap` cannot fail: the outcome was assigned just above.
1812 {
1814 let mut guard = cb_mutex.lock().unwrap_or_else(|e| e.into_inner());
1815 if let Some(ref mut cb_state) = *guard {
1816 cb_state.record(tool_outcome_for_callback.as_ref().unwrap());
1817 }
1818 }
1819
// On failure, the first on_tool_error callback returning `Ok(Some(_))`
// supplies the response value. An `Err` is logged and ends the chain,
// keeping the default error JSON.
1820 let final_function_response = if !tool_success {
1822 let mut fallback_result = None;
1823 let error_msg = tool_error_message.clone().unwrap_or_default();
1824 for callback in on_tool_error_callbacks.as_ref() {
1825 match callback(
1826 ctx.clone() as Arc<dyn CallbackContext>,
1827 tool.clone(),
1828 args.clone(),
1829 error_msg.clone(),
1830 ).await {
1831 Ok(Some(result)) => { fallback_result = Some(result); break; }
1832 Ok(None) => continue,
1833 Err(e) => { tracing::warn!(error = %e, "on_tool_error callback failed"); break; }
1834 }
1835 }
1836 fallback_result.unwrap_or(function_response)
1837 } else {
1838 function_response
1839 };
1840
// Adopt the actions the tool recorded on its context, but keep the stage-1
// confirmation decision if the tool did not set one itself.
1841 let confirmation_decision = tool_actions.tool_confirmation_decision;
1842 tool_actions = tool_ctx.actions();
1843 if tool_actions.tool_confirmation_decision.is_none() {
1844 tool_actions.tool_confirmation_decision = confirmation_decision;
1845 }
1846 executed_tool = Some(tool.clone());
1847 executed_tool_response = Some(final_function_response.clone());
1848 response_content = Some(Content {
1849 role: "function".to_string(),
1850 parts: vec![Part::FunctionResponse {
1851 function_response: FunctionResponseData::from_tool_result(
1852 name.clone(),
1853 final_function_response,
1854 ),
1855 id: id.clone(),
1856 }],
1857 });
1858 } else {
// Unknown tool name: report it to the model. After-tool callbacks still run,
// but without a ToolOutcome.
1859 response_content = Some(Content {
1860 role: "function".to_string(),
1861 parts: vec![Part::FunctionResponse {
1862 function_response: FunctionResponseData::new(
1863 name.clone(),
1864 serde_json::json!({
1865 "error": format!("Tool {} not found", name)
1866 }),
1867 ),
1868 id: id.clone(),
1869 }],
1870 });
1871 }
1872 }
1873
// Every path above assigns `response_content`, so this expect cannot fire.
1874 let mut response_content = response_content.expect("tool response content is set");
// After-tool callbacks see a ToolOutcome-wrapped context when the tool
// actually ran, and the plain context otherwise.
//   - The first `Ok(Some(_))` replaces the response.
//   - An `Err` replaces it with an error response.
1876 if run_after_tool_callbacks {
1877 let outcome_ctx: Arc<dyn CallbackContext> = match tool_outcome_for_callback {
1878 Some(outcome) => Arc::new(ToolOutcomeCallbackContext {
1879 inner: ctx.clone() as Arc<dyn CallbackContext>,
1880 outcome,
1881 }),
1882 None => ctx.clone() as Arc<dyn CallbackContext>,
1883 };
1884 let cb_ctx: Arc<dyn CallbackContext> = Arc::new(ToolCallbackContext::new(
1885 outcome_ctx,
1886 name.clone(),
1887 args.clone(),
1888 ));
1889 for callback in after_tool_callbacks.as_ref() {
1890 match callback(cb_ctx.clone()).await {
1891 Ok(Some(modified)) => { response_content = modified; break; }
1892 Ok(None) => continue,
1893 Err(e) => {
1894 response_content = Content {
1895 role: "function".to_string(),
1896 parts: vec![Part::FunctionResponse {
1897 function_response: FunctionResponseData::new(
1898 name.clone(),
1899 serde_json::json!({ "error": e.to_string() }),
1900 ),
1901 id: id.clone(),
1902 }],
1903 };
1904 break;
1905 }
1906 }
1907 }
// Full after-tool callbacks run only when the tool executed. They receive the
// tool and its JSON result (possibly the on_tool_error fallback). A `Some` or
// `Err` from them overrides whatever the plain after-tool callbacks produced.
1908 if let (Some(tool_ref), Some(tool_resp)) = (&executed_tool, executed_tool_response) {
1909 for callback in after_tool_callbacks_full.as_ref() {
1910 match callback(
1911 cb_ctx.clone(), tool_ref.clone(), args.clone(), tool_resp.clone(),
1912 ).await {
1913 Ok(Some(modified_value)) => {
1914 response_content = Content {
1915 role: "function".to_string(),
1916 parts: vec![Part::FunctionResponse {
1917 function_response: FunctionResponseData::from_tool_result(
1918 name.clone(),
1919 modified_value,
1920 ),
1921 id: id.clone(),
1922 }],
1923 };
1924 break;
1925 }
1926 Ok(None) => continue,
1927 Err(e) => {
1928 response_content = Content {
1929 role: "function".to_string(),
1930 parts: vec![Part::FunctionResponse {
1931 function_response: FunctionResponseData::new(
1932 name.clone(),
1933 serde_json::json!({ "error": e.to_string() }),
1934 ),
1935 id: id.clone(),
1936 }],
1937 };
1938 break;
1939 }
1940 }
1941 }
1942 }
1943 }
1944
// The caller stops the agent loop right after emitting this tool's event when
// `escalate` or `skip_summarization` is set.
1945 let escalate_or_skip = tool_actions.escalate || tool_actions.skip_summarization;
1946 (idx, response_content, tool_actions, escalate_or_skip)
1947 }
1948 };
1949
1950 let results: Vec<(usize, Content, EventActions, bool)> = match strategy {
1952 ToolExecutionStrategy::Sequential => {
1953 let mut results = Vec::with_capacity(fc_parts.len());
1954 for (idx, name, args, id, fcid) in fc_parts {
1955 results.push(execute_one_tool(idx, name, args, id, fcid).await);
1956 }
1957 results
1958 }
1959 ToolExecutionStrategy::Parallel => {
1960 let futs: Vec<_> = fc_parts.into_iter()
1961 .map(|(idx, name, args, id, fcid)| execute_one_tool(idx, name, args, id, fcid))
1962 .collect();
1963 futures::future::join_all(futs).await
1964 }
1965 ToolExecutionStrategy::Auto => {
1966 let mut read_only_fcs = Vec::new();
1968 let mut mutable_fcs = Vec::new();
1969 for fc in fc_parts {
1970 let is_ro = tool_map.get(&fc.1)
1971 .map(|t| t.is_read_only())
1972 .unwrap_or(false);
1973 if is_ro { read_only_fcs.push(fc); } else { mutable_fcs.push(fc); }
1974 }
1975 let mut all_results = Vec::new();
1976 if !read_only_fcs.is_empty() {
1978 let ro_futs: Vec<_> = read_only_fcs.into_iter()
1979 .map(|(idx, name, args, id, fcid)| execute_one_tool(idx, name, args, id, fcid))
1980 .collect();
1981 all_results.extend(futures::future::join_all(ro_futs).await);
1982 }
1983 for (idx, name, args, id, fcid) in mutable_fcs {
1985 all_results.push(execute_one_tool(idx, name, args, id, fcid).await);
1986 }
1987 all_results.sort_by_key(|r| r.0);
1989 all_results
1990 }
1991 };
1992
1993 circuit_breaker_state = cb_mutex.into_inner().unwrap_or_else(|e| e.into_inner());
1995
1996 for (_, response_content, tool_actions, escalate_or_skip) in results {
1998 let mut tool_event = Event::new(&invocation_id);
1999 tool_event.author = agent_name.clone();
2000 tool_event.actions = tool_actions;
2001 tool_event.llm_response.content = Some(response_content.clone());
2002 yield Ok(tool_event);
2003
2004 if escalate_or_skip {
2005 return;
2006 }
2007
2008 conversation_history.push(response_content);
2009 }
2010 }
2011
2012 if all_calls_are_long_running {
2016 }
2020 }
2021
2022 for callback in after_agent_callbacks.as_ref() {
2025 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
2026 Ok(Some(content)) => {
2027 let mut after_event = Event::new(&invocation_id);
2029 after_event.author = agent_name.clone();
2030 after_event.llm_response.content = Some(content);
2031 yield Ok(after_event);
2032 break; }
2034 Ok(None) => {
2035 continue;
2037 }
2038 Err(e) => {
2039 yield Err(e);
2041 return;
2042 }
2043 }
2044 }
2045 };
2046
2047 Ok(Box::pin(s))
2048 }
2049}