Skip to main content

adk_agent/
llm_agent.rs

1use adk_core::{
2    AfterAgentCallback, AfterModelCallback, AfterToolCallback, AfterToolCallbackFull, Agent,
3    BeforeAgentCallback, BeforeModelCallback, BeforeModelResult, BeforeToolCallback,
4    CallbackContext, Content, Event, EventActions, FunctionResponseData, GlobalInstructionProvider,
5    InstructionProvider, InvocationContext, Llm, LlmRequest, LlmResponse, MemoryEntry,
6    OnToolErrorCallback, Part, ReadonlyContext, Result, RetryBudget, Tool,
7    ToolConfirmationDecision, ToolConfirmationPolicy, ToolConfirmationRequest, ToolContext,
8    ToolOutcome, Toolset,
9};
10use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index, select_skill_prompt_block};
11use async_stream::stream;
12use async_trait::async_trait;
13use std::sync::{Arc, Mutex};
14use tracing::Instrument;
15
16use crate::{
17    guardrails::{GuardrailSet, enforce_guardrails},
18    tool_call_markup::normalize_option_content,
19    workflow::with_user_content_override,
20};
21
/// Upper bound on LLM round-trips (iterations) an agent performs before it stops,
/// used unless the builder overrides it with `max_iterations`.
pub const DEFAULT_MAX_ITERATIONS: u32 = 100;

