1use adk_core::{
2 AfterAgentCallback, AfterModelCallback, AfterToolCallback, AfterToolCallbackFull, Agent,
3 BeforeAgentCallback, BeforeModelCallback, BeforeModelResult, BeforeToolCallback,
4 CallbackContext, Content, Event, EventActions, FunctionResponseData, GlobalInstructionProvider,
5 InstructionProvider, InvocationContext, Llm, LlmRequest, LlmResponse, MemoryEntry,
6 OnToolErrorCallback, Part, ReadonlyContext, Result, RetryBudget, Tool, ToolCallbackContext,
7 ToolConfirmationDecision, ToolConfirmationPolicy, ToolConfirmationRequest, ToolContext,
8 ToolExecutionStrategy, ToolOutcome, Toolset,
9};
10use async_stream::stream;
11use async_trait::async_trait;
12use std::sync::{Arc, Mutex};
13use tracing::Instrument;
14
15#[cfg(feature = "enhanced-plugins")]
16use adk_plugin::{
17 BeforeModelCallResult, BeforeToolCallResult, EnhancedPlugin, EnhancedPluginManager,
18};
19
20#[cfg(feature = "skills")]
21use crate::skill_shim::load_skill_index;
22use crate::{
23 guardrails::{GuardrailSet, enforce_guardrails},
24 skill_shim::{SelectionPolicy, SkillIndex, select_skill_prompt_block},
25 tool_call_markup::normalize_option_content,
26 workflow::with_user_content_override,
27};
28
/// Default cap on the number of iterations of the agent run loop.
///
/// [`LlmAgentBuilder::new`] uses this as the initial `max_iterations` value.
pub const DEFAULT_MAX_ITERATIONS: u32 = 100;

