1use adk_core::{
2 AfterAgentCallback, AfterModelCallback, AfterToolCallback, AfterToolCallbackFull, Agent,
3 BeforeAgentCallback, BeforeModelCallback, BeforeModelResult, BeforeToolCallback,
4 CallbackContext, Content, Event, EventActions, FunctionResponseData, GlobalInstructionProvider,
5 InstructionProvider, InvocationContext, Llm, LlmRequest, LlmResponse, MemoryEntry,
6 OnToolErrorCallback, Part, ReadonlyContext, Result, RetryBudget, Tool, ToolCallbackContext,
7 ToolConfirmationDecision, ToolConfirmationPolicy, ToolConfirmationRequest, ToolContext,
8 ToolExecutionStrategy, ToolOutcome, Toolset,
9};
10use async_stream::stream;
11use async_trait::async_trait;
12use std::sync::{Arc, Mutex};
13use tracing::Instrument;
14
15#[cfg(feature = "enhanced-plugins")]
16use adk_plugin::{
17 BeforeModelCallResult, BeforeToolCallResult, EnhancedPlugin, EnhancedPluginManager,
18};
19
20#[cfg(feature = "skills")]
21use crate::skill_shim::load_skill_index;
22use crate::{
23 guardrails::{GuardrailSet, enforce_guardrails},
24 skill_shim::{SelectionPolicy, SkillIndex, select_skill_prompt_block},
25 tool_call_markup::normalize_option_content,
26 workflow::with_user_content_override,
27};
28
/// Number of model/tool loop iterations an [`LlmAgent`] run may start before it fails,
/// unless overridden with [`LlmAgentBuilder::max_iterations`].
pub const DEFAULT_MAX_ITERATIONS: u32 = 100;