/// Per-call tool execution timeout used unless the builder overrides it with
/// `tool_timeout`: five minutes.
pub const DEFAULT_TOOL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5 * 60);
27
/// An agent that drives an LLM, optionally calling tools and delegating to sub-agents.
///
/// Instances are produced by [`LlmAgentBuilder::build`]; all configuration is fixed at build
/// time. Callback lists and guardrail sets are stored behind `Arc` so `run` can clone them
/// cheaply into the event stream it returns.
pub struct LlmAgent {
    /// Agent name, returned by `Agent::name`.
    name: String,
    /// Agent description, returned by `Agent::description` (empty when never set on the builder).
    description: String,
    /// LLM used for model calls; its `name()` is shown in the `Debug` output.
    model: Arc<dyn Llm>,
    /// Static instruction text, if configured.
    instruction: Option<String>,
    /// Dynamic instruction provider, if configured.
    instruction_provider: Option<Arc<InstructionProvider>>,
    /// Static global instruction text, if configured.
    global_instruction: Option<String>,
    /// Dynamic global instruction provider, if configured.
    global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
    /// Skills index supplied to (or loaded by) the builder, if any.
    skills_index: Option<Arc<SkillIndex>>,
    /// Skill selection policy (builder: `with_skill_policy`).
    skill_policy: SelectionPolicy,
    /// Limit on injected skill content length (builder default: 2000).
    max_skill_chars: usize,
    /// Optional input JSON schema.
    #[allow(dead_code)] // Part of public API via builder
    input_schema: Option<serde_json::Value>,
    /// Optional output JSON schema.
    output_schema: Option<serde_json::Value>,
    /// Set via `LlmAgentBuilder::disallow_transfer_to_parent`.
    disallow_transfer_to_parent: bool,
    /// Set via `LlmAgentBuilder::disallow_transfer_to_peers`.
    disallow_transfer_to_peers: bool,
    /// Content-inclusion mode (builder default: `IncludeContents::Default`).
    include_contents: adk_core::IncludeContents,
    /// Statically registered tools.
    tools: Vec<Arc<dyn Tool>>,
    /// Dynamic toolsets, resolved per invocation.
    // NOTE(review): `run` already reads this field (`self.toolsets.clone()`), so the
    // dead_code allow below may be stale.
    #[allow(dead_code)] // Used in runtime toolset resolution (task 2.2)
    toolsets: Vec<Arc<dyn Toolset>>,
    /// Sub-agents; `build` guarantees their names are unique.
    sub_agents: Vec<Arc<dyn Agent>>,
    /// Optional output key.
    // NOTE(review): presumably the session-state key the final response is stored under —
    // confirm in `run`.
    output_key: Option<String>,
    /// Default generation config (temperature, top_p, etc.) applied to every LLM request.
    generate_content_config: Option<adk_core::GenerateContentConfig>,
    /// Maximum number of LLM round-trips before stopping
    max_iterations: u32,
    /// Timeout for individual tool executions
    tool_timeout: std::time::Duration,
    /// Callbacks run before the agent starts.
    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
    /// Callbacks run after the agent finishes.
    after_callbacks: Arc<Vec<AfterAgentCallback>>,
    /// Callbacks run before each model call.
    before_model_callbacks: Arc<Vec<BeforeModelCallback>>,
    /// Callbacks run after each model call.
    after_model_callbacks: Arc<Vec<AfterModelCallback>>,
    /// Callbacks run before each tool call.
    before_tool_callbacks: Arc<Vec<BeforeToolCallback>>,
    /// Legacy after-tool callbacks (receive only the callback context).
    after_tool_callbacks: Arc<Vec<AfterToolCallback>>,
    /// Callbacks consulted when a tool fails after retries are exhausted.
    on_tool_error_callbacks: Arc<Vec<OnToolErrorCallback>>,
    /// Rich after-tool callbacks that receive tool, args, and response.
    after_tool_callbacks_full: Arc<Vec<AfterToolCallbackFull>>,
    /// Default retry budget applied to all tools without a per-tool override.
    default_retry_budget: Option<RetryBudget>,
    /// Per-tool retry budget overrides, keyed by tool name.
    tool_retry_budgets: std::collections::HashMap<String, RetryBudget>,
    /// Circuit breaker failure threshold. When set, tools are temporarily disabled
    /// after this many consecutive failures within a single invocation.
    circuit_breaker_threshold: Option<u32>,
    /// Which tool calls require confirmation (builder default: `Never`).
    tool_confirmation_policy: ToolConfirmationPolicy,
    /// Guardrails applied to the user content at the start of `run`.
    input_guardrails: Arc<GuardrailSet>,
    /// Guardrails applied to agent output content.
    output_guardrails: Arc<GuardrailSet>,
}
76
77impl std::fmt::Debug for LlmAgent {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct("LlmAgent")
80            .field("name", &self.name)
81            .field("description", &self.description)
82            .field("model", &self.model.name())
83            .field("instruction", &self.instruction)
84            .field("tools_count", &self.tools.len())
85            .field("sub_agents_count", &self.sub_agents.len())
86            .finish()
87    }
88}
89
90impl LlmAgent {
91    async fn apply_input_guardrails(
92        ctx: Arc<dyn InvocationContext>,
93        input_guardrails: Arc<GuardrailSet>,
94    ) -> Result<Arc<dyn InvocationContext>> {
95        let content =
96            enforce_guardrails(input_guardrails.as_ref(), ctx.user_content(), "input").await?;
97        if content.role != ctx.user_content().role || content.parts != ctx.user_content().parts {
98            Ok(with_user_content_override(ctx, content))
99        } else {
100            Ok(ctx)
101        }
102    }
103
104    async fn apply_output_guardrails(
105        output_guardrails: &GuardrailSet,
106        content: Content,
107    ) -> Result<Content> {
108        enforce_guardrails(output_guardrails, &content, "output").await
109    }
110}
111
/// Builder for [`LlmAgent`].
///
/// Create it with [`LlmAgentBuilder::new`], configure it with the chained setters, and
/// finish with [`LlmAgentBuilder::build`]. Only the model is mandatory; every other field
/// starts from the default assigned in `new`. Callback vectors and guardrail sets are
/// moved behind `Arc` when the agent is built.
pub struct LlmAgentBuilder {
    name: String,
    // `None` becomes an empty description at build time.
    description: Option<String>,
    // Required: `build` fails while this is `None`.
    model: Option<Arc<dyn Llm>>,
    instruction: Option<String>,
    instruction_provider: Option<Arc<InstructionProvider>>,
    global_instruction: Option<String>,
    global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
    skills_index: Option<Arc<SkillIndex>>,
    skill_policy: SelectionPolicy,
    max_skill_chars: usize,
    input_schema: Option<serde_json::Value>,
    output_schema: Option<serde_json::Value>,
    disallow_transfer_to_parent: bool,
    disallow_transfer_to_peers: bool,
    include_contents: adk_core::IncludeContents,
    tools: Vec<Arc<dyn Tool>>,
    toolsets: Vec<Arc<dyn Toolset>>,
    // Names must be unique; checked in `build`.
    sub_agents: Vec<Arc<dyn Agent>>,
    output_key: Option<String>,
    // Lazily created by the temperature/top_p/top_k/max_output_tokens shorthands.
    generate_content_config: Option<adk_core::GenerateContentConfig>,
    max_iterations: u32,
    tool_timeout: std::time::Duration,
    before_callbacks: Vec<BeforeAgentCallback>,
    after_callbacks: Vec<AfterAgentCallback>,
    before_model_callbacks: Vec<BeforeModelCallback>,
    after_model_callbacks: Vec<AfterModelCallback>,
    before_tool_callbacks: Vec<BeforeToolCallback>,
    after_tool_callbacks: Vec<AfterToolCallback>,
    on_tool_error_callbacks: Vec<OnToolErrorCallback>,
    after_tool_callbacks_full: Vec<AfterToolCallbackFull>,
    default_retry_budget: Option<RetryBudget>,
    // Per-tool overrides of `default_retry_budget`, keyed by tool name.
    tool_retry_budgets: std::collections::HashMap<String, RetryBudget>,
    circuit_breaker_threshold: Option<u32>,
    tool_confirmation_policy: ToolConfirmationPolicy,
    input_guardrails: GuardrailSet,
    output_guardrails: GuardrailSet,
}
150
151impl LlmAgentBuilder {
152    pub fn new(name: impl Into<String>) -> Self {
153        Self {
154            name: name.into(),
155            description: None,
156            model: None,
157            instruction: None,
158            instruction_provider: None,
159            global_instruction: None,
160            global_instruction_provider: None,
161            skills_index: None,
162            skill_policy: SelectionPolicy::default(),
163            max_skill_chars: 2000,
164            input_schema: None,
165            output_schema: None,
166            disallow_transfer_to_parent: false,
167            disallow_transfer_to_peers: false,
168            include_contents: adk_core::IncludeContents::Default,
169            tools: Vec::new(),
170            toolsets: Vec::new(),
171            sub_agents: Vec::new(),
172            output_key: None,
173            generate_content_config: None,
174            max_iterations: DEFAULT_MAX_ITERATIONS,
175            tool_timeout: DEFAULT_TOOL_TIMEOUT,
176            before_callbacks: Vec::new(),
177            after_callbacks: Vec::new(),
178            before_model_callbacks: Vec::new(),
179            after_model_callbacks: Vec::new(),
180            before_tool_callbacks: Vec::new(),
181            after_tool_callbacks: Vec::new(),
182            on_tool_error_callbacks: Vec::new(),
183            after_tool_callbacks_full: Vec::new(),
184            default_retry_budget: None,
185            tool_retry_budgets: std::collections::HashMap::new(),
186            circuit_breaker_threshold: None,
187            tool_confirmation_policy: ToolConfirmationPolicy::Never,
188            input_guardrails: GuardrailSet::new(),
189            output_guardrails: GuardrailSet::new(),
190        }
191    }
192
193    pub fn description(mut self, desc: impl Into<String>) -> Self {
194        self.description = Some(desc.into());
195        self
196    }
197
198    pub fn model(mut self, model: Arc<dyn Llm>) -> Self {
199        self.model = Some(model);
200        self
201    }
202
203    pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
204        self.instruction = Some(instruction.into());
205        self
206    }
207
208    pub fn instruction_provider(mut self, provider: InstructionProvider) -> Self {
209        self.instruction_provider = Some(Arc::new(provider));
210        self
211    }
212
213    pub fn global_instruction(mut self, instruction: impl Into<String>) -> Self {
214        self.global_instruction = Some(instruction.into());
215        self
216    }
217
218    pub fn global_instruction_provider(mut self, provider: GlobalInstructionProvider) -> Self {
219        self.global_instruction_provider = Some(Arc::new(provider));
220        self
221    }
222
223    /// Set a preloaded skills index for this agent.
224    pub fn with_skills(mut self, index: SkillIndex) -> Self {
225        self.skills_index = Some(Arc::new(index));
226        self
227    }
228
229    /// Auto-load skills from `.skills/` in the current working directory.
230    pub fn with_auto_skills(self) -> Result<Self> {
231        self.with_skills_from_root(".")
232    }
233
234    /// Auto-load skills from `.skills/` under a custom root directory.
235    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
236        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::Agent(e.to_string()))?;
237        self.skills_index = Some(Arc::new(index));
238        Ok(self)
239    }
240
241    /// Customize skill selection behavior.
242    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
243        self.skill_policy = policy;
244        self
245    }
246
247    /// Limit injected skill content length.
248    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
249        self.max_skill_chars = max_chars;
250        self
251    }
252
253    pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
254        self.input_schema = Some(schema);
255        self
256    }
257
258    pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
259        self.output_schema = Some(schema);
260        self
261    }
262
263    pub fn disallow_transfer_to_parent(mut self, disallow: bool) -> Self {
264        self.disallow_transfer_to_parent = disallow;
265        self
266    }
267
268    pub fn disallow_transfer_to_peers(mut self, disallow: bool) -> Self {
269        self.disallow_transfer_to_peers = disallow;
270        self
271    }
272
273    pub fn include_contents(mut self, include: adk_core::IncludeContents) -> Self {
274        self.include_contents = include;
275        self
276    }
277
278    pub fn output_key(mut self, key: impl Into<String>) -> Self {
279        self.output_key = Some(key.into());
280        self
281    }
282
283    /// Set default generation parameters (temperature, top_p, top_k, max_output_tokens)
284    /// applied to every LLM request made by this agent.
285    ///
286    /// These defaults are merged with any per-request config. If `output_schema` is also
287    /// set, the schema is preserved alongside these generation parameters.
288    ///
289    /// # Example
290    ///
291    /// ```rust,ignore
292    /// use adk_core::GenerateContentConfig;
293    ///
294    /// let agent = LlmAgentBuilder::new("my-agent")
295    ///     .model(model)
296    ///     .generate_content_config(GenerateContentConfig {
297    ///         temperature: Some(0.7),
298    ///         max_output_tokens: Some(2048),
299    ///         ..Default::default()
300    ///     })
301    ///     .build()?;
302    /// ```
303    pub fn generate_content_config(mut self, config: adk_core::GenerateContentConfig) -> Self {
304        self.generate_content_config = Some(config);
305        self
306    }
307
308    /// Set the default temperature for LLM requests.
309    /// Shorthand for setting just temperature without a full `GenerateContentConfig`.
310    pub fn temperature(mut self, temperature: f32) -> Self {
311        self.generate_content_config
312            .get_or_insert(adk_core::GenerateContentConfig::default())
313            .temperature = Some(temperature);
314        self
315    }
316
317    /// Set the default top_p for LLM requests.
318    pub fn top_p(mut self, top_p: f32) -> Self {
319        self.generate_content_config
320            .get_or_insert(adk_core::GenerateContentConfig::default())
321            .top_p = Some(top_p);
322        self
323    }
324
325    /// Set the default top_k for LLM requests.
326    pub fn top_k(mut self, top_k: i32) -> Self {
327        self.generate_content_config
328            .get_or_insert(adk_core::GenerateContentConfig::default())
329            .top_k = Some(top_k);
330        self
331    }
332
333    /// Set the default max output tokens for LLM requests.
334    pub fn max_output_tokens(mut self, max_tokens: i32) -> Self {
335        self.generate_content_config
336            .get_or_insert(adk_core::GenerateContentConfig::default())
337            .max_output_tokens = Some(max_tokens);
338        self
339    }
340
341    /// Set the maximum number of LLM round-trips (iterations) before the agent stops.
342    /// Default is 100.
343    pub fn max_iterations(mut self, max: u32) -> Self {
344        self.max_iterations = max;
345        self
346    }
347
348    /// Set the timeout for individual tool executions.
349    /// Default is 5 minutes. Tools that exceed this timeout will return an error.
350    pub fn tool_timeout(mut self, timeout: std::time::Duration) -> Self {
351        self.tool_timeout = timeout;
352        self
353    }
354
355    pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
356        self.tools.push(tool);
357        self
358    }
359
360    /// Register a dynamic toolset for per-invocation tool resolution.
361    ///
362    /// Toolsets are resolved at the start of each `run()` call using the
363    /// invocation's `ReadonlyContext`. This enables context-dependent tools
364    /// like per-user browser sessions from a pool.
365    pub fn toolset(mut self, toolset: Arc<dyn Toolset>) -> Self {
366        self.toolsets.push(toolset);
367        self
368    }
369
370    pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
371        self.sub_agents.push(agent);
372        self
373    }
374
375    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
376        self.before_callbacks.push(callback);
377        self
378    }
379
380    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
381        self.after_callbacks.push(callback);
382        self
383    }
384
385    pub fn before_model_callback(mut self, callback: BeforeModelCallback) -> Self {
386        self.before_model_callbacks.push(callback);
387        self
388    }
389
390    pub fn after_model_callback(mut self, callback: AfterModelCallback) -> Self {
391        self.after_model_callbacks.push(callback);
392        self
393    }
394
395    pub fn before_tool_callback(mut self, callback: BeforeToolCallback) -> Self {
396        self.before_tool_callbacks.push(callback);
397        self
398    }
399
400    pub fn after_tool_callback(mut self, callback: AfterToolCallback) -> Self {
401        self.after_tool_callbacks.push(callback);
402        self
403    }
404
405    /// Register a rich after-tool callback that receives the tool, arguments,
406    /// and response value.
407    ///
408    /// This is the V2 callback surface aligned with the Python/Go ADK model
409    /// where `after_tool_callback` receives the full tool execution context.
410    /// Unlike [`after_tool_callback`](Self::after_tool_callback) (which only
411    /// receives `CallbackContext`), this callback can inspect and modify tool
412    /// results directly.
413    ///
414    /// Return `Ok(None)` to keep the original response, or `Ok(Some(value))`
415    /// to replace the function response sent to the LLM.
416    ///
417    /// These callbacks run after the legacy `after_tool_callback` chain.
418    /// `ToolOutcome` is available via `ctx.tool_outcome()`.
419    pub fn after_tool_callback_full(mut self, callback: AfterToolCallbackFull) -> Self {
420        self.after_tool_callbacks_full.push(callback);
421        self
422    }
423
424    /// Register a callback invoked when a tool execution fails
425    /// (after retries are exhausted).
426    ///
427    /// If the callback returns `Ok(Some(value))`, the value is used as a
428    /// fallback function response to the LLM. If it returns `Ok(None)`,
429    /// the next callback in the chain is tried. If no callback provides a
430    /// fallback, the original error is reported to the LLM.
431    pub fn on_tool_error(mut self, callback: OnToolErrorCallback) -> Self {
432        self.on_tool_error_callbacks.push(callback);
433        self
434    }
435
436    /// Set a default retry budget applied to all tools that do not have
437    /// a per-tool override.
438    ///
439    /// When a tool execution fails and a retry budget applies, the agent
440    /// retries up to `budget.max_retries` times with the configured delay
441    /// between attempts.
442    pub fn default_retry_budget(mut self, budget: RetryBudget) -> Self {
443        self.default_retry_budget = Some(budget);
444        self
445    }
446
447    /// Set a per-tool retry budget that overrides the default for the
448    /// named tool.
449    ///
450    /// Per-tool budgets take precedence over the default retry budget.
451    pub fn tool_retry_budget(mut self, tool_name: impl Into<String>, budget: RetryBudget) -> Self {
452        self.tool_retry_budgets.insert(tool_name.into(), budget);
453        self
454    }
455
456    /// Configure a circuit breaker that temporarily disables tools after
457    /// `threshold` consecutive failures within a single invocation.
458    ///
459    /// When a tool's consecutive failure count reaches the threshold, subsequent
460    /// calls to that tool are short-circuited with an immediate error response
461    /// until the next invocation (which resets the state).
462    pub fn circuit_breaker_threshold(mut self, threshold: u32) -> Self {
463        self.circuit_breaker_threshold = Some(threshold);
464        self
465    }
466
467    /// Configure tool confirmation requirements for this agent.
468    pub fn tool_confirmation_policy(mut self, policy: ToolConfirmationPolicy) -> Self {
469        self.tool_confirmation_policy = policy;
470        self
471    }
472
473    /// Require confirmation for a specific tool name.
474    pub fn require_tool_confirmation(mut self, tool_name: impl Into<String>) -> Self {
475        self.tool_confirmation_policy = self.tool_confirmation_policy.with_tool(tool_name);
476        self
477    }
478
479    /// Require confirmation for all tool calls.
480    pub fn require_tool_confirmation_for_all(mut self) -> Self {
481        self.tool_confirmation_policy = ToolConfirmationPolicy::Always;
482        self
483    }
484
485    /// Set input guardrails to validate user input before processing.
486    ///
487    /// Input guardrails run before the agent processes the request and can:
488    /// - Block harmful or off-topic content
489    /// - Redact PII from user input
490    /// - Enforce input length limits
491    ///
492    /// Requires the `guardrails` feature.
493    pub fn input_guardrails(mut self, guardrails: GuardrailSet) -> Self {
494        self.input_guardrails = guardrails;
495        self
496    }
497
498    /// Set output guardrails to validate agent responses.
499    ///
500    /// Output guardrails run after the agent generates a response and can:
501    /// - Enforce JSON schema compliance
502    /// - Redact PII from responses
503    /// - Block harmful content in responses
504    ///
505    /// Requires the `guardrails` feature.
506    pub fn output_guardrails(mut self, guardrails: GuardrailSet) -> Self {
507        self.output_guardrails = guardrails;
508        self
509    }
510
511    pub fn build(self) -> Result<LlmAgent> {
512        let model =
513            self.model.ok_or_else(|| adk_core::AdkError::Agent("Model is required".to_string()))?;
514
515        let mut seen_names = std::collections::HashSet::new();
516        for agent in &self.sub_agents {
517            if !seen_names.insert(agent.name()) {
518                return Err(adk_core::AdkError::Agent(format!(
519                    "Duplicate sub-agent name: {}",
520                    agent.name()
521                )));
522            }
523        }
524
525        Ok(LlmAgent {
526            name: self.name,
527            description: self.description.unwrap_or_default(),
528            model,
529            instruction: self.instruction,
530            instruction_provider: self.instruction_provider,
531            global_instruction: self.global_instruction,
532            global_instruction_provider: self.global_instruction_provider,
533            skills_index: self.skills_index,
534            skill_policy: self.skill_policy,
535            max_skill_chars: self.max_skill_chars,
536            input_schema: self.input_schema,
537            output_schema: self.output_schema,
538            disallow_transfer_to_parent: self.disallow_transfer_to_parent,
539            disallow_transfer_to_peers: self.disallow_transfer_to_peers,
540            include_contents: self.include_contents,
541            tools: self.tools,
542            toolsets: self.toolsets,
543            sub_agents: self.sub_agents,
544            output_key: self.output_key,
545            generate_content_config: self.generate_content_config,
546            max_iterations: self.max_iterations,
547            tool_timeout: self.tool_timeout,
548            before_callbacks: Arc::new(self.before_callbacks),
549            after_callbacks: Arc::new(self.after_callbacks),
550            before_model_callbacks: Arc::new(self.before_model_callbacks),
551            after_model_callbacks: Arc::new(self.after_model_callbacks),
552            before_tool_callbacks: Arc::new(self.before_tool_callbacks),
553            after_tool_callbacks: Arc::new(self.after_tool_callbacks),
554            on_tool_error_callbacks: Arc::new(self.on_tool_error_callbacks),
555            after_tool_callbacks_full: Arc::new(self.after_tool_callbacks_full),
556            default_retry_budget: self.default_retry_budget,
557            tool_retry_budgets: self.tool_retry_budgets,
558            circuit_breaker_threshold: self.circuit_breaker_threshold,
559            tool_confirmation_policy: self.tool_confirmation_policy,
560            input_guardrails: Arc::new(self.input_guardrails),
561            output_guardrails: Arc::new(self.output_guardrails),
562        })
563    }
564}
565
566// AgentToolContext wraps the parent InvocationContext and preserves all context
567// instead of throwing it away like SimpleToolContext did
568struct AgentToolContext {
569    parent_ctx: Arc<dyn InvocationContext>,
570    function_call_id: String,
571    actions: Mutex<EventActions>,
572}
573
574impl AgentToolContext {
575    fn new(parent_ctx: Arc<dyn InvocationContext>, function_call_id: String) -> Self {
576        Self { parent_ctx, function_call_id, actions: Mutex::new(EventActions::default()) }
577    }
578
579    fn actions_guard(&self) -> std::sync::MutexGuard<'_, EventActions> {
580        self.actions.lock().unwrap_or_else(|e| e.into_inner())
581    }
582}
583
// Read-only view for tools. Every accessor forwards unchanged to the parent invocation
// context, so a tool observes the same invocation id, agent name, user, app, session,
// branch and user content as the agent that invoked it.
#[async_trait]
impl ReadonlyContext for AgentToolContext {
    fn invocation_id(&self) -> &str {
        self.parent_ctx.invocation_id()
    }

