1use adk_core::{
2 AfterAgentCallback, AfterModelCallback, AfterToolCallback, AfterToolCallbackFull, Agent,
3 BeforeAgentCallback, BeforeModelCallback, BeforeModelResult, BeforeToolCallback,
4 CallbackContext, Content, Event, EventActions, FunctionResponseData, GlobalInstructionProvider,
5 InstructionProvider, InvocationContext, Llm, LlmRequest, LlmResponse, MemoryEntry,
6 OnToolErrorCallback, Part, ReadonlyContext, Result, RetryBudget, Tool, ToolCallbackContext,
7 ToolConfirmationDecision, ToolConfirmationPolicy, ToolConfirmationRequest, ToolContext,
8 ToolExecutionStrategy, ToolOutcome, Toolset,
9};
10use async_stream::stream;
11use async_trait::async_trait;
12use std::sync::{Arc, Mutex};
13use tracing::Instrument;
14
15#[cfg(feature = "skills")]
16use crate::skill_shim::load_skill_index;
17use crate::{
18 guardrails::{GuardrailSet, enforce_guardrails},
19 skill_shim::{SelectionPolicy, SkillIndex, select_skill_prompt_block},
20 tool_call_markup::normalize_option_content,
21 workflow::with_user_content_override,
22};
23
/// Default cap on model/tool loop iterations for an agent run.
///
/// Used as the initial `max_iterations` value in `LlmAgentBuilder::new`; the run
/// loop yields a "Max iterations exceeded" error once the cap is passed.
pub const DEFAULT_MAX_ITERATIONS: u32 = 100;
26
/// Default tool timeout (300 seconds).
///
/// Used as the initial `tool_timeout` value in `LlmAgentBuilder::new`.
pub const DEFAULT_TOOL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
29
30fn trace_json_payload<T: serde::Serialize>(
31 value: &T,
32 record_payloads: bool,
33 max_bytes: usize,
34) -> String {
35 let json = serde_json::to_string(value).unwrap_or_default();
36 if cfg!(feature = "record-payloads") && record_payloads {
37 return json;
38 }
39
40 let max_bytes = max_bytes.max(32);
41 if json.len() <= max_bytes {
42 return json;
43 }
44
45 let mut end = max_bytes;
46 while !json.is_char_boundary(end) {
47 end -= 1;
48 }
49 format!("{}...[truncated {} bytes]", &json[..end], json.len() - end)
50}
51
/// An agent driven by an LLM, constructed through `LlmAgentBuilder`.
///
/// The struct is a bag of configuration: `run` clones these fields into the
/// event stream it returns, so the agent itself is never mutated while running.
pub struct LlmAgent {
    /// Agent name, returned by `Agent::name` and used as the author of emitted events.
    name: String,
    /// Human-readable description, returned by `Agent::description`.
    description: String,
    /// Model handle; its `name()` is placed in each `LlmRequest`.
    model: Arc<dyn Llm>,
    /// Static instruction template; used only when no `instruction_provider` is set.
    instruction: Option<String>,
    /// Dynamic instruction source; takes precedence over `instruction`.
    instruction_provider: Option<Arc<InstructionProvider>>,
    /// Static global instruction template; used only when no
    /// `global_instruction_provider` is set.
    global_instruction: Option<String>,
    /// Dynamic global instruction source; takes precedence over `global_instruction`.
    global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
    /// Optional skill index consulted to prepend a skill block to the prompt.
    skills_index: Option<Arc<SkillIndex>>,
    /// Policy passed to `select_skill_prompt_block` when choosing a skill.
    skill_policy: SelectionPolicy,
    /// Size budget passed to `select_skill_prompt_block`.
    max_skill_chars: usize,
    /// Input schema supplied through the builder.
    #[allow(dead_code)] input_schema: Option<serde_json::Value>,
    /// Output schema; merged into the request config as `response_schema`.
    output_schema: Option<serde_json::Value>,
    /// When `true`, the parent agent is excluded from run-config transfer targets.
    disallow_transfer_to_parent: bool,
    /// When `true`, non-parent run-config transfer targets are excluded.
    disallow_transfer_to_peers: bool,
    /// Controls whether session history is included in the model request or only
    /// the prompt preamble plus the current user content.
    include_contents: adk_core::IncludeContents,
    /// Statically registered tools.
    tools: Vec<Arc<dyn Tool>>,
    /// Toolsets resolved into concrete tools at the start of each run.
    // NOTE(review): `run` reads this field, so the `dead_code` allowance looks
    // unnecessary — confirm before removing.
    #[allow(dead_code)] toolsets: Vec<Arc<dyn Toolset>>,
    /// Child agents; their names are always valid transfer targets.
    sub_agents: Vec<Arc<dyn Agent>>,
    /// Optional output key supplied through the builder.
    output_key: Option<String>,
    /// Base generation config sent with each model request.
    generate_content_config: Option<adk_core::GenerateContentConfig>,
    /// Upper bound on run-loop iterations before an error is yielded.
    max_iterations: u32,
    /// Tool timeout supplied through the builder.
    tool_timeout: std::time::Duration,
    /// Callbacks invoked before the agent runs; the first one returning content
    /// short-circuits the run.
    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
    /// Callbacks invoked after the agent produced its output.
    after_callbacks: Arc<Vec<AfterAgentCallback>>,
    /// Callbacks that may rewrite the request or skip the model call.
    before_model_callbacks: Arc<Vec<BeforeModelCallback>>,
    /// Callbacks registered to run after a model call.
    after_model_callbacks: Arc<Vec<AfterModelCallback>>,
    /// Callbacks registered to run before a tool call.
    before_tool_callbacks: Arc<Vec<BeforeToolCallback>>,
    /// Callbacks registered to run after a tool call.
    after_tool_callbacks: Arc<Vec<AfterToolCallback>>,
    /// Callbacks registered for tool errors.
    on_tool_error_callbacks: Arc<Vec<OnToolErrorCallback>>,
    /// After-tool callbacks using the `AfterToolCallbackFull` signature.
    after_tool_callbacks_full: Arc<Vec<AfterToolCallbackFull>>,
    /// Retry budget applied when no per-tool budget exists — presumably the
    /// fallback for `tool_retry_budgets`; confirm against the tool-execution code.
    default_retry_budget: Option<RetryBudget>,
    /// Per-tool retry budgets keyed by tool name.
    tool_retry_budgets: std::collections::HashMap<String, RetryBudget>,
    /// When set, a `CircuitBreakerState` with this threshold is created per run.
    circuit_breaker_threshold: Option<u32>,
    /// Policy describing which tools require confirmation.
    tool_confirmation_policy: ToolConfirmationPolicy,
    /// Optional agent-level tool execution strategy.
    tool_execution_strategy: Option<ToolExecutionStrategy>,
    /// Guardrails applied to the user content before the run starts.
    input_guardrails: Arc<GuardrailSet>,
    /// Guardrails applied to model output that contains no function calls.
    output_guardrails: Arc<GuardrailSet>,
}
103
104impl std::fmt::Debug for LlmAgent {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 f.debug_struct("LlmAgent")
107 .field("name", &self.name)
108 .field("description", &self.description)
109 .field("model", &self.model.name())
110 .field("instruction", &self.instruction)
111 .field("tools_count", &self.tools.len())
112 .field("sub_agents_count", &self.sub_agents.len())
113 .finish()
114 }
115}
116
117impl LlmAgent {
118 async fn apply_input_guardrails(
119 ctx: Arc<dyn InvocationContext>,
120 input_guardrails: Arc<GuardrailSet>,
121 ) -> Result<Arc<dyn InvocationContext>> {
122 let content =
123 enforce_guardrails(input_guardrails.as_ref(), ctx.user_content(), "input").await?;
124 if content.role != ctx.user_content().role || content.parts != ctx.user_content().parts {
125 Ok(with_user_content_override(ctx, content))
126 } else {
127 Ok(ctx)
128 }
129 }
130
131 async fn apply_output_guardrails(
132 output_guardrails: &GuardrailSet,
133 content: Content,
134 ) -> Result<Content> {
135 enforce_guardrails(output_guardrails, &content, "output").await
136 }
137
138 fn history_parts_from_provider_metadata(
139 provider_metadata: Option<&serde_json::Value>,
140 ) -> Vec<Part> {
141 let Some(provider_metadata) = provider_metadata else {
142 return Vec::new();
143 };
144
145 let history_parts = provider_metadata
146 .get("conversation_history_parts")
147 .or_else(|| {
148 provider_metadata
149 .get("openai")
150 .and_then(|openai| openai.get("conversation_history_parts"))
151 })
152 .and_then(serde_json::Value::as_array);
153
154 history_parts
155 .into_iter()
156 .flatten()
157 .filter_map(|value| serde_json::from_value::<Part>(value.clone()).ok())
158 .collect()
159 }
160
161 fn augment_content_for_history(
162 content: &Content,
163 provider_metadata: Option<&serde_json::Value>,
164 ) -> Content {
165 let mut augmented = content.clone();
166 augmented.parts.extend(Self::history_parts_from_provider_metadata(provider_metadata));
167 augmented
168 }
169}
170
/// Builder for `LlmAgent`.
///
/// Fields mirror those of `LlmAgent`. The differences are that `description`
/// and `model` are optional until `build`, and that callback lists and
/// guardrail sets are held by value and only wrapped in `Arc` by `build`.
pub struct LlmAgentBuilder {
    /// Agent name (required, supplied to `new`).
    name: String,
    /// Optional description; `build` substitutes an empty string when unset.
    description: Option<String>,
    /// Model handle; `build` fails when this is still `None`.
    model: Option<Arc<dyn Llm>>,
    /// Static instruction template.
    instruction: Option<String>,
    /// Dynamic instruction provider.
    instruction_provider: Option<Arc<InstructionProvider>>,
    /// Static global instruction template.
    global_instruction: Option<String>,
    /// Dynamic global instruction provider.
    global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
    /// Optional skill index.
    skills_index: Option<Arc<SkillIndex>>,
    /// Skill selection policy (defaults to `SelectionPolicy::default()`).
    skill_policy: SelectionPolicy,
    /// Skill size budget (defaults to 2000).
    max_skill_chars: usize,
    /// Optional input schema.
    input_schema: Option<serde_json::Value>,
    /// Optional output schema.
    output_schema: Option<serde_json::Value>,
    /// Whether transfers to the parent agent are disallowed (defaults to `false`).
    disallow_transfer_to_parent: bool,
    /// Whether transfers to peer agents are disallowed (defaults to `false`).
    disallow_transfer_to_peers: bool,
    /// History inclusion mode (defaults to `IncludeContents::Default`).
    include_contents: adk_core::IncludeContents,
    /// Statically registered tools, in registration order.
    tools: Vec<Arc<dyn Tool>>,
    /// Registered toolsets, in registration order.
    toolsets: Vec<Arc<dyn Toolset>>,
    /// Registered sub-agents; names must be unique at `build` time.
    sub_agents: Vec<Arc<dyn Agent>>,
    /// Optional output key.
    output_key: Option<String>,
    /// Optional generation config; created on demand by the sampling setters.
    generate_content_config: Option<adk_core::GenerateContentConfig>,
    /// Iteration cap (defaults to `DEFAULT_MAX_ITERATIONS`).
    max_iterations: u32,
    /// Tool timeout (defaults to `DEFAULT_TOOL_TIMEOUT`).
    tool_timeout: std::time::Duration,
    /// Before-agent callbacks, in registration order.
    before_callbacks: Vec<BeforeAgentCallback>,
    /// After-agent callbacks, in registration order.
    after_callbacks: Vec<AfterAgentCallback>,
    /// Before-model callbacks, in registration order.
    before_model_callbacks: Vec<BeforeModelCallback>,
    /// After-model callbacks, in registration order.
    after_model_callbacks: Vec<AfterModelCallback>,
    /// Before-tool callbacks, in registration order.
    before_tool_callbacks: Vec<BeforeToolCallback>,
    /// After-tool callbacks, in registration order.
    after_tool_callbacks: Vec<AfterToolCallback>,
    /// Tool-error callbacks, in registration order.
    on_tool_error_callbacks: Vec<OnToolErrorCallback>,
    /// After-tool callbacks using the `AfterToolCallbackFull` signature.
    after_tool_callbacks_full: Vec<AfterToolCallbackFull>,
    /// Optional retry budget not tied to a specific tool.
    default_retry_budget: Option<RetryBudget>,
    /// Per-tool retry budgets keyed by tool name.
    tool_retry_budgets: std::collections::HashMap<String, RetryBudget>,
    /// Optional circuit-breaker failure threshold.
    circuit_breaker_threshold: Option<u32>,
    /// Tool confirmation policy (defaults to `ToolConfirmationPolicy::Never`).
    tool_confirmation_policy: ToolConfirmationPolicy,
    /// Optional tool execution strategy.
    tool_execution_strategy: Option<ToolExecutionStrategy>,
    /// Input guardrails (defaults to an empty set).
    input_guardrails: GuardrailSet,
    /// Output guardrails (defaults to an empty set).
    output_guardrails: GuardrailSet,
}
210
211impl LlmAgentBuilder {
212 pub fn new(name: impl Into<String>) -> Self {
213 Self {
214 name: name.into(),
215 description: None,
216 model: None,
217 instruction: None,
218 instruction_provider: None,
219 global_instruction: None,
220 global_instruction_provider: None,
221 skills_index: None,
222 skill_policy: SelectionPolicy::default(),
223 max_skill_chars: 2000,
224 input_schema: None,
225 output_schema: None,
226 disallow_transfer_to_parent: false,
227 disallow_transfer_to_peers: false,
228 include_contents: adk_core::IncludeContents::Default,
229 tools: Vec::new(),
230 toolsets: Vec::new(),
231 sub_agents: Vec::new(),
232 output_key: None,
233 generate_content_config: None,
234 max_iterations: DEFAULT_MAX_ITERATIONS,
235 tool_timeout: DEFAULT_TOOL_TIMEOUT,
236 before_callbacks: Vec::new(),
237 after_callbacks: Vec::new(),
238 before_model_callbacks: Vec::new(),
239 after_model_callbacks: Vec::new(),
240 before_tool_callbacks: Vec::new(),
241 after_tool_callbacks: Vec::new(),
242 on_tool_error_callbacks: Vec::new(),
243 after_tool_callbacks_full: Vec::new(),
244 default_retry_budget: None,
245 tool_retry_budgets: std::collections::HashMap::new(),
246 circuit_breaker_threshold: None,
247 tool_confirmation_policy: ToolConfirmationPolicy::Never,
248 tool_execution_strategy: None,
249 input_guardrails: GuardrailSet::new(),
250 output_guardrails: GuardrailSet::new(),
251 }
252 }
253
254 pub fn description(mut self, desc: impl Into<String>) -> Self {
255 self.description = Some(desc.into());
256 self
257 }
258
259 pub fn model(mut self, model: Arc<dyn Llm>) -> Self {
260 self.model = Some(model);
261 self
262 }
263
264 pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
265 self.instruction = Some(instruction.into());
266 self
267 }
268
269 pub fn instruction_provider(mut self, provider: InstructionProvider) -> Self {
270 self.instruction_provider = Some(Arc::new(provider));
271 self
272 }
273
274 pub fn global_instruction(mut self, instruction: impl Into<String>) -> Self {
275 self.global_instruction = Some(instruction.into());
276 self
277 }
278
279 pub fn global_instruction_provider(mut self, provider: GlobalInstructionProvider) -> Self {
280 self.global_instruction_provider = Some(Arc::new(provider));
281 self
282 }
283
284 #[cfg(feature = "skills")]
286 pub fn with_skills(mut self, index: SkillIndex) -> Self {
287 self.skills_index = Some(Arc::new(index));
288 self
289 }
290
291 #[cfg(feature = "skills")]
293 pub fn with_auto_skills(self) -> Result<Self> {
294 self.with_skills_from_root(".")
295 }
296
297 #[cfg(feature = "skills")]
299 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
300 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
301 self.skills_index = Some(Arc::new(index));
302 Ok(self)
303 }
304
305 #[cfg(feature = "skills")]
307 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
308 self.skill_policy = policy;
309 self
310 }
311
312 #[cfg(feature = "skills")]
314 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
315 self.max_skill_chars = max_chars;
316 self
317 }
318
319 pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
320 self.input_schema = Some(schema);
321 self
322 }
323
324 pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
325 self.output_schema = Some(schema);
326 self
327 }
328
329 pub fn disallow_transfer_to_parent(mut self, disallow: bool) -> Self {
330 self.disallow_transfer_to_parent = disallow;
331 self
332 }
333
334 pub fn disallow_transfer_to_peers(mut self, disallow: bool) -> Self {
335 self.disallow_transfer_to_peers = disallow;
336 self
337 }
338
339 pub fn include_contents(mut self, include: adk_core::IncludeContents) -> Self {
340 self.include_contents = include;
341 self
342 }
343
344 pub fn output_key(mut self, key: impl Into<String>) -> Self {
345 self.output_key = Some(key.into());
346 self
347 }
348
349 pub fn generate_content_config(mut self, config: adk_core::GenerateContentConfig) -> Self {
370 self.generate_content_config = Some(config);
371 self
372 }
373
374 pub fn temperature(mut self, temperature: f32) -> Self {
377 self.generate_content_config
378 .get_or_insert(adk_core::GenerateContentConfig::default())
379 .temperature = Some(temperature);
380 self
381 }
382
383 pub fn top_p(mut self, top_p: f32) -> Self {
385 self.generate_content_config
386 .get_or_insert(adk_core::GenerateContentConfig::default())
387 .top_p = Some(top_p);
388 self
389 }
390
391 pub fn top_k(mut self, top_k: i32) -> Self {
393 self.generate_content_config
394 .get_or_insert(adk_core::GenerateContentConfig::default())
395 .top_k = Some(top_k);
396 self
397 }
398
399 pub fn max_output_tokens(mut self, max_tokens: i32) -> Self {
401 self.generate_content_config
402 .get_or_insert(adk_core::GenerateContentConfig::default())
403 .max_output_tokens = Some(max_tokens);
404 self
405 }
406
407 pub fn max_iterations(mut self, max: u32) -> Self {
410 self.max_iterations = max;
411 self
412 }
413
414 pub fn tool_timeout(mut self, timeout: std::time::Duration) -> Self {
417 self.tool_timeout = timeout;
418 self
419 }
420
421 pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
422 self.tools.push(tool);
423 self
424 }
425
426 pub fn toolset(mut self, toolset: Arc<dyn Toolset>) -> Self {
432 self.toolsets.push(toolset);
433 self
434 }
435
436 pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
437 self.sub_agents.push(agent);
438 self
439 }
440
441 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
442 self.before_callbacks.push(callback);
443 self
444 }
445
446 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
447 self.after_callbacks.push(callback);
448 self
449 }
450
451 pub fn before_model_callback(mut self, callback: BeforeModelCallback) -> Self {
452 self.before_model_callbacks.push(callback);
453 self
454 }
455
456 pub fn after_model_callback(mut self, callback: AfterModelCallback) -> Self {
457 self.after_model_callbacks.push(callback);
458 self
459 }
460
461 pub fn before_tool_callback(mut self, callback: BeforeToolCallback) -> Self {
462 self.before_tool_callbacks.push(callback);
463 self
464 }
465
466 pub fn after_tool_callback(mut self, callback: AfterToolCallback) -> Self {
467 self.after_tool_callbacks.push(callback);
468 self
469 }
470
471 pub fn after_tool_callback_full(mut self, callback: AfterToolCallbackFull) -> Self {
486 self.after_tool_callbacks_full.push(callback);
487 self
488 }
489
490 pub fn on_tool_error(mut self, callback: OnToolErrorCallback) -> Self {
498 self.on_tool_error_callbacks.push(callback);
499 self
500 }
501
502 pub fn default_retry_budget(mut self, budget: RetryBudget) -> Self {
509 self.default_retry_budget = Some(budget);
510 self
511 }
512
513 pub fn tool_retry_budget(mut self, tool_name: impl Into<String>, budget: RetryBudget) -> Self {
518 self.tool_retry_budgets.insert(tool_name.into(), budget);
519 self
520 }
521
522 pub fn circuit_breaker_threshold(mut self, threshold: u32) -> Self {
529 self.circuit_breaker_threshold = Some(threshold);
530 self
531 }
532
533 pub fn tool_confirmation_policy(mut self, policy: ToolConfirmationPolicy) -> Self {
535 self.tool_confirmation_policy = policy;
536 self
537 }
538
539 pub fn require_tool_confirmation(mut self, tool_name: impl Into<String>) -> Self {
541 self.tool_confirmation_policy = self.tool_confirmation_policy.with_tool(tool_name);
542 self
543 }
544
545 pub fn require_tool_confirmation_for_all(mut self) -> Self {
547 self.tool_confirmation_policy = ToolConfirmationPolicy::Always;
548 self
549 }
550
551 pub fn tool_execution_strategy(mut self, strategy: ToolExecutionStrategy) -> Self {
557 self.tool_execution_strategy = Some(strategy);
558 self
559 }
560
561 pub fn input_guardrails(mut self, guardrails: GuardrailSet) -> Self {
570 self.input_guardrails = guardrails;
571 self
572 }
573
574 pub fn output_guardrails(mut self, guardrails: GuardrailSet) -> Self {
583 self.output_guardrails = guardrails;
584 self
585 }
586
587 pub fn build(self) -> Result<LlmAgent> {
588 let model = self.model.ok_or_else(|| adk_core::AdkError::agent("Model is required"))?;
589
590 let mut seen_names = std::collections::HashSet::new();
591 for agent in &self.sub_agents {
592 if !seen_names.insert(agent.name()) {
593 return Err(adk_core::AdkError::agent(format!(
594 "Duplicate sub-agent name: {}",
595 agent.name()
596 )));
597 }
598 }
599
600 Ok(LlmAgent {
601 name: self.name,
602 description: self.description.unwrap_or_default(),
603 model,
604 instruction: self.instruction,
605 instruction_provider: self.instruction_provider,
606 global_instruction: self.global_instruction,
607 global_instruction_provider: self.global_instruction_provider,
608 skills_index: self.skills_index,
609 skill_policy: self.skill_policy,
610 max_skill_chars: self.max_skill_chars,
611 input_schema: self.input_schema,
612 output_schema: self.output_schema,
613 disallow_transfer_to_parent: self.disallow_transfer_to_parent,
614 disallow_transfer_to_peers: self.disallow_transfer_to_peers,
615 include_contents: self.include_contents,
616 tools: self.tools,
617 toolsets: self.toolsets,
618 sub_agents: self.sub_agents,
619 output_key: self.output_key,
620 generate_content_config: self.generate_content_config,
621 max_iterations: self.max_iterations,
622 tool_timeout: self.tool_timeout,
623 before_callbacks: Arc::new(self.before_callbacks),
624 after_callbacks: Arc::new(self.after_callbacks),
625 before_model_callbacks: Arc::new(self.before_model_callbacks),
626 after_model_callbacks: Arc::new(self.after_model_callbacks),
627 before_tool_callbacks: Arc::new(self.before_tool_callbacks),
628 after_tool_callbacks: Arc::new(self.after_tool_callbacks),
629 on_tool_error_callbacks: Arc::new(self.on_tool_error_callbacks),
630 after_tool_callbacks_full: Arc::new(self.after_tool_callbacks_full),
631 default_retry_budget: self.default_retry_budget,
632 tool_retry_budgets: self.tool_retry_budgets,
633 circuit_breaker_threshold: self.circuit_breaker_threshold,
634 tool_confirmation_policy: self.tool_confirmation_policy,
635 tool_execution_strategy: self.tool_execution_strategy,
636 input_guardrails: Arc::new(self.input_guardrails),
637 output_guardrails: Arc::new(self.output_guardrails),
638 })
639 }
640}
641
/// Tool-facing context for a single function call.
///
/// Wraps the parent invocation context and adds the call id plus a mutable
/// `EventActions` slot for the tool.
struct AgentToolContext {
    /// Invocation context that identity, artifact, memory and secret lookups are
    /// forwarded to.
    parent_ctx: Arc<dyn InvocationContext>,
    /// Id of the function call this context was created for.
    function_call_id: String,
    /// Actions recorded for this call; behind a mutex because the accessors take `&self`.
    actions: Mutex<EventActions>,
}
649
650impl AgentToolContext {
651 fn new(parent_ctx: Arc<dyn InvocationContext>, function_call_id: String) -> Self {
652 Self { parent_ctx, function_call_id, actions: Mutex::new(EventActions::default()) }
653 }
654
655 fn actions_guard(&self) -> std::sync::MutexGuard<'_, EventActions> {
656 self.actions.lock().unwrap_or_else(|e| e.into_inner())
657 }
658}
659
/// Read-only accessors; every method forwards unchanged to the parent
/// invocation context.
#[async_trait]
impl ReadonlyContext for AgentToolContext {
    /// Invocation id of the parent context.
    fn invocation_id(&self) -> &str {
        self.parent_ctx.invocation_id()
    }