/// Default tool timeout: 300 seconds (five minutes).
///
/// [`LlmAgentBuilder::new`] uses this as the initial `tool_timeout` value.
pub const DEFAULT_TOOL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
34
35fn trace_json_payload<T: serde::Serialize>(
36 value: &T,
37 record_payloads: bool,
38 max_bytes: usize,
39) -> String {
40 let json = serde_json::to_string(value).unwrap_or_default();
41 if cfg!(feature = "record-payloads") && record_payloads {
42 return json;
43 }
44
45 let max_bytes = max_bytes.max(32);
46 if json.len() <= max_bytes {
47 return json;
48 }
49
50 let mut end = max_bytes;
51 while !json.is_char_boundary(end) {
52 end -= 1;
53 }
54 format!("{}...[truncated {} bytes]", &json[..end], json.len() - end)
55}
56
/// An [`Agent`] implementation that drives an [`Llm`] model.
///
/// Instances are produced by [`LlmAgentBuilder::build`]; each field mirrors a builder setting.
pub struct LlmAgent {
    /// Agent name, returned by `Agent::name`.
    name: String,
    /// Agent description, returned by `Agent::description`.
    description: String,
    /// The model; its name is placed in every `LlmRequest` the run loop builds.
    model: Arc<dyn Llm>,
    /// Static instruction template, used only when `instruction_provider` is `None`.
    instruction: Option<String>,
    /// Dynamic instruction source; takes precedence over `instruction`.
    instruction_provider: Option<Arc<InstructionProvider>>,
    /// Static global instruction template, used only when `global_instruction_provider` is `None`.
    global_instruction: Option<String>,
    /// Dynamic global instruction source; takes precedence over `global_instruction`.
    global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
    /// Optional skill index queried with the user's text to pick a skill prompt block.
    skills_index: Option<Arc<SkillIndex>>,
    /// Selection policy handed to the skill lookup.
    skill_policy: SelectionPolicy,
    /// Character budget handed to the skill lookup.
    max_skill_chars: usize,
    // Stored from the builder; not read in the portion of the file reviewed here.
    #[allow(dead_code)] input_schema: Option<serde_json::Value>,
    /// When set, copied into the generation config's `response_schema`.
    output_schema: Option<serde_json::Value>,
    /// Drops the parent agent from the run-config transfer targets.
    disallow_transfer_to_parent: bool,
    /// Drops non-parent agents from the run-config transfer targets.
    disallow_transfer_to_peers: bool,
    /// `Default` sends the session history; `None` sends only the preamble and current user content.
    include_contents: adk_core::IncludeContents,
    /// Statically registered tools.
    tools: Vec<Arc<dyn Tool>>,
    /// Toolsets whose tools are resolved at the start of each run.
    #[allow(dead_code)] toolsets: Vec<Arc<dyn Toolset>>,
    /// Sub-agents; their names are always valid transfer targets.
    sub_agents: Vec<Arc<dyn Agent>>,
    // NOTE(review): usage is outside the reviewed portion — presumably a state key for the output.
    output_key: Option<String>,
    /// Base generation config merged with `output_schema` for each request.
    generate_content_config: Option<adk_core::GenerateContentConfig>,
    /// Run-loop iteration cap; exceeding it yields a "Max iterations" error.
    max_iterations: u32,
    // NOTE(review): presumably the per-tool execution timeout — usage is outside the reviewed portion.
    tool_timeout: std::time::Duration,
    /// Run before anything else; a callback returning content ends the run early.
    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
    /// Run after a before-callback short-circuits (other call sites are outside the reviewed portion).
    after_callbacks: Arc<Vec<AfterAgentCallback>>,
    before_model_callbacks: Arc<Vec<BeforeModelCallback>>,
    after_model_callbacks: Arc<Vec<AfterModelCallback>>,
    before_tool_callbacks: Arc<Vec<BeforeToolCallback>>,
    after_tool_callbacks: Arc<Vec<AfterToolCallback>>,
    on_tool_error_callbacks: Arc<Vec<OnToolErrorCallback>>,
    after_tool_callbacks_full: Arc<Vec<AfterToolCallbackFull>>,
    // NOTE(review): presumably the fallback for tools without an entry in `tool_retry_budgets`.
    default_retry_budget: Option<RetryBudget>,
    /// Retry budgets keyed by tool name.
    tool_retry_budgets: std::collections::HashMap<String, RetryBudget>,
    /// When set, each run creates a `CircuitBreakerState` with this threshold.
    circuit_breaker_threshold: Option<u32>,
    tool_confirmation_policy: ToolConfirmationPolicy,
    tool_execution_strategy: Option<ToolExecutionStrategy>,
    /// Applied to the user content before the run starts.
    input_guardrails: Arc<GuardrailSet>,
    /// Guardrails for model output.
    output_guardrails: Arc<GuardrailSet>,
    /// Present only when at least one enhanced plugin was registered on the builder.
    #[cfg(feature = "enhanced-plugins")]
    enhanced_plugin_manager: Option<Arc<EnhancedPluginManager>>,
}
112
113impl std::fmt::Debug for LlmAgent {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.debug_struct("LlmAgent")
116 .field("name", &self.name)
117 .field("description", &self.description)
118 .field("model", &self.model.name())
119 .field("instruction", &self.instruction)
120 .field("tools_count", &self.tools.len())
121 .field("sub_agents_count", &self.sub_agents.len())
122 .finish()
123 }
124}
125
126impl LlmAgent {
127 async fn apply_input_guardrails(
128 ctx: Arc<dyn InvocationContext>,
129 input_guardrails: Arc<GuardrailSet>,
130 ) -> Result<Arc<dyn InvocationContext>> {
131 let content =
132 enforce_guardrails(input_guardrails.as_ref(), ctx.user_content(), "input").await?;
133 if content.role != ctx.user_content().role || content.parts != ctx.user_content().parts {
134 Ok(with_user_content_override(ctx, content))
135 } else {
136 Ok(ctx)
137 }
138 }
139
140 async fn apply_output_guardrails(
141 output_guardrails: &GuardrailSet,
142 content: Content,
143 ) -> Result<Content> {
144 enforce_guardrails(output_guardrails, &content, "output").await
145 }
146
147 fn history_parts_from_provider_metadata(
148 provider_metadata: Option<&serde_json::Value>,
149 ) -> Vec<Part> {
150 let Some(provider_metadata) = provider_metadata else {
151 return Vec::new();
152 };
153
154 let history_parts = provider_metadata
155 .get("conversation_history_parts")
156 .or_else(|| {
157 provider_metadata
158 .get("openai")
159 .and_then(|openai| openai.get("conversation_history_parts"))
160 })
161 .and_then(serde_json::Value::as_array);
162
163 history_parts
164 .into_iter()
165 .flatten()
166 .filter_map(|value| serde_json::from_value::<Part>(value.clone()).ok())
167 .collect()
168 }
169
170 fn augment_content_for_history(
171 content: &Content,
172 provider_metadata: Option<&serde_json::Value>,
173 ) -> Content {
174 let mut augmented = content.clone();
175 augmented.parts.extend(Self::history_parts_from_provider_metadata(provider_metadata));
176 augmented
177 }
178}
179
/// Builder for [`LlmAgent`].
///
/// Start with [`LlmAgentBuilder::new`], chain the setters, and finish with
/// [`LlmAgentBuilder::build`]. Only the model is mandatory; `build` fails without it.
pub struct LlmAgentBuilder {
    /// Agent name supplied to `new`.
    name: String,
    /// Optional description; `build` substitutes an empty string when unset.
    description: Option<String>,
    /// Required model; `build` returns an error when this is `None`.
    model: Option<Arc<dyn Llm>>,
    instruction: Option<String>,
    instruction_provider: Option<Arc<InstructionProvider>>,
    global_instruction: Option<String>,
    global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
    skills_index: Option<Arc<SkillIndex>>,
    skill_policy: SelectionPolicy,
    /// Defaults to 2000.
    max_skill_chars: usize,
    input_schema: Option<serde_json::Value>,
    output_schema: Option<serde_json::Value>,
    disallow_transfer_to_parent: bool,
    disallow_transfer_to_peers: bool,
    include_contents: adk_core::IncludeContents,
    tools: Vec<Arc<dyn Tool>>,
    toolsets: Vec<Arc<dyn Toolset>>,
    /// Sub-agents; `build` rejects duplicate names.
    sub_agents: Vec<Arc<dyn Agent>>,
    output_key: Option<String>,
    /// Created on demand by the `temperature`/`top_p`/`top_k`/`max_output_tokens` setters.
    generate_content_config: Option<adk_core::GenerateContentConfig>,
    /// Defaults to [`DEFAULT_MAX_ITERATIONS`].
    max_iterations: u32,
    /// Defaults to [`DEFAULT_TOOL_TIMEOUT`].
    tool_timeout: std::time::Duration,
    // Callback lists: each setter appends; `build` wraps every list in an `Arc`.
    before_callbacks: Vec<BeforeAgentCallback>,
    after_callbacks: Vec<AfterAgentCallback>,
    before_model_callbacks: Vec<BeforeModelCallback>,
    after_model_callbacks: Vec<AfterModelCallback>,
    before_tool_callbacks: Vec<BeforeToolCallback>,
    after_tool_callbacks: Vec<AfterToolCallback>,
    on_tool_error_callbacks: Vec<OnToolErrorCallback>,
    after_tool_callbacks_full: Vec<AfterToolCallbackFull>,
    default_retry_budget: Option<RetryBudget>,
    /// Retry budgets keyed by tool name.
    tool_retry_budgets: std::collections::HashMap<String, RetryBudget>,
    circuit_breaker_threshold: Option<u32>,
    /// Defaults to `ToolConfirmationPolicy::Never`.
    tool_confirmation_policy: ToolConfirmationPolicy,
    tool_execution_strategy: Option<ToolExecutionStrategy>,
    input_guardrails: GuardrailSet,
    output_guardrails: GuardrailSet,
    /// Plugins handed to an `EnhancedPluginManager` by `build` when the list is non-empty.
    #[cfg(feature = "enhanced-plugins")]
    enhanced_plugins: Vec<Arc<dyn EnhancedPlugin>>,
}
222
223impl LlmAgentBuilder {
224 pub fn new(name: impl Into<String>) -> Self {
225 Self {
226 name: name.into(),
227 description: None,
228 model: None,
229 instruction: None,
230 instruction_provider: None,
231 global_instruction: None,
232 global_instruction_provider: None,
233 skills_index: None,
234 skill_policy: SelectionPolicy::default(),
235 max_skill_chars: 2000,
236 input_schema: None,
237 output_schema: None,
238 disallow_transfer_to_parent: false,
239 disallow_transfer_to_peers: false,
240 include_contents: adk_core::IncludeContents::Default,
241 tools: Vec::new(),
242 toolsets: Vec::new(),
243 sub_agents: Vec::new(),
244 output_key: None,
245 generate_content_config: None,
246 max_iterations: DEFAULT_MAX_ITERATIONS,
247 tool_timeout: DEFAULT_TOOL_TIMEOUT,
248 before_callbacks: Vec::new(),
249 after_callbacks: Vec::new(),
250 before_model_callbacks: Vec::new(),
251 after_model_callbacks: Vec::new(),
252 before_tool_callbacks: Vec::new(),
253 after_tool_callbacks: Vec::new(),
254 on_tool_error_callbacks: Vec::new(),
255 after_tool_callbacks_full: Vec::new(),
256 default_retry_budget: None,
257 tool_retry_budgets: std::collections::HashMap::new(),
258 circuit_breaker_threshold: None,
259 tool_confirmation_policy: ToolConfirmationPolicy::Never,
260 tool_execution_strategy: None,
261 input_guardrails: GuardrailSet::new(),
262 output_guardrails: GuardrailSet::new(),
263 #[cfg(feature = "enhanced-plugins")]
264 enhanced_plugins: Vec::new(),
265 }
266 }
267
268 pub fn description(mut self, desc: impl Into<String>) -> Self {
269 self.description = Some(desc.into());
270 self
271 }
272
273 pub fn model(mut self, model: Arc<dyn Llm>) -> Self {
274 self.model = Some(model);
275 self
276 }
277
278 pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
279 self.instruction = Some(instruction.into());
280 self
281 }
282
283 pub fn instruction_provider(mut self, provider: InstructionProvider) -> Self {
284 self.instruction_provider = Some(Arc::new(provider));
285 self
286 }
287
288 pub fn global_instruction(mut self, instruction: impl Into<String>) -> Self {
289 self.global_instruction = Some(instruction.into());
290 self
291 }
292
293 pub fn global_instruction_provider(mut self, provider: GlobalInstructionProvider) -> Self {
294 self.global_instruction_provider = Some(Arc::new(provider));
295 self
296 }
297
298 #[cfg(feature = "skills")]
300 pub fn with_skills(mut self, index: SkillIndex) -> Self {
301 self.skills_index = Some(Arc::new(index));
302 self
303 }
304
305 #[cfg(feature = "skills")]
307 pub fn with_auto_skills(self) -> Result<Self> {
308 self.with_skills_from_root(".")
309 }
310
311 #[cfg(feature = "skills")]
313 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
314 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
315 self.skills_index = Some(Arc::new(index));
316 Ok(self)
317 }
318
319 #[cfg(feature = "skills")]
321 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
322 self.skill_policy = policy;
323 self
324 }
325
326 #[cfg(feature = "skills")]
328 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
329 self.max_skill_chars = max_chars;
330 self
331 }
332
333 pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
334 self.input_schema = Some(schema);
335 self
336 }
337
338 pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
339 self.output_schema = Some(schema);
340 self
341 }
342
343 pub fn disallow_transfer_to_parent(mut self, disallow: bool) -> Self {
344 self.disallow_transfer_to_parent = disallow;
345 self
346 }
347
348 pub fn disallow_transfer_to_peers(mut self, disallow: bool) -> Self {
349 self.disallow_transfer_to_peers = disallow;
350 self
351 }
352
353 pub fn include_contents(mut self, include: adk_core::IncludeContents) -> Self {
354 self.include_contents = include;
355 self
356 }
357
358 pub fn output_key(mut self, key: impl Into<String>) -> Self {
359 self.output_key = Some(key.into());
360 self
361 }
362
363 pub fn generate_content_config(mut self, config: adk_core::GenerateContentConfig) -> Self {
384 self.generate_content_config = Some(config);
385 self
386 }
387
388 pub fn temperature(mut self, temperature: f32) -> Self {
391 self.generate_content_config
392 .get_or_insert(adk_core::GenerateContentConfig::default())
393 .temperature = Some(temperature);
394 self
395 }
396
397 pub fn top_p(mut self, top_p: f32) -> Self {
399 self.generate_content_config
400 .get_or_insert(adk_core::GenerateContentConfig::default())
401 .top_p = Some(top_p);
402 self
403 }
404
405 pub fn top_k(mut self, top_k: i32) -> Self {
407 self.generate_content_config
408 .get_or_insert(adk_core::GenerateContentConfig::default())
409 .top_k = Some(top_k);
410 self
411 }
412
413 pub fn max_output_tokens(mut self, max_tokens: i32) -> Self {
415 self.generate_content_config
416 .get_or_insert(adk_core::GenerateContentConfig::default())
417 .max_output_tokens = Some(max_tokens);
418 self
419 }
420
421 pub fn max_iterations(mut self, max: u32) -> Self {
424 self.max_iterations = max;
425 self
426 }
427
428 pub fn tool_timeout(mut self, timeout: std::time::Duration) -> Self {
431 self.tool_timeout = timeout;
432 self
433 }
434
435 pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
436 self.tools.push(tool);
437 self
438 }
439
440 pub fn toolset(mut self, toolset: Arc<dyn Toolset>) -> Self {
446 self.toolsets.push(toolset);
447 self
448 }
449
450 pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
451 self.sub_agents.push(agent);
452 self
453 }
454
455 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
456 self.before_callbacks.push(callback);
457 self
458 }
459
460 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
461 self.after_callbacks.push(callback);
462 self
463 }
464
465 pub fn before_model_callback(mut self, callback: BeforeModelCallback) -> Self {
466 self.before_model_callbacks.push(callback);
467 self
468 }
469
470 pub fn after_model_callback(mut self, callback: AfterModelCallback) -> Self {
471 self.after_model_callbacks.push(callback);
472 self
473 }
474
475 pub fn before_tool_callback(mut self, callback: BeforeToolCallback) -> Self {
476 self.before_tool_callbacks.push(callback);
477 self
478 }
479
480 pub fn after_tool_callback(mut self, callback: AfterToolCallback) -> Self {
481 self.after_tool_callbacks.push(callback);
482 self
483 }
484
485 pub fn after_tool_callback_full(mut self, callback: AfterToolCallbackFull) -> Self {
500 self.after_tool_callbacks_full.push(callback);
501 self
502 }
503
504 pub fn on_tool_error(mut self, callback: OnToolErrorCallback) -> Self {
512 self.on_tool_error_callbacks.push(callback);
513 self
514 }
515
516 pub fn default_retry_budget(mut self, budget: RetryBudget) -> Self {
523 self.default_retry_budget = Some(budget);
524 self
525 }
526
527 pub fn tool_retry_budget(mut self, tool_name: impl Into<String>, budget: RetryBudget) -> Self {
532 self.tool_retry_budgets.insert(tool_name.into(), budget);
533 self
534 }
535
536 pub fn circuit_breaker_threshold(mut self, threshold: u32) -> Self {
543 self.circuit_breaker_threshold = Some(threshold);
544 self
545 }
546
547 pub fn tool_confirmation_policy(mut self, policy: ToolConfirmationPolicy) -> Self {
549 self.tool_confirmation_policy = policy;
550 self
551 }
552
553 pub fn require_tool_confirmation(mut self, tool_name: impl Into<String>) -> Self {
555 self.tool_confirmation_policy = self.tool_confirmation_policy.with_tool(tool_name);
556 self
557 }
558
559 pub fn require_tool_confirmation_for_all(mut self) -> Self {
561 self.tool_confirmation_policy = ToolConfirmationPolicy::Always;
562 self
563 }
564
565 pub fn tool_execution_strategy(mut self, strategy: ToolExecutionStrategy) -> Self {
571 self.tool_execution_strategy = Some(strategy);
572 self
573 }
574
575 pub fn input_guardrails(mut self, guardrails: GuardrailSet) -> Self {
584 self.input_guardrails = guardrails;
585 self
586 }
587
588 pub fn output_guardrails(mut self, guardrails: GuardrailSet) -> Self {
597 self.output_guardrails = guardrails;
598 self
599 }
600
601 #[cfg(feature = "enhanced-plugins")]
621 pub fn enhanced_plugin(mut self, plugin: Arc<dyn EnhancedPlugin>) -> Self {
622 self.enhanced_plugins.push(plugin);
623 self
624 }
625
626 #[cfg(feature = "enhanced-plugins")]
648 pub fn enhanced_plugins(mut self, plugins: Vec<Arc<dyn EnhancedPlugin>>) -> Self {
649 self.enhanced_plugins.extend(plugins);
650 self
651 }
652
653 pub fn build(self) -> Result<LlmAgent> {
654 let model = self.model.ok_or_else(|| adk_core::AdkError::agent("Model is required"))?;
655
656 let mut seen_names = std::collections::HashSet::new();
657 for agent in &self.sub_agents {
658 if !seen_names.insert(agent.name()) {
659 return Err(adk_core::AdkError::agent(format!(
660 "Duplicate sub-agent name: {}",
661 agent.name()
662 )));
663 }
664 }
665
666 #[cfg(feature = "enhanced-plugins")]
668 let enhanced_plugin_manager = if self.enhanced_plugins.is_empty() {
669 None
670 } else {
671 Some(Arc::new(EnhancedPluginManager::new(self.enhanced_plugins)))
672 };
673
674 Ok(LlmAgent {
675 name: self.name,
676 description: self.description.unwrap_or_default(),
677 model,
678 instruction: self.instruction,
679 instruction_provider: self.instruction_provider,
680 global_instruction: self.global_instruction,
681 global_instruction_provider: self.global_instruction_provider,
682 skills_index: self.skills_index,
683 skill_policy: self.skill_policy,
684 max_skill_chars: self.max_skill_chars,
685 input_schema: self.input_schema,
686 output_schema: self.output_schema,
687 disallow_transfer_to_parent: self.disallow_transfer_to_parent,
688 disallow_transfer_to_peers: self.disallow_transfer_to_peers,
689 include_contents: self.include_contents,
690 tools: self.tools,
691 toolsets: self.toolsets,
692 sub_agents: self.sub_agents,
693 output_key: self.output_key,
694 generate_content_config: self.generate_content_config,
695 max_iterations: self.max_iterations,
696 tool_timeout: self.tool_timeout,
697 before_callbacks: Arc::new(self.before_callbacks),
698 after_callbacks: Arc::new(self.after_callbacks),
699 before_model_callbacks: Arc::new(self.before_model_callbacks),
700 after_model_callbacks: Arc::new(self.after_model_callbacks),
701 before_tool_callbacks: Arc::new(self.before_tool_callbacks),
702 after_tool_callbacks: Arc::new(self.after_tool_callbacks),
703 on_tool_error_callbacks: Arc::new(self.on_tool_error_callbacks),
704 after_tool_callbacks_full: Arc::new(self.after_tool_callbacks_full),
705 default_retry_budget: self.default_retry_budget,
706 tool_retry_budgets: self.tool_retry_budgets,
707 circuit_breaker_threshold: self.circuit_breaker_threshold,
708 tool_confirmation_policy: self.tool_confirmation_policy,
709 tool_execution_strategy: self.tool_execution_strategy,
710 input_guardrails: Arc::new(self.input_guardrails),
711 output_guardrails: Arc::new(self.output_guardrails),
712 #[cfg(feature = "enhanced-plugins")]
713 enhanced_plugin_manager,
714 })
715 }
716}
717
/// Tool-side context: forwards read-only lookups to the parent invocation context and
/// keeps its own [`EventActions`] for the tool call.
struct AgentToolContext {
    /// The invocation context that every delegated lookup goes to.
    parent_ctx: Arc<dyn InvocationContext>,
    /// Identifier returned by `ToolContext::function_call_id`.
    function_call_id: String,
    /// Actions read and replaced through `ToolContext::actions` / `set_actions`.
    actions: Mutex<EventActions>,
}
725
726impl AgentToolContext {
727 fn new(parent_ctx: Arc<dyn InvocationContext>, function_call_id: String) -> Self {
728 Self { parent_ctx, function_call_id, actions: Mutex::new(EventActions::default()) }
729 }
730
731 fn actions_guard(&self) -> std::sync::MutexGuard<'_, EventActions> {
732 self.actions.lock().unwrap_or_else(|e| e.into_inner())
733 }
734}
735
736#[async_trait]
737impl ReadonlyContext for AgentToolContext {
738 fn invocation_id(&self) -> &str {
739 self.parent_ctx.invocation_id()
740 }
741
742 fn agent_name(&self) -> &str {
743 self.parent_ctx.agent_name()
744 }
745
746 fn user_id(&self) -> &str {
747 self.parent_ctx.user_id()
749 }
750
751 fn app_name(&self) -> &str {
752 self.parent_ctx.app_name()
754 }
755
756 fn session_id(&self) -> &str {
757 self.parent_ctx.session_id()
759 }
760
761 fn branch(&self) -> &str {
762 self.parent_ctx.branch()
763 }
764
765 fn user_content(&self) -> &Content {
766 self.parent_ctx.user_content()
767 }
768}
769
770#[async_trait]
771impl CallbackContext for AgentToolContext {
772 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
773 self.parent_ctx.artifacts()
775 }
776
777 fn shared_state(&self) -> Option<Arc<adk_core::SharedState>> {
778 self.parent_ctx.shared_state()
779 }
780}
781
782#[async_trait]
783impl ToolContext for AgentToolContext {
784 fn function_call_id(&self) -> &str {
785 &self.function_call_id
786 }
787
788 fn actions(&self) -> EventActions {
789 self.actions_guard().clone()
790 }
791
792 fn set_actions(&self, actions: EventActions) {
793 *self.actions_guard() = actions;
794 }
795
796 async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>> {
797 if let Some(memory) = self.parent_ctx.memory() {
799 memory.search(query).await
800 } else {
801 Ok(vec![])
802 }
803 }
804
805 fn user_scopes(&self) -> Vec<String> {
806 self.parent_ctx.user_scopes()
807 }
808
809 async fn get_secret(&self, name: &str) -> Result<Option<String>> {
810 self.parent_ctx.get_secret(name).await
811 }
812}
813
/// Wraps an existing [`CallbackContext`] and additionally exposes a [`ToolOutcome`]
/// through `CallbackContext::tool_outcome`.
struct ToolOutcomeCallbackContext {
    /// The wrapped context that the delegated accessors are answered by.
    inner: Arc<dyn CallbackContext>,
    /// The outcome handed out (cloned) by `tool_outcome`.
    outcome: ToolOutcome,
}
821
822#[async_trait]
823impl ReadonlyContext for ToolOutcomeCallbackContext {
824 fn invocation_id(&self) -> &str {
825 self.inner.invocation_id()
826 }
827
828 fn agent_name(&self) -> &str {
829 self.inner.agent_name()
830 }
831
832 fn user_id(&self) -> &str {
833 self.inner.user_id()
834 }
835
836 fn app_name(&self) -> &str {
837 self.inner.app_name()
838 }
839
840 fn session_id(&self) -> &str {
841 self.inner.session_id()
842 }
843
844 fn branch(&self) -> &str {
845 self.inner.branch()
846 }
847
848 fn user_content(&self) -> &Content {
849 self.inner.user_content()
850 }
851}
852
853#[async_trait]
854impl CallbackContext for ToolOutcomeCallbackContext {
855 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
856 self.inner.artifacts()
857 }
858
859 fn tool_outcome(&self) -> Option<ToolOutcome> {
860 Some(self.outcome.clone())
861 }
862}
863
/// Per-tool failure tracking used to stop calling tools that keep failing.
struct CircuitBreakerState {
    /// A tool is considered open once its recorded failure count reaches this value.
    threshold: u32,
    /// Failures recorded per tool name since that tool's last success
    /// (a success removes the entry).
    failures: std::collections::HashMap<String, u32>,
}
878
879impl CircuitBreakerState {
880 fn new(threshold: u32) -> Self {
881 Self { threshold, failures: std::collections::HashMap::new() }
882 }
883
884 fn is_open(&self, tool_name: &str) -> bool {
886 self.failures.get(tool_name).copied().unwrap_or(0) >= self.threshold
887 }
888
889 fn record(&mut self, outcome: &ToolOutcome) {
891 if outcome.success {
892 self.failures.remove(&outcome.tool_name);
893 } else {
894 let count = self.failures.entry(outcome.tool_name.clone()).or_insert(0);
895 *count += 1;
896 }
897 }
898}
899
900#[async_trait]
901impl Agent for LlmAgent {
902 fn name(&self) -> &str {
903 &self.name
904 }
905
906 fn description(&self) -> &str {
907 &self.description
908 }
909
910 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
911 &self.sub_agents
912 }
913
914 #[adk_telemetry::instrument(
915 skip(self, ctx),
916 fields(
917 agent.name = %self.name,
918 agent.description = %self.description,
919 invocation.id = %ctx.invocation_id(),
920 user.id = %ctx.user_id(),
921 session.id = %ctx.session_id()
922 )
923 )]
924 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
925 adk_telemetry::info!("Starting agent execution");
926 let ctx = Self::apply_input_guardrails(ctx, self.input_guardrails.clone()).await?;
927
928 let agent_name = self.name.clone();
929 let invocation_id = ctx.invocation_id().to_string();
930 let model = self.model.clone();
931 let tools = self.tools.clone();
932 let toolsets = self.toolsets.clone();
933 let sub_agents = self.sub_agents.clone();
934
935 let instruction = self.instruction.clone();
936 let instruction_provider = self.instruction_provider.clone();
937 let global_instruction = self.global_instruction.clone();
938 let global_instruction_provider = self.global_instruction_provider.clone();
939 let skills_index = self.skills_index.clone();
940 let skill_policy = self.skill_policy.clone();
941 let max_skill_chars = self.max_skill_chars;
942 let output_key = self.output_key.clone();
943 let output_schema = self.output_schema.clone();
944 let generate_content_config = self.generate_content_config.clone();
945 let include_contents = self.include_contents;
946 let max_iterations = self.max_iterations;
947 let tool_timeout = self.tool_timeout;
948 let before_agent_callbacks = self.before_callbacks.clone();
950 let after_agent_callbacks = self.after_callbacks.clone();
951 let before_model_callbacks = self.before_model_callbacks.clone();
952 let after_model_callbacks = self.after_model_callbacks.clone();
953 let before_tool_callbacks = self.before_tool_callbacks.clone();
954 let after_tool_callbacks = self.after_tool_callbacks.clone();
955 let on_tool_error_callbacks = self.on_tool_error_callbacks.clone();
956 let after_tool_callbacks_full = self.after_tool_callbacks_full.clone();
957 let default_retry_budget = self.default_retry_budget.clone();
958 let tool_retry_budgets = self.tool_retry_budgets.clone();
959 let circuit_breaker_threshold = self.circuit_breaker_threshold;
960 let tool_confirmation_policy = self.tool_confirmation_policy.clone();
961 let disallow_transfer_to_parent = self.disallow_transfer_to_parent;
962 let disallow_transfer_to_peers = self.disallow_transfer_to_peers;
963 let output_guardrails = self.output_guardrails.clone();
964 let agent_tool_execution_strategy = self.tool_execution_strategy;
965 #[cfg(feature = "enhanced-plugins")]
966 let enhanced_plugin_manager = self.enhanced_plugin_manager.clone();
967
968 let s = stream! {
969 for callback in before_agent_callbacks.as_ref() {
973 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
974 Ok(Some(content)) => {
975 let mut early_event = Event::new(&invocation_id);
977 early_event.author = agent_name.clone();
978 early_event.llm_response.content = Some(content);
979 yield Ok(early_event);
980
981 for after_callback in after_agent_callbacks.as_ref() {
983 match after_callback(ctx.clone() as Arc<dyn CallbackContext>).await {
984 Ok(Some(after_content)) => {
985 let mut after_event = Event::new(&invocation_id);
986 after_event.author = agent_name.clone();
987 after_event.llm_response.content = Some(after_content);
988 yield Ok(after_event);
989 return;
990 }
991 Ok(None) => continue,
992 Err(e) => {
993 yield Err(e);
994 return;
995 }
996 }
997 }
998 return;
999 }
1000 Ok(None) => {
1001 continue;
1003 }
1004 Err(e) => {
1005 yield Err(e);
1007 return;
1008 }
1009 }
1010 }
1011
1012 let mut prompt_preamble = Vec::new();
1014
1015 if let Some(index) = &skills_index {
1019 let user_query = ctx
1020 .user_content()
1021 .parts
1022 .iter()
1023 .filter_map(|part| match part {
1024 Part::Text { text } => Some(text.as_str()),
1025 _ => None,
1026 })
1027 .collect::<Vec<_>>()
1028 .join("\n");
1029
1030 if let Some((_matched, skill_block)) = select_skill_prompt_block(
1031 index.as_ref(),
1032 &user_query,
1033 &skill_policy,
1034 max_skill_chars,
1035 ) {
1036 prompt_preamble.push(Content {
1037 role: "user".to_string(),
1038 parts: vec![Part::Text { text: skill_block }],
1039 });
1040 }
1041 }
1042
1043 if let Some(provider) = &global_instruction_provider {
1046 let global_inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
1048 if !global_inst.is_empty() {
1049 prompt_preamble.push(Content {
1050 role: "user".to_string(),
1051 parts: vec![Part::Text { text: global_inst }],
1052 });
1053 }
1054 } else if let Some(ref template) = global_instruction {
1055 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
1057 if !processed.is_empty() {
1058 prompt_preamble.push(Content {
1059 role: "user".to_string(),
1060 parts: vec![Part::Text { text: processed }],
1061 });
1062 }
1063 }
1064
1065 if let Some(provider) = &instruction_provider {
1068 let inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
1070 if !inst.is_empty() {
1071 prompt_preamble.push(Content {
1072 role: "user".to_string(),
1073 parts: vec![Part::Text { text: inst }],
1074 });
1075 }
1076 } else if let Some(ref template) = instruction {
1077 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
1079 if !processed.is_empty() {
1080 prompt_preamble.push(Content {
1081 role: "user".to_string(),
1082 parts: vec![Part::Text { text: processed }],
1083 });
1084 }
1085 }
1086
1087 let session_history = if !ctx.run_config().transfer_targets.is_empty() {
1093 ctx.session().conversation_history_for_agent(&agent_name)
1094 } else {
1095 ctx.session().conversation_history()
1096 };
1097 let mut session_history = session_history;
1098 let current_user_content = ctx.user_content().clone();
1099 if let Some(index) = session_history.iter().rposition(|content| content.role == "user") {
1100 session_history[index] = current_user_content.clone();
1101 } else {
1102 session_history.push(current_user_content.clone());
1103 }
1104
1105 let mut conversation_history = match include_contents {
1108 adk_core::IncludeContents::None => {
1109 let mut filtered = prompt_preamble.clone();
1110 filtered.push(current_user_content);
1111 filtered
1112 }
1113 adk_core::IncludeContents::Default => {
1114 let mut full_history = prompt_preamble;
1115 full_history.extend(session_history);
1116 full_history
1117 }
1118 };
1119
1120 let mut resolved_tools: Vec<Arc<dyn Tool>> = tools.clone();
1123 let static_tool_names: std::collections::HashSet<String> =
1124 tools.iter().map(|t| t.name().to_string()).collect();
1125
1126 let mut toolset_source: std::collections::HashMap<String, String> =
1128 std::collections::HashMap::new();
1129
1130 for toolset in &toolsets {
1131 let toolset_tools = match toolset
1132 .tools(ctx.clone() as Arc<dyn ReadonlyContext>)
1133 .await
1134 {
1135 Ok(t) => t,
1136 Err(e) => {
1137 yield Err(e);
1138 return;
1139 }
1140 };
1141 for tool in &toolset_tools {
1142 let name = tool.name().to_string();
1143 if static_tool_names.contains(&name) {
1145 yield Err(adk_core::AdkError::agent(format!(
1146 "Duplicate tool name '{name}': conflict between static tool and toolset '{}'",
1147 toolset.name()
1148 )));
1149 return;
1150 }
1151 if let Some(other_toolset_name) = toolset_source.get(&name) {
1153 yield Err(adk_core::AdkError::agent(format!(
1154 "Duplicate tool name '{name}': conflict between toolset '{}' and toolset '{}'",
1155 other_toolset_name,
1156 toolset.name()
1157 )));
1158 return;
1159 }
1160 toolset_source.insert(name, toolset.name().to_string());
1161 resolved_tools.push(tool.clone());
1162 }
1163 }
1164
1165 let tool_map: std::collections::HashMap<String, Arc<dyn Tool>> = resolved_tools
1167 .iter()
1168 .map(|t| (t.name().to_string(), t.clone()))
1169 .collect();
1170
1171 let collect_long_running_ids = |content: &Content| -> Vec<String> {
1173 content.parts.iter()
1174 .filter_map(|p| {
1175 if let Part::FunctionCall { name, .. } = p {
1176 if let Some(tool) = tool_map.get(name) {
1177 if tool.is_long_running() {
1178 return Some(name.clone());
1179 }
1180 }
1181 }
1182 None
1183 })
1184 .collect()
1185 };
1186
1187 let mut tool_declarations = std::collections::HashMap::new();
1192 for tool in &resolved_tools {
1193 tool_declarations.insert(tool.name().to_string(), tool.declaration());
1194 }
1195
1196 let mut valid_transfer_targets: Vec<String> = sub_agents
1201 .iter()
1202 .map(|a| a.name().to_string())
1203 .collect();
1204
1205 let run_config_targets = &ctx.run_config().transfer_targets;
1207 let parent_agent_name = ctx.run_config().parent_agent.clone();
1208 let sub_agent_names: std::collections::HashSet<&str> = sub_agents
1209 .iter()
1210 .map(|a| a.name())
1211 .collect();
1212
1213 for target in run_config_targets {
1214 if sub_agent_names.contains(target.as_str()) {
1216 continue;
1217 }
1218
1219 let is_parent = parent_agent_name.as_deref() == Some(target.as_str());
1221 if is_parent && disallow_transfer_to_parent {
1222 continue;
1223 }
1224 if !is_parent && disallow_transfer_to_peers {
1225 continue;
1226 }
1227
1228 valid_transfer_targets.push(target.clone());
1229 }
1230
1231 if !valid_transfer_targets.is_empty() {
1233 let transfer_tool_name = "transfer_to_agent";
1234 let transfer_tool_decl = serde_json::json!({
1235 "name": transfer_tool_name,
1236 "description": format!(
1237 "Transfer execution to another agent. Valid targets: {}",
1238 valid_transfer_targets.join(", ")
1239 ),
1240 "parameters": {
1241 "type": "object",
1242 "properties": {
1243 "agent_name": {
1244 "type": "string",
1245 "description": "The name of the agent to transfer to.",
1246 "enum": valid_transfer_targets
1247 }
1248 },
1249 "required": ["agent_name"]
1250 }
1251 });
1252 tool_declarations.insert(transfer_tool_name.to_string(), transfer_tool_decl);
1253 }
1254
1255
1256 let mut circuit_breaker_state = circuit_breaker_threshold.map(CircuitBreakerState::new);
1259
1260 let mut iteration = 0;
1262
1263 loop {
1264 iteration += 1;
1265 if iteration > max_iterations {
1266 yield Err(adk_core::AdkError::agent(
1267 format!("Max iterations ({max_iterations}) exceeded")
1268 ));
1269 return;
1270 }
1271
1272 let config = match (&generate_content_config, &output_schema) {
1279 (Some(base), Some(schema)) => {
1280 let mut merged = base.clone();
1281 merged.response_schema = Some(schema.clone());
1282 Some(merged)
1283 }
1284 (Some(base), None) => Some(base.clone()),
1285 (None, Some(schema)) => Some(adk_core::GenerateContentConfig {
1286 response_schema: Some(schema.clone()),
1287 ..Default::default()
1288 }),
1289 (None, None) => None,
1290 };
1291
1292 let config = if let Some(ref cached) = ctx.run_config().cached_content {
1294 let mut cfg = config.unwrap_or_default();
1295 if cfg.cached_content.is_none() {
1297 cfg.cached_content = Some(cached.clone());
1298 }
1299 Some(cfg)
1300 } else {
1301 config
1302 };
1303
1304 let request = LlmRequest {
1305 model: model.name().to_string(),
1306 contents: conversation_history.clone(),
1307 tools: tool_declarations.clone(),
1308 config,
1309 };
1310
1311 #[cfg(feature = "enhanced-plugins")]
1315 let (request, model_response_override_from_plugin) = {
1316 if let Some(epm) = &enhanced_plugin_manager {
1317 match epm.run_before_model_call(request, ctx.clone() as Arc<dyn CallbackContext>).await {
1318 Ok(BeforeModelCallResult::Continue(modified_request)) => {
1319 (modified_request, None)
1320 }
1321 Ok(BeforeModelCallResult::ShortCircuit(response)) => {
1322 (LlmRequest::new("", vec![]), Some(response))
1324 }
1325 Err(e) => {
1326 yield Err(e);
1327 return;
1328 }
1329 }
1330 } else {
1331 (request, None)
1332 }
1333 };
1334 #[cfg(not(feature = "enhanced-plugins"))]
1335 let model_response_override_from_plugin: Option<LlmResponse> = None;
1336
1337 let mut current_request = request;
1340 let mut model_response_override = model_response_override_from_plugin;
1341 if model_response_override.is_none() {
1342 for callback in before_model_callbacks.as_ref() {
1343 match callback(ctx.clone() as Arc<dyn CallbackContext>, current_request.clone()).await {
1344 Ok(BeforeModelResult::Continue(modified_request)) => {
1345 current_request = modified_request;
1347 }
1348 Ok(BeforeModelResult::Skip(response)) => {
1349 model_response_override = Some(response);
1351 break;
1352 }
1353 Err(e) => {
1354 yield Err(e);
1356 return;
1357 }
1358 }
1359 }
1360 }
1361 let request = current_request;
1362
1363 let mut accumulated_content: Option<Content> = None;
1365 let mut final_provider_metadata: Option<serde_json::Value> = None;
1366
1367 if let Some(cached_response) = model_response_override {
1368 accumulated_content = cached_response.content.clone();
1371 final_provider_metadata = cached_response.provider_metadata.clone();
1372 normalize_option_content(&mut accumulated_content);
1373 if let Some(content) = accumulated_content.take() {
1374 let has_function_calls = content
1375 .parts
1376 .iter()
1377 .any(|part| matches!(part, Part::FunctionCall { .. }));
1378 let content = if has_function_calls {
1379 content
1380 } else {
1381 Self::apply_output_guardrails(output_guardrails.as_ref(), content).await?
1382 };
1383 accumulated_content = Some(content);
1384 }
1385
1386 let mut cached_event = Event::new(&invocation_id);
1387 cached_event.author = agent_name.clone();
1388 cached_event.llm_response.content = accumulated_content.clone();
1389 cached_event.llm_response.provider_metadata = cached_response.provider_metadata.clone();
1390 cached_event.llm_request = Some(serde_json::to_string(&request).unwrap_or_default());
1391 cached_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), serde_json::to_string(&request).unwrap_or_default());
1392 cached_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(&cached_response).unwrap_or_default());
1393
1394 if let Some(ref content) = accumulated_content {
1396 cached_event.long_running_tool_ids = collect_long_running_ids(content);
1397 }
1398
1399 yield Ok(cached_event);
1400 } else {
1401 let request_json = serde_json::to_string(&request).unwrap_or_default();
1403 let trace_request_json = trace_json_payload(
1404 &request,
1405 ctx.run_config().record_payloads,
1406 ctx.run_config().trace_payload_max_bytes,
1407 );
1408
1409 let llm_ts = std::time::SystemTime::now()
1411 .duration_since(std::time::UNIX_EPOCH)
1412 .unwrap_or_default()
1413 .as_nanos();
1414 let llm_event_id = format!("{}_llm_{}", invocation_id, llm_ts);
1415 let llm_span = tracing::info_span!(
1416 "call_llm",
1417 "gcp.vertex.agent.event_id" = %llm_event_id,
1418 "gcp.vertex.agent.invocation_id" = %invocation_id,
1419 "gcp.vertex.agent.session_id" = %ctx.session_id(),
1420 "gen_ai.conversation.id" = %ctx.session_id(),
1421 "gcp.vertex.agent.llm_request" = %trace_request_json,
1422 "gcp.vertex.agent.llm_response" = tracing::field::Empty );
1424 let _llm_guard = llm_span.enter();
1425
1426 use adk_core::StreamingMode;
1428 let streaming_mode = ctx.run_config().streaming_mode;
1429 let should_stream_to_client = matches!(streaming_mode, StreamingMode::SSE | StreamingMode::Bidi)
1430 && output_guardrails.is_empty();
1431
1432 let mut response_stream = model.generate_content(request, true).await?;
1434
1435 use futures::StreamExt;
1436
1437 let mut last_chunk: Option<LlmResponse> = None;
1439
1440 while let Some(chunk_result) = response_stream.next().await {
1442 let mut chunk = match chunk_result {
1443 Ok(c) => c,
1444 Err(e) => {
1445 yield Err(e);
1446 return;
1447 }
1448 };
1449
1450 for callback in after_model_callbacks.as_ref() {
1453 match callback(ctx.clone() as Arc<dyn CallbackContext>, chunk.clone()).await {
1454 Ok(Some(modified_chunk)) => {
1455 chunk = modified_chunk;
1457 break;
1458 }
1459 Ok(None) => {
1460 continue;
1462 }
1463 Err(e) => {
1464 yield Err(e);
1466 return;
1467 }
1468 }
1469 }
1470
1471 normalize_option_content(&mut chunk.content);
1472
1473 if let Some(chunk_content) = chunk.content.clone() {
1475 if let Some(ref mut acc) = accumulated_content {
1476 acc.parts.extend(chunk_content.parts);
1477 } else {
1478 accumulated_content = Some(chunk_content);
1479 }
1480 }
1481
1482 if should_stream_to_client {
1484 let mut partial_event = Event::with_id(&llm_event_id, &invocation_id);
1485 partial_event.author = agent_name.clone();
1486 partial_event.llm_request = Some(request_json.clone());
1487 partial_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), request_json.clone());
1488 partial_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(&chunk).unwrap_or_default());
1489 partial_event.llm_response.partial = chunk.partial;
1490 partial_event.llm_response.turn_complete = chunk.turn_complete;
1491 partial_event.llm_response.finish_reason = chunk.finish_reason;
1492 partial_event.llm_response.usage_metadata = chunk.usage_metadata.clone();
1493 partial_event.llm_response.content = chunk.content.clone();
1494 partial_event.llm_response.provider_metadata = chunk.provider_metadata.clone();
1495
1496 if let Some(ref content) = chunk.content {
1498 partial_event.long_running_tool_ids = collect_long_running_ids(content);
1499 }
1500
1501 yield Ok(partial_event);
1502 }
1503
1504 last_chunk = Some(chunk.clone());
1506
1507 if chunk.turn_complete {
1509 break;
1510 }
1511 }
1512
1513 if !should_stream_to_client {
1515 if let Some(content) = accumulated_content.take() {
1516 let has_function_calls = content
1517 .parts
1518 .iter()
1519 .any(|part| matches!(part, Part::FunctionCall { .. }));
1520 let content = if has_function_calls {
1521 content
1522 } else {
1523 Self::apply_output_guardrails(output_guardrails.as_ref(), content).await?
1524 };
1525 accumulated_content = Some(content);
1526 }
1527
1528 let mut final_event = Event::with_id(&llm_event_id, &invocation_id);
1529 final_event.author = agent_name.clone();
1530 final_event.llm_request = Some(request_json.clone());
1531 final_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), request_json.clone());
1532 final_event.llm_response.content = accumulated_content.clone();
1533 final_event.llm_response.partial = false;
1534 final_event.llm_response.turn_complete = true;
1535
1536 if let Some(ref last) = last_chunk {
1538 final_event.llm_response.finish_reason = last.finish_reason;
1539 final_event.llm_response.usage_metadata = last.usage_metadata.clone();
1540 final_event.llm_response.provider_metadata = last.provider_metadata.clone();
1541 final_provider_metadata = last.provider_metadata.clone();
1542 final_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(last).unwrap_or_default());
1543 }
1544
1545 if let Some(ref content) = accumulated_content {
1547 final_event.long_running_tool_ids = collect_long_running_ids(content);
1548 }
1549
1550 yield Ok(final_event);
1551 }
1552
1553 if let Some(ref content) = accumulated_content {
1555 let response_json = trace_json_payload(
1556 content,
1557 ctx.run_config().record_payloads,
1558 ctx.run_config().trace_payload_max_bytes,
1559 );
1560 llm_span.record("gcp.vertex.agent.llm_response", &response_json);
1561 }
1562 }
1563
1564 #[cfg(feature = "enhanced-plugins")]
1568 if let Some(epm) = &enhanced_plugin_manager {
1569 if let Some(ref content) = accumulated_content {
1570 let response_for_hook = LlmResponse {
1571 content: Some(content.clone()),
1572 provider_metadata: final_provider_metadata.clone(),
1573 ..Default::default()
1574 };
1575 match epm.run_after_model_call(response_for_hook, ctx.clone() as Arc<dyn CallbackContext>).await {
1576 Ok(adk_plugin::AfterModelCallResult::Continue(modified_response)) => {
1577 accumulated_content = modified_response.content;
1578 if modified_response.provider_metadata.is_some() {
1579 final_provider_metadata = modified_response.provider_metadata;
1580 }
1581 }
1582 Err(e) => {
1583 yield Err(e);
1584 return;
1585 }
1586 }
1587 }
1588 }
1589
1590 let function_call_names: Vec<String> = accumulated_content.as_ref()
1592 .map(|c| c.parts.iter()
1593 .filter_map(|p| {
1594 if let Part::FunctionCall { name, .. } = p {
1595 Some(name.clone())
1596 } else {
1597 None
1598 }
1599 })
1600 .collect())
1601 .unwrap_or_default();
1602
1603 let has_function_calls = !function_call_names.is_empty();
1604
1605 let all_calls_are_long_running = has_function_calls && function_call_names.iter().all(|name| {
1609 tool_map.get(name)
1610 .map(|t| t.is_long_running())
1611 .unwrap_or(false)
1612 });
1613
1614 if let Some(ref content) = accumulated_content {
1616 conversation_history.push(Self::augment_content_for_history(
1617 content,
1618 final_provider_metadata.as_ref(),
1619 ));
1620
1621 if let Some(ref output_key) = output_key {
1623 if !has_function_calls { let mut text_parts = String::new();
1625 for part in &content.parts {
1626 if let Part::Text { text } = part {
1627 text_parts.push_str(text);
1628 }
1629 }
1630 if !text_parts.is_empty() {
1631 let mut state_event = Event::new(&invocation_id);
1633 state_event.author = agent_name.clone();
1634 state_event.actions.state_delta.insert(
1635 output_key.clone(),
1636 serde_json::Value::String(text_parts),
1637 );
1638 yield Ok(state_event);
1639 }
1640 }
1641 }
1642 }
1643
1644 if !has_function_calls {
1645 if let Some(ref content) = accumulated_content {
1648 let response_json = trace_json_payload(
1649 content,
1650 ctx.run_config().record_payloads,
1651 ctx.run_config().trace_payload_max_bytes,
1652 );
1653 tracing::Span::current().record("gcp.vertex.agent.llm_response", &response_json);
1654 }
1655
1656 tracing::info!(agent.name = %agent_name, "Agent execution complete");
1657 break;
1658 }
1659
1660 if let Some(content) = &accumulated_content {
1662 let strategy = agent_tool_execution_strategy
1665 .unwrap_or(ToolExecutionStrategy::Sequential);
1666
1667 let mut fc_parts: Vec<(usize, String, serde_json::Value, Option<String>, String)> = Vec::new();
1671 {
1672 let mut tci = 0usize;
1673 for part in &content.parts {
1674 if let Part::FunctionCall { name, args, id, .. } = part {
1675 let fallback = format!("{}_{}_{}", invocation_id, name, tci);
1676 let fcid = id.clone().unwrap_or(fallback);
1677 fc_parts.push((tci, name.clone(), args.clone(), id.clone(), fcid));
1678 tci += 1;
1679 }
1680 }
1681 }
1682
1683 let mut transfer_handled = false;
1687 for (_, fc_name, fc_args, fc_id, _) in &fc_parts {
1688 if fc_name == "transfer_to_agent" {
1689 let target_agent = fc_args.get("agent_name")
1690 .and_then(|v| v.as_str())
1691 .unwrap_or_default()
1692 .to_string();
1693
1694 let valid_target = valid_transfer_targets.iter().any(|n| n == &target_agent);
1695 if !valid_target {
1696 let error_content = Content {
1697 role: "function".to_string(),
1698 parts: vec![Part::FunctionResponse {
1699 function_response: FunctionResponseData::new(
1700 fc_name.clone(),
1701 serde_json::json!({
1702 "error": format!(
1703 "Agent '{}' not found. Available agents: {:?}",
1704 target_agent, valid_transfer_targets
1705 )
1706 }),
1707 ),
1708 id: fc_id.clone(),
1709 }],
1710 };
1711 conversation_history.push(error_content.clone());
1712 let mut error_event = Event::new(&invocation_id);
1713 error_event.author = agent_name.clone();
1714 error_event.llm_response.content = Some(error_content);
1715 yield Ok(error_event);
1716 continue;
1717 }
1718
1719 let mut transfer_event = Event::new(&invocation_id);
1720 transfer_event.author = agent_name.clone();
1721 transfer_event.actions.transfer_to_agent = Some(target_agent);
1722 yield Ok(transfer_event);
1723 transfer_handled = true;
1724 break;
1725 }
1726 }
1727 if transfer_handled {
1728 return;
1729 }
1730
1731 let fc_parts: Vec<_> = fc_parts.into_iter().filter(|(_, fc_name, _, _, _)| {
1733 if fc_name == "transfer_to_agent" {
1734 return false;
1735 }
1736 if let Some(tool) = tool_map.get(fc_name) {
1737 if tool.is_builtin() {
1738 adk_telemetry::debug!(tool.name = %fc_name, "skipping built-in tool execution");
1739 return false;
1740 }
1741 }
1742 true
1743 }).collect();
1744
1745 let mut confirmation_interrupted = false;
1749 for (_, fc_name, fc_args, _, fc_call_id) in &fc_parts {
1750 if tool_confirmation_policy.requires_confirmation(fc_name)
1751 && ctx.run_config().tool_confirmation_decisions.get(fc_name).copied().is_none()
1752 {
1753 let mut ce = Event::new(&invocation_id);
1754 ce.author = agent_name.clone();
1755 ce.llm_response.interrupted = true;
1756 ce.llm_response.turn_complete = true;
1757 ce.llm_response.content = Some(Content {
1758 role: "model".to_string(),
1759 parts: vec![Part::Text {
1760 text: format!(
1761 "Tool confirmation required for '{}'. Provide approve/deny decision to continue.",
1762 fc_name
1763 ),
1764 }],
1765 });
1766 ce.actions.tool_confirmation = Some(ToolConfirmationRequest {
1767 tool_name: fc_name.clone(),
1768 function_call_id: Some(fc_call_id.clone()),
1769 args: fc_args.clone(),
1770 });
1771 yield Ok(ce);
1772 confirmation_interrupted = true;
1773 break;
1774 }
1775 }
1776 if confirmation_interrupted {
1777 return;
1778 }
1779
1780 let cb_mutex = std::sync::Mutex::new(circuit_breaker_state.take());
1782
1783 let execute_one_tool = |idx: usize, name: String, args: serde_json::Value,
1788 id: Option<String>, function_call_id: String| {
1789 let ctx = ctx.clone();
1790 let tool_map = &tool_map;
1791 let tool_retry_budgets = &tool_retry_budgets;
1792 let default_retry_budget = &default_retry_budget;
1793 let before_tool_callbacks = &before_tool_callbacks;
1794 let after_tool_callbacks = &after_tool_callbacks;
1795 let after_tool_callbacks_full = &after_tool_callbacks_full;
1796 let on_tool_error_callbacks = &on_tool_error_callbacks;
1797 let tool_confirmation_policy = &tool_confirmation_policy;
1798 let cb_mutex = &cb_mutex;
1799 let invocation_id = &invocation_id;
1800 #[cfg(feature = "enhanced-plugins")]
1801 let enhanced_plugin_manager = &enhanced_plugin_manager;
1802 async move {
1803 let mut tool_actions = EventActions::default();
1804 let mut response_content: Option<Content> = None;
1805 let mut run_after_tool_callbacks = true;
1806 let mut tool_outcome_for_callback: Option<ToolOutcome> = None;
1807 let mut executed_tool: Option<Arc<dyn Tool>> = None;
1808 let mut executed_tool_response: Option<serde_json::Value> = None;
1809
1810 if tool_confirmation_policy.requires_confirmation(&name) {
1812 match ctx.run_config().tool_confirmation_decisions.get(&name).copied() {
1813 Some(ToolConfirmationDecision::Approve) => {
1814 tool_actions.tool_confirmation_decision =
1815 Some(ToolConfirmationDecision::Approve);
1816 }
1817 Some(ToolConfirmationDecision::Deny) => {
1818 tool_actions.tool_confirmation_decision =
1819 Some(ToolConfirmationDecision::Deny);
1820 response_content = Some(Content {
1821 role: "function".to_string(),
1822 parts: vec![Part::FunctionResponse {
1823 function_response: FunctionResponseData::new(
1824 name.clone(),
1825 serde_json::json!({
1826 "error": format!("Tool '{}' execution denied by confirmation policy", name)
1827 }),
1828 ),
1829 id: id.clone(),
1830 }],
1831 });
1832 run_after_tool_callbacks = false;
1833 }
1834 None => {
1835 response_content = Some(Content {
1836 role: "function".to_string(),
1837 parts: vec![Part::FunctionResponse {
1838 function_response: FunctionResponseData::new(
1839 name.clone(),
1840 serde_json::json!({
1841 "error": format!("Tool '{}' requires confirmation", name)
1842 }),
1843 ),
1844 id: id.clone(),
1845 }],
1846 });
1847 run_after_tool_callbacks = false;
1848 }
1849 }
1850 }
1851
1852 #[allow(unused_mut)]
1855 let mut final_args = args.clone();
1856
1857 #[cfg(feature = "enhanced-plugins")]
1859 if response_content.is_none() {
1860 if let Some(epm) = &enhanced_plugin_manager {
1861 if let Some(tool_ref) = tool_map.get(&name) {
1862 match epm.run_before_tool_call(
1863 tool_ref.clone(),
1864 final_args.clone(),
1865 ctx.clone() as Arc<dyn CallbackContext>,
1866 ).await {
1867 Ok(BeforeToolCallResult::Continue(modified_args)) => {
1868 final_args = modified_args;
1869 }
1870 Ok(BeforeToolCallResult::ShortCircuit(synthetic_result)) => {
1871 response_content = Some(Content {
1873 role: "function".to_string(),
1874 parts: vec![Part::FunctionResponse {
1875 function_response: FunctionResponseData::from_tool_result(
1876 name.clone(),
1877 synthetic_result,
1878 ),
1879 id: id.clone(),
1880 }],
1881 });
1882 executed_tool = Some(tool_ref.clone());
1883 }
1884 Err(e) => {
1885 response_content = Some(Content {
1886 role: "function".to_string(),
1887 parts: vec![Part::FunctionResponse {
1888 function_response: FunctionResponseData::new(
1889 name.clone(),
1890 serde_json::json!({ "error": e.to_string() }),
1891 ),
1892 id: id.clone(),
1893 }],
1894 });
1895 run_after_tool_callbacks = false;
1896 }
1897 }
1898 }
1899 }
1900 }
1901
1902 if response_content.is_none() {
1903 let tool_ctx = Arc::new(ToolCallbackContext::new(
1904 ctx.clone(),
1905 name.clone(),
1906 final_args.clone(),
1907 ));
1908 for callback in before_tool_callbacks.as_ref() {
1909 match callback(tool_ctx.clone() as Arc<dyn CallbackContext>).await {
1910 Ok(Some(c)) => { response_content = Some(c); break; }
1911 Ok(None) => continue,
1912 Err(e) => {
1913 response_content = Some(Content {
1914 role: "function".to_string(),
1915 parts: vec![Part::FunctionResponse {
1916 function_response: FunctionResponseData::new(
1917 name.clone(),
1918 serde_json::json!({ "error": e.to_string() }),
1919 ),
1920 id: id.clone(),
1921 }],
1922 });
1923 run_after_tool_callbacks = false;
1924 break;
1925 }
1926 }
1927 }
1928 }
1929
1930 if response_content.is_none() {
1932 let guard = cb_mutex.lock().unwrap_or_else(|e| e.into_inner());
1933 if let Some(ref cb_state) = *guard {
1934 if cb_state.is_open(&name) {
1935 let msg = format!(
1936 "Tool '{}' is temporarily disabled after {} consecutive failures",
1937 name, cb_state.threshold
1938 );
1939 tracing::warn!(tool.name = %name, "circuit breaker open, skipping tool execution");
1940 response_content = Some(Content {
1941 role: "function".to_string(),
1942 parts: vec![Part::FunctionResponse {
1943 function_response: FunctionResponseData::new(
1944 name.clone(),
1945 serde_json::json!({ "error": msg }),
1946 ),
1947 id: id.clone(),
1948 }],
1949 });
1950 run_after_tool_callbacks = false;
1951 }
1952 }
1953 drop(guard);
1954 }
1955
1956 if response_content.is_none() {
1958 if let Some(tool) = tool_map.get(&name) {
1959 let tool_ctx: Arc<dyn ToolContext> = Arc::new(
1960 AgentToolContext::new(ctx.clone(), function_call_id.clone()),
1961 );
1962 let span_name = format!("execute_tool {name}");
1963 let tool_span = tracing::info_span!(
1964 "",
1965 otel.name = %span_name,
1966 tool.name = %name,
1967 "gcp.vertex.agent.event_id" = %format!("{}_{}", invocation_id, name),
1968 "gcp.vertex.agent.invocation_id" = %invocation_id,
1969 "gcp.vertex.agent.session_id" = %ctx.session_id(),
1970 "gen_ai.conversation.id" = %ctx.session_id()
1971 );
1972
1973 let budget = tool_retry_budgets.get(&name)
1974 .or(default_retry_budget.as_ref());
1975 let max_attempts = budget.map(|b| b.max_retries + 1).unwrap_or(1);
1976 let retry_delay = budget.map(|b| b.delay).unwrap_or_default();
1977
1978 let tool_clone = tool.clone();
1979 let tool_start = std::time::Instant::now();
1980 let mut last_error = String::new();
1981 let mut final_attempt: u32 = 0;
1982 let mut retry_result: Option<serde_json::Value> = None;
1983
1984 for attempt in 0..max_attempts {
1985 final_attempt = attempt;
1986 if attempt > 0 {
1987 tokio::time::sleep(retry_delay).await;
1988 }
1989 match async {
1990 let args_payload = trace_json_payload(
1991 &final_args,
1992 ctx.run_config().record_payloads,
1993 ctx.run_config().trace_payload_max_bytes,
1994 );
1995 tracing::debug!(tool.name = %name, tool.args = %args_payload, attempt = attempt, "tool_call");
1996 let exec_future = tool_clone.execute(tool_ctx.clone(), final_args.clone());
1997 tokio::time::timeout(tool_timeout, exec_future).await
1998 }.instrument(tool_span.clone()).await {
1999 Ok(Ok(value)) => {
2000 let result_payload = trace_json_payload(
2001 &value,
2002 ctx.run_config().record_payloads,
2003 ctx.run_config().trace_payload_max_bytes,
2004 );
2005 tracing::debug!(tool.name = %name, tool.result = %result_payload, "tool_result");
2006 retry_result = Some(value);
2007 break;
2008 }
2009 Ok(Err(e)) => {
2010 last_error = e.to_string();
2011 if attempt + 1 < max_attempts {
2012 tracing::warn!(tool.name = %name, attempt = attempt, error = %last_error, "tool execution failed, retrying");
2013 } else {
2014 tracing::warn!(tool.name = %name, error = %last_error, "tool_error");
2015 }
2016 }
2017 Err(_) => {
2018 last_error = format!(
2019 "Tool '{}' timed out after {} seconds",
2020 name, tool_timeout.as_secs()
2021 );
2022 if attempt + 1 < max_attempts {
2023 tracing::warn!(tool.name = %name, attempt = attempt, timeout_secs = tool_timeout.as_secs(), "tool timed out, retrying");
2024 } else {
2025 tracing::warn!(tool.name = %name, timeout_secs = tool_timeout.as_secs(), "tool_timeout");
2026 }
2027 }
2028 }
2029 }
2030
2031 let tool_duration = tool_start.elapsed();
2032 let (tool_success, tool_error_message, function_response) = match retry_result {
2033 Some(value) => (true, None, value),
2034 None => (false, Some(last_error.clone()), serde_json::json!({ "error": last_error })),
2035 };
2036
2037 let outcome = ToolOutcome {
2038 tool_name: name.clone(),
2039 tool_args: final_args.clone(),
2040 success: tool_success,
2041 duration: tool_duration,
2042 error_message: tool_error_message.clone(),
2043 attempt: final_attempt,
2044 };
2045 tool_outcome_for_callback = Some(outcome);
2046
2047 {
2049 let mut guard = cb_mutex.lock().unwrap_or_else(|e| e.into_inner());
2050 if let Some(ref mut cb_state) = *guard {
2051 cb_state.record(tool_outcome_for_callback.as_ref().unwrap());
2052 }
2053 }
2054
2055 let final_function_response = if !tool_success {
2057 let mut fallback_result = None;
2058 let error_msg = tool_error_message.clone().unwrap_or_default();
2059 for callback in on_tool_error_callbacks.as_ref() {
2060 match callback(
2061 ctx.clone() as Arc<dyn CallbackContext>,
2062 tool.clone(),
2063 final_args.clone(),
2064 error_msg.clone(),
2065 ).await {
2066 Ok(Some(result)) => { fallback_result = Some(result); break; }
2067 Ok(None) => continue,
2068 Err(e) => { tracing::warn!(error = %e, "on_tool_error callback failed"); break; }
2069 }
2070 }
2071 fallback_result.unwrap_or(function_response)
2072 } else {
2073 function_response
2074 };
2075
2076 let confirmation_decision = tool_actions.tool_confirmation_decision;
2077 tool_actions = tool_ctx.actions();
2078 if tool_actions.tool_confirmation_decision.is_none() {
2079 tool_actions.tool_confirmation_decision = confirmation_decision;
2080 }
2081 executed_tool = Some(tool.clone());
2082 executed_tool_response = Some(final_function_response.clone());
2083 response_content = Some(Content {
2084 role: "function".to_string(),
2085 parts: vec![Part::FunctionResponse {
2086 function_response: FunctionResponseData::from_tool_result(
2087 name.clone(),
2088 final_function_response,
2089 ),
2090 id: id.clone(),
2091 }],
2092 });
2093 } else {
2094 response_content = Some(Content {
2095 role: "function".to_string(),
2096 parts: vec![Part::FunctionResponse {
2097 function_response: FunctionResponseData::new(
2098 name.clone(),
2099 serde_json::json!({
2100 "error": format!("Tool {} not found", name)
2101 }),
2102 ),
2103 id: id.clone(),
2104 }],
2105 });
2106 }
2107 }
2108
2109 let mut response_content = response_content.expect("tool response content is set");
2111 if run_after_tool_callbacks {
2112 let outcome_ctx: Arc<dyn CallbackContext> = match tool_outcome_for_callback {
2113 Some(outcome) => Arc::new(ToolOutcomeCallbackContext {
2114 inner: ctx.clone() as Arc<dyn CallbackContext>,
2115 outcome,
2116 }),
2117 None => ctx.clone() as Arc<dyn CallbackContext>,
2118 };
2119 let cb_ctx: Arc<dyn CallbackContext> = Arc::new(ToolCallbackContext::new(
2120 outcome_ctx,
2121 name.clone(),
2122 final_args.clone(),
2123 ));
2124 for callback in after_tool_callbacks.as_ref() {
2125 match callback(cb_ctx.clone()).await {
2126 Ok(Some(modified)) => { response_content = modified; break; }
2127 Ok(None) => continue,
2128 Err(e) => {
2129 response_content = Content {
2130 role: "function".to_string(),
2131 parts: vec![Part::FunctionResponse {
2132 function_response: FunctionResponseData::new(
2133 name.clone(),
2134 serde_json::json!({ "error": e.to_string() }),
2135 ),
2136 id: id.clone(),
2137 }],
2138 };
2139 break;
2140 }
2141 }
2142 }
2143 if let (Some(tool_ref), Some(tool_resp)) = (&executed_tool, executed_tool_response) {
2144 for callback in after_tool_callbacks_full.as_ref() {
2145 match callback(
2146 cb_ctx.clone(), tool_ref.clone(), final_args.clone(), tool_resp.clone(),
2147 ).await {
2148 Ok(Some(modified_value)) => {
2149 response_content = Content {
2150 role: "function".to_string(),
2151 parts: vec![Part::FunctionResponse {
2152 function_response: FunctionResponseData::from_tool_result(
2153 name.clone(),
2154 modified_value,
2155 ),
2156 id: id.clone(),
2157 }],
2158 };
2159 break;
2160 }
2161 Ok(None) => continue,
2162 Err(e) => {
2163 response_content = Content {
2164 role: "function".to_string(),
2165 parts: vec![Part::FunctionResponse {
2166 function_response: FunctionResponseData::new(
2167 name.clone(),
2168 serde_json::json!({ "error": e.to_string() }),
2169 ),
2170 id: id.clone(),
2171 }],
2172 };
2173 break;
2174 }
2175 }
2176 }
2177 }
2178
2179 #[cfg(feature = "enhanced-plugins")]
2182 if let Some(epm) = &enhanced_plugin_manager {
2183 if let Some(tool_ref) = &executed_tool {
2184 let result_value = response_content.parts.iter()
2186 .find_map(|p| {
2187 if let Part::FunctionResponse { function_response, .. } = p {
2188 Some(function_response.response.clone())
2189 } else {
2190 None
2191 }
2192 })
2193 .unwrap_or(serde_json::json!(null));
2194
2195 match epm.run_after_tool_call(
2196 tool_ref.clone(),
2197 &final_args,
2198 result_value,
2199 ctx.clone() as Arc<dyn CallbackContext>,
2200 ).await {
2201 Ok(adk_plugin::AfterToolCallResult::Continue(modified_result)) => {
2202 response_content = Content {
2203 role: "function".to_string(),
2204 parts: vec![Part::FunctionResponse {
2205 function_response: FunctionResponseData::from_tool_result(
2206 name.clone(),
2207 modified_result,
2208 ),
2209 id: id.clone(),
2210 }],
2211 };
2212 }
2213 Err(e) => {
2214 response_content = Content {
2215 role: "function".to_string(),
2216 parts: vec![Part::FunctionResponse {
2217 function_response: FunctionResponseData::new(
2218 name.clone(),
2219 serde_json::json!({ "error": e.to_string() }),
2220 ),
2221 id: id.clone(),
2222 }],
2223 };
2224 }
2225 }
2226 }
2227 }
2228 }
2229
2230 let escalate_or_skip = tool_actions.escalate || tool_actions.skip_summarization;
2231 (idx, response_content, tool_actions, escalate_or_skip)
2232 }
2233 };
2234
2235 let mut results: Vec<(usize, Content, EventActions, bool)> = match strategy {
2237 ToolExecutionStrategy::Sequential => {
2238 let mut results = Vec::with_capacity(fc_parts.len());
2239 for (idx, name, args, id, fcid) in fc_parts {
2240 results.push(execute_one_tool(idx, name, args, id, fcid).await);
2241 }
2242 results
2243 }
2244 ToolExecutionStrategy::Parallel => {
2245 use futures::StreamExt as _;
2246 let limit = ctx
2247 .run_config()
2248 .max_tool_concurrency
2249 .unwrap_or(fc_parts.len())
2250 .max(1);
2251 futures::stream::iter(fc_parts.into_iter().map(
2252 |(idx, name, args, id, fcid)| {
2253 execute_one_tool(idx, name, args, id, fcid)
2254 },
2255 ))
2256 .buffer_unordered(limit)
2257 .collect()
2258 .await
2259 }
2260 ToolExecutionStrategy::Auto => {
2261 let mut read_only_fcs = Vec::new();
2263 let mut mutable_fcs = Vec::new();
2264 for fc in fc_parts {
2265 let is_ro = tool_map.get(&fc.1)
2266 .map(|t| t.is_read_only())
2267 .unwrap_or(false);
2268 if is_ro { read_only_fcs.push(fc); } else { mutable_fcs.push(fc); }
2269 }
2270 let mut all_results = Vec::new();
2271 if !read_only_fcs.is_empty() {
2273 use futures::StreamExt as _;
2274 let limit = ctx
2275 .run_config()
2276 .max_tool_concurrency
2277 .unwrap_or(read_only_fcs.len())
2278 .max(1);
2279 all_results.extend(
2280 futures::stream::iter(read_only_fcs.into_iter().map(
2281 |(idx, name, args, id, fcid)| {
2282 execute_one_tool(idx, name, args, id, fcid)
2283 },
2284 ))
2285 .buffer_unordered(limit)
2286 .collect::<Vec<_>>()
2287 .await,
2288 );
2289 }
2290 for (idx, name, args, id, fcid) in mutable_fcs {
2292 all_results.push(execute_one_tool(idx, name, args, id, fcid).await);
2293 }
2294 all_results
2295 }
2296 };
2297 results.sort_by_key(|r| r.0);
2299
2300 circuit_breaker_state = cb_mutex.into_inner().unwrap_or_else(|e| e.into_inner());
2302
2303 for (_, response_content, tool_actions, escalate_or_skip) in results {
2305 let mut tool_event = Event::new(&invocation_id);
2306 tool_event.author = agent_name.clone();
2307 tool_event.actions = tool_actions;
2308 tool_event.llm_response.content = Some(response_content.clone());
2309 yield Ok(tool_event);
2310
2311 if escalate_or_skip {
2312 return;
2313 }
2314
2315 conversation_history.push(response_content);
2316 }
2317 }
2318
2319 if all_calls_are_long_running {
2323 }
2327 }
2328
2329 for callback in after_agent_callbacks.as_ref() {
2332 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
2333 Ok(Some(content)) => {
2334 let mut after_event = Event::new(&invocation_id);
2336 after_event.author = agent_name.clone();
2337 after_event.llm_response.content = Some(content);
2338 yield Ok(after_event);
2339 break; }
2341 Ok(None) => {
2342 continue;
2344 }
2345 Err(e) => {
2346 yield Err(e);
2348 return;
2349 }
2350 }
2351 }
2352 };
2353
2354 Ok(Box::pin(s))
2355 }
2356}