    fn agent_name(&self) -> &str {
        self.parent_ctx.agent_name()
    }

    fn user_id(&self) -> &str {
        // Forwarded so tools receive the invocation's real user id.
        self.parent_ctx.user_id()
    }

    fn app_name(&self) -> &str {
        // Forwarded so tools receive the invocation's real app name.
        self.parent_ctx.app_name()
    }

    fn session_id(&self) -> &str {
        // Forwarded so tools receive the invocation's real session id.
        self.parent_ctx.session_id()
    }

    fn branch(&self) -> &str {
        self.parent_ctx.branch()
    }

    fn user_content(&self) -> &Content {
        self.parent_ctx.user_content()
    }
}
617
// Callback-level access for tools. Artifacts come from the parent invocation context.
#[async_trait]
impl CallbackContext for AgentToolContext {
    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
        // Forwarded unchanged: `None` whenever the parent context has no artifacts.
        self.parent_ctx.artifacts()
    }
}
625
626#[async_trait]
627impl ToolContext for AgentToolContext {
628    fn function_call_id(&self) -> &str {
629        &self.function_call_id
630    }
631
632    fn actions(&self) -> EventActions {
633        self.actions_guard().clone()
634    }
635
636    fn set_actions(&self, actions: EventActions) {
637        *self.actions_guard() = actions;
638    }
639
640    async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>> {
641        // ✅ Delegate to parent's memory if available
642        if let Some(memory) = self.parent_ctx.memory() {
643            memory.search(query).await
644        } else {
645            Ok(vec![])
646        }
647    }
648
649    fn user_scopes(&self) -> Vec<String> {
650        self.parent_ctx.user_scopes()
651    }
652}
653
/// Wrapper that adds ToolOutcome to an existing CallbackContext.
/// Used only during after-tool callback invocation so callbacks
/// can inspect structured metadata about the completed tool execution.
struct ToolCallbackContext {
    /// Context that every `ReadonlyContext` accessor and `artifacts()` forwards to.
    inner: Arc<dyn CallbackContext>,
    /// Outcome handed out (cloned) by `tool_outcome()`.
    outcome: ToolOutcome,
}
661
// Pure forwarding: every read-only accessor returns the wrapped context's value unchanged,
// so callbacks see the same identifiers they would without the outcome wrapper.
#[async_trait]
impl ReadonlyContext for ToolCallbackContext {
    fn invocation_id(&self) -> &str {
        self.inner.invocation_id()
    }

    fn agent_name(&self) -> &str {
        self.inner.agent_name()
    }

    fn user_id(&self) -> &str {
        self.inner.user_id()
    }

    fn app_name(&self) -> &str {
        self.inner.app_name()
    }