/// Tool timeout used unless overridden with [`LlmAgentBuilder::tool_timeout`]: five minutes.
pub const DEFAULT_TOOL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5 * 60);
34
35fn trace_json_payload<T: serde::Serialize>(
36 value: &T,
37 record_payloads: bool,
38 max_bytes: usize,
39) -> String {
40 let json = serde_json::to_string(value).unwrap_or_default();
41 if cfg!(feature = "record-payloads") && record_payloads {
42 return json;
43 }
44
45 let max_bytes = max_bytes.max(32);
46 if json.len() <= max_bytes {
47 return json;
48 }
49
50 let mut end = max_bytes;
51 while !json.is_char_boundary(end) {
52 end -= 1;
53 }
54 format!("{}...[truncated {} bytes]", &json[..end], json.len() - end)
55}
56
57pub struct LlmAgent {
65 name: String,
66 description: String,
67 model: Arc<dyn Llm>,
68 instruction: Option<String>,
69 instruction_provider: Option<Arc<InstructionProvider>>,
70 global_instruction: Option<String>,
71 global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
72 skills_index: Option<Arc<SkillIndex>>,
73 skill_policy: SelectionPolicy,
74 max_skill_chars: usize,
75 #[allow(dead_code)] input_schema: Option<serde_json::Value>,
77 output_schema: Option<serde_json::Value>,
78 disallow_transfer_to_parent: bool,
79 disallow_transfer_to_peers: bool,
80 include_contents: adk_core::IncludeContents,
81 tools: Vec<Arc<dyn Tool>>,
82 #[allow(dead_code)] toolsets: Vec<Arc<dyn Toolset>>,
84 sub_agents: Vec<Arc<dyn Agent>>,
85 output_key: Option<String>,
86 generate_content_config: Option<adk_core::GenerateContentConfig>,
88 max_iterations: u32,
90 tool_timeout: std::time::Duration,
92 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
93 after_callbacks: Arc<Vec<AfterAgentCallback>>,
94 before_model_callbacks: Arc<Vec<BeforeModelCallback>>,
95 after_model_callbacks: Arc<Vec<AfterModelCallback>>,
96 before_tool_callbacks: Arc<Vec<BeforeToolCallback>>,
97 after_tool_callbacks: Arc<Vec<AfterToolCallback>>,
98 on_tool_error_callbacks: Arc<Vec<OnToolErrorCallback>>,
99 after_tool_callbacks_full: Arc<Vec<AfterToolCallbackFull>>,
101 default_retry_budget: Option<RetryBudget>,
103 tool_retry_budgets: std::collections::HashMap<String, RetryBudget>,
105 circuit_breaker_threshold: Option<u32>,
108 tool_confirmation_policy: ToolConfirmationPolicy,
109 tool_execution_strategy: Option<ToolExecutionStrategy>,
112 input_guardrails: Arc<GuardrailSet>,
113 output_guardrails: Arc<GuardrailSet>,
114 #[cfg(feature = "enhanced-plugins")]
117 enhanced_plugin_manager: Option<Arc<EnhancedPluginManager>>,
118}
119
120impl std::fmt::Debug for LlmAgent {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 f.debug_struct("LlmAgent")
123 .field("name", &self.name)
124 .field("description", &self.description)
125 .field("model", &self.model.name())
126 .field("instruction", &self.instruction)
127 .field("tools_count", &self.tools.len())
128 .field("sub_agents_count", &self.sub_agents.len())
129 .finish()
130 }
131}
132
133impl LlmAgent {
134 async fn apply_input_guardrails(
135 ctx: Arc<dyn InvocationContext>,
136 input_guardrails: Arc<GuardrailSet>,
137 ) -> Result<Arc<dyn InvocationContext>> {
138 let content =
139 enforce_guardrails(input_guardrails.as_ref(), ctx.user_content(), "input").await?;
140 if content.role != ctx.user_content().role || content.parts != ctx.user_content().parts {
141 Ok(with_user_content_override(ctx, content))
142 } else {
143 Ok(ctx)
144 }
145 }
146
147 async fn apply_output_guardrails(
148 output_guardrails: &GuardrailSet,
149 content: Content,
150 ) -> Result<Content> {
151 enforce_guardrails(output_guardrails, &content, "output").await
152 }
153
154 fn history_parts_from_provider_metadata(
155 provider_metadata: Option<&serde_json::Value>,
156 ) -> Vec<Part> {
157 let Some(provider_metadata) = provider_metadata else {
158 return Vec::new();
159 };
160
161 let history_parts = provider_metadata
162 .get("conversation_history_parts")
163 .or_else(|| {
164 provider_metadata
165 .get("openai")
166 .and_then(|openai| openai.get("conversation_history_parts"))
167 })
168 .and_then(serde_json::Value::as_array);
169
170 history_parts
171 .into_iter()
172 .flatten()
173 .filter_map(|value| serde_json::from_value::<Part>(value.clone()).ok())
174 .collect()
175 }
176
177 fn augment_content_for_history(
178 content: &Content,
179 provider_metadata: Option<&serde_json::Value>,
180 ) -> Content {
181 let mut augmented = content.clone();
182 augmented.parts.extend(Self::history_parts_from_provider_metadata(provider_metadata));
183 augmented
184 }
185}
186
/// Fluent builder for [`LlmAgent`].
///
/// Start with [`LlmAgentBuilder::new`], chain setters, and finish with
/// [`LlmAgentBuilder::build`]. Callback lists are held as plain `Vec`s here and are wrapped in
/// `Arc`s on the built agent.
pub struct LlmAgentBuilder {
    // Identity and model.
    name: String,
    description: Option<String>,
    // Required: `build` fails while this is `None`.
    model: Option<Arc<dyn Llm>>,
    // Prompt assembly.
    instruction: Option<String>,
    instruction_provider: Option<Arc<InstructionProvider>>,
    global_instruction: Option<String>,
    global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
    skills_index: Option<Arc<SkillIndex>>,
    skill_policy: SelectionPolicy,
    max_skill_chars: usize,
    // Request/response shaping.
    input_schema: Option<serde_json::Value>,
    output_schema: Option<serde_json::Value>,
    // Agent transfer and history.
    disallow_transfer_to_parent: bool,
    disallow_transfer_to_peers: bool,
    include_contents: adk_core::IncludeContents,
    // Tools and delegation.
    tools: Vec<Arc<dyn Tool>>,
    toolsets: Vec<Arc<dyn Toolset>>,
    sub_agents: Vec<Arc<dyn Agent>>,
    output_key: Option<String>,
    generate_content_config: Option<adk_core::GenerateContentConfig>,
    // Loop limits.
    max_iterations: u32,
    tool_timeout: std::time::Duration,
    // Callbacks, in registration order.
    before_callbacks: Vec<BeforeAgentCallback>,
    after_callbacks: Vec<AfterAgentCallback>,
    before_model_callbacks: Vec<BeforeModelCallback>,
    after_model_callbacks: Vec<AfterModelCallback>,
    before_tool_callbacks: Vec<BeforeToolCallback>,
    after_tool_callbacks: Vec<AfterToolCallback>,
    on_tool_error_callbacks: Vec<OnToolErrorCallback>,
    after_tool_callbacks_full: Vec<AfterToolCallbackFull>,
    // Tool resilience and confirmation.
    default_retry_budget: Option<RetryBudget>,
    tool_retry_budgets: std::collections::HashMap<String, RetryBudget>,
    circuit_breaker_threshold: Option<u32>,
    tool_confirmation_policy: ToolConfirmationPolicy,
    tool_execution_strategy: Option<ToolExecutionStrategy>,
    // Guardrails.
    input_guardrails: GuardrailSet,
    output_guardrails: GuardrailSet,
    // Enhanced plugins; turned into an `EnhancedPluginManager` by `build` when non-empty.
    #[cfg(feature = "enhanced-plugins")]
    enhanced_plugins: Vec<Arc<dyn EnhancedPlugin>>,
}
230
231impl LlmAgentBuilder {
232 pub fn new(name: impl Into<String>) -> Self {
234 Self {
235 name: name.into(),
236 description: None,
237 model: None,
238 instruction: None,
239 instruction_provider: None,
240 global_instruction: None,
241 global_instruction_provider: None,
242 skills_index: None,
243 skill_policy: SelectionPolicy::default(),
244 max_skill_chars: 2000,
245 input_schema: None,
246 output_schema: None,
247 disallow_transfer_to_parent: false,
248 disallow_transfer_to_peers: false,
249 include_contents: adk_core::IncludeContents::Default,
250 tools: Vec::new(),
251 toolsets: Vec::new(),
252 sub_agents: Vec::new(),
253 output_key: None,
254 generate_content_config: None,
255 max_iterations: DEFAULT_MAX_ITERATIONS,
256 tool_timeout: DEFAULT_TOOL_TIMEOUT,
257 before_callbacks: Vec::new(),
258 after_callbacks: Vec::new(),
259 before_model_callbacks: Vec::new(),
260 after_model_callbacks: Vec::new(),
261 before_tool_callbacks: Vec::new(),
262 after_tool_callbacks: Vec::new(),
263 on_tool_error_callbacks: Vec::new(),
264 after_tool_callbacks_full: Vec::new(),
265 default_retry_budget: None,
266 tool_retry_budgets: std::collections::HashMap::new(),
267 circuit_breaker_threshold: None,
268 tool_confirmation_policy: ToolConfirmationPolicy::Never,
269 tool_execution_strategy: None,
270 input_guardrails: GuardrailSet::new(),
271 output_guardrails: GuardrailSet::new(),
272 #[cfg(feature = "enhanced-plugins")]
273 enhanced_plugins: Vec::new(),
274 }
275 }
276
277 pub fn description(mut self, desc: impl Into<String>) -> Self {
279 self.description = Some(desc.into());
280 self
281 }
282
283 pub fn model(mut self, model: Arc<dyn Llm>) -> Self {
285 self.model = Some(model);
286 self
287 }
288
289 pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
291 self.instruction = Some(instruction.into());
292 self
293 }
294
295 pub fn instruction_provider(mut self, provider: InstructionProvider) -> Self {
297 self.instruction_provider = Some(Arc::new(provider));
298 self
299 }
300
301 pub fn global_instruction(mut self, instruction: impl Into<String>) -> Self {
303 self.global_instruction = Some(instruction.into());
304 self
305 }
306
307 pub fn global_instruction_provider(mut self, provider: GlobalInstructionProvider) -> Self {
309 self.global_instruction_provider = Some(Arc::new(provider));
310 self
311 }
312
313 #[cfg(feature = "skills")]
315 pub fn with_skills(mut self, index: SkillIndex) -> Self {
316 self.skills_index = Some(Arc::new(index));
317 self
318 }
319
320 #[cfg(feature = "skills")]
322 pub fn with_auto_skills(self) -> Result<Self> {
323 self.with_skills_from_root(".")
324 }
325
326 #[cfg(feature = "skills")]
328 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
329 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
330 self.skills_index = Some(Arc::new(index));
331 Ok(self)
332 }
333
334 #[cfg(feature = "skills")]
336 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
337 self.skill_policy = policy;
338 self
339 }
340
341 #[cfg(feature = "skills")]
343 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
344 self.max_skill_chars = max_chars;
345 self
346 }
347
348 pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
350 self.input_schema = Some(schema);
351 self
352 }
353
354 pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
356 self.output_schema = Some(schema);
357 self
358 }
359
360 pub fn disallow_transfer_to_parent(mut self, disallow: bool) -> Self {
362 self.disallow_transfer_to_parent = disallow;
363 self
364 }
365
366 pub fn disallow_transfer_to_peers(mut self, disallow: bool) -> Self {
368 self.disallow_transfer_to_peers = disallow;
369 self
370 }
371
372 pub fn include_contents(mut self, include: adk_core::IncludeContents) -> Self {
374 self.include_contents = include;
375 self
376 }
377
378 pub fn output_key(mut self, key: impl Into<String>) -> Self {
380 self.output_key = Some(key.into());
381 self
382 }
383
384 pub fn generate_content_config(mut self, config: adk_core::GenerateContentConfig) -> Self {
405 self.generate_content_config = Some(config);
406 self
407 }
408
409 pub fn temperature(mut self, temperature: f32) -> Self {
412 self.generate_content_config
413 .get_or_insert(adk_core::GenerateContentConfig::default())
414 .temperature = Some(temperature);
415 self
416 }
417
418 pub fn top_p(mut self, top_p: f32) -> Self {
420 self.generate_content_config
421 .get_or_insert(adk_core::GenerateContentConfig::default())
422 .top_p = Some(top_p);
423 self
424 }
425
426 pub fn top_k(mut self, top_k: i32) -> Self {
428 self.generate_content_config
429 .get_or_insert(adk_core::GenerateContentConfig::default())
430 .top_k = Some(top_k);
431 self
432 }
433
434 pub fn max_output_tokens(mut self, max_tokens: i32) -> Self {
436 self.generate_content_config
437 .get_or_insert(adk_core::GenerateContentConfig::default())
438 .max_output_tokens = Some(max_tokens);
439 self
440 }
441
442 pub fn max_iterations(mut self, max: u32) -> Self {
445 self.max_iterations = max;
446 self
447 }
448
449 pub fn tool_timeout(mut self, timeout: std::time::Duration) -> Self {
452 self.tool_timeout = timeout;
453 self
454 }
455
456 pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
458 self.tools.push(tool);
459 self
460 }
461
462 pub fn toolset(mut self, toolset: Arc<dyn Toolset>) -> Self {
468 self.toolsets.push(toolset);
469 self
470 }
471
472 pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
474 self.sub_agents.push(agent);
475 self
476 }
477
478 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
480 self.before_callbacks.push(callback);
481 self
482 }
483
484 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
486 self.after_callbacks.push(callback);
487 self
488 }
489
490 pub fn before_model_callback(mut self, callback: BeforeModelCallback) -> Self {
492 self.before_model_callbacks.push(callback);
493 self
494 }
495
496 pub fn after_model_callback(mut self, callback: AfterModelCallback) -> Self {
498 self.after_model_callbacks.push(callback);
499 self
500 }
501
502 pub fn before_tool_callback(mut self, callback: BeforeToolCallback) -> Self {
504 self.before_tool_callbacks.push(callback);
505 self
506 }
507
508 pub fn after_tool_callback(mut self, callback: AfterToolCallback) -> Self {
510 self.after_tool_callbacks.push(callback);
511 self
512 }
513
514 pub fn after_tool_callback_full(mut self, callback: AfterToolCallbackFull) -> Self {
529 self.after_tool_callbacks_full.push(callback);
530 self
531 }
532
533 pub fn on_tool_error(mut self, callback: OnToolErrorCallback) -> Self {
541 self.on_tool_error_callbacks.push(callback);
542 self
543 }
544
545 pub fn default_retry_budget(mut self, budget: RetryBudget) -> Self {
552 self.default_retry_budget = Some(budget);
553 self
554 }
555
556 pub fn tool_retry_budget(mut self, tool_name: impl Into<String>, budget: RetryBudget) -> Self {
561 self.tool_retry_budgets.insert(tool_name.into(), budget);
562 self
563 }
564
565 pub fn circuit_breaker_threshold(mut self, threshold: u32) -> Self {
572 self.circuit_breaker_threshold = Some(threshold);
573 self
574 }
575
576 pub fn tool_confirmation_policy(mut self, policy: ToolConfirmationPolicy) -> Self {
578 self.tool_confirmation_policy = policy;
579 self
580 }
581
582 pub fn require_tool_confirmation(mut self, tool_name: impl Into<String>) -> Self {
584 self.tool_confirmation_policy = self.tool_confirmation_policy.with_tool(tool_name);
585 self
586 }
587
588 pub fn require_tool_confirmation_for_all(mut self) -> Self {
590 self.tool_confirmation_policy = ToolConfirmationPolicy::Always;
591 self
592 }
593
594 pub fn tool_execution_strategy(mut self, strategy: ToolExecutionStrategy) -> Self {
600 self.tool_execution_strategy = Some(strategy);
601 self
602 }
603
604 pub fn input_guardrails(mut self, guardrails: GuardrailSet) -> Self {
613 self.input_guardrails = guardrails;
614 self
615 }
616
617 pub fn output_guardrails(mut self, guardrails: GuardrailSet) -> Self {
626 self.output_guardrails = guardrails;
627 self
628 }
629
630 #[cfg(feature = "enhanced-plugins")]
650 pub fn enhanced_plugin(mut self, plugin: Arc<dyn EnhancedPlugin>) -> Self {
651 self.enhanced_plugins.push(plugin);
652 self
653 }
654
655 #[cfg(feature = "enhanced-plugins")]
677 pub fn enhanced_plugins(mut self, plugins: Vec<Arc<dyn EnhancedPlugin>>) -> Self {
678 self.enhanced_plugins.extend(plugins);
679 self
680 }
681
682 pub fn build(self) -> Result<LlmAgent> {
684 let model = self.model.ok_or_else(|| adk_core::AdkError::agent("Model is required"))?;
685
686 let mut seen_names = std::collections::HashSet::new();
687 for agent in &self.sub_agents {
688 if !seen_names.insert(agent.name()) {
689 return Err(adk_core::AdkError::agent(format!(
690 "Duplicate sub-agent name: {}",
691 agent.name()
692 )));
693 }
694 }
695
696 #[cfg(feature = "enhanced-plugins")]
698 let enhanced_plugin_manager = if self.enhanced_plugins.is_empty() {
699 None
700 } else {
701 Some(Arc::new(EnhancedPluginManager::new(self.enhanced_plugins)))
702 };
703
704 Ok(LlmAgent {
705 name: self.name,
706 description: self.description.unwrap_or_default(),
707 model,
708 instruction: self.instruction,
709 instruction_provider: self.instruction_provider,
710 global_instruction: self.global_instruction,
711 global_instruction_provider: self.global_instruction_provider,
712 skills_index: self.skills_index,
713 skill_policy: self.skill_policy,
714 max_skill_chars: self.max_skill_chars,
715 input_schema: self.input_schema,
716 output_schema: self.output_schema,
717 disallow_transfer_to_parent: self.disallow_transfer_to_parent,
718 disallow_transfer_to_peers: self.disallow_transfer_to_peers,
719 include_contents: self.include_contents,
720 tools: self.tools,
721 toolsets: self.toolsets,
722 sub_agents: self.sub_agents,
723 output_key: self.output_key,
724 generate_content_config: self.generate_content_config,
725 max_iterations: self.max_iterations,
726 tool_timeout: self.tool_timeout,
727 before_callbacks: Arc::new(self.before_callbacks),
728 after_callbacks: Arc::new(self.after_callbacks),
729 before_model_callbacks: Arc::new(self.before_model_callbacks),
730 after_model_callbacks: Arc::new(self.after_model_callbacks),
731 before_tool_callbacks: Arc::new(self.before_tool_callbacks),
732 after_tool_callbacks: Arc::new(self.after_tool_callbacks),
733 on_tool_error_callbacks: Arc::new(self.on_tool_error_callbacks),
734 after_tool_callbacks_full: Arc::new(self.after_tool_callbacks_full),
735 default_retry_budget: self.default_retry_budget,
736 tool_retry_budgets: self.tool_retry_budgets,
737 circuit_breaker_threshold: self.circuit_breaker_threshold,
738 tool_confirmation_policy: self.tool_confirmation_policy,
739 tool_execution_strategy: self.tool_execution_strategy,
740 input_guardrails: Arc::new(self.input_guardrails),
741 output_guardrails: Arc::new(self.output_guardrails),
742 #[cfg(feature = "enhanced-plugins")]
743 enhanced_plugin_manager,
744 })
745 }
746}
747
748struct AgentToolContext {
751 parent_ctx: Arc<dyn InvocationContext>,
752 function_call_id: String,
753 actions: Mutex<EventActions>,
754}
755
756impl AgentToolContext {
757 fn new(parent_ctx: Arc<dyn InvocationContext>, function_call_id: String) -> Self {
758 Self { parent_ctx, function_call_id, actions: Mutex::new(EventActions::default()) }
759 }
760
761 fn actions_guard(&self) -> std::sync::MutexGuard<'_, EventActions> {
762 self.actions.lock().unwrap_or_else(|e| e.into_inner())
763 }
764}
765
766#[async_trait]
767impl ReadonlyContext for AgentToolContext {
768 fn invocation_id(&self) -> &str {
769 self.parent_ctx.invocation_id()
770 }
771
772 fn agent_name(&self) -> &str {
773 self.parent_ctx.agent_name()
774 }
775
776 fn user_id(&self) -> &str {
777 self.parent_ctx.user_id()
779 }
780
781 fn app_name(&self) -> &str {
782 self.parent_ctx.app_name()
784 }
785
786 fn session_id(&self) -> &str {
787 self.parent_ctx.session_id()
789 }
790
791 fn branch(&self) -> &str {
792 self.parent_ctx.branch()
793 }
794
795 fn user_content(&self) -> &Content {
796 self.parent_ctx.user_content()
797 }
798}
799
800#[async_trait]
801impl CallbackContext for AgentToolContext {
802 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
803 self.parent_ctx.artifacts()
805 }
806
807 fn shared_state(&self) -> Option<Arc<adk_core::SharedState>> {
808 self.parent_ctx.shared_state()
809 }
810}
811
812#[async_trait]
813impl ToolContext for AgentToolContext {
814 fn function_call_id(&self) -> &str {
815 &self.function_call_id
816 }
817
818 fn actions(&self) -> EventActions {
819 self.actions_guard().clone()
820 }
821
822 fn set_actions(&self, actions: EventActions) {
823 *self.actions_guard() = actions;
824 }
825
826 async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>> {
827 if let Some(memory) = self.parent_ctx.memory() {
829 memory.search(query).await
830 } else {
831 Ok(vec![])
832 }
833 }
834
835 fn user_scopes(&self) -> Vec<String> {
836 self.parent_ctx.user_scopes()
837 }
838
839 async fn get_secret(&self, name: &str) -> Result<Option<String>> {
840 self.parent_ctx.get_secret(name).await
841 }
842}
843
844struct ToolOutcomeCallbackContext {
848 inner: Arc<dyn CallbackContext>,
849 outcome: ToolOutcome,
850}
851
852#[async_trait]
853impl ReadonlyContext for ToolOutcomeCallbackContext {
854 fn invocation_id(&self) -> &str {
855 self.inner.invocation_id()
856 }
857
858 fn agent_name(&self) -> &str {
859 self.inner.agent_name()
860 }
861
862 fn user_id(&self) -> &str {
863 self.inner.user_id()
864 }
865
866 fn app_name(&self) -> &str {
867 self.inner.app_name()
868 }
869
870 fn session_id(&self) -> &str {
871 self.inner.session_id()
872 }
873
874 fn branch(&self) -> &str {
875 self.inner.branch()
876 }
877
878 fn user_content(&self) -> &Content {
879 self.inner.user_content()
880 }
881}
882
883#[async_trait]
884impl CallbackContext for ToolOutcomeCallbackContext {
885 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
886 self.inner.artifacts()
887 }
888
889 fn tool_outcome(&self) -> Option<ToolOutcome> {
890 Some(self.outcome.clone())
891 }
892}
893
894struct CircuitBreakerState {
904 threshold: u32,
905 failures: std::collections::HashMap<String, u32>,
907}
908
909impl CircuitBreakerState {
910 fn new(threshold: u32) -> Self {
911 Self { threshold, failures: std::collections::HashMap::new() }
912 }
913
914 fn is_open(&self, tool_name: &str) -> bool {
916 self.failures.get(tool_name).copied().unwrap_or(0) >= self.threshold
917 }
918
919 fn record(&mut self, outcome: &ToolOutcome) {
921 if outcome.success {
922 self.failures.remove(&outcome.tool_name);
923 } else {
924 let count = self.failures.entry(outcome.tool_name.clone()).or_insert(0);
925 *count += 1;
926 }
927 }
928}
929
930#[async_trait]
931impl Agent for LlmAgent {
932 fn name(&self) -> &str {
933 &self.name
934 }
935
936 fn description(&self) -> &str {
937 &self.description
938 }
939
940 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
941 &self.sub_agents
942 }
943
944 #[adk_telemetry::instrument(
945 skip(self, ctx),
946 fields(
947 agent.name = %self.name,
948 agent.description = %self.description,
949 invocation.id = %ctx.invocation_id(),
950 user.id = %ctx.user_id(),
951 session.id = %ctx.session_id()
952 )
953 )]
954 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
955 adk_telemetry::info!("Starting agent execution");
956 let ctx = Self::apply_input_guardrails(ctx, self.input_guardrails.clone()).await?;
957
958 let agent_name = self.name.clone();
959 let invocation_id = ctx.invocation_id().to_string();
960 let model = self.model.clone();
961 let tools = self.tools.clone();
962 let toolsets = self.toolsets.clone();
963 let sub_agents = self.sub_agents.clone();
964
965 let instruction = self.instruction.clone();
966 let instruction_provider = self.instruction_provider.clone();
967 let global_instruction = self.global_instruction.clone();
968 let global_instruction_provider = self.global_instruction_provider.clone();
969 let skills_index = self.skills_index.clone();
970 let skill_policy = self.skill_policy.clone();
971 let max_skill_chars = self.max_skill_chars;
972 let output_key = self.output_key.clone();
973 let output_schema = self.output_schema.clone();
974 let generate_content_config = self.generate_content_config.clone();
975 let include_contents = self.include_contents;
976 let max_iterations = self.max_iterations;
977 let tool_timeout = self.tool_timeout;
978 let before_agent_callbacks = self.before_callbacks.clone();
980 let after_agent_callbacks = self.after_callbacks.clone();
981 let before_model_callbacks = self.before_model_callbacks.clone();
982 let after_model_callbacks = self.after_model_callbacks.clone();
983 let before_tool_callbacks = self.before_tool_callbacks.clone();
984 let after_tool_callbacks = self.after_tool_callbacks.clone();
985 let on_tool_error_callbacks = self.on_tool_error_callbacks.clone();
986 let after_tool_callbacks_full = self.after_tool_callbacks_full.clone();
987 let default_retry_budget = self.default_retry_budget.clone();
988 let tool_retry_budgets = self.tool_retry_budgets.clone();
989 let circuit_breaker_threshold = self.circuit_breaker_threshold;
990 let tool_confirmation_policy = self.tool_confirmation_policy.clone();
991 let disallow_transfer_to_parent = self.disallow_transfer_to_parent;
992 let disallow_transfer_to_peers = self.disallow_transfer_to_peers;
993 let output_guardrails = self.output_guardrails.clone();
994 let agent_tool_execution_strategy = self.tool_execution_strategy;
995 #[cfg(feature = "enhanced-plugins")]
996 let enhanced_plugin_manager = self.enhanced_plugin_manager.clone();
997
998 let s = stream! {
999 for callback in before_agent_callbacks.as_ref() {
1003 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
1004 Ok(Some(content)) => {
1005 let mut early_event = Event::new(&invocation_id);
1007 early_event.author = agent_name.clone();
1008 early_event.llm_response.content = Some(content);
1009 yield Ok(early_event);
1010
1011 for after_callback in after_agent_callbacks.as_ref() {
1013 match after_callback(ctx.clone() as Arc<dyn CallbackContext>).await {
1014 Ok(Some(after_content)) => {
1015 let mut after_event = Event::new(&invocation_id);
1016 after_event.author = agent_name.clone();
1017 after_event.llm_response.content = Some(after_content);
1018 yield Ok(after_event);
1019 return;
1020 }
1021 Ok(None) => continue,
1022 Err(e) => {
1023 yield Err(e);
1024 return;
1025 }
1026 }
1027 }
1028 return;
1029 }
1030 Ok(None) => {
1031 continue;
1033 }
1034 Err(e) => {
1035 yield Err(e);
1037 return;
1038 }
1039 }
1040 }
1041
1042 let mut prompt_preamble = Vec::new();
1044
1045 if let Some(index) = &skills_index {
1049 let user_query = ctx
1050 .user_content()
1051 .parts
1052 .iter()
1053 .filter_map(|part| match part {
1054 Part::Text { text } => Some(text.as_str()),
1055 _ => None,
1056 })
1057 .collect::<Vec<_>>()
1058 .join("\n");
1059
1060 if let Some((_matched, skill_block)) = select_skill_prompt_block(
1061 index.as_ref(),
1062 &user_query,
1063 &skill_policy,
1064 max_skill_chars,
1065 ) {
1066 prompt_preamble.push(Content {
1067 role: "user".to_string(),
1068 parts: vec![Part::Text { text: skill_block }],
1069 });
1070 }
1071 }
1072
1073 if let Some(provider) = &global_instruction_provider {
1076 let global_inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
1078 if !global_inst.is_empty() {
1079 prompt_preamble.push(Content {
1080 role: "user".to_string(),
1081 parts: vec![Part::Text { text: global_inst }],
1082 });
1083 }
1084 } else if let Some(ref template) = global_instruction {
1085 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
1087 if !processed.is_empty() {
1088 prompt_preamble.push(Content {
1089 role: "user".to_string(),
1090 parts: vec![Part::Text { text: processed }],
1091 });
1092 }
1093 }
1094
1095 if let Some(provider) = &instruction_provider {
1098 let inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
1100 if !inst.is_empty() {
1101 prompt_preamble.push(Content {
1102 role: "user".to_string(),
1103 parts: vec![Part::Text { text: inst }],
1104 });
1105 }
1106 } else if let Some(ref template) = instruction {
1107 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
1109 if !processed.is_empty() {
1110 prompt_preamble.push(Content {
1111 role: "user".to_string(),
1112 parts: vec![Part::Text { text: processed }],
1113 });
1114 }
1115 }
1116
1117 let session_history = if !ctx.run_config().transfer_targets.is_empty() {
1123 ctx.session().conversation_history_for_agent(&agent_name)
1124 } else {
1125 ctx.session().conversation_history()
1126 };
1127 let mut session_history = session_history;
1128 let current_user_content = ctx.user_content().clone();
1129 if let Some(index) = session_history.iter().rposition(|content| content.role == "user") {
1130 session_history[index] = current_user_content.clone();
1131 } else {
1132 session_history.push(current_user_content.clone());
1133 }
1134
1135 let mut conversation_history = match include_contents {
1138 adk_core::IncludeContents::None => {
1139 let mut filtered = prompt_preamble.clone();
1140 filtered.push(current_user_content);
1141 filtered
1142 }
1143 adk_core::IncludeContents::Default => {
1144 let mut full_history = prompt_preamble;
1145 full_history.extend(session_history);
1146 full_history
1147 }
1148 };
1149
1150 let mut resolved_tools: Vec<Arc<dyn Tool>> = tools.clone();
1153 let static_tool_names: std::collections::HashSet<String> =
1154 tools.iter().map(|t| t.name().to_string()).collect();
1155
1156 let mut toolset_source: std::collections::HashMap<String, String> =
1158 std::collections::HashMap::new();
1159
1160 for toolset in &toolsets {
1161 let toolset_tools = match toolset
1162 .tools(ctx.clone() as Arc<dyn ReadonlyContext>)
1163 .await
1164 {
1165 Ok(t) => t,
1166 Err(e) => {
1167 yield Err(e);
1168 return;
1169 }
1170 };
1171 for tool in &toolset_tools {
1172 let name = tool.name().to_string();
1173 if static_tool_names.contains(&name) {
1175 yield Err(adk_core::AdkError::agent(format!(
1176 "Duplicate tool name '{name}': conflict between static tool and toolset '{}'",
1177 toolset.name()
1178 )));
1179 return;
1180 }
1181 if let Some(other_toolset_name) = toolset_source.get(&name) {
1183 yield Err(adk_core::AdkError::agent(format!(
1184 "Duplicate tool name '{name}': conflict between toolset '{}' and toolset '{}'",
1185 other_toolset_name,
1186 toolset.name()
1187 )));
1188 return;
1189 }
1190 toolset_source.insert(name, toolset.name().to_string());
1191 resolved_tools.push(tool.clone());
1192 }
1193 }
1194
1195 let tool_map: std::collections::HashMap<String, Arc<dyn Tool>> = resolved_tools
1197 .iter()
1198 .map(|t| (t.name().to_string(), t.clone()))
1199 .collect();
1200
1201 let collect_long_running_ids = |content: &Content| -> Vec<String> {
1203 content.parts.iter()
1204 .filter_map(|p| {
1205 if let Part::FunctionCall { name, .. } = p {
1206 if let Some(tool) = tool_map.get(name) {
1207 if tool.is_long_running() {
1208 return Some(name.clone());
1209 }
1210 }
1211 }
1212 None
1213 })
1214 .collect()
1215 };
1216
1217 let mut tool_declarations = std::collections::HashMap::new();
1222 for tool in &resolved_tools {
1223 tool_declarations.insert(tool.name().to_string(), tool.declaration());
1224 }
1225
1226 let mut valid_transfer_targets: Vec<String> = sub_agents
1231 .iter()
1232 .map(|a| a.name().to_string())
1233 .collect();
1234
1235 let run_config_targets = &ctx.run_config().transfer_targets;
1237 let parent_agent_name = ctx.run_config().parent_agent.clone();
1238 let sub_agent_names: std::collections::HashSet<&str> = sub_agents
1239 .iter()
1240 .map(|a| a.name())
1241 .collect();
1242
1243 for target in run_config_targets {
1244 if sub_agent_names.contains(target.as_str()) {
1246 continue;
1247 }
1248
1249 let is_parent = parent_agent_name.as_deref() == Some(target.as_str());
1251 if is_parent && disallow_transfer_to_parent {
1252 continue;
1253 }
1254 if !is_parent && disallow_transfer_to_peers {
1255 continue;
1256 }
1257
1258 valid_transfer_targets.push(target.clone());
1259 }
1260
1261 if !valid_transfer_targets.is_empty() {
1263 let transfer_tool_name = "transfer_to_agent";
1264 let transfer_tool_decl = serde_json::json!({
1265 "name": transfer_tool_name,
1266 "description": format!(
1267 "Transfer execution to another agent. Valid targets: {}",
1268 valid_transfer_targets.join(", ")
1269 ),
1270 "parameters": {
1271 "type": "object",
1272 "properties": {
1273 "agent_name": {
1274 "type": "string",
1275 "description": "The name of the agent to transfer to.",
1276 "enum": valid_transfer_targets
1277 }
1278 },
1279 "required": ["agent_name"]
1280 }
1281 });
1282 tool_declarations.insert(transfer_tool_name.to_string(), transfer_tool_decl);
1283 }
1284
1285
1286 let mut circuit_breaker_state = circuit_breaker_threshold.map(CircuitBreakerState::new);
1289
1290 let mut iteration = 0;
1292
1293 loop {
1294 iteration += 1;
1295 if iteration > max_iterations {
1296 yield Err(adk_core::AdkError::agent(
1297 format!("Max iterations ({max_iterations}) exceeded")
1298 ));
1299 return;
1300 }
1301
1302 let config = match (&generate_content_config, &output_schema) {
1309 (Some(base), Some(schema)) => {
1310 let mut merged = base.clone();
1311 merged.response_schema = Some(schema.clone());
1312 Some(merged)
1313 }
1314 (Some(base), None) => Some(base.clone()),
1315 (None, Some(schema)) => Some(adk_core::GenerateContentConfig {
1316 response_schema: Some(schema.clone()),
1317 ..Default::default()
1318 }),
1319 (None, None) => None,
1320 };
1321
1322 let config = if let Some(ref cached) = ctx.run_config().cached_content {
1324 let mut cfg = config.unwrap_or_default();
1325 if cfg.cached_content.is_none() {
1327 cfg.cached_content = Some(cached.clone());
1328 }
1329 Some(cfg)
1330 } else {
1331 config
1332 };
1333
1334 let request = LlmRequest {
1335 model: model.name().to_string(),
1336 contents: conversation_history.clone(),
1337 tools: tool_declarations.clone(),
1338 config,
1339 };
1340
1341 #[cfg(feature = "enhanced-plugins")]
1345 let (request, model_response_override_from_plugin) = {
1346 if let Some(epm) = &enhanced_plugin_manager {
1347 match epm.run_before_model_call(request, ctx.clone() as Arc<dyn CallbackContext>).await {
1348 Ok(BeforeModelCallResult::Continue(modified_request)) => {
1349 (modified_request, None)
1350 }
1351 Ok(BeforeModelCallResult::ShortCircuit(response)) => {
1352 (LlmRequest::new("", vec![]), Some(response))
1354 }
1355 Err(e) => {
1356 yield Err(e);
1357 return;
1358 }
1359 }
1360 } else {
1361 (request, None)
1362 }
1363 };
1364 #[cfg(not(feature = "enhanced-plugins"))]
1365 let model_response_override_from_plugin: Option<LlmResponse> = None;
1366
1367 let mut current_request = request;
1370 let mut model_response_override = model_response_override_from_plugin;
1371 if model_response_override.is_none() {
1372 for callback in before_model_callbacks.as_ref() {
1373 match callback(ctx.clone() as Arc<dyn CallbackContext>, current_request.clone()).await {
1374 Ok(BeforeModelResult::Continue(modified_request)) => {
1375 current_request = modified_request;
1377 }
1378 Ok(BeforeModelResult::Skip(response)) => {
1379 model_response_override = Some(response);
1381 break;
1382 }
1383 Err(e) => {
1384 yield Err(e);
1386 return;
1387 }
1388 }
1389 }
1390 }
1391 let request = current_request;
1392
1393 let mut accumulated_content: Option<Content> = None;
1395 let mut final_provider_metadata: Option<serde_json::Value> = None;
1396
1397 if let Some(cached_response) = model_response_override {
1398 accumulated_content = cached_response.content.clone();
1401 final_provider_metadata = cached_response.provider_metadata.clone();
1402 normalize_option_content(&mut accumulated_content);
1403 if let Some(content) = accumulated_content.take() {
1404 let has_function_calls = content
1405 .parts
1406 .iter()
1407 .any(|part| matches!(part, Part::FunctionCall { .. }));
1408 let content = if has_function_calls {
1409 content
1410 } else {
1411 Self::apply_output_guardrails(output_guardrails.as_ref(), content).await?
1412 };
1413 accumulated_content = Some(content);
1414 }
1415
1416 let mut cached_event = Event::new(&invocation_id);
1417 cached_event.author = agent_name.clone();
1418 cached_event.llm_response.content = accumulated_content.clone();
1419 cached_event.llm_response.provider_metadata = cached_response.provider_metadata.clone();
1420 cached_event.llm_request = Some(serde_json::to_string(&request).unwrap_or_default());
1421 cached_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), serde_json::to_string(&request).unwrap_or_default());
1422 cached_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(&cached_response).unwrap_or_default());
1423
1424 if let Some(ref content) = accumulated_content {
1426 cached_event.long_running_tool_ids = collect_long_running_ids(content);
1427 }
1428
1429 yield Ok(cached_event);
1430 } else {
1431 let request_json = serde_json::to_string(&request).unwrap_or_default();
1433 let trace_request_json = trace_json_payload(
1434 &request,
1435 ctx.run_config().record_payloads,
1436 ctx.run_config().trace_payload_max_bytes,
1437 );
1438
1439 let llm_ts = std::time::SystemTime::now()
1441 .duration_since(std::time::UNIX_EPOCH)
1442 .unwrap_or_default()
1443 .as_nanos();
1444 let llm_event_id = format!("{}_llm_{}", invocation_id, llm_ts);
1445 let llm_span = tracing::info_span!(
1446 "call_llm",
1447 "gcp.vertex.agent.event_id" = %llm_event_id,
1448 "gcp.vertex.agent.invocation_id" = %invocation_id,
1449 "gcp.vertex.agent.session_id" = %ctx.session_id(),
1450 "gen_ai.conversation.id" = %ctx.session_id(),
1451 "gcp.vertex.agent.llm_request" = %trace_request_json,
1452 "gcp.vertex.agent.llm_response" = tracing::field::Empty );
1454 let _llm_guard = llm_span.enter();
1455
1456 use adk_core::StreamingMode;
1458 let streaming_mode = ctx.run_config().streaming_mode;
1459 let should_stream_to_client = matches!(streaming_mode, StreamingMode::SSE | StreamingMode::Bidi)
1460 && output_guardrails.is_empty();
1461
1462 let mut response_stream = model.generate_content(request, true).await?;
1464
1465 use futures::StreamExt;
1466
1467 let mut last_chunk: Option<LlmResponse> = None;
1469
1470 while let Some(chunk_result) = response_stream.next().await {
1472 let mut chunk = match chunk_result {
1473 Ok(c) => c,
1474 Err(e) => {
1475 yield Err(e);
1476 return;
1477 }
1478 };
1479
1480 for callback in after_model_callbacks.as_ref() {
1483 match callback(ctx.clone() as Arc<dyn CallbackContext>, chunk.clone()).await {
1484 Ok(Some(modified_chunk)) => {
1485 chunk = modified_chunk;
1487 break;
1488 }
1489 Ok(None) => {
1490 continue;
1492 }
1493 Err(e) => {
1494 yield Err(e);
1496 return;
1497 }
1498 }
1499 }
1500
1501 normalize_option_content(&mut chunk.content);
1502
1503 if let Some(chunk_content) = chunk.content.clone() {
1505 if let Some(ref mut acc) = accumulated_content {
1506 acc.parts.extend(chunk_content.parts);
1507 } else {
1508 accumulated_content = Some(chunk_content);
1509 }
1510 }
1511
1512 if should_stream_to_client {
1514 let mut partial_event = Event::with_id(&llm_event_id, &invocation_id);
1515 partial_event.author = agent_name.clone();
1516 partial_event.llm_request = Some(request_json.clone());
1517 partial_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), request_json.clone());
1518 partial_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(&chunk).unwrap_or_default());
1519 partial_event.llm_response.partial = chunk.partial;
1520 partial_event.llm_response.turn_complete = chunk.turn_complete;
1521 partial_event.llm_response.finish_reason = chunk.finish_reason;
1522 partial_event.llm_response.usage_metadata = chunk.usage_metadata.clone();
1523 partial_event.llm_response.content = chunk.content.clone();
1524 partial_event.llm_response.provider_metadata = chunk.provider_metadata.clone();
1525
1526 if let Some(ref content) = chunk.content {
1528 partial_event.long_running_tool_ids = collect_long_running_ids(content);
1529 }
1530
1531 yield Ok(partial_event);
1532 }
1533
1534 last_chunk = Some(chunk.clone());
1536
1537 if chunk.turn_complete {
1539 break;
1540 }
1541 }
1542
1543 if !should_stream_to_client {
1545 if let Some(content) = accumulated_content.take() {
1546 let has_function_calls = content
1547 .parts
1548 .iter()
1549 .any(|part| matches!(part, Part::FunctionCall { .. }));
1550 let content = if has_function_calls {
1551 content
1552 } else {
1553 Self::apply_output_guardrails(output_guardrails.as_ref(), content).await?
1554 };
1555 accumulated_content = Some(content);
1556 }
1557
1558 let mut final_event = Event::with_id(&llm_event_id, &invocation_id);
1559 final_event.author = agent_name.clone();
1560 final_event.llm_request = Some(request_json.clone());
1561 final_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), request_json.clone());
1562 final_event.llm_response.content = accumulated_content.clone();
1563 final_event.llm_response.partial = false;
1564 final_event.llm_response.turn_complete = true;
1565
1566 if let Some(ref last) = last_chunk {
1568 final_event.llm_response.finish_reason = last.finish_reason;
1569 final_event.llm_response.usage_metadata = last.usage_metadata.clone();
1570 final_event.llm_response.provider_metadata = last.provider_metadata.clone();
1571 final_provider_metadata = last.provider_metadata.clone();
1572 final_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(last).unwrap_or_default());
1573 }
1574
1575 if let Some(ref content) = accumulated_content {
1577 final_event.long_running_tool_ids = collect_long_running_ids(content);
1578 }
1579
1580 yield Ok(final_event);
1581 }
1582
1583 if let Some(ref content) = accumulated_content {
1585 let response_json = trace_json_payload(
1586 content,
1587 ctx.run_config().record_payloads,
1588 ctx.run_config().trace_payload_max_bytes,
1589 );
1590 llm_span.record("gcp.vertex.agent.llm_response", &response_json);
1591 }
1592 }
1593
1594 #[cfg(feature = "enhanced-plugins")]
1598 if let Some(epm) = &enhanced_plugin_manager {
1599 if let Some(ref content) = accumulated_content {
1600 let response_for_hook = LlmResponse {
1601 content: Some(content.clone()),
1602 provider_metadata: final_provider_metadata.clone(),
1603 ..Default::default()
1604 };
1605 match epm.run_after_model_call(response_for_hook, ctx.clone() as Arc<dyn CallbackContext>).await {
1606 Ok(adk_plugin::AfterModelCallResult::Continue(modified_response)) => {
1607 accumulated_content = modified_response.content;
1608 if modified_response.provider_metadata.is_some() {
1609 final_provider_metadata = modified_response.provider_metadata;
1610 }
1611 }
1612 Err(e) => {
1613 yield Err(e);
1614 return;
1615 }
1616 }
1617 }
1618 }
1619
1620 let function_call_names: Vec<String> = accumulated_content.as_ref()
1622 .map(|c| c.parts.iter()
1623 .filter_map(|p| {
1624 if let Part::FunctionCall { name, .. } = p {
1625 Some(name.clone())
1626 } else {
1627 None
1628 }
1629 })
1630 .collect())
1631 .unwrap_or_default();
1632
1633 let has_function_calls = !function_call_names.is_empty();
1634
1635 let all_calls_are_long_running = has_function_calls && function_call_names.iter().all(|name| {
1639 tool_map.get(name)
1640 .map(|t| t.is_long_running())
1641 .unwrap_or(false)
1642 });
1643
1644 if let Some(ref content) = accumulated_content {
1646 conversation_history.push(Self::augment_content_for_history(
1647 content,
1648 final_provider_metadata.as_ref(),
1649 ));
1650
1651 if let Some(ref output_key) = output_key {
1653 if !has_function_calls { let mut text_parts = String::new();
1655 for part in &content.parts {
1656 if let Part::Text { text } = part {
1657 text_parts.push_str(text);
1658 }
1659 }
1660 if !text_parts.is_empty() {
1661 let mut state_event = Event::new(&invocation_id);
1663 state_event.author = agent_name.clone();
1664 state_event.actions.state_delta.insert(
1665 output_key.clone(),
1666 serde_json::Value::String(text_parts),
1667 );
1668 yield Ok(state_event);
1669 }
1670 }
1671 }
1672 }
1673
1674 if !has_function_calls {
1675 if let Some(ref content) = accumulated_content {
1678 let response_json = trace_json_payload(
1679 content,
1680 ctx.run_config().record_payloads,
1681 ctx.run_config().trace_payload_max_bytes,
1682 );
1683 tracing::Span::current().record("gcp.vertex.agent.llm_response", &response_json);
1684 }
1685
1686 tracing::info!(agent.name = %agent_name, "Agent execution complete");
1687 break;
1688 }
1689
1690 if let Some(content) = &accumulated_content {
1692 let strategy = agent_tool_execution_strategy
1695 .unwrap_or(ToolExecutionStrategy::Sequential);
1696
1697 let mut fc_parts: Vec<(usize, String, serde_json::Value, Option<String>, String)> = Vec::new();
1701 {
1702 let mut tci = 0usize;
1703 for part in &content.parts {
1704 if let Part::FunctionCall { name, args, id, .. } = part {
1705 let fallback = format!("{}_{}_{}", invocation_id, name, tci);
1706 let fcid = id.clone().unwrap_or(fallback);
1707 fc_parts.push((tci, name.clone(), args.clone(), id.clone(), fcid));
1708 tci += 1;
1709 }
1710 }
1711 }
1712
1713 let mut transfer_handled = false;
1717 for (_, fc_name, fc_args, fc_id, _) in &fc_parts {
1718 if fc_name == "transfer_to_agent" {
1719 let target_agent = fc_args.get("agent_name")
1720 .and_then(|v| v.as_str())
1721 .unwrap_or_default()
1722 .to_string();
1723
1724 let valid_target = valid_transfer_targets.iter().any(|n| n == &target_agent);
1725 if !valid_target {
1726 let error_content = Content {
1727 role: "function".to_string(),
1728 parts: vec![Part::FunctionResponse {
1729 function_response: FunctionResponseData::new(
1730 fc_name.clone(),
1731 serde_json::json!({
1732 "error": format!(
1733 "Agent '{}' not found. Available agents: {:?}",
1734 target_agent, valid_transfer_targets
1735 )
1736 }),
1737 ),
1738 id: fc_id.clone(),
1739 }],
1740 };
1741 conversation_history.push(error_content.clone());
1742 let mut error_event = Event::new(&invocation_id);
1743 error_event.author = agent_name.clone();
1744 error_event.llm_response.content = Some(error_content);
1745 yield Ok(error_event);
1746 continue;
1747 }
1748
1749 let mut transfer_event = Event::new(&invocation_id);
1750 transfer_event.author = agent_name.clone();
1751 transfer_event.actions.transfer_to_agent = Some(target_agent);
1752 yield Ok(transfer_event);
1753 transfer_handled = true;
1754 break;
1755 }
1756 }
1757 if transfer_handled {
1758 return;
1759 }
1760
1761 let fc_parts: Vec<_> = fc_parts.into_iter().filter(|(_, fc_name, _, _, _)| {
1763 if fc_name == "transfer_to_agent" {
1764 return false;
1765 }
1766 if let Some(tool) = tool_map.get(fc_name) {
1767 if tool.is_builtin() {
1768 adk_telemetry::debug!(tool.name = %fc_name, "skipping built-in tool execution");
1769 return false;
1770 }
1771 }
1772 true
1773 }).collect();
1774
1775 let mut confirmation_interrupted = false;
1779 for (_, fc_name, fc_args, _, fc_call_id) in &fc_parts {
1780 if tool_confirmation_policy.requires_confirmation(fc_name)
1781 && ctx.run_config().tool_confirmation_decisions.get(fc_name).copied().is_none()
1782 {
1783 let mut ce = Event::new(&invocation_id);
1784 ce.author = agent_name.clone();
1785 ce.llm_response.interrupted = true;
1786 ce.llm_response.turn_complete = true;
1787 ce.llm_response.content = Some(Content {
1788 role: "model".to_string(),
1789 parts: vec![Part::Text {
1790 text: format!(
1791 "Tool confirmation required for '{}'. Provide approve/deny decision to continue.",
1792 fc_name
1793 ),
1794 }],
1795 });
1796 ce.actions.tool_confirmation = Some(ToolConfirmationRequest {
1797 tool_name: fc_name.clone(),
1798 function_call_id: Some(fc_call_id.clone()),
1799 args: fc_args.clone(),
1800 });
1801 yield Ok(ce);
1802 confirmation_interrupted = true;
1803 break;
1804 }
1805 }
1806 if confirmation_interrupted {
1807 return;
1808 }
1809
1810 let cb_mutex = std::sync::Mutex::new(circuit_breaker_state.take());
1812
1813 let concurrency_manager = adk_core::ToolConcurrencyManager::new(
1816 &ctx.run_config().tool_concurrency,
1817 );
1818
1819 let execute_one_tool = |idx: usize, name: String, args: serde_json::Value,
1824 id: Option<String>, function_call_id: String| {
1825 let ctx = ctx.clone();
1826 let tool_map = &tool_map;
1827 let tool_retry_budgets = &tool_retry_budgets;
1828 let default_retry_budget = &default_retry_budget;
1829 let before_tool_callbacks = &before_tool_callbacks;
1830 let after_tool_callbacks = &after_tool_callbacks;
1831 let after_tool_callbacks_full = &after_tool_callbacks_full;
1832 let on_tool_error_callbacks = &on_tool_error_callbacks;
1833 let tool_confirmation_policy = &tool_confirmation_policy;
1834 let cb_mutex = &cb_mutex;
1835 let invocation_id = &invocation_id;
1836 let concurrency_manager = &concurrency_manager;
1837 #[cfg(feature = "enhanced-plugins")]
1838 let enhanced_plugin_manager = &enhanced_plugin_manager;
1839 async move {
1840 let mut tool_actions = EventActions::default();
1841 let mut response_content: Option<Content> = None;
1842 let mut run_after_tool_callbacks = true;
1843 let mut tool_outcome_for_callback: Option<ToolOutcome> = None;
1844 let mut executed_tool: Option<Arc<dyn Tool>> = None;
1845 let mut executed_tool_response: Option<serde_json::Value> = None;
1846
1847 let _concurrency_permit = match concurrency_manager.acquire(&name).await {
1851 Ok(permit) => Some(permit),
1852 Err(e) => {
1853 let error_content = Content {
1855 role: "function".to_string(),
1856 parts: vec![Part::FunctionResponse {
1857 function_response: FunctionResponseData::new(
1858 name.clone(),
1859 serde_json::json!({ "error": e.to_string() }),
1860 ),
1861 id: id.clone(),
1862 }],
1863 };
1864 return (idx, error_content, tool_actions, false);
1865 }
1866 };
1867
1868 if tool_confirmation_policy.requires_confirmation(&name) {
1870 match ctx.run_config().tool_confirmation_decisions.get(&name).copied() {
1871 Some(ToolConfirmationDecision::Approve) => {
1872 tool_actions.tool_confirmation_decision =
1873 Some(ToolConfirmationDecision::Approve);
1874 }
1875 Some(ToolConfirmationDecision::Deny) => {
1876 tool_actions.tool_confirmation_decision =
1877 Some(ToolConfirmationDecision::Deny);
1878 response_content = Some(Content {
1879 role: "function".to_string(),
1880 parts: vec![Part::FunctionResponse {
1881 function_response: FunctionResponseData::new(
1882 name.clone(),
1883 serde_json::json!({
1884 "error": format!("Tool '{}' execution denied by confirmation policy", name)
1885 }),
1886 ),
1887 id: id.clone(),
1888 }],
1889 });
1890 run_after_tool_callbacks = false;
1891 }
1892 None => {
1893 response_content = Some(Content {
1894 role: "function".to_string(),
1895 parts: vec![Part::FunctionResponse {
1896 function_response: FunctionResponseData::new(
1897 name.clone(),
1898 serde_json::json!({
1899 "error": format!("Tool '{}' requires confirmation", name)
1900 }),
1901 ),
1902 id: id.clone(),
1903 }],
1904 });
1905 run_after_tool_callbacks = false;
1906 }
1907 }
1908 }
1909
1910 #[allow(unused_mut)]
1913 let mut final_args = args.clone();
1914
1915 #[cfg(feature = "enhanced-plugins")]
1917 if response_content.is_none() {
1918 if let Some(epm) = &enhanced_plugin_manager {
1919 if let Some(tool_ref) = tool_map.get(&name) {
1920 match epm.run_before_tool_call(
1921 tool_ref.clone(),
1922 final_args.clone(),
1923 ctx.clone() as Arc<dyn CallbackContext>,
1924 ).await {
1925 Ok(BeforeToolCallResult::Continue(modified_args)) => {
1926 final_args = modified_args;
1927 }
1928 Ok(BeforeToolCallResult::ShortCircuit(synthetic_result)) => {
1929 response_content = Some(Content {
1931 role: "function".to_string(),
1932 parts: vec![Part::FunctionResponse {
1933 function_response: FunctionResponseData::from_tool_result(
1934 name.clone(),
1935 synthetic_result,
1936 ),
1937 id: id.clone(),
1938 }],
1939 });
1940 executed_tool = Some(tool_ref.clone());
1941 }
1942 Err(e) => {
1943 response_content = Some(Content {
1944 role: "function".to_string(),
1945 parts: vec![Part::FunctionResponse {
1946 function_response: FunctionResponseData::new(
1947 name.clone(),
1948 serde_json::json!({ "error": e.to_string() }),
1949 ),
1950 id: id.clone(),
1951 }],
1952 });
1953 run_after_tool_callbacks = false;
1954 }
1955 }
1956 }
1957 }
1958 }
1959
1960 if response_content.is_none() {
1961 let tool_ctx = Arc::new(ToolCallbackContext::new(
1962 ctx.clone(),
1963 name.clone(),
1964 final_args.clone(),
1965 ));
1966 for callback in before_tool_callbacks.as_ref() {
1967 match callback(tool_ctx.clone() as Arc<dyn CallbackContext>).await {
1968 Ok(Some(c)) => { response_content = Some(c); break; }
1969 Ok(None) => continue,
1970 Err(e) => {
1971 response_content = Some(Content {
1972 role: "function".to_string(),
1973 parts: vec![Part::FunctionResponse {
1974 function_response: FunctionResponseData::new(
1975 name.clone(),
1976 serde_json::json!({ "error": e.to_string() }),
1977 ),
1978 id: id.clone(),
1979 }],
1980 });
1981 run_after_tool_callbacks = false;
1982 break;
1983 }
1984 }
1985 }
1986 }
1987
1988 if response_content.is_none() {
1990 let guard = cb_mutex.lock().unwrap_or_else(|e| e.into_inner());
1991 if let Some(ref cb_state) = *guard {
1992 if cb_state.is_open(&name) {
1993 let msg = format!(
1994 "Tool '{}' is temporarily disabled after {} consecutive failures",
1995 name, cb_state.threshold
1996 );
1997 tracing::warn!(tool.name = %name, "circuit breaker open, skipping tool execution");
1998 response_content = Some(Content {
1999 role: "function".to_string(),
2000 parts: vec![Part::FunctionResponse {
2001 function_response: FunctionResponseData::new(
2002 name.clone(),
2003 serde_json::json!({ "error": msg }),
2004 ),
2005 id: id.clone(),
2006 }],
2007 });
2008 run_after_tool_callbacks = false;
2009 }
2010 }
2011 drop(guard);
2012 }
2013
2014 if response_content.is_none() {
2016 if let Some(tool) = tool_map.get(&name) {
2017 let tool_ctx: Arc<dyn ToolContext> = Arc::new(
2018 AgentToolContext::new(ctx.clone(), function_call_id.clone()),
2019 );
2020 let span_name = format!("execute_tool {name}");
2021 let tool_span = tracing::info_span!(
2022 "",
2023 otel.name = %span_name,
2024 tool.name = %name,
2025 "gcp.vertex.agent.event_id" = %format!("{}_{}", invocation_id, name),
2026 "gcp.vertex.agent.invocation_id" = %invocation_id,
2027 "gcp.vertex.agent.session_id" = %ctx.session_id(),
2028 "gen_ai.conversation.id" = %ctx.session_id()
2029 );
2030
2031 let budget = tool_retry_budgets.get(&name)
2032 .or(default_retry_budget.as_ref());
2033 let max_attempts = budget.map(|b| b.max_retries + 1).unwrap_or(1);
2034 let retry_delay = budget.map(|b| b.delay).unwrap_or_default();
2035
2036 let tool_clone = tool.clone();
2037 let tool_start = std::time::Instant::now();
2038 let mut last_error = String::new();
2039 let mut final_attempt: u32 = 0;
2040 let mut retry_result: Option<serde_json::Value> = None;
2041
2042 for attempt in 0..max_attempts {
2043 final_attempt = attempt;
2044 if attempt > 0 {
2045 tokio::time::sleep(retry_delay).await;
2046 }
2047 match async {
2048 let args_payload = trace_json_payload(
2049 &final_args,
2050 ctx.run_config().record_payloads,
2051 ctx.run_config().trace_payload_max_bytes,
2052 );
2053 tracing::debug!(tool.name = %name, tool.args = %args_payload, attempt = attempt, "tool_call");
2054 let exec_future = tool_clone.execute(tool_ctx.clone(), final_args.clone());
2055 tokio::time::timeout(tool_timeout, exec_future).await
2056 }.instrument(tool_span.clone()).await {
2057 Ok(Ok(value)) => {
2058 let result_payload = trace_json_payload(
2059 &value,
2060 ctx.run_config().record_payloads,
2061 ctx.run_config().trace_payload_max_bytes,
2062 );
2063 tracing::debug!(tool.name = %name, tool.result = %result_payload, "tool_result");
2064 retry_result = Some(value);
2065 break;
2066 }
2067 Ok(Err(e)) => {
2068 last_error = e.to_string();
2069 if attempt + 1 < max_attempts {
2070 tracing::warn!(tool.name = %name, attempt = attempt, error = %last_error, "tool execution failed, retrying");
2071 } else {
2072 tracing::warn!(tool.name = %name, error = %last_error, "tool_error");
2073 }
2074 }
2075 Err(_) => {
2076 last_error = format!(
2077 "Tool '{}' timed out after {} seconds",
2078 name, tool_timeout.as_secs()
2079 );
2080 if attempt + 1 < max_attempts {
2081 tracing::warn!(tool.name = %name, attempt = attempt, timeout_secs = tool_timeout.as_secs(), "tool timed out, retrying");
2082 } else {
2083 tracing::warn!(tool.name = %name, timeout_secs = tool_timeout.as_secs(), "tool_timeout");
2084 }
2085 }
2086 }
2087 }
2088
2089 let tool_duration = tool_start.elapsed();
2090 let (tool_success, tool_error_message, function_response) = match retry_result {
2091 Some(value) => (true, None, value),
2092 None => (false, Some(last_error.clone()), serde_json::json!({ "error": last_error })),
2093 };
2094
2095 let outcome = ToolOutcome {
2096 tool_name: name.clone(),
2097 tool_args: final_args.clone(),
2098 success: tool_success,
2099 duration: tool_duration,
2100 error_message: tool_error_message.clone(),
2101 attempt: final_attempt,
2102 };
2103 tool_outcome_for_callback = Some(outcome);
2104
2105 {
2107 let mut guard = cb_mutex.lock().unwrap_or_else(|e| e.into_inner());
2108 if let Some(ref mut cb_state) = *guard {
2109 cb_state.record(tool_outcome_for_callback.as_ref().unwrap());
2110 }
2111 }
2112
2113 let final_function_response = if !tool_success {
2115 let mut fallback_result = None;
2116 let error_msg = tool_error_message.clone().unwrap_or_default();
2117 for callback in on_tool_error_callbacks.as_ref() {
2118 match callback(
2119 ctx.clone() as Arc<dyn CallbackContext>,
2120 tool.clone(),
2121 final_args.clone(),
2122 error_msg.clone(),
2123 ).await {
2124 Ok(Some(result)) => { fallback_result = Some(result); break; }
2125 Ok(None) => continue,
2126 Err(e) => { tracing::warn!(error = %e, "on_tool_error callback failed"); break; }
2127 }
2128 }
2129 fallback_result.unwrap_or(function_response)
2130 } else {
2131 function_response
2132 };
2133
2134 let confirmation_decision = tool_actions.tool_confirmation_decision;
2135 tool_actions = tool_ctx.actions();
2136 if tool_actions.tool_confirmation_decision.is_none() {
2137 tool_actions.tool_confirmation_decision = confirmation_decision;
2138 }
2139 executed_tool = Some(tool.clone());
2140 executed_tool_response = Some(final_function_response.clone());
2141 response_content = Some(Content {
2142 role: "function".to_string(),
2143 parts: vec![Part::FunctionResponse {
2144 function_response: FunctionResponseData::from_tool_result(
2145 name.clone(),
2146 final_function_response,
2147 ),
2148 id: id.clone(),
2149 }],
2150 });
2151 } else {
2152 response_content = Some(Content {
2153 role: "function".to_string(),
2154 parts: vec![Part::FunctionResponse {
2155 function_response: FunctionResponseData::new(
2156 name.clone(),
2157 serde_json::json!({
2158 "error": format!("Tool {} not found", name)
2159 }),
2160 ),
2161 id: id.clone(),
2162 }],
2163 });
2164 }
2165 }
2166
2167 let mut response_content = response_content.expect("tool response content is set");
2169 if run_after_tool_callbacks {
2170 let outcome_ctx: Arc<dyn CallbackContext> = match tool_outcome_for_callback {
2171 Some(outcome) => Arc::new(ToolOutcomeCallbackContext {
2172 inner: ctx.clone() as Arc<dyn CallbackContext>,
2173 outcome,
2174 }),
2175 None => ctx.clone() as Arc<dyn CallbackContext>,
2176 };
2177 let cb_ctx: Arc<dyn CallbackContext> = Arc::new(ToolCallbackContext::new(
2178 outcome_ctx,
2179 name.clone(),
2180 final_args.clone(),
2181 ));
2182 for callback in after_tool_callbacks.as_ref() {
2183 match callback(cb_ctx.clone()).await {
2184 Ok(Some(modified)) => { response_content = modified; break; }
2185 Ok(None) => continue,
2186 Err(e) => {
2187 response_content = Content {
2188 role: "function".to_string(),
2189 parts: vec![Part::FunctionResponse {
2190 function_response: FunctionResponseData::new(
2191 name.clone(),
2192 serde_json::json!({ "error": e.to_string() }),
2193 ),
2194 id: id.clone(),
2195 }],
2196 };
2197 break;
2198 }
2199 }
2200 }
2201 if let (Some(tool_ref), Some(tool_resp)) = (&executed_tool, executed_tool_response) {
2202 for callback in after_tool_callbacks_full.as_ref() {
2203 match callback(
2204 cb_ctx.clone(), tool_ref.clone(), final_args.clone(), tool_resp.clone(),
2205 ).await {
2206 Ok(Some(modified_value)) => {
2207 response_content = Content {
2208 role: "function".to_string(),
2209 parts: vec![Part::FunctionResponse {
2210 function_response: FunctionResponseData::from_tool_result(
2211 name.clone(),
2212 modified_value,
2213 ),
2214 id: id.clone(),
2215 }],
2216 };
2217 break;
2218 }
2219 Ok(None) => continue,
2220 Err(e) => {
2221 response_content = Content {
2222 role: "function".to_string(),
2223 parts: vec![Part::FunctionResponse {
2224 function_response: FunctionResponseData::new(
2225 name.clone(),
2226 serde_json::json!({ "error": e.to_string() }),
2227 ),
2228 id: id.clone(),
2229 }],
2230 };
2231 break;
2232 }
2233 }
2234 }
2235 }
2236
2237 #[cfg(feature = "enhanced-plugins")]
2240 if let Some(epm) = &enhanced_plugin_manager {
2241 if let Some(tool_ref) = &executed_tool {
2242 let result_value = response_content.parts.iter()
2244 .find_map(|p| {
2245 if let Part::FunctionResponse { function_response, .. } = p {
2246 Some(function_response.response.clone())
2247 } else {
2248 None
2249 }
2250 })
2251 .unwrap_or(serde_json::json!(null));
2252
2253 match epm.run_after_tool_call(
2254 tool_ref.clone(),
2255 &final_args,
2256 result_value,
2257 ctx.clone() as Arc<dyn CallbackContext>,
2258 ).await {
2259 Ok(adk_plugin::AfterToolCallResult::Continue(modified_result)) => {
2260 response_content = Content {
2261 role: "function".to_string(),
2262 parts: vec![Part::FunctionResponse {
2263 function_response: FunctionResponseData::from_tool_result(
2264 name.clone(),
2265 modified_result,
2266 ),
2267 id: id.clone(),
2268 }],
2269 };
2270 }
2271 Err(e) => {
2272 response_content = Content {
2273 role: "function".to_string(),
2274 parts: vec![Part::FunctionResponse {
2275 function_response: FunctionResponseData::new(
2276 name.clone(),
2277 serde_json::json!({ "error": e.to_string() }),
2278 ),
2279 id: id.clone(),
2280 }],
2281 };
2282 }
2283 }
2284 }
2285 }
2286 }
2287
2288 let escalate_or_skip = tool_actions.escalate || tool_actions.skip_summarization;
2289 (idx, response_content, tool_actions, escalate_or_skip)
2290 }
2291 };
2292
2293 let mut results: Vec<(usize, Content, EventActions, bool)> = match strategy {
2295 ToolExecutionStrategy::Sequential => {
2296 let mut results = Vec::with_capacity(fc_parts.len());
2297 for (idx, name, args, id, fcid) in fc_parts {
2298 results.push(execute_one_tool(idx, name, args, id, fcid).await);
2299 }
2300 results
2301 }
2302 ToolExecutionStrategy::Parallel => {
2303 use futures::StreamExt as _;
2304 let buffer_size = fc_parts.len().max(1);
2309 futures::stream::iter(fc_parts.into_iter().map(
2310 |(idx, name, args, id, fcid)| {
2311 execute_one_tool(idx, name, args, id, fcid)
2312 },
2313 ))
2314 .buffer_unordered(buffer_size)
2315 .collect()
2316 .await
2317 }
2318 ToolExecutionStrategy::Auto => {
2319 let mut read_only_fcs = Vec::new();
2321 let mut mutable_fcs = Vec::new();
2322 for fc in fc_parts {
2323 let is_ro = tool_map.get(&fc.1)
2324 .map(|t| t.is_read_only())
2325 .unwrap_or(false);
2326 if is_ro { read_only_fcs.push(fc); } else { mutable_fcs.push(fc); }
2327 }
2328 let mut all_results = Vec::new();
2329 if !read_only_fcs.is_empty() {
2333 use futures::StreamExt as _;
2334 let buffer_size = read_only_fcs.len().max(1);
2335 all_results.extend(
2336 futures::stream::iter(read_only_fcs.into_iter().map(
2337 |(idx, name, args, id, fcid)| {
2338 execute_one_tool(idx, name, args, id, fcid)
2339 },
2340 ))
2341 .buffer_unordered(buffer_size)
2342 .collect::<Vec<_>>()
2343 .await,
2344 );
2345 }
2346 for (idx, name, args, id, fcid) in mutable_fcs {
2348 all_results.push(execute_one_tool(idx, name, args, id, fcid).await);
2349 }
2350 all_results
2351 }
2352 };
2353 results.sort_by_key(|r| r.0);
2355
2356 circuit_breaker_state = cb_mutex.into_inner().unwrap_or_else(|e| e.into_inner());
2358
2359 for (_, response_content, tool_actions, escalate_or_skip) in results {
2361 let mut tool_event = Event::new(&invocation_id);
2362 tool_event.author = agent_name.clone();
2363 tool_event.actions = tool_actions;
2364 tool_event.llm_response.content = Some(response_content.clone());
2365 yield Ok(tool_event);
2366
2367 if escalate_or_skip {
2368 return;
2369 }
2370
2371 conversation_history.push(response_content);
2372 }
2373 }
2374
2375 if all_calls_are_long_running {
2379 }
2383 }
2384
2385 for callback in after_agent_callbacks.as_ref() {
2388 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
2389 Ok(Some(content)) => {
2390 let mut after_event = Event::new(&invocation_id);
2392 after_event.author = agent_name.clone();
2393 after_event.llm_response.content = Some(content);
2394 yield Ok(after_event);
2395 break; }
2397 Ok(None) => {
2398 continue;
2400 }
2401 Err(e) => {
2402 yield Err(e);
2404 return;
2405 }
2406 }
2407 }
2408 };
2409
2410 Ok(Box::pin(s))
2411 }
2412}