    /// Agent name reported by the parent context.
    fn agent_name(&self) -> &str {
        self.parent_ctx.agent_name()
    }

    /// User id reported by the parent context.
    fn user_id(&self) -> &str {
        self.parent_ctx.user_id()
    }

    /// Application name reported by the parent context.
    fn app_name(&self) -> &str {
        self.parent_ctx.app_name()
    }

    /// Session id reported by the parent context.
    fn session_id(&self) -> &str {
        self.parent_ctx.session_id()
    }

    /// Branch reported by the parent context.
    fn branch(&self) -> &str {
        self.parent_ctx.branch()
    }

    /// User content of the parent invocation.
    fn user_content(&self) -> &Content {
        self.parent_ctx.user_content()
    }
}
693
/// Callback-level accessors; both methods forward to the parent invocation context.
#[async_trait]
impl CallbackContext for AgentToolContext {
    /// Artifact store of the parent context, if any.
    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
        self.parent_ctx.artifacts()
    }

    /// Shared state of the parent context, if any.
    fn shared_state(&self) -> Option<Arc<adk_core::SharedState>> {
        self.parent_ctx.shared_state()
    }
}
705
706#[async_trait]
707impl ToolContext for AgentToolContext {
708 fn function_call_id(&self) -> &str {
709 &self.function_call_id
710 }
711
712 fn actions(&self) -> EventActions {
713 self.actions_guard().clone()
714 }
715
716 fn set_actions(&self, actions: EventActions) {
717 *self.actions_guard() = actions;
718 }
719
720 async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>> {
721 if let Some(memory) = self.parent_ctx.memory() {
723 memory.search(query).await
724 } else {
725 Ok(vec![])
726 }
727 }
728
729 fn user_scopes(&self) -> Vec<String> {
730 self.parent_ctx.user_scopes()
731 }
732
733 async fn get_secret(&self, name: &str) -> Result<Option<String>> {
734 self.parent_ctx.get_secret(name).await
735 }
736}
737
/// Callback context decorator that attaches a `ToolOutcome` to an existing context.
struct ToolOutcomeCallbackContext {
    /// Wrapped context that all other lookups are forwarded to.
    inner: Arc<dyn CallbackContext>,
    /// Outcome exposed through `CallbackContext::tool_outcome`.
    outcome: ToolOutcome,
}
745
/// Read-only accessors; every method forwards unchanged to the wrapped context.
#[async_trait]
impl ReadonlyContext for ToolOutcomeCallbackContext {
    /// Invocation id of the wrapped context.
    fn invocation_id(&self) -> &str {
        self.inner.invocation_id()
    }