    fn session_id(&self) -> &str {
        self.inner.session_id()
    }

    fn branch(&self) -> &str {
        self.inner.branch()
    }

    fn user_content(&self) -> &Content {
        self.inner.user_content()
    }
}
692
693#[async_trait]
694impl CallbackContext for ToolCallbackContext {
695    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
696        self.inner.artifacts()
697    }
698
699    fn tool_outcome(&self) -> Option<ToolOutcome> {
700        Some(self.outcome.clone())
701    }
702}
703
704/// Per-invocation circuit breaker state.
705///
706/// Tracks consecutive failures per tool name within a single agent
707/// invocation. When a tool's consecutive failure count reaches the
708/// configured threshold the breaker "opens" and subsequent calls to
709/// that tool are short-circuited with an immediate error response.
710///
711/// The state is created fresh at the start of each `run()` call so
712/// it automatically resets between invocations.
713struct CircuitBreakerState {
714    threshold: u32,
715    /// tool_name → consecutive failure count
716    failures: std::collections::HashMap<String, u32>,
717}
718
719impl CircuitBreakerState {
720    fn new(threshold: u32) -> Self {
721        Self { threshold, failures: std::collections::HashMap::new() }
722    }
723
724    /// Returns `true` if the tool is currently tripped (open state).
725    fn is_open(&self, tool_name: &str) -> bool {
726        self.failures.get(tool_name).copied().unwrap_or(0) >= self.threshold
727    }
728
729    /// Record a tool outcome. Resets count on success, increments on failure.
730    fn record(&mut self, outcome: &ToolOutcome) {
731        if outcome.success {
732            self.failures.remove(&outcome.tool_name);
733        } else {
734            let count = self.failures.entry(outcome.tool_name.clone()).or_insert(0);
735            *count += 1;
736        }
737    }
738}
739
740#[async_trait]
741impl Agent for LlmAgent {
742    fn name(&self) -> &str {
743        &self.name
744    }
745
746    fn description(&self) -> &str {
747        &self.description
748    }
749
750    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
751        &self.sub_agents
752    }
753
754    #[adk_telemetry::instrument(
755        skip(self, ctx),
756        fields(
757            agent.name = %self.name,
758            agent.description = %self.description,
759            invocation.id = %ctx.invocation_id(),
760            user.id = %ctx.user_id(),
761            session.id = %ctx.session_id()
762        )
763    )]
764    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
765        adk_telemetry::info!("Starting agent execution");
766        let ctx = Self::apply_input_guardrails(ctx, self.input_guardrails.clone()).await?;
767
768        let agent_name = self.name.clone();
769        let invocation_id = ctx.invocation_id().to_string();
770        let model = self.model.clone();
771        let tools = self.tools.clone();
772        let toolsets = self.toolsets.clone();
773        let sub_agents = self.sub_agents.clone();
774
775        let instruction = self.instruction.clone();
776        let instruction_provider = self.instruction_provider.clone();
777        let global_instruction = self.global_instruction.clone();
778        let global_instruction_provider = self.global_instruction_provider.clone();
779        let skills_index = self.skills_index.clone();
780        let skill_policy = self.skill_policy.clone();
781        let max_skill_chars = self.max_skill_chars;
782        let output_key = self.output_key.clone();
783        let output_schema = self.output_schema.clone();
784        let generate_content_config = self.generate_content_config.clone();
785        let include_contents = self.include_contents;
786        let max_iterations = self.max_iterations;
787        let tool_timeout = self.tool_timeout;
788        // Clone Arc references (cheap)
789        let before_agent_callbacks = self.before_callbacks.clone();
790        let after_agent_callbacks = self.after_callbacks.clone();
791        let before_model_callbacks = self.before_model_callbacks.clone();
792        let after_model_callbacks = self.after_model_callbacks.clone();
793        let before_tool_callbacks = self.before_tool_callbacks.clone();
794        let after_tool_callbacks = self.after_tool_callbacks.clone();
795        let on_tool_error_callbacks = self.on_tool_error_callbacks.clone();
796        let after_tool_callbacks_full = self.after_tool_callbacks_full.clone();
797        let default_retry_budget = self.default_retry_budget.clone();
798        let tool_retry_budgets = self.tool_retry_budgets.clone();
799        let circuit_breaker_threshold = self.circuit_breaker_threshold;
800        let tool_confirmation_policy = self.tool_confirmation_policy.clone();
801        let disallow_transfer_to_parent = self.disallow_transfer_to_parent;
802        let disallow_transfer_to_peers = self.disallow_transfer_to_peers;
803        let output_guardrails = self.output_guardrails.clone();
804
805        let s = stream! {
806            // ===== BEFORE AGENT CALLBACKS =====
807            // Execute before the agent starts running
808            // If any returns content, skip agent execution
809            for callback in before_agent_callbacks.as_ref() {
810                match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
811                    Ok(Some(content)) => {
812                        // Callback returned content - yield it and skip agent execution
813                        let mut early_event = Event::new(&invocation_id);
814                        early_event.author = agent_name.clone();
815                        early_event.llm_response.content = Some(content);
816                        yield Ok(early_event);
817
818                        // Skip rest of agent execution and go to after callbacks
819                        for after_callback in after_agent_callbacks.as_ref() {
820                            match after_callback(ctx.clone() as Arc<dyn CallbackContext>).await {
821                                Ok(Some(after_content)) => {
822                                    let mut after_event = Event::new(&invocation_id);
823                                    after_event.author = agent_name.clone();
824                                    after_event.llm_response.content = Some(after_content);
825                                    yield Ok(after_event);
826                                    return;
827                                }
828                                Ok(None) => continue,
829                                Err(e) => {
830                                    yield Err(e);
831                                    return;
832                                }
833                            }
834                        }
835                        return;
836                    }
837                    Ok(None) => {
838                        // Continue to next callback
839                        continue;
840                    }
841                    Err(e) => {
842                        // Callback failed - propagate error
843                        yield Err(e);
844                        return;
845                    }
846                }
847            }
848
849            // ===== MAIN AGENT EXECUTION =====
850            let mut prompt_preamble = Vec::new();
851
852            // ===== PROCESS SKILL CONTEXT =====
853            // If skills are configured, select the most relevant skill from user input
854            // and inject it as a compact instruction block before other prompts.
855            if let Some(index) = &skills_index {
856                let user_query = ctx
857                    .user_content()
858                    .parts
859                    .iter()
860                    .filter_map(|part| match part {
861                        Part::Text { text } => Some(text.as_str()),
862                        _ => None,
863                    })
864                    .collect::<Vec<_>>()
865                    .join("\n");
866
867                if let Some((_matched, skill_block)) = select_skill_prompt_block(
868                    index.as_ref(),
869                    &user_query,
870                    &skill_policy,
871                    max_skill_chars,
872                ) {
873                    prompt_preamble.push(Content {
874                        role: "user".to_string(),
875                        parts: vec![Part::Text { text: skill_block }],
876                    });
877                }
878            }
879
880            // ===== PROCESS GLOBAL INSTRUCTION =====
881            // GlobalInstruction provides tree-wide personality/identity
882            if let Some(provider) = &global_instruction_provider {
883                // Dynamic global instruction via provider
884                let global_inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
885                if !global_inst.is_empty() {
886                    prompt_preamble.push(Content {
887                        role: "user".to_string(),
888                        parts: vec![Part::Text { text: global_inst }],
889                    });
890                }
891            } else if let Some(ref template) = global_instruction {
892                // Static global instruction with template injection
893                let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
894                if !processed.is_empty() {
895                    prompt_preamble.push(Content {
896                        role: "user".to_string(),
897                        parts: vec![Part::Text { text: processed }],
898                    });
899                }
900            }
901
902            // ===== PROCESS AGENT INSTRUCTION =====
903            // Agent-specific instruction
904            if let Some(provider) = &instruction_provider {
905                // Dynamic instruction via provider
906                let inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
907                if !inst.is_empty() {
908                    prompt_preamble.push(Content {
909                        role: "user".to_string(),
910                        parts: vec![Part::Text { text: inst }],
911                    });
912                }
913            } else if let Some(ref template) = instruction {
914                // Static instruction with template injection
915                let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
916                if !processed.is_empty() {
917                    prompt_preamble.push(Content {
918                        role: "user".to_string(),
919                        parts: vec![Part::Text { text: processed }],
920                    });
921                }
922            }
923
924            // ===== LOAD SESSION HISTORY =====
925            // Load previous conversation turns from the session
926            // NOTE: Session history already includes the current user message (added by Runner before agent runs)
927            // When transfer_targets is set, this agent was invoked via transfer — filter out
928            // other agents' events so the LLM doesn't see the parent's tool calls as its own.
929            let session_history = if !ctx.run_config().transfer_targets.is_empty() {
930                ctx.session().conversation_history_for_agent(&agent_name)
931            } else {
932                ctx.session().conversation_history()
933            };
934            let mut session_history = session_history;
935            let current_user_content = ctx.user_content().clone();
936            if let Some(index) = session_history.iter().rposition(|content| content.role == "user") {
937                session_history[index] = current_user_content.clone();
938            } else {
939                session_history.push(current_user_content.clone());
940            }
941
942            // ===== APPLY INCLUDE_CONTENTS FILTERING =====
943            // Control what conversation history the agent sees
944            let mut conversation_history = match include_contents {
945                adk_core::IncludeContents::None => {
946                    let mut filtered = prompt_preamble.clone();
947                    filtered.push(current_user_content);
948                    filtered
949                }
950                adk_core::IncludeContents::Default => {
951                    let mut full_history = prompt_preamble;
952                    full_history.extend(session_history);
953                    full_history
954                }
955            };
956
957            // ===== RESOLVE TOOLSETS =====
958            // Start with static tools, then merge in toolset-provided tools
959            let mut resolved_tools: Vec<Arc<dyn Tool>> = tools.clone();
960            let static_tool_names: std::collections::HashSet<String> =
961                tools.iter().map(|t| t.name().to_string()).collect();
962
963            // Track which toolset provided each tool for deterministic error messages
964            let mut toolset_source: std::collections::HashMap<String, String> =
965                std::collections::HashMap::new();
966
967            for toolset in &toolsets {
968                let toolset_tools = match toolset
969                    .tools(ctx.clone() as Arc<dyn ReadonlyContext>)
970                    .await
971                {
972                    Ok(t) => t,
973                    Err(e) => {
974                        yield Err(e);
975                        return;
976                    }
977                };
978                for tool in &toolset_tools {
979                    let name = tool.name().to_string();
980                    // Check static-vs-toolset conflict
981                    if static_tool_names.contains(&name) {
982                        yield Err(adk_core::AdkError::Agent(format!(
983                            "Duplicate tool name '{}': conflict between static tool and toolset '{}'",
984                            name,
985                            toolset.name()
986                        )));
987                        return;
988                    }
989                    // Check toolset-vs-toolset conflict
990                    if let Some(other_toolset_name) = toolset_source.get(&name) {
991                        yield Err(adk_core::AdkError::Agent(format!(
992                            "Duplicate tool name '{}': conflict between toolset '{}' and toolset '{}'",
993                            name,
994                            other_toolset_name,
995                            toolset.name()
996                        )));
997                        return;
998                    }
999                    toolset_source.insert(name, toolset.name().to_string());
1000                    resolved_tools.push(tool.clone());
1001                }
1002            }
1003
1004            // Build tool lookup map for O(1) access from merged resolved_tools
1005            let tool_map: std::collections::HashMap<String, Arc<dyn Tool>> = resolved_tools
1006                .iter()
1007                .map(|t| (t.name().to_string(), t.clone()))
1008                .collect();
1009
1010            // Helper: extract long-running tool IDs from content
1011            let collect_long_running_ids = |content: &Content| -> Vec<String> {
1012                content.parts.iter()
1013                    .filter_map(|p| {
1014                        if let Part::FunctionCall { name, .. } = p {
1015                            if let Some(tool) = tool_map.get(name) {
1016                                if tool.is_long_running() {
1017                                    return Some(name.clone());
1018                                }
1019                            }
1020                        }
1021                        None
1022                    })
1023                    .collect()
1024            };
1025
1026            // Build tool declarations for Gemini
1027            // Uses enhanced_description() which includes NOTE for long-running tools
1028            let mut tool_declarations = std::collections::HashMap::new();
1029            for tool in &resolved_tools {
1030                // Build FunctionDeclaration JSON with enhanced description
1031                // For long-running tools, this includes a warning not to call again if pending
1032                let mut decl = serde_json::json!({
1033                    "name": tool.name(),
1034                    "description": tool.enhanced_description(),
1035                });
1036
1037                if let Some(params) = tool.parameters_schema() {
1038                    decl["parameters"] = params;
1039                }
1040
1041                if let Some(response) = tool.response_schema() {
1042                    decl["response"] = response;
1043                }
1044
1045                tool_declarations.insert(tool.name().to_string(), decl);
1046            }
1047
1048            // Build the list of valid transfer targets.
1049            // Sources: sub_agents (always) + transfer_targets from RunConfig
1050            // (set by the runner to include parent/peers for transferred agents).
1051            // Apply disallow_transfer_to_parent / disallow_transfer_to_peers filtering.
1052            let mut valid_transfer_targets: Vec<String> = sub_agents
1053                .iter()
1054                .map(|a| a.name().to_string())
1055                .collect();
1056
1057            // Merge in runner-provided targets (parent, peers) from RunConfig
1058            let run_config_targets = &ctx.run_config().transfer_targets;
1059            let parent_agent_name = ctx.run_config().parent_agent.clone();
1060            let sub_agent_names: std::collections::HashSet<&str> = sub_agents
1061                .iter()
1062                .map(|a| a.name())
1063                .collect();
1064
1065            for target in run_config_targets {
1066                // Skip if already in the list (from sub_agents)
1067                if sub_agent_names.contains(target.as_str()) {
1068                    continue;
1069                }
1070
1071                // Apply disallow flags
1072                let is_parent = parent_agent_name.as_deref() == Some(target.as_str());
1073                if is_parent && disallow_transfer_to_parent {
1074                    continue;
1075                }
1076                if !is_parent && disallow_transfer_to_peers {
1077                    continue;
1078                }
1079
1080                valid_transfer_targets.push(target.clone());
1081            }
1082
1083            // Inject transfer_to_agent tool if there are valid targets
1084            if !valid_transfer_targets.is_empty() {
1085                let transfer_tool_name = "transfer_to_agent";
1086                let transfer_tool_decl = serde_json::json!({
1087                    "name": transfer_tool_name,
1088                    "description": format!(
1089                        "Transfer execution to another agent. Valid targets: {}",
1090                        valid_transfer_targets.join(", ")
1091                    ),
1092                    "parameters": {
1093                        "type": "object",
1094                        "properties": {
1095                            "agent_name": {
1096                                "type": "string",
1097                                "description": "The name of the agent to transfer to.",
1098                                "enum": valid_transfer_targets
1099                            }
1100                        },
1101                        "required": ["agent_name"]
1102                    }
1103                });
1104                tool_declarations.insert(transfer_tool_name.to_string(), transfer_tool_decl);
1105            }
1106
1107
1108            // ===== CIRCUIT BREAKER STATE =====
1109            // Created fresh per invocation so it resets between runs.
1110            let mut circuit_breaker_state = circuit_breaker_threshold.map(CircuitBreakerState::new);
1111
1112            // Multi-turn loop with max iterations
1113            let mut iteration = 0;
1114
1115            loop {
1116                iteration += 1;
1117                if iteration > max_iterations {
1118                    yield Err(adk_core::AdkError::Agent(
1119                        format!("Max iterations ({}) exceeded", max_iterations)
1120                    ));
1121                    return;
1122                }
1123
1124                // Build request with conversation history
1125                // Merge agent-level generate_content_config with output_schema.
1126                // Agent-level config provides defaults (temperature, top_p, etc.),
1127                // output_schema is layered on top as response_schema.
1128                // If the runner set a cached_content name (via automatic cache lifecycle),
1129                // merge it into the config so the provider can reuse cached content.
1130                let config = match (&generate_content_config, &output_schema) {
1131                    (Some(base), Some(schema)) => {
1132                        let mut merged = base.clone();
1133                        merged.response_schema = Some(schema.clone());
1134                        Some(merged)
1135                    }
1136                    (Some(base), None) => Some(base.clone()),
1137                    (None, Some(schema)) => Some(adk_core::GenerateContentConfig {
1138                        response_schema: Some(schema.clone()),
1139                        ..Default::default()
1140                    }),
1141                    (None, None) => None,
1142                };
1143
1144                // Layer cached_content from RunConfig onto the request config.
1145                let config = if let Some(ref cached) = ctx.run_config().cached_content {
1146                    let mut cfg = config.unwrap_or_default();
1147                    // Only set if the agent hasn't already specified one
1148                    if cfg.cached_content.is_none() {
1149                        cfg.cached_content = Some(cached.clone());
1150                    }
1151                    Some(cfg)
1152                } else {
1153                    config
1154                };
1155
1156                let request = LlmRequest {
1157                    model: model.name().to_string(),
1158                    contents: conversation_history.clone(),
1159                    tools: tool_declarations.clone(),
1160                    config,
1161                };
1162
1163                // ===== BEFORE MODEL CALLBACKS =====
1164                // These can modify the request or skip the model call by returning a response
1165                let mut current_request = request;
1166                let mut model_response_override = None;
1167                for callback in before_model_callbacks.as_ref() {
1168                    match callback(ctx.clone() as Arc<dyn CallbackContext>, current_request.clone()).await {
1169                        Ok(BeforeModelResult::Continue(modified_request)) => {
1170                            // Callback may have modified the request, continue with it
1171                            current_request = modified_request;
1172                        }
1173                        Ok(BeforeModelResult::Skip(response)) => {
1174                            // Callback returned a response - skip model call
1175                            model_response_override = Some(response);
1176                            break;
1177                        }
1178                        Err(e) => {
1179                            // Callback failed - propagate error
1180                            yield Err(e);
1181                            return;
1182                        }
1183                    }
1184                }
1185                let request = current_request;
1186
1187                // Determine streaming source: cached response or real model
1188                let mut accumulated_content: Option<Content> = None;
1189
1190                if let Some(cached_response) = model_response_override {
1191                    // Use callback-provided response (e.g., from cache)
1192                    // Yield it as an event
1193                    accumulated_content = cached_response.content.clone();
1194                    normalize_option_content(&mut accumulated_content);
1195                    if let Some(content) = accumulated_content.take() {
1196                        let has_function_calls = content
1197                            .parts
1198                            .iter()
1199                            .any(|part| matches!(part, Part::FunctionCall { .. }));
1200                        let content = if has_function_calls {
1201                            content
1202                        } else {
1203                            Self::apply_output_guardrails(output_guardrails.as_ref(), content).await?
1204                        };
1205                        accumulated_content = Some(content);
1206                    }
1207
1208                    let mut cached_event = Event::new(&invocation_id);
1209                    cached_event.author = agent_name.clone();
1210                    cached_event.llm_response.content = accumulated_content.clone();
1211                    cached_event.llm_request = Some(serde_json::to_string(&request).unwrap_or_default());
1212                    cached_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), serde_json::to_string(&request).unwrap_or_default());
1213                    cached_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(&cached_response).unwrap_or_default());
1214
1215                    // Populate long_running_tool_ids for function calls from long-running tools
1216                    if let Some(ref content) = accumulated_content {
1217                        cached_event.long_running_tool_ids = collect_long_running_ids(content);
1218                    }
1219
1220                    yield Ok(cached_event);
1221                } else {
1222                    // Record LLM request for tracing
1223                    let request_json = serde_json::to_string(&request).unwrap_or_default();
1224
1225                    // Create call_llm span with GCP attributes (works for all model types)
1226                    let llm_ts = std::time::SystemTime::now()
1227                        .duration_since(std::time::UNIX_EPOCH)
1228                        .unwrap_or_default()
1229                        .as_nanos();
1230                    let llm_event_id = format!("{}_llm_{}", invocation_id, llm_ts);
1231                    let llm_span = tracing::info_span!(
1232                        "call_llm",
1233                        "gcp.vertex.agent.event_id" = %llm_event_id,
1234                        "gcp.vertex.agent.invocation_id" = %invocation_id,
1235                        "gcp.vertex.agent.session_id" = %ctx.session_id(),
1236                        "gen_ai.conversation.id" = %ctx.session_id(),
1237                        "gcp.vertex.agent.llm_request" = %request_json,
1238                        "gcp.vertex.agent.llm_response" = tracing::field::Empty  // Placeholder for later recording
1239                    );
1240                    let _llm_guard = llm_span.enter();
1241
1242                    // Check streaming mode from run config
1243                    use adk_core::StreamingMode;
1244                    let streaming_mode = ctx.run_config().streaming_mode;
1245                    let should_stream_to_client = matches!(streaming_mode, StreamingMode::SSE | StreamingMode::Bidi)
1246                        && output_guardrails.is_empty();
1247
1248                    // Always use streaming internally for LLM calls
1249                    let mut response_stream = model.generate_content(request, true).await?;
1250
1251                    use futures::StreamExt;
1252
1253                    // Track last chunk for final event metadata (used in None mode)
1254                    let mut last_chunk: Option<LlmResponse> = None;
1255
1256                    // Stream and process chunks with AfterModel callbacks
1257                    while let Some(chunk_result) = response_stream.next().await {
1258                        let mut chunk = match chunk_result {
1259                            Ok(c) => c,
1260                            Err(e) => {
1261                                yield Err(e);
1262                                return;
1263                            }
1264                        };
1265
1266                        // ===== AFTER MODEL CALLBACKS (per chunk) =====
1267                        // Callbacks can modify each streaming chunk
1268                        for callback in after_model_callbacks.as_ref() {
1269                            match callback(ctx.clone() as Arc<dyn CallbackContext>, chunk.clone()).await {
1270                                Ok(Some(modified_chunk)) => {
1271                                    // Callback modified this chunk
1272                                    chunk = modified_chunk;
1273                                    break;
1274                                }
1275                                Ok(None) => {
1276                                    // Continue to next callback
1277                                    continue;
1278                                }
1279                                Err(e) => {
1280                                    // Callback failed - propagate error
1281                                    yield Err(e);
1282                                    return;
1283                                }
1284                            }
1285                        }
1286
1287                        normalize_option_content(&mut chunk.content);
1288
1289                        // Accumulate content for conversation history (always needed)
1290                        if let Some(chunk_content) = chunk.content.clone() {
1291                            if let Some(ref mut acc) = accumulated_content {
1292                                acc.parts.extend(chunk_content.parts);
1293                            } else {
1294                                accumulated_content = Some(chunk_content);
1295                            }
1296                        }
1297
1298                        // For SSE/Bidi mode: yield each chunk immediately with stable event ID
1299                        if should_stream_to_client {
1300                            let mut partial_event = Event::with_id(&llm_event_id, &invocation_id);
1301                            partial_event.author = agent_name.clone();
1302                            partial_event.llm_request = Some(request_json.clone());
1303                            partial_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), request_json.clone());
1304                            partial_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(&chunk).unwrap_or_default());
1305                            partial_event.llm_response.partial = chunk.partial;
1306                            partial_event.llm_response.turn_complete = chunk.turn_complete;
1307                            partial_event.llm_response.finish_reason = chunk.finish_reason;
1308                            partial_event.llm_response.usage_metadata = chunk.usage_metadata.clone();
1309                            partial_event.llm_response.content = chunk.content.clone();
1310
1311                            // Populate long_running_tool_ids
1312                            if let Some(ref content) = chunk.content {
1313                                partial_event.long_running_tool_ids = collect_long_running_ids(content);
1314                            }
1315
1316                            yield Ok(partial_event);
1317                        }
1318
1319                        // Store last chunk for final event metadata
1320                        last_chunk = Some(chunk.clone());
1321
1322                        // Check if turn is complete
1323                        if chunk.turn_complete {
1324                            break;
1325                        }
1326                    }
1327
1328                    // For None mode: yield single final event with accumulated content
1329                    if !should_stream_to_client {
1330                        if let Some(content) = accumulated_content.take() {
1331                            let has_function_calls = content
1332                                .parts
1333                                .iter()
1334                                .any(|part| matches!(part, Part::FunctionCall { .. }));
1335                            let content = if has_function_calls {
1336                                content
1337                            } else {
1338                                Self::apply_output_guardrails(output_guardrails.as_ref(), content).await?
1339                            };
1340                            accumulated_content = Some(content);
1341                        }
1342
1343                        let mut final_event = Event::with_id(&llm_event_id, &invocation_id);
1344                        final_event.author = agent_name.clone();
1345                        final_event.llm_request = Some(request_json.clone());
1346                        final_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), request_json.clone());
1347                        final_event.llm_response.content = accumulated_content.clone();
1348                        final_event.llm_response.partial = false;
1349                        final_event.llm_response.turn_complete = true;
1350
1351                        // Copy metadata from last chunk
1352                        if let Some(ref last) = last_chunk {
1353                            final_event.llm_response.finish_reason = last.finish_reason;
1354                            final_event.llm_response.usage_metadata = last.usage_metadata.clone();
1355                            final_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(last).unwrap_or_default());
1356                        }
1357
1358                        // Populate long_running_tool_ids
1359                        if let Some(ref content) = accumulated_content {
1360                            final_event.long_running_tool_ids = collect_long_running_ids(content);
1361                        }
1362
1363                        yield Ok(final_event);
1364                    }
1365
1366                    // Record LLM response to span before guard drops
1367                    if let Some(ref content) = accumulated_content {
1368                        let response_json = serde_json::to_string(content).unwrap_or_default();
1369                        llm_span.record("gcp.vertex.agent.llm_response", &response_json);
1370                    }
1371                }
1372
1373                // After streaming/caching completes, check for function calls in accumulated content
1374                let function_call_names: Vec<String> = accumulated_content.as_ref()
1375                    .map(|c| c.parts.iter()
1376                        .filter_map(|p| {
1377                            if let Part::FunctionCall { name, .. } = p {
1378                                Some(name.clone())
1379                            } else {
1380                                None
1381                            }
1382                        })
1383                        .collect())
1384                    .unwrap_or_default();
1385
1386                let has_function_calls = !function_call_names.is_empty();
1387
1388                // Check if ALL function calls are from long-running tools
1389                // If so, we should NOT continue the loop - the tool returned a pending status
1390                // and the agent/client will poll for completion later
1391                let all_calls_are_long_running = has_function_calls && function_call_names.iter().all(|name| {
1392                    tool_map.get(name)
1393                        .map(|t| t.is_long_running())
1394                        .unwrap_or(false)
1395                });
1396
1397                // Add final content to history
1398                if let Some(ref content) = accumulated_content {
1399                    conversation_history.push(content.clone());
1400
1401                    // Handle output_key: save final agent output to state_delta
1402                    if let Some(ref output_key) = output_key {
1403                        if !has_function_calls {  // Only save if not calling tools
1404                            let mut text_parts = String::new();
1405                            for part in &content.parts {
1406                                if let Part::Text { text } = part {
1407                                    text_parts.push_str(text);
1408                                }
1409                            }
1410                            if !text_parts.is_empty() {
1411                                // Yield a final state update event
1412                                let mut state_event = Event::new(&invocation_id);
1413                                state_event.author = agent_name.clone();
1414                                state_event.actions.state_delta.insert(
1415                                    output_key.clone(),
1416                                    serde_json::Value::String(text_parts),
1417                                );
1418                                yield Ok(state_event);
1419                            }
1420                        }
1421                    }
1422                }
1423
1424                if !has_function_calls {
1425                    // No function calls, we're done
1426                    // Record LLM response for tracing
1427                    if let Some(ref content) = accumulated_content {
1428                        let response_json = serde_json::to_string(content).unwrap_or_default();
1429                        tracing::Span::current().record("gcp.vertex.agent.llm_response", &response_json);
1430                    }
1431
1432                    tracing::info!(agent.name = %agent_name, "Agent execution complete");
1433                    break;
1434                }
1435
1436                // Execute function calls and add responses to history
1437                if let Some(content) = &accumulated_content {
1438                    let mut tool_call_index = 0usize;
1439                    for part in &content.parts {
1440                        if let Part::FunctionCall { name, args, id, .. } = part {
1441                            let fallback_call_id =
1442                                format!("{}_{}_{}", invocation_id, name, tool_call_index);
1443                            tool_call_index += 1;
1444                            let function_call_id = id.clone().unwrap_or(fallback_call_id);
1445
1446                            // Handle transfer_to_agent specially
1447                            if name == "transfer_to_agent" {
1448                                let target_agent = args.get("agent_name")
1449                                    .and_then(|v| v.as_str())
1450                                    .unwrap_or_default()
1451                                    .to_string();
1452
1453                                // Validate against the full set of valid transfer targets
1454                                // (sub-agents + parent/peers from RunConfig)
1455                                let valid_target = valid_transfer_targets.iter().any(|n| n == &target_agent);
1456                                if !valid_target {
1457                                    // Return error to LLM so it can try again
1458                                    let error_content = Content {
1459                                        role: "function".to_string(),
1460                                        parts: vec![Part::FunctionResponse {
1461                                            function_response: FunctionResponseData {
1462                                                name: name.clone(),
1463                                                response: serde_json::json!({
1464                                                    "error": format!(
1465                                                        "Agent '{}' not found. Available agents: {:?}",
1466                                                        target_agent,
1467                                                        valid_transfer_targets
1468                                                    )
1469                                                }),
1470                                            },
1471                                            id: id.clone(),
1472                                        }],
1473                                    };
1474                                    conversation_history.push(error_content.clone());
1475
1476                                    let mut error_event = Event::new(&invocation_id);
1477                                    error_event.author = agent_name.clone();
1478                                    error_event.llm_response.content = Some(error_content);
1479                                    yield Ok(error_event);
1480                                    continue;
1481                                }
1482
1483                                let mut transfer_event = Event::new(&invocation_id);
1484                                transfer_event.author = agent_name.clone();
1485                                transfer_event.actions.transfer_to_agent = Some(target_agent);
1486
1487                                yield Ok(transfer_event);
1488                                return;
1489                            }
1490
1491                            // ===== BEFORE TOOL CALLBACKS =====
1492                            // Allows policy checks or callback-provided short-circuit responses.
1493                            let mut tool_actions = EventActions::default();
1494                            let mut response_content: Option<Content> = None;
1495                            let mut run_after_tool_callbacks = true;
1496                            let mut tool_outcome_for_callback: Option<ToolOutcome> = None;
1497                            // Track tool ref, args, and response for AfterToolCallbackFull
1498                            let mut executed_tool: Option<Arc<dyn Tool>> = None;
1499                            let mut executed_tool_response: Option<serde_json::Value> = None;
1500
1501                            // ===== TOOL CONFIRMATION POLICY =====
1502                            // If configured and no decision is provided, emit an interrupt event
1503                            // before execution. Callers can resume by re-running with a decision
1504                            // in RunConfig.tool_confirmation_decisions.
1505                            if tool_confirmation_policy.requires_confirmation(name) {
1506                                match ctx.run_config().tool_confirmation_decisions.get(name).copied()
1507                                {
1508                                    Some(ToolConfirmationDecision::Approve) => {
1509                                        tool_actions.tool_confirmation_decision =
1510                                            Some(ToolConfirmationDecision::Approve);
1511                                    }
1512                                    Some(ToolConfirmationDecision::Deny) => {
1513                                        tool_actions.tool_confirmation_decision =
1514                                            Some(ToolConfirmationDecision::Deny);
1515                                        response_content = Some(Content {
1516                                            role: "function".to_string(),
1517                                            parts: vec![Part::FunctionResponse {
1518                                                function_response: FunctionResponseData {
1519                                                    name: name.clone(),
1520                                                    response: serde_json::json!({
1521                                                        "error": format!(
1522                                                            "Tool '{}' execution denied by confirmation policy",
1523                                                            name
1524                                                        ),
1525                                                    }),
1526                                                },
1527                                                id: id.clone(),
1528                                            }],
1529                                        });
1530                                        run_after_tool_callbacks = false;
1531                                    }
1532                                    None => {
1533                                        let mut confirmation_event = Event::new(&invocation_id);
1534                                        confirmation_event.author = agent_name.clone();
1535                                        confirmation_event.llm_response.interrupted = true;
1536                                        confirmation_event.llm_response.turn_complete = true;
1537                                        confirmation_event.llm_response.content = Some(Content {
1538                                            role: "model".to_string(),
1539                                            parts: vec![Part::Text {
1540                                                text: format!(
1541                                                    "Tool confirmation required for '{}'. Provide approve/deny decision to continue.",
1542                                                    name
1543                                                ),
1544                                            }],
1545                                        });
1546                                        confirmation_event.actions.tool_confirmation =
1547                                            Some(ToolConfirmationRequest {
1548                                                tool_name: name.clone(),
1549                                                function_call_id: Some(function_call_id),
1550                                                args: args.clone(),
1551                                            });
1552                                        yield Ok(confirmation_event);
1553                                        return;
1554                                    }
1555                                }
1556                            }
1557
1558                            if response_content.is_none() {
1559                                for callback in before_tool_callbacks.as_ref() {
1560                                    match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
1561                                        Ok(Some(content)) => {
1562                                            response_content = Some(content);
1563                                            break;
1564                                        }
1565                                        Ok(None) => continue,
1566                                        Err(e) => {
1567                                            yield Err(e);
1568                                            return;
1569                                        }
1570                                    }
1571                                }
1572                            }
1573
1574                            // Find and execute tool unless callbacks already short-circuited.
1575                            if response_content.is_none() {
1576                                // ===== CIRCUIT BREAKER CHECK =====
1577                                // If the circuit breaker is open for this tool, skip execution
1578                                // and return an immediate error response to the LLM.
1579                                if let Some(ref cb_state) = circuit_breaker_state {
1580                                    if cb_state.is_open(name) {
1581                                        let error_msg = format!(
1582                                            "Tool '{}' is temporarily disabled after {} consecutive failures",
1583                                            name, cb_state.threshold
1584                                        );
1585                                        tracing::warn!(tool.name = %name, "circuit breaker open, skipping tool execution");
1586                                        response_content = Some(Content {
1587                                            role: "function".to_string(),
1588                                            parts: vec![Part::FunctionResponse {
1589                                                function_response: FunctionResponseData {
1590                                                    name: name.clone(),
1591                                                    response: serde_json::json!({ "error": error_msg }),
1592                                                },
1593                                                id: id.clone(),
1594                                            }],
1595                                        });
1596                                        run_after_tool_callbacks = false;
1597                                    }
1598                                }
1599                            }
1600
1601                            if response_content.is_none() {
1602                                if let Some(tool) = tool_map.get(name) {
1603                                    // ✅ Use AgentToolContext that preserves parent context
1604                                    let tool_ctx: Arc<dyn ToolContext> = Arc::new(AgentToolContext::new(
1605                                        ctx.clone(),
1606                                        function_call_id.clone(),
1607                                    ));
1608
1609                                    // Create span name following adk-go pattern: "execute_tool {name}"
1610                                    let span_name = format!("execute_tool {}", name);
1611                                    let tool_span = tracing::info_span!(
1612                                        "",
1613                                        otel.name = %span_name,
1614                                        tool.name = %name,
1615                                        "gcp.vertex.agent.event_id" = %format!("{}_{}", invocation_id, name),
1616                                        "gcp.vertex.agent.invocation_id" = %invocation_id,
1617                                        "gcp.vertex.agent.session_id" = %ctx.session_id(),
1618                                        "gen_ai.conversation.id" = %ctx.session_id()
1619                                    );
1620
1621                                    // ===== RETRY BUDGET RESOLUTION =====
1622                                    // Look up per-tool budget first, fall back to default budget.
1623                                    // When no budget is configured, max_attempts is 1 (single attempt, no retry).
1624                                    let budget = tool_retry_budgets.get(name)
1625                                        .or(default_retry_budget.as_ref());
1626                                    let max_attempts = budget.map(|b| b.max_retries + 1).unwrap_or(1);
1627                                    let retry_delay = budget.map(|b| b.delay).unwrap_or_default();
1628
1629                                    // Time tool execution with retry loop.
1630                                    // Derive success from Ok/Err/timeout path — never from JSON inspection.
1631                                    let tool_clone = tool.clone();
1632                                    let tool_start = std::time::Instant::now();
1633
1634                                    let mut last_error = String::new();
1635                                    let mut final_attempt: u32 = 0;
1636                                    let mut retry_result: Option<serde_json::Value> = None;
1637
1638                                    for attempt in 0..max_attempts {
1639                                        final_attempt = attempt;
1640
1641                                        if attempt > 0 {
1642                                            tokio::time::sleep(retry_delay).await;
1643                                        }
1644
1645                                        match async {
1646                                            tracing::info!(tool.name = %name, tool.args = %args, attempt = attempt, "tool_call");
1647                                            let exec_future = tool_clone.execute(tool_ctx.clone(), args.clone());
1648                                            tokio::time::timeout(tool_timeout, exec_future).await
1649                                        }.instrument(tool_span.clone()).await {
1650                                            Ok(Ok(value)) => {
1651                                                tracing::info!(tool.name = %name, tool.result = %value, "tool_result");
1652                                                retry_result = Some(value);
1653                                                break;
1654                                            }
1655                                            Ok(Err(e)) => {
1656                                                last_error = e.to_string();
1657                                                if attempt + 1 < max_attempts {
1658                                                    tracing::warn!(tool.name = %name, attempt = attempt, error = %last_error, "tool execution failed, retrying");
1659                                                } else {
1660                                                    tracing::warn!(tool.name = %name, error = %last_error, "tool_error");
1661                                                }
1662                                            }
1663                                            Err(_) => {
1664                                                last_error = format!(
1665                                                    "Tool '{}' timed out after {} seconds",
1666                                                    name, tool_timeout.as_secs()
1667                                                );
1668                                                if attempt + 1 < max_attempts {
1669                                                    tracing::warn!(tool.name = %name, attempt = attempt, timeout_secs = tool_timeout.as_secs(), "tool timed out, retrying");
1670                                                } else {
1671                                                    tracing::warn!(tool.name = %name, timeout_secs = tool_timeout.as_secs(), "tool_timeout");
1672                                                }
1673                                            }
1674                                        }
1675                                    }
1676
1677                                    let tool_duration = tool_start.elapsed();
1678
1679                                    // Derive final success/error/response from retry loop outcome
1680                                    let (tool_success, tool_error_message, function_response) = match retry_result {
1681                                        Some(value) => (true, None, value),
1682                                        None => (false, Some(last_error.clone()), serde_json::json!({ "error": last_error })),
1683                                    };
1684
1685                                    // Build ToolOutcome from execution state — never from JSON inspection.
1686                                    // Emitted for the final attempt only, with correct attempt number (0-based).
1687                                    let outcome = ToolOutcome {
1688                                        tool_name: name.clone(),
1689                                        tool_args: args.clone(),
1690                                        success: tool_success,
1691                                        duration: tool_duration,
1692                                        error_message: tool_error_message.clone(),
1693                                        attempt: final_attempt,
1694                                    };
1695                                    tool_outcome_for_callback = Some(outcome);
1696
1697                                    // ===== CIRCUIT BREAKER RECORDING =====
1698                                    if let Some(ref mut cb_state) = circuit_breaker_state {
1699                                        cb_state.record(tool_outcome_for_callback.as_ref().unwrap());
1700                                    }
1701
1702                                    // ===== ON TOOL ERROR CALLBACKS =====
1703                                    // When tool failed (after all retries exhausted), try on_tool_error
1704                                    // callbacks for a fallback result. Only invoked after ALL retries
1705                                    // are exhausted, not on each individual retry attempt.
1706                                    let final_function_response = if !tool_success {
1707                                        let mut fallback_result = None;
1708                                        let error_msg = tool_error_message.clone().unwrap_or_default();
1709                                        for callback in on_tool_error_callbacks.as_ref() {
1710                                            match callback(
1711                                                ctx.clone() as Arc<dyn CallbackContext>,
1712                                                tool.clone(),
1713                                                args.clone(),
1714                                                error_msg.clone(),
1715                                            ).await {
1716                                                Ok(Some(result)) => {
1717                                                    fallback_result = Some(result);
1718                                                    break;
1719                                                }
1720                                                Ok(None) => continue,
1721                                                Err(e) => {
1722                                                    tracing::warn!(error = %e, "on_tool_error callback failed");
1723                                                    break;
1724                                                }
1725                                            }
1726                                        }
1727                                        if let Some(fallback) = fallback_result {
1728                                            fallback
1729                                        } else {
1730                                            function_response
1731                                        }
1732                                    } else {
1733                                        function_response
1734                                    };
1735
1736                                    let confirmation_decision =
1737                                        tool_actions.tool_confirmation_decision;
1738                                    tool_actions = tool_ctx.actions();
1739                                    if tool_actions.tool_confirmation_decision.is_none() {
1740                                        tool_actions.tool_confirmation_decision =
1741                                            confirmation_decision;
1742                                    }
1743                                    // Capture tool and response for AfterToolCallbackFull
1744                                    executed_tool = Some(tool.clone());
1745                                    executed_tool_response = Some(final_function_response.clone());
1746                                    response_content = Some(Content {
1747                                        role: "function".to_string(),
1748                                        parts: vec![Part::FunctionResponse {
1749                                            function_response: FunctionResponseData {
1750                                                name: name.clone(),
1751                                                response: final_function_response,
1752                                            },
1753                                            id: id.clone(),
1754                                        }],
1755                                    });
1756                                } else {
1757                                    response_content = Some(Content {
1758                                        role: "function".to_string(),
1759                                        parts: vec![Part::FunctionResponse {
1760                                            function_response: FunctionResponseData {
1761                                                name: name.clone(),
1762                                                response: serde_json::json!({
1763                                                    "error": format!("Tool {} not found", name)
1764                                                }),
1765                                            },
1766                                            id: id.clone(),
1767                                        }],
1768                                    });
1769                                }
1770                            }
1771
1772                            // ===== AFTER TOOL CALLBACKS =====
1773                            // Allows post-processing of tool outputs or audit side effects.
1774                            let mut response_content = response_content.expect("tool response content is set");
1775                            if run_after_tool_callbacks {
1776                                // Build callback context with ToolOutcome if available
1777                                let cb_ctx: Arc<dyn CallbackContext> = match tool_outcome_for_callback {
1778                                    Some(outcome) => Arc::new(ToolCallbackContext {
1779                                        inner: ctx.clone() as Arc<dyn CallbackContext>,
1780                                        outcome,
1781                                    }),
1782                                    None => ctx.clone() as Arc<dyn CallbackContext>,
1783                                };
1784                                for callback in after_tool_callbacks.as_ref() {
1785                                    match callback(cb_ctx.clone()).await {
1786                                        Ok(Some(modified_content)) => {
1787                                            response_content = modified_content;
1788                                            break;
1789                                        }
1790                                        Ok(None) => continue,
1791                                        Err(e) => {
1792                                            yield Err(e);
1793                                            return;
1794                                        }
1795                                    }
1796                                }
1797
1798                                // ===== AFTER TOOL CALLBACKS (FULL / V2) =====
1799                                // Rich callbacks that receive tool, args, and response value.
1800                                // Aligned with Python/Go ADK after_tool_callback signature.
1801                                // Run after legacy AfterToolCallback chain.
1802                                if let (Some(tool_ref), Some(tool_resp)) = (&executed_tool, executed_tool_response) {
1803                                    for callback in after_tool_callbacks_full.as_ref() {
1804                                        match callback(
1805                                            cb_ctx.clone(),
1806                                            tool_ref.clone(),
1807                                            args.clone(),
1808                                            tool_resp.clone(),
1809                                        ).await {
1810                                            Ok(Some(modified_value)) => {
1811                                                // Replace the function response in the content
1812                                                response_content = Content {
1813                                                    role: "function".to_string(),
1814                                                    parts: vec![Part::FunctionResponse {
1815                                                        function_response: FunctionResponseData {
1816                                                            name: name.clone(),
1817                                                            response: modified_value,
1818                                                        },
1819                                                        id: id.clone(),
1820                                                    }],
1821                                                };
1822                                                break;
1823                                            }
1824                                            Ok(None) => continue,
1825                                            Err(e) => {
1826                                                yield Err(e);
1827                                                return;
1828                                            }
1829                                        }
1830                                    }
1831                                }
1832                            }
1833
1834                            // Yield tool execution event
1835                            let mut tool_event = Event::new(&invocation_id);
1836                            tool_event.author = agent_name.clone();
1837                            tool_event.actions = tool_actions.clone();
1838                            tool_event.llm_response.content = Some(response_content.clone());
1839                            yield Ok(tool_event);
1840
1841                            // Check if tool requested escalation or skip_summarization
1842                            if tool_actions.escalate || tool_actions.skip_summarization {
1843                                // Tool wants to terminate agent loop
1844                                return;
1845                            }
1846
1847                            // Add function response to history
1848                            conversation_history.push(response_content);
1849                        }
1850                    }
1851                }
1852
1853                // If all function calls were from long-running tools, we need ONE more model call
1854                // to let the model generate a user-friendly response about the pending task
1855                // But we mark this as the final iteration to prevent infinite loops
1856                if all_calls_are_long_running {
1857                    // Continue to next iteration for model to respond, but this will be the last
1858                    // The model will see the tool response and generate text like "Started task X..."
1859                    // On next iteration, there won't be function calls, so we'll break naturally
1860                }
1861            }
1862
1863            // ===== AFTER AGENT CALLBACKS =====
1864            // Execute after the agent completes
1865            for callback in after_agent_callbacks.as_ref() {
1866                match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
1867                    Ok(Some(content)) => {
1868                        // Callback returned content - yield it
1869                        let mut after_event = Event::new(&invocation_id);
1870                        after_event.author = agent_name.clone();
1871                        after_event.llm_response.content = Some(content);
1872                        yield Ok(after_event);
1873                        break; // First callback that returns content wins
1874                    }
1875                    Ok(None) => {
1876                        // Continue to next callback
1877                        continue;
1878                    }
1879                    Err(e) => {
1880                        // Callback failed - propagate error
1881                        yield Err(e);
1882                        return;
1883                    }
1884                }
1885            }
1886        };
1887
1888        Ok(Box::pin(s))
1889    }
1890}