    /// Agent name reported by the wrapped context.
    fn agent_name(&self) -> &str {
        self.inner.agent_name()
    }

    /// User id reported by the wrapped context.
    fn user_id(&self) -> &str {
        self.inner.user_id()
    }

    /// Application name reported by the wrapped context.
    fn app_name(&self) -> &str {
        self.inner.app_name()
    }

    /// Session id reported by the wrapped context.
    fn session_id(&self) -> &str {
        self.inner.session_id()
    }

    /// Branch reported by the wrapped context.
    fn branch(&self) -> &str {
        self.inner.branch()
    }

    /// User content of the wrapped context.
    fn user_content(&self) -> &Content {
        self.inner.user_content()
    }
}
776
777#[async_trait]
778impl CallbackContext for ToolOutcomeCallbackContext {
779 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
780 self.inner.artifacts()
781 }
782
783 fn tool_outcome(&self) -> Option<ToolOutcome> {
784 Some(self.outcome.clone())
785 }
786}
787
/// Per-tool failure counters used to decide whether a tool's circuit is open.
struct CircuitBreakerState {
    /// Failure count at which a tool's circuit is considered open.
    threshold: u32,
    /// Current failure count per tool name; an entry is removed on success.
    failures: std::collections::HashMap<String, u32>,
}
802
803impl CircuitBreakerState {
804 fn new(threshold: u32) -> Self {
805 Self { threshold, failures: std::collections::HashMap::new() }
806 }
807
808 fn is_open(&self, tool_name: &str) -> bool {
810 self.failures.get(tool_name).copied().unwrap_or(0) >= self.threshold
811 }
812
813 fn record(&mut self, outcome: &ToolOutcome) {
815 if outcome.success {
816 self.failures.remove(&outcome.tool_name);
817 } else {
818 let count = self.failures.entry(outcome.tool_name.clone()).or_insert(0);
819 *count += 1;
820 }
821 }
822}
823
824#[async_trait]
825impl Agent for LlmAgent {
826 fn name(&self) -> &str {
827 &self.name
828 }
829
830 fn description(&self) -> &str {
831 &self.description
832 }
833
834 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
835 &self.sub_agents
836 }
837
838 #[adk_telemetry::instrument(
839 skip(self, ctx),
840 fields(
841 agent.name = %self.name,
842 agent.description = %self.description,
843 invocation.id = %ctx.invocation_id(),
844 user.id = %ctx.user_id(),
845 session.id = %ctx.session_id()
846 )
847 )]
848 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
849 adk_telemetry::info!("Starting agent execution");
850 let ctx = Self::apply_input_guardrails(ctx, self.input_guardrails.clone()).await?;
851
852 let agent_name = self.name.clone();
853 let invocation_id = ctx.invocation_id().to_string();
854 let model = self.model.clone();
855 let tools = self.tools.clone();
856 let toolsets = self.toolsets.clone();
857 let sub_agents = self.sub_agents.clone();
858
859 let instruction = self.instruction.clone();
860 let instruction_provider = self.instruction_provider.clone();
861 let global_instruction = self.global_instruction.clone();
862 let global_instruction_provider = self.global_instruction_provider.clone();
863 let skills_index = self.skills_index.clone();
864 let skill_policy = self.skill_policy.clone();
865 let max_skill_chars = self.max_skill_chars;
866 let output_key = self.output_key.clone();
867 let output_schema = self.output_schema.clone();
868 let generate_content_config = self.generate_content_config.clone();
869 let include_contents = self.include_contents;
870 let max_iterations = self.max_iterations;
871 let tool_timeout = self.tool_timeout;
872 let before_agent_callbacks = self.before_callbacks.clone();
874 let after_agent_callbacks = self.after_callbacks.clone();
875 let before_model_callbacks = self.before_model_callbacks.clone();
876 let after_model_callbacks = self.after_model_callbacks.clone();
877 let before_tool_callbacks = self.before_tool_callbacks.clone();
878 let after_tool_callbacks = self.after_tool_callbacks.clone();
879 let on_tool_error_callbacks = self.on_tool_error_callbacks.clone();
880 let after_tool_callbacks_full = self.after_tool_callbacks_full.clone();
881 let default_retry_budget = self.default_retry_budget.clone();
882 let tool_retry_budgets = self.tool_retry_budgets.clone();
883 let circuit_breaker_threshold = self.circuit_breaker_threshold;
884 let tool_confirmation_policy = self.tool_confirmation_policy.clone();
885 let disallow_transfer_to_parent = self.disallow_transfer_to_parent;
886 let disallow_transfer_to_peers = self.disallow_transfer_to_peers;
887 let output_guardrails = self.output_guardrails.clone();
888 let agent_tool_execution_strategy = self.tool_execution_strategy;
889
890 let s = stream! {
891 for callback in before_agent_callbacks.as_ref() {
895 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
896 Ok(Some(content)) => {
897 let mut early_event = Event::new(&invocation_id);
899 early_event.author = agent_name.clone();
900 early_event.llm_response.content = Some(content);
901 yield Ok(early_event);
902
903 for after_callback in after_agent_callbacks.as_ref() {
905 match after_callback(ctx.clone() as Arc<dyn CallbackContext>).await {
906 Ok(Some(after_content)) => {
907 let mut after_event = Event::new(&invocation_id);
908 after_event.author = agent_name.clone();
909 after_event.llm_response.content = Some(after_content);
910 yield Ok(after_event);
911 return;
912 }
913 Ok(None) => continue,
914 Err(e) => {
915 yield Err(e);
916 return;
917 }
918 }
919 }
920 return;
921 }
922 Ok(None) => {
923 continue;
925 }
926 Err(e) => {
927 yield Err(e);
929 return;
930 }
931 }
932 }
933
934 let mut prompt_preamble = Vec::new();
936
937 if let Some(index) = &skills_index {
941 let user_query = ctx
942 .user_content()
943 .parts
944 .iter()
945 .filter_map(|part| match part {
946 Part::Text { text } => Some(text.as_str()),
947 _ => None,
948 })
949 .collect::<Vec<_>>()
950 .join("\n");
951
952 if let Some((_matched, skill_block)) = select_skill_prompt_block(
953 index.as_ref(),
954 &user_query,
955 &skill_policy,
956 max_skill_chars,
957 ) {
958 prompt_preamble.push(Content {
959 role: "user".to_string(),
960 parts: vec![Part::Text { text: skill_block }],
961 });
962 }
963 }
964
965 if let Some(provider) = &global_instruction_provider {
968 let global_inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
970 if !global_inst.is_empty() {
971 prompt_preamble.push(Content {
972 role: "user".to_string(),
973 parts: vec![Part::Text { text: global_inst }],
974 });
975 }
976 } else if let Some(ref template) = global_instruction {
977 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
979 if !processed.is_empty() {
980 prompt_preamble.push(Content {
981 role: "user".to_string(),
982 parts: vec![Part::Text { text: processed }],
983 });
984 }
985 }
986
987 if let Some(provider) = &instruction_provider {
990 let inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
992 if !inst.is_empty() {
993 prompt_preamble.push(Content {
994 role: "user".to_string(),
995 parts: vec![Part::Text { text: inst }],
996 });
997 }
998 } else if let Some(ref template) = instruction {
999 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
1001 if !processed.is_empty() {
1002 prompt_preamble.push(Content {
1003 role: "user".to_string(),
1004 parts: vec![Part::Text { text: processed }],
1005 });
1006 }
1007 }
1008
1009 let session_history = if !ctx.run_config().transfer_targets.is_empty() {
1015 ctx.session().conversation_history_for_agent(&agent_name)
1016 } else {
1017 ctx.session().conversation_history()
1018 };
1019 let mut session_history = session_history;
1020 let current_user_content = ctx.user_content().clone();
1021 if let Some(index) = session_history.iter().rposition(|content| content.role == "user") {
1022 session_history[index] = current_user_content.clone();
1023 } else {
1024 session_history.push(current_user_content.clone());
1025 }
1026
1027 let mut conversation_history = match include_contents {
1030 adk_core::IncludeContents::None => {
1031 let mut filtered = prompt_preamble.clone();
1032 filtered.push(current_user_content);
1033 filtered
1034 }
1035 adk_core::IncludeContents::Default => {
1036 let mut full_history = prompt_preamble;
1037 full_history.extend(session_history);
1038 full_history
1039 }
1040 };
1041
1042 let mut resolved_tools: Vec<Arc<dyn Tool>> = tools.clone();
1045 let static_tool_names: std::collections::HashSet<String> =
1046 tools.iter().map(|t| t.name().to_string()).collect();
1047
1048 let mut toolset_source: std::collections::HashMap<String, String> =
1050 std::collections::HashMap::new();
1051
1052 for toolset in &toolsets {
1053 let toolset_tools = match toolset
1054 .tools(ctx.clone() as Arc<dyn ReadonlyContext>)
1055 .await
1056 {
1057 Ok(t) => t,
1058 Err(e) => {
1059 yield Err(e);
1060 return;
1061 }
1062 };
1063 for tool in &toolset_tools {
1064 let name = tool.name().to_string();
1065 if static_tool_names.contains(&name) {
1067 yield Err(adk_core::AdkError::agent(format!(
1068 "Duplicate tool name '{name}': conflict between static tool and toolset '{}'",
1069 toolset.name()
1070 )));
1071 return;
1072 }
1073 if let Some(other_toolset_name) = toolset_source.get(&name) {
1075 yield Err(adk_core::AdkError::agent(format!(
1076 "Duplicate tool name '{name}': conflict between toolset '{}' and toolset '{}'",
1077 other_toolset_name,
1078 toolset.name()
1079 )));
1080 return;
1081 }
1082 toolset_source.insert(name, toolset.name().to_string());
1083 resolved_tools.push(tool.clone());
1084 }
1085 }
1086
1087 let tool_map: std::collections::HashMap<String, Arc<dyn Tool>> = resolved_tools
1089 .iter()
1090 .map(|t| (t.name().to_string(), t.clone()))
1091 .collect();
1092
1093 let collect_long_running_ids = |content: &Content| -> Vec<String> {
1095 content.parts.iter()
1096 .filter_map(|p| {
1097 if let Part::FunctionCall { name, .. } = p {
1098 if let Some(tool) = tool_map.get(name) {
1099 if tool.is_long_running() {
1100 return Some(name.clone());
1101 }
1102 }
1103 }
1104 None
1105 })
1106 .collect()
1107 };
1108
1109 let mut tool_declarations = std::collections::HashMap::new();
1114 for tool in &resolved_tools {
1115 tool_declarations.insert(tool.name().to_string(), tool.declaration());
1116 }
1117
1118 let mut valid_transfer_targets: Vec<String> = sub_agents
1123 .iter()
1124 .map(|a| a.name().to_string())
1125 .collect();
1126
1127 let run_config_targets = &ctx.run_config().transfer_targets;
1129 let parent_agent_name = ctx.run_config().parent_agent.clone();
1130 let sub_agent_names: std::collections::HashSet<&str> = sub_agents
1131 .iter()
1132 .map(|a| a.name())
1133 .collect();
1134
1135 for target in run_config_targets {
1136 if sub_agent_names.contains(target.as_str()) {
1138 continue;
1139 }
1140
1141 let is_parent = parent_agent_name.as_deref() == Some(target.as_str());
1143 if is_parent && disallow_transfer_to_parent {
1144 continue;
1145 }
1146 if !is_parent && disallow_transfer_to_peers {
1147 continue;
1148 }
1149
1150 valid_transfer_targets.push(target.clone());
1151 }
1152
1153 if !valid_transfer_targets.is_empty() {
1155 let transfer_tool_name = "transfer_to_agent";
1156 let transfer_tool_decl = serde_json::json!({
1157 "name": transfer_tool_name,
1158 "description": format!(
1159 "Transfer execution to another agent. Valid targets: {}",
1160 valid_transfer_targets.join(", ")
1161 ),
1162 "parameters": {
1163 "type": "object",
1164 "properties": {
1165 "agent_name": {
1166 "type": "string",
1167 "description": "The name of the agent to transfer to.",
1168 "enum": valid_transfer_targets
1169 }
1170 },
1171 "required": ["agent_name"]
1172 }
1173 });
1174 tool_declarations.insert(transfer_tool_name.to_string(), transfer_tool_decl);
1175 }
1176
1177
1178 let mut circuit_breaker_state = circuit_breaker_threshold.map(CircuitBreakerState::new);
1181
1182 let mut iteration = 0;
1184
1185 loop {
1186 iteration += 1;
1187 if iteration > max_iterations {
1188 yield Err(adk_core::AdkError::agent(
1189 format!("Max iterations ({max_iterations}) exceeded")
1190 ));
1191 return;
1192 }
1193
1194 let config = match (&generate_content_config, &output_schema) {
1201 (Some(base), Some(schema)) => {
1202 let mut merged = base.clone();
1203 merged.response_schema = Some(schema.clone());
1204 Some(merged)
1205 }
1206 (Some(base), None) => Some(base.clone()),
1207 (None, Some(schema)) => Some(adk_core::GenerateContentConfig {
1208 response_schema: Some(schema.clone()),
1209 ..Default::default()
1210 }),
1211 (None, None) => None,
1212 };
1213
1214 let config = if let Some(ref cached) = ctx.run_config().cached_content {
1216 let mut cfg = config.unwrap_or_default();
1217 if cfg.cached_content.is_none() {
1219 cfg.cached_content = Some(cached.clone());
1220 }
1221 Some(cfg)
1222 } else {
1223 config
1224 };
1225
1226 let request = LlmRequest {
1227 model: model.name().to_string(),
1228 contents: conversation_history.clone(),
1229 tools: tool_declarations.clone(),
1230 config,
1231 };
1232
1233 let mut current_request = request;
1236 let mut model_response_override = None;
1237 for callback in before_model_callbacks.as_ref() {
1238 match callback(ctx.clone() as Arc<dyn CallbackContext>, current_request.clone()).await {
1239 Ok(BeforeModelResult::Continue(modified_request)) => {
1240 current_request = modified_request;
1242 }
1243 Ok(BeforeModelResult::Skip(response)) => {
1244 model_response_override = Some(response);
1246 break;
1247 }
1248 Err(e) => {
1249 yield Err(e);
1251 return;
1252 }
1253 }
1254 }
1255 let request = current_request;
1256
1257 let mut accumulated_content: Option<Content> = None;
1259 let mut final_provider_metadata: Option<serde_json::Value> = None;
1260
1261 if let Some(cached_response) = model_response_override {
1262 accumulated_content = cached_response.content.clone();
1265 final_provider_metadata = cached_response.provider_metadata.clone();
1266 normalize_option_content(&mut accumulated_content);
1267 if let Some(content) = accumulated_content.take() {
1268 let has_function_calls = content
1269 .parts
1270 .iter()
1271 .any(|part| matches!(part, Part::FunctionCall { .. }));
1272 let content = if has_function_calls {
1273 content
1274 } else {
1275 Self::apply_output_guardrails(output_guardrails.as_ref(), content).await?
1276 };
1277 accumulated_content = Some(content);
1278 }
1279
1280 let mut cached_event = Event::new(&invocation_id);
1281 cached_event.author = agent_name.clone();
1282 cached_event.llm_response.content = accumulated_content.clone();
1283 cached_event.llm_response.provider_metadata = cached_response.provider_metadata.clone();
1284 cached_event.llm_request = Some(serde_json::to_string(&request).unwrap_or_default());
1285 cached_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), serde_json::to_string(&request).unwrap_or_default());
1286 cached_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(&cached_response).unwrap_or_default());
1287
1288 if let Some(ref content) = accumulated_content {
1290 cached_event.long_running_tool_ids = collect_long_running_ids(content);
1291 }
1292
1293 yield Ok(cached_event);
1294 } else {
1295 let request_json = serde_json::to_string(&request).unwrap_or_default();
1297 let trace_request_json = trace_json_payload(
1298 &request,
1299 ctx.run_config().record_payloads,
1300 ctx.run_config().trace_payload_max_bytes,
1301 );
1302
1303 let llm_ts = std::time::SystemTime::now()
1305 .duration_since(std::time::UNIX_EPOCH)
1306 .unwrap_or_default()
1307 .as_nanos();
1308 let llm_event_id = format!("{}_llm_{}", invocation_id, llm_ts);
1309 let llm_span = tracing::info_span!(
1310 "call_llm",
1311 "gcp.vertex.agent.event_id" = %llm_event_id,
1312 "gcp.vertex.agent.invocation_id" = %invocation_id,
1313 "gcp.vertex.agent.session_id" = %ctx.session_id(),
1314 "gen_ai.conversation.id" = %ctx.session_id(),
1315 "gcp.vertex.agent.llm_request" = %trace_request_json,
1316 "gcp.vertex.agent.llm_response" = tracing::field::Empty );
1318 let _llm_guard = llm_span.enter();
1319
1320 use adk_core::StreamingMode;
1322 let streaming_mode = ctx.run_config().streaming_mode;
1323 let should_stream_to_client = matches!(streaming_mode, StreamingMode::SSE | StreamingMode::Bidi)
1324 && output_guardrails.is_empty();
1325
1326 let mut response_stream = model.generate_content(request, true).await?;
1328
1329 use futures::StreamExt;
1330
1331 let mut last_chunk: Option<LlmResponse> = None;
1333
1334 while let Some(chunk_result) = response_stream.next().await {
1336 let mut chunk = match chunk_result {
1337 Ok(c) => c,
1338 Err(e) => {
1339 yield Err(e);
1340 return;
1341 }
1342 };
1343
1344 for callback in after_model_callbacks.as_ref() {
1347 match callback(ctx.clone() as Arc<dyn CallbackContext>, chunk.clone()).await {
1348 Ok(Some(modified_chunk)) => {
1349 chunk = modified_chunk;
1351 break;
1352 }
1353 Ok(None) => {
1354 continue;
1356 }
1357 Err(e) => {
1358 yield Err(e);
1360 return;
1361 }
1362 }
1363 }
1364
1365 normalize_option_content(&mut chunk.content);
1366
1367 if let Some(chunk_content) = chunk.content.clone() {
1369 if let Some(ref mut acc) = accumulated_content {
1370 acc.parts.extend(chunk_content.parts);
1371 } else {
1372 accumulated_content = Some(chunk_content);
1373 }
1374 }
1375
1376 if should_stream_to_client {
1378 let mut partial_event = Event::with_id(&llm_event_id, &invocation_id);
1379 partial_event.author = agent_name.clone();
1380 partial_event.llm_request = Some(request_json.clone());
1381 partial_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), request_json.clone());
1382 partial_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(&chunk).unwrap_or_default());
1383 partial_event.llm_response.partial = chunk.partial;
1384 partial_event.llm_response.turn_complete = chunk.turn_complete;
1385 partial_event.llm_response.finish_reason = chunk.finish_reason;
1386 partial_event.llm_response.usage_metadata = chunk.usage_metadata.clone();
1387 partial_event.llm_response.content = chunk.content.clone();
1388 partial_event.llm_response.provider_metadata = chunk.provider_metadata.clone();
1389
1390 if let Some(ref content) = chunk.content {
1392 partial_event.long_running_tool_ids = collect_long_running_ids(content);
1393 }
1394
1395 yield Ok(partial_event);
1396 }
1397
1398 last_chunk = Some(chunk.clone());
1400
1401 if chunk.turn_complete {
1403 break;
1404 }
1405 }
1406
1407 if !should_stream_to_client {
1409 if let Some(content) = accumulated_content.take() {
1410 let has_function_calls = content
1411 .parts
1412 .iter()
1413 .any(|part| matches!(part, Part::FunctionCall { .. }));
1414 let content = if has_function_calls {
1415 content
1416 } else {
1417 Self::apply_output_guardrails(output_guardrails.as_ref(), content).await?
1418 };
1419 accumulated_content = Some(content);
1420 }
1421
1422 let mut final_event = Event::with_id(&llm_event_id, &invocation_id);
1423 final_event.author = agent_name.clone();
1424 final_event.llm_request = Some(request_json.clone());
1425 final_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), request_json.clone());
1426 final_event.llm_response.content = accumulated_content.clone();
1427 final_event.llm_response.partial = false;
1428 final_event.llm_response.turn_complete = true;
1429
1430 if let Some(ref last) = last_chunk {
1432 final_event.llm_response.finish_reason = last.finish_reason;
1433 final_event.llm_response.usage_metadata = last.usage_metadata.clone();
1434 final_event.llm_response.provider_metadata = last.provider_metadata.clone();
1435 final_provider_metadata = last.provider_metadata.clone();
1436 final_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(last).unwrap_or_default());
1437 }
1438
1439 if let Some(ref content) = accumulated_content {
1441 final_event.long_running_tool_ids = collect_long_running_ids(content);
1442 }
1443
1444 yield Ok(final_event);
1445 }
1446
1447 if let Some(ref content) = accumulated_content {
1449 let response_json = trace_json_payload(
1450 content,
1451 ctx.run_config().record_payloads,
1452 ctx.run_config().trace_payload_max_bytes,
1453 );
1454 llm_span.record("gcp.vertex.agent.llm_response", &response_json);
1455 }
1456 }
1457
1458 let function_call_names: Vec<String> = accumulated_content.as_ref()
1460 .map(|c| c.parts.iter()
1461 .filter_map(|p| {
1462 if let Part::FunctionCall { name, .. } = p {
1463 Some(name.clone())
1464 } else {
1465 None
1466 }
1467 })
1468 .collect())
1469 .unwrap_or_default();
1470
1471 let has_function_calls = !function_call_names.is_empty();
1472
1473 let all_calls_are_long_running = has_function_calls && function_call_names.iter().all(|name| {
1477 tool_map.get(name)
1478 .map(|t| t.is_long_running())
1479 .unwrap_or(false)
1480 });
1481
1482 if let Some(ref content) = accumulated_content {
1484 conversation_history.push(Self::augment_content_for_history(
1485 content,
1486 final_provider_metadata.as_ref(),
1487 ));
1488
1489 if let Some(ref output_key) = output_key {
1491 if !has_function_calls { let mut text_parts = String::new();
1493 for part in &content.parts {
1494 if let Part::Text { text } = part {
1495 text_parts.push_str(text);
1496 }
1497 }
1498 if !text_parts.is_empty() {
1499 let mut state_event = Event::new(&invocation_id);
1501 state_event.author = agent_name.clone();
1502 state_event.actions.state_delta.insert(
1503 output_key.clone(),
1504 serde_json::Value::String(text_parts),
1505 );
1506 yield Ok(state_event);
1507 }
1508 }
1509 }
1510 }
1511
1512 if !has_function_calls {
1513 if let Some(ref content) = accumulated_content {
1516 let response_json = trace_json_payload(
1517 content,
1518 ctx.run_config().record_payloads,
1519 ctx.run_config().trace_payload_max_bytes,
1520 );
1521 tracing::Span::current().record("gcp.vertex.agent.llm_response", &response_json);
1522 }
1523
1524 tracing::info!(agent.name = %agent_name, "Agent execution complete");
1525 break;
1526 }
1527
1528 if let Some(content) = &accumulated_content {
1530 let strategy = agent_tool_execution_strategy
1533 .unwrap_or(ToolExecutionStrategy::Sequential);
1534
1535 let mut fc_parts: Vec<(usize, String, serde_json::Value, Option<String>, String)> = Vec::new();
1539 {
1540 let mut tci = 0usize;
1541 for part in &content.parts {
1542 if let Part::FunctionCall { name, args, id, .. } = part {
1543 let fallback = format!("{}_{}_{}", invocation_id, name, tci);
1544 let fcid = id.clone().unwrap_or(fallback);
1545 fc_parts.push((tci, name.clone(), args.clone(), id.clone(), fcid));
1546 tci += 1;
1547 }
1548 }
1549 }
1550
1551 let mut transfer_handled = false;
1555 for (_, fc_name, fc_args, fc_id, _) in &fc_parts {
1556 if fc_name == "transfer_to_agent" {
1557 let target_agent = fc_args.get("agent_name")
1558 .and_then(|v| v.as_str())
1559 .unwrap_or_default()
1560 .to_string();
1561
1562 let valid_target = valid_transfer_targets.iter().any(|n| n == &target_agent);
1563 if !valid_target {
1564 let error_content = Content {
1565 role: "function".to_string(),
1566 parts: vec![Part::FunctionResponse {
1567 function_response: FunctionResponseData::new(
1568 fc_name.clone(),
1569 serde_json::json!({
1570 "error": format!(
1571 "Agent '{}' not found. Available agents: {:?}",
1572 target_agent, valid_transfer_targets
1573 )
1574 }),
1575 ),
1576 id: fc_id.clone(),
1577 }],
1578 };
1579 conversation_history.push(error_content.clone());
1580 let mut error_event = Event::new(&invocation_id);
1581 error_event.author = agent_name.clone();
1582 error_event.llm_response.content = Some(error_content);
1583 yield Ok(error_event);
1584 continue;
1585 }
1586
1587 let mut transfer_event = Event::new(&invocation_id);
1588 transfer_event.author = agent_name.clone();
1589 transfer_event.actions.transfer_to_agent = Some(target_agent);
1590 yield Ok(transfer_event);
1591 transfer_handled = true;
1592 break;
1593 }
1594 }
1595 if transfer_handled {
1596 return;
1597 }
1598
1599 let fc_parts: Vec<_> = fc_parts.into_iter().filter(|(_, fc_name, _, _, _)| {
1601 if fc_name == "transfer_to_agent" {
1602 return false;
1603 }
1604 if let Some(tool) = tool_map.get(fc_name) {
1605 if tool.is_builtin() {
1606 adk_telemetry::debug!(tool.name = %fc_name, "skipping built-in tool execution");
1607 return false;
1608 }
1609 }
1610 true
1611 }).collect();
1612
1613 let mut confirmation_interrupted = false;
1617 for (_, fc_name, fc_args, _, fc_call_id) in &fc_parts {
1618 if tool_confirmation_policy.requires_confirmation(fc_name)
1619 && ctx.run_config().tool_confirmation_decisions.get(fc_name).copied().is_none()
1620 {
1621 let mut ce = Event::new(&invocation_id);
1622 ce.author = agent_name.clone();
1623 ce.llm_response.interrupted = true;
1624 ce.llm_response.turn_complete = true;
1625 ce.llm_response.content = Some(Content {
1626 role: "model".to_string(),
1627 parts: vec![Part::Text {
1628 text: format!(
1629 "Tool confirmation required for '{}'. Provide approve/deny decision to continue.",
1630 fc_name
1631 ),
1632 }],
1633 });
1634 ce.actions.tool_confirmation = Some(ToolConfirmationRequest {
1635 tool_name: fc_name.clone(),
1636 function_call_id: Some(fc_call_id.clone()),
1637 args: fc_args.clone(),
1638 });
1639 yield Ok(ce);
1640 confirmation_interrupted = true;
1641 break;
1642 }
1643 }
1644 if confirmation_interrupted {
1645 return;
1646 }
1647
1648 let cb_mutex = std::sync::Mutex::new(circuit_breaker_state.take());
1650
1651 let execute_one_tool = |idx: usize, name: String, args: serde_json::Value,
1656 id: Option<String>, function_call_id: String| {
1657 let ctx = ctx.clone();
1658 let tool_map = &tool_map;
1659 let tool_retry_budgets = &tool_retry_budgets;
1660 let default_retry_budget = &default_retry_budget;
1661 let before_tool_callbacks = &before_tool_callbacks;
1662 let after_tool_callbacks = &after_tool_callbacks;
1663 let after_tool_callbacks_full = &after_tool_callbacks_full;
1664 let on_tool_error_callbacks = &on_tool_error_callbacks;
1665 let tool_confirmation_policy = &tool_confirmation_policy;
1666 let cb_mutex = &cb_mutex;
1667 let invocation_id = &invocation_id;
1668 async move {
1669 let mut tool_actions = EventActions::default();
1670 let mut response_content: Option<Content> = None;
1671 let mut run_after_tool_callbacks = true;
1672 let mut tool_outcome_for_callback: Option<ToolOutcome> = None;
1673 let mut executed_tool: Option<Arc<dyn Tool>> = None;
1674 let mut executed_tool_response: Option<serde_json::Value> = None;
1675
1676 if tool_confirmation_policy.requires_confirmation(&name) {
1678 match ctx.run_config().tool_confirmation_decisions.get(&name).copied() {
1679 Some(ToolConfirmationDecision::Approve) => {
1680 tool_actions.tool_confirmation_decision =
1681 Some(ToolConfirmationDecision::Approve);
1682 }
1683 Some(ToolConfirmationDecision::Deny) => {
1684 tool_actions.tool_confirmation_decision =
1685 Some(ToolConfirmationDecision::Deny);
1686 response_content = Some(Content {
1687 role: "function".to_string(),
1688 parts: vec![Part::FunctionResponse {
1689 function_response: FunctionResponseData::new(
1690 name.clone(),
1691 serde_json::json!({
1692 "error": format!("Tool '{}' execution denied by confirmation policy", name)
1693 }),
1694 ),
1695 id: id.clone(),
1696 }],
1697 });
1698 run_after_tool_callbacks = false;
1699 }
1700 None => {
1701 response_content = Some(Content {
1702 role: "function".to_string(),
1703 parts: vec![Part::FunctionResponse {
1704 function_response: FunctionResponseData::new(
1705 name.clone(),
1706 serde_json::json!({
1707 "error": format!("Tool '{}' requires confirmation", name)
1708 }),
1709 ),
1710 id: id.clone(),
1711 }],
1712 });
1713 run_after_tool_callbacks = false;
1714 }
1715 }
1716 }
1717
1718 if response_content.is_none() {
1720 let tool_ctx = Arc::new(ToolCallbackContext::new(
1721 ctx.clone(),
1722 name.clone(),
1723 args.clone(),
1724 ));
1725 for callback in before_tool_callbacks.as_ref() {
1726 match callback(tool_ctx.clone() as Arc<dyn CallbackContext>).await {
1727 Ok(Some(c)) => { response_content = Some(c); break; }
1728 Ok(None) => continue,
1729 Err(e) => {
1730 response_content = Some(Content {
1731 role: "function".to_string(),
1732 parts: vec![Part::FunctionResponse {
1733 function_response: FunctionResponseData::new(
1734 name.clone(),
1735 serde_json::json!({ "error": e.to_string() }),
1736 ),
1737 id: id.clone(),
1738 }],
1739 });
1740 run_after_tool_callbacks = false;
1741 break;
1742 }
1743 }
1744 }
1745 }
1746
1747 if response_content.is_none() {
1749 let guard = cb_mutex.lock().unwrap_or_else(|e| e.into_inner());
1750 if let Some(ref cb_state) = *guard {
1751 if cb_state.is_open(&name) {
1752 let msg = format!(
1753 "Tool '{}' is temporarily disabled after {} consecutive failures",
1754 name, cb_state.threshold
1755 );
1756 tracing::warn!(tool.name = %name, "circuit breaker open, skipping tool execution");
1757 response_content = Some(Content {
1758 role: "function".to_string(),
1759 parts: vec![Part::FunctionResponse {
1760 function_response: FunctionResponseData::new(
1761 name.clone(),
1762 serde_json::json!({ "error": msg }),
1763 ),
1764 id: id.clone(),
1765 }],
1766 });
1767 run_after_tool_callbacks = false;
1768 }
1769 }
1770 drop(guard);
1771 }
1772
1773 if response_content.is_none() {
1775 if let Some(tool) = tool_map.get(&name) {
1776 let tool_ctx: Arc<dyn ToolContext> = Arc::new(
1777 AgentToolContext::new(ctx.clone(), function_call_id.clone()),
1778 );
1779 let span_name = format!("execute_tool {name}");
1780 let tool_span = tracing::info_span!(
1781 "",
1782 otel.name = %span_name,
1783 tool.name = %name,
1784 "gcp.vertex.agent.event_id" = %format!("{}_{}", invocation_id, name),
1785 "gcp.vertex.agent.invocation_id" = %invocation_id,
1786 "gcp.vertex.agent.session_id" = %ctx.session_id(),
1787 "gen_ai.conversation.id" = %ctx.session_id()
1788 );
1789
1790 let budget = tool_retry_budgets.get(&name)
1791 .or(default_retry_budget.as_ref());
1792 let max_attempts = budget.map(|b| b.max_retries + 1).unwrap_or(1);
1793 let retry_delay = budget.map(|b| b.delay).unwrap_or_default();
1794
1795 let tool_clone = tool.clone();
1796 let tool_start = std::time::Instant::now();
1797 let mut last_error = String::new();
1798 let mut final_attempt: u32 = 0;
1799 let mut retry_result: Option<serde_json::Value> = None;
1800
1801 for attempt in 0..max_attempts {
1802 final_attempt = attempt;
1803 if attempt > 0 {
1804 tokio::time::sleep(retry_delay).await;
1805 }
1806 match async {
1807 let args_payload = trace_json_payload(
1808 &args,
1809 ctx.run_config().record_payloads,
1810 ctx.run_config().trace_payload_max_bytes,
1811 );
1812 tracing::debug!(tool.name = %name, tool.args = %args_payload, attempt = attempt, "tool_call");
1813 let exec_future = tool_clone.execute(tool_ctx.clone(), args.clone());
1814 tokio::time::timeout(tool_timeout, exec_future).await
1815 }.instrument(tool_span.clone()).await {
1816 Ok(Ok(value)) => {
1817 let result_payload = trace_json_payload(
1818 &value,
1819 ctx.run_config().record_payloads,
1820 ctx.run_config().trace_payload_max_bytes,
1821 );
1822 tracing::debug!(tool.name = %name, tool.result = %result_payload, "tool_result");
1823 retry_result = Some(value);
1824 break;
1825 }
1826 Ok(Err(e)) => {
1827 last_error = e.to_string();
1828 if attempt + 1 < max_attempts {
1829 tracing::warn!(tool.name = %name, attempt = attempt, error = %last_error, "tool execution failed, retrying");
1830 } else {
1831 tracing::warn!(tool.name = %name, error = %last_error, "tool_error");
1832 }
1833 }
1834 Err(_) => {
1835 last_error = format!(
1836 "Tool '{}' timed out after {} seconds",
1837 name, tool_timeout.as_secs()
1838 );
1839 if attempt + 1 < max_attempts {
1840 tracing::warn!(tool.name = %name, attempt = attempt, timeout_secs = tool_timeout.as_secs(), "tool timed out, retrying");
1841 } else {
1842 tracing::warn!(tool.name = %name, timeout_secs = tool_timeout.as_secs(), "tool_timeout");
1843 }
1844 }
1845 }
1846 }
1847
1848 let tool_duration = tool_start.elapsed();
1849 let (tool_success, tool_error_message, function_response) = match retry_result {
1850 Some(value) => (true, None, value),
1851 None => (false, Some(last_error.clone()), serde_json::json!({ "error": last_error })),
1852 };
1853
1854 let outcome = ToolOutcome {
1855 tool_name: name.clone(),
1856 tool_args: args.clone(),
1857 success: tool_success,
1858 duration: tool_duration,
1859 error_message: tool_error_message.clone(),
1860 attempt: final_attempt,
1861 };
1862 tool_outcome_for_callback = Some(outcome);
1863
1864 {
1866 let mut guard = cb_mutex.lock().unwrap_or_else(|e| e.into_inner());
1867 if let Some(ref mut cb_state) = *guard {
1868 cb_state.record(tool_outcome_for_callback.as_ref().unwrap());
1869 }
1870 }
1871
1872 let final_function_response = if !tool_success {
1874 let mut fallback_result = None;
1875 let error_msg = tool_error_message.clone().unwrap_or_default();
1876 for callback in on_tool_error_callbacks.as_ref() {
1877 match callback(
1878 ctx.clone() as Arc<dyn CallbackContext>,
1879 tool.clone(),
1880 args.clone(),
1881 error_msg.clone(),
1882 ).await {
1883 Ok(Some(result)) => { fallback_result = Some(result); break; }
1884 Ok(None) => continue,
1885 Err(e) => { tracing::warn!(error = %e, "on_tool_error callback failed"); break; }
1886 }
1887 }
1888 fallback_result.unwrap_or(function_response)
1889 } else {
1890 function_response
1891 };
1892
1893 let confirmation_decision = tool_actions.tool_confirmation_decision;
1894 tool_actions = tool_ctx.actions();
1895 if tool_actions.tool_confirmation_decision.is_none() {
1896 tool_actions.tool_confirmation_decision = confirmation_decision;
1897 }
1898 executed_tool = Some(tool.clone());
1899 executed_tool_response = Some(final_function_response.clone());
1900 response_content = Some(Content {
1901 role: "function".to_string(),
1902 parts: vec![Part::FunctionResponse {
1903 function_response: FunctionResponseData::from_tool_result(
1904 name.clone(),
1905 final_function_response,
1906 ),
1907 id: id.clone(),
1908 }],
1909 });
1910 } else {
1911 response_content = Some(Content {
1912 role: "function".to_string(),
1913 parts: vec![Part::FunctionResponse {
1914 function_response: FunctionResponseData::new(
1915 name.clone(),
1916 serde_json::json!({
1917 "error": format!("Tool {} not found", name)
1918 }),
1919 ),
1920 id: id.clone(),
1921 }],
1922 });
1923 }
1924 }
1925
1926 let mut response_content = response_content.expect("tool response content is set");
1928 if run_after_tool_callbacks {
1929 let outcome_ctx: Arc<dyn CallbackContext> = match tool_outcome_for_callback {
1930 Some(outcome) => Arc::new(ToolOutcomeCallbackContext {
1931 inner: ctx.clone() as Arc<dyn CallbackContext>,
1932 outcome,
1933 }),
1934 None => ctx.clone() as Arc<dyn CallbackContext>,
1935 };
1936 let cb_ctx: Arc<dyn CallbackContext> = Arc::new(ToolCallbackContext::new(
1937 outcome_ctx,
1938 name.clone(),
1939 args.clone(),
1940 ));
1941 for callback in after_tool_callbacks.as_ref() {
1942 match callback(cb_ctx.clone()).await {
1943 Ok(Some(modified)) => { response_content = modified; break; }
1944 Ok(None) => continue,
1945 Err(e) => {
1946 response_content = Content {
1947 role: "function".to_string(),
1948 parts: vec![Part::FunctionResponse {
1949 function_response: FunctionResponseData::new(
1950 name.clone(),
1951 serde_json::json!({ "error": e.to_string() }),
1952 ),
1953 id: id.clone(),
1954 }],
1955 };
1956 break;
1957 }
1958 }
1959 }
1960 if let (Some(tool_ref), Some(tool_resp)) = (&executed_tool, executed_tool_response) {
1961 for callback in after_tool_callbacks_full.as_ref() {
1962 match callback(
1963 cb_ctx.clone(), tool_ref.clone(), args.clone(), tool_resp.clone(),
1964 ).await {
1965 Ok(Some(modified_value)) => {
1966 response_content = Content {
1967 role: "function".to_string(),
1968 parts: vec![Part::FunctionResponse {
1969 function_response: FunctionResponseData::from_tool_result(
1970 name.clone(),
1971 modified_value,
1972 ),
1973 id: id.clone(),
1974 }],
1975 };
1976 break;
1977 }
1978 Ok(None) => continue,
1979 Err(e) => {
1980 response_content = Content {
1981 role: "function".to_string(),
1982 parts: vec![Part::FunctionResponse {
1983 function_response: FunctionResponseData::new(
1984 name.clone(),
1985 serde_json::json!({ "error": e.to_string() }),
1986 ),
1987 id: id.clone(),
1988 }],
1989 };
1990 break;
1991 }
1992 }
1993 }
1994 }
1995 }
1996
1997 let escalate_or_skip = tool_actions.escalate || tool_actions.skip_summarization;
1998 (idx, response_content, tool_actions, escalate_or_skip)
1999 }
2000 };
2001
2002 let mut results: Vec<(usize, Content, EventActions, bool)> = match strategy {
2004 ToolExecutionStrategy::Sequential => {
2005 let mut results = Vec::with_capacity(fc_parts.len());
2006 for (idx, name, args, id, fcid) in fc_parts {
2007 results.push(execute_one_tool(idx, name, args, id, fcid).await);
2008 }
2009 results
2010 }
2011 ToolExecutionStrategy::Parallel => {
2012 use futures::StreamExt as _;
2013 let limit = ctx
2014 .run_config()
2015 .max_tool_concurrency
2016 .unwrap_or(fc_parts.len())
2017 .max(1);
2018 futures::stream::iter(fc_parts.into_iter().map(
2019 |(idx, name, args, id, fcid)| {
2020 execute_one_tool(idx, name, args, id, fcid)
2021 },
2022 ))
2023 .buffer_unordered(limit)
2024 .collect()
2025 .await
2026 }
2027 ToolExecutionStrategy::Auto => {
2028 let mut read_only_fcs = Vec::new();
2030 let mut mutable_fcs = Vec::new();
2031 for fc in fc_parts {
2032 let is_ro = tool_map.get(&fc.1)
2033 .map(|t| t.is_read_only())
2034 .unwrap_or(false);
2035 if is_ro { read_only_fcs.push(fc); } else { mutable_fcs.push(fc); }
2036 }
2037 let mut all_results = Vec::new();
2038 if !read_only_fcs.is_empty() {
2040 use futures::StreamExt as _;
2041 let limit = ctx
2042 .run_config()
2043 .max_tool_concurrency
2044 .unwrap_or(read_only_fcs.len())
2045 .max(1);
2046 all_results.extend(
2047 futures::stream::iter(read_only_fcs.into_iter().map(
2048 |(idx, name, args, id, fcid)| {
2049 execute_one_tool(idx, name, args, id, fcid)
2050 },
2051 ))
2052 .buffer_unordered(limit)
2053 .collect::<Vec<_>>()
2054 .await,
2055 );
2056 }
2057 for (idx, name, args, id, fcid) in mutable_fcs {
2059 all_results.push(execute_one_tool(idx, name, args, id, fcid).await);
2060 }
2061 all_results
2062 }
2063 };
2064 results.sort_by_key(|r| r.0);
2066
2067 circuit_breaker_state = cb_mutex.into_inner().unwrap_or_else(|e| e.into_inner());
2069
2070 for (_, response_content, tool_actions, escalate_or_skip) in results {
2072 let mut tool_event = Event::new(&invocation_id);
2073 tool_event.author = agent_name.clone();
2074 tool_event.actions = tool_actions;
2075 tool_event.llm_response.content = Some(response_content.clone());
2076 yield Ok(tool_event);
2077
2078 if escalate_or_skip {
2079 return;
2080 }
2081
2082 conversation_history.push(response_content);
2083 }
2084 }
2085
2086 if all_calls_are_long_running {
2090 }
2094 }
2095
2096 for callback in after_agent_callbacks.as_ref() {
2099 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
2100 Ok(Some(content)) => {
2101 let mut after_event = Event::new(&invocation_id);
2103 after_event.author = agent_name.clone();
2104 after_event.llm_response.content = Some(content);
2105 yield Ok(after_event);
2106 break; }
2108 Ok(None) => {
2109 continue;
2111 }
2112 Err(e) => {
2113 yield Err(e);
2115 return;
2116 }
2117 }
2118 }
2119 };
2120
2121 Ok(Box::pin(s))
2122 }
2123}