// adk_agent/llm_agent.rs

1use adk_core::{
2    AfterAgentCallback, AfterModelCallback, AfterToolCallback, Agent, BeforeAgentCallback,
3    BeforeModelCallback, BeforeModelResult, BeforeToolCallback, CallbackContext, Content, Event,
4    EventActions, FunctionResponseData, GlobalInstructionProvider, InstructionProvider,
5    InvocationContext, Llm, LlmRequest, LlmResponse, MemoryEntry, Part, ReadonlyContext, Result,
6    Tool, ToolContext,
7};
8use async_stream::stream;
9use async_trait::async_trait;
10use std::sync::{Arc, Mutex};
11use tracing::Instrument;
12
13use crate::guardrails::GuardrailSet;
14
/// Upper bound on LLM round-trips an [`LlmAgent`] performs in one run when the
/// builder was not given an explicit `max_iterations`.
pub const DEFAULT_MAX_ITERATIONS: u32 = 100;
17
18pub struct LlmAgent {
19    name: String,
20    description: String,
21    model: Arc<dyn Llm>,
22    instruction: Option<String>,
23    instruction_provider: Option<Arc<InstructionProvider>>,
24    global_instruction: Option<String>,
25    global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
26    #[allow(dead_code)] // Part of public API via builder
27    input_schema: Option<serde_json::Value>,
28    output_schema: Option<serde_json::Value>,
29    #[allow(dead_code)] // Part of public API via builder
30    disallow_transfer_to_parent: bool,
31    #[allow(dead_code)] // Part of public API via builder
32    disallow_transfer_to_peers: bool,
33    include_contents: adk_core::IncludeContents,
34    tools: Vec<Arc<dyn Tool>>,
35    sub_agents: Vec<Arc<dyn Agent>>,
36    output_key: Option<String>,
37    /// Maximum number of LLM round-trips before stopping
38    max_iterations: u32,
39    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
40    after_callbacks: Arc<Vec<AfterAgentCallback>>,
41    before_model_callbacks: Arc<Vec<BeforeModelCallback>>,
42    after_model_callbacks: Arc<Vec<AfterModelCallback>>,
43    before_tool_callbacks: Arc<Vec<BeforeToolCallback>>,
44    after_tool_callbacks: Arc<Vec<AfterToolCallback>>,
45    #[allow(dead_code)] // Used when guardrails feature is enabled
46    input_guardrails: GuardrailSet,
47    #[allow(dead_code)] // Used when guardrails feature is enabled
48    output_guardrails: GuardrailSet,
49}
50
51impl std::fmt::Debug for LlmAgent {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("LlmAgent")
54            .field("name", &self.name)
55            .field("description", &self.description)
56            .field("model", &self.model.name())
57            .field("instruction", &self.instruction)
58            .field("tools_count", &self.tools.len())
59            .field("sub_agents_count", &self.sub_agents.len())
60            .finish()
61    }
62}
63
64pub struct LlmAgentBuilder {
65    name: String,
66    description: Option<String>,
67    model: Option<Arc<dyn Llm>>,
68    instruction: Option<String>,
69    instruction_provider: Option<Arc<InstructionProvider>>,
70    global_instruction: Option<String>,
71    global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
72    input_schema: Option<serde_json::Value>,
73    output_schema: Option<serde_json::Value>,
74    disallow_transfer_to_parent: bool,
75    disallow_transfer_to_peers: bool,
76    include_contents: adk_core::IncludeContents,
77    tools: Vec<Arc<dyn Tool>>,
78    sub_agents: Vec<Arc<dyn Agent>>,
79    output_key: Option<String>,
80    max_iterations: u32,
81    before_callbacks: Vec<BeforeAgentCallback>,
82    after_callbacks: Vec<AfterAgentCallback>,
83    before_model_callbacks: Vec<BeforeModelCallback>,
84    after_model_callbacks: Vec<AfterModelCallback>,
85    before_tool_callbacks: Vec<BeforeToolCallback>,
86    after_tool_callbacks: Vec<AfterToolCallback>,
87    input_guardrails: GuardrailSet,
88    output_guardrails: GuardrailSet,
89}
90
91impl LlmAgentBuilder {
92    pub fn new(name: impl Into<String>) -> Self {
93        Self {
94            name: name.into(),
95            description: None,
96            model: None,
97            instruction: None,
98            instruction_provider: None,
99            global_instruction: None,
100            global_instruction_provider: None,
101            input_schema: None,
102            output_schema: None,
103            disallow_transfer_to_parent: false,
104            disallow_transfer_to_peers: false,
105            include_contents: adk_core::IncludeContents::Default,
106            tools: Vec::new(),
107            sub_agents: Vec::new(),
108            output_key: None,
109            max_iterations: DEFAULT_MAX_ITERATIONS,
110            before_callbacks: Vec::new(),
111            after_callbacks: Vec::new(),
112            before_model_callbacks: Vec::new(),
113            after_model_callbacks: Vec::new(),
114            before_tool_callbacks: Vec::new(),
115            after_tool_callbacks: Vec::new(),
116            input_guardrails: GuardrailSet::new(),
117            output_guardrails: GuardrailSet::new(),
118        }
119    }
120
121    pub fn description(mut self, desc: impl Into<String>) -> Self {
122        self.description = Some(desc.into());
123        self
124    }
125
126    pub fn model(mut self, model: Arc<dyn Llm>) -> Self {
127        self.model = Some(model);
128        self
129    }
130
131    pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
132        self.instruction = Some(instruction.into());
133        self
134    }
135
136    pub fn instruction_provider(mut self, provider: InstructionProvider) -> Self {
137        self.instruction_provider = Some(Arc::new(provider));
138        self
139    }
140
141    pub fn global_instruction(mut self, instruction: impl Into<String>) -> Self {
142        self.global_instruction = Some(instruction.into());
143        self
144    }
145
146    pub fn global_instruction_provider(mut self, provider: GlobalInstructionProvider) -> Self {
147        self.global_instruction_provider = Some(Arc::new(provider));
148        self
149    }
150
151    pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
152        self.input_schema = Some(schema);
153        self
154    }
155
156    pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
157        self.output_schema = Some(schema);
158        self
159    }
160
161    pub fn disallow_transfer_to_parent(mut self, disallow: bool) -> Self {
162        self.disallow_transfer_to_parent = disallow;
163        self
164    }
165
166    pub fn disallow_transfer_to_peers(mut self, disallow: bool) -> Self {
167        self.disallow_transfer_to_peers = disallow;
168        self
169    }
170
171    pub fn include_contents(mut self, include: adk_core::IncludeContents) -> Self {
172        self.include_contents = include;
173        self
174    }
175
176    pub fn output_key(mut self, key: impl Into<String>) -> Self {
177        self.output_key = Some(key.into());
178        self
179    }
180
181    /// Set the maximum number of LLM round-trips (iterations) before the agent stops.
182    /// Default is 100.
183    pub fn max_iterations(mut self, max: u32) -> Self {
184        self.max_iterations = max;
185        self
186    }
187
188    pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
189        self.tools.push(tool);
190        self
191    }
192
193    pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
194        self.sub_agents.push(agent);
195        self
196    }
197
198    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
199        self.before_callbacks.push(callback);
200        self
201    }
202
203    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
204        self.after_callbacks.push(callback);
205        self
206    }
207
208    pub fn before_model_callback(mut self, callback: BeforeModelCallback) -> Self {
209        self.before_model_callbacks.push(callback);
210        self
211    }
212
213    pub fn after_model_callback(mut self, callback: AfterModelCallback) -> Self {
214        self.after_model_callbacks.push(callback);
215        self
216    }
217
218    pub fn before_tool_callback(mut self, callback: BeforeToolCallback) -> Self {
219        self.before_tool_callbacks.push(callback);
220        self
221    }
222
223    pub fn after_tool_callback(mut self, callback: AfterToolCallback) -> Self {
224        self.after_tool_callbacks.push(callback);
225        self
226    }
227
228    /// Set input guardrails to validate user input before processing.
229    ///
230    /// Input guardrails run before the agent processes the request and can:
231    /// - Block harmful or off-topic content
232    /// - Redact PII from user input
233    /// - Enforce input length limits
234    ///
235    /// Requires the `guardrails` feature.
236    pub fn input_guardrails(mut self, guardrails: GuardrailSet) -> Self {
237        self.input_guardrails = guardrails;
238        self
239    }
240
241    /// Set output guardrails to validate agent responses.
242    ///
243    /// Output guardrails run after the agent generates a response and can:
244    /// - Enforce JSON schema compliance
245    /// - Redact PII from responses
246    /// - Block harmful content in responses
247    ///
248    /// Requires the `guardrails` feature.
249    pub fn output_guardrails(mut self, guardrails: GuardrailSet) -> Self {
250        self.output_guardrails = guardrails;
251        self
252    }
253
254    pub fn build(self) -> Result<LlmAgent> {
255        let model =
256            self.model.ok_or_else(|| adk_core::AdkError::Agent("Model is required".to_string()))?;
257
258        Ok(LlmAgent {
259            name: self.name,
260            description: self.description.unwrap_or_default(),
261            model,
262            instruction: self.instruction,
263            instruction_provider: self.instruction_provider,
264            global_instruction: self.global_instruction,
265            global_instruction_provider: self.global_instruction_provider,
266            input_schema: self.input_schema,
267            output_schema: self.output_schema,
268            disallow_transfer_to_parent: self.disallow_transfer_to_parent,
269            disallow_transfer_to_peers: self.disallow_transfer_to_peers,
270            include_contents: self.include_contents,
271            tools: self.tools,
272            sub_agents: self.sub_agents,
273            output_key: self.output_key,
274            max_iterations: self.max_iterations,
275            before_callbacks: Arc::new(self.before_callbacks),
276            after_callbacks: Arc::new(self.after_callbacks),
277            before_model_callbacks: Arc::new(self.before_model_callbacks),
278            after_model_callbacks: Arc::new(self.after_model_callbacks),
279            before_tool_callbacks: Arc::new(self.before_tool_callbacks),
280            after_tool_callbacks: Arc::new(self.after_tool_callbacks),
281            input_guardrails: self.input_guardrails,
282            output_guardrails: self.output_guardrails,
283        })
284    }
285}
286
287// AgentToolContext wraps the parent InvocationContext and preserves all context
288// instead of throwing it away like SimpleToolContext did
289struct AgentToolContext {
290    parent_ctx: Arc<dyn InvocationContext>,
291    function_call_id: String,
292    actions: Mutex<EventActions>,
293}
294
295impl AgentToolContext {
296    fn new(parent_ctx: Arc<dyn InvocationContext>, function_call_id: String) -> Self {
297        Self { parent_ctx, function_call_id, actions: Mutex::new(EventActions::default()) }
298    }
299}
300
301#[async_trait]
302impl ReadonlyContext for AgentToolContext {
303    fn invocation_id(&self) -> &str {
304        self.parent_ctx.invocation_id()
305    }
306
307    fn agent_name(&self) -> &str {
308        self.parent_ctx.agent_name()
309    }
310
311    fn user_id(&self) -> &str {
312        // ✅ Delegate to parent - now tools get the real user_id!
313        self.parent_ctx.user_id()
314    }
315
316    fn app_name(&self) -> &str {
317        // ✅ Delegate to parent - now tools get the real app_name!
318        self.parent_ctx.app_name()
319    }
320
321    fn session_id(&self) -> &str {
322        // ✅ Delegate to parent - now tools get the real session_id!
323        self.parent_ctx.session_id()
324    }
325
326    fn branch(&self) -> &str {
327        self.parent_ctx.branch()
328    }
329
330    fn user_content(&self) -> &Content {
331        self.parent_ctx.user_content()
332    }
333}
334
335#[async_trait]
336impl CallbackContext for AgentToolContext {
337    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
338        // ✅ Delegate to parent - tools can now access artifacts!
339        self.parent_ctx.artifacts()
340    }
341}
342
343#[async_trait]
344impl ToolContext for AgentToolContext {
345    fn function_call_id(&self) -> &str {
346        &self.function_call_id
347    }
348
349    fn actions(&self) -> EventActions {
350        self.actions.lock().unwrap().clone()
351    }
352
353    fn set_actions(&self, actions: EventActions) {
354        *self.actions.lock().unwrap() = actions;
355    }
356
357    async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>> {
358        // ✅ Delegate to parent's memory if available
359        if let Some(memory) = self.parent_ctx.memory() {
360            memory.search(query).await
361        } else {
362            Ok(vec![])
363        }
364    }
365}
366
367#[async_trait]
368impl Agent for LlmAgent {
369    fn name(&self) -> &str {
370        &self.name
371    }
372
373    fn description(&self) -> &str {
374        &self.description
375    }
376
377    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
378        &self.sub_agents
379    }
380
381    #[adk_telemetry::instrument(
382        skip(self, ctx),
383        fields(
384            agent.name = %self.name,
385            agent.description = %self.description,
386            invocation.id = %ctx.invocation_id(),
387            user.id = %ctx.user_id(),
388            session.id = %ctx.session_id()
389        )
390    )]
391    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
392        adk_telemetry::info!("Starting agent execution");
393
394        let agent_name = self.name.clone();
395        let invocation_id = ctx.invocation_id().to_string();
396        let model = self.model.clone();
397        let tools = self.tools.clone();
398        let sub_agents = self.sub_agents.clone();
399
400        let instruction = self.instruction.clone();
401        let instruction_provider = self.instruction_provider.clone();
402        let global_instruction = self.global_instruction.clone();
403        let global_instruction_provider = self.global_instruction_provider.clone();
404        let output_key = self.output_key.clone();
405        let output_schema = self.output_schema.clone();
406        let include_contents = self.include_contents;
407        let max_iterations = self.max_iterations;
408        // Clone Arc references (cheap)
409        let before_agent_callbacks = self.before_callbacks.clone();
410        let after_agent_callbacks = self.after_callbacks.clone();
411        let before_model_callbacks = self.before_model_callbacks.clone();
412        let after_model_callbacks = self.after_model_callbacks.clone();
413        let _before_tool_callbacks = self.before_tool_callbacks.clone();
414        let _after_tool_callbacks = self.after_tool_callbacks.clone();
415
416        let s = stream! {
417            // ===== BEFORE AGENT CALLBACKS =====
418            // Execute before the agent starts running
419            // If any returns content, skip agent execution
420            for callback in before_agent_callbacks.as_ref() {
421                match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
422                    Ok(Some(content)) => {
423                        // Callback returned content - yield it and skip agent execution
424                        let mut early_event = Event::new(&invocation_id);
425                        early_event.author = agent_name.clone();
426                        early_event.llm_response.content = Some(content);
427                        yield Ok(early_event);
428
429                        // Skip rest of agent execution and go to after callbacks
430                        for after_callback in after_agent_callbacks.as_ref() {
431                            match after_callback(ctx.clone() as Arc<dyn CallbackContext>).await {
432                                Ok(Some(after_content)) => {
433                                    let mut after_event = Event::new(&invocation_id);
434                                    after_event.author = agent_name.clone();
435                                    after_event.llm_response.content = Some(after_content);
436                                    yield Ok(after_event);
437                                    return;
438                                }
439                                Ok(None) => continue,
440                                Err(e) => {
441                                    yield Err(e);
442                                    return;
443                                }
444                            }
445                        }
446                        return;
447                    }
448                    Ok(None) => {
449                        // Continue to next callback
450                        continue;
451                    }
452                    Err(e) => {
453                        // Callback failed - propagate error
454                        yield Err(e);
455                        return;
456                    }
457                }
458            }
459
460            // ===== MAIN AGENT EXECUTION =====
461            let mut conversation_history = Vec::new();
462
463            // ===== PROCESS GLOBAL INSTRUCTION =====
464            // GlobalInstruction provides tree-wide personality/identity
465            if let Some(provider) = &global_instruction_provider {
466                // Dynamic global instruction via provider
467                let global_inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
468                if !global_inst.is_empty() {
469                    conversation_history.push(Content {
470                        role: "user".to_string(),
471                        parts: vec![Part::Text { text: global_inst }],
472                    });
473                }
474            } else if let Some(ref template) = global_instruction {
475                // Static global instruction with template injection
476                let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
477                if !processed.is_empty() {
478                    conversation_history.push(Content {
479                        role: "user".to_string(),
480                        parts: vec![Part::Text { text: processed }],
481                    });
482                }
483            }
484
485            // ===== PROCESS AGENT INSTRUCTION =====
486            // Agent-specific instruction
487            if let Some(provider) = &instruction_provider {
488                // Dynamic instruction via provider
489                let inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
490                if !inst.is_empty() {
491                    conversation_history.push(Content {
492                        role: "user".to_string(),
493                        parts: vec![Part::Text { text: inst }],
494                    });
495                }
496            } else if let Some(ref template) = instruction {
497                // Static instruction with template injection
498                let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
499                if !processed.is_empty() {
500                    conversation_history.push(Content {
501                        role: "user".to_string(),
502                        parts: vec![Part::Text { text: processed }],
503                    });
504                }
505            }
506
507            // ===== LOAD SESSION HISTORY =====
508            // Load previous conversation turns from the session
509            // NOTE: Session history already includes the current user message (added by Runner before agent runs)
510            let session_history = ctx.session().conversation_history();
511            conversation_history.extend(session_history);
512
513            // ===== APPLY INCLUDE_CONTENTS FILTERING =====
514            // Control what conversation history the agent sees
515            let mut conversation_history = match include_contents {
516                adk_core::IncludeContents::None => {
517                    // Agent operates solely on current turn - only keep the latest user input
518                    // Remove all previous history except instructions and current user message
519                    let mut filtered = Vec::new();
520
521                    // Keep global and agent instructions (already added above)
522                    let instruction_count = conversation_history.iter()
523                        .take_while(|c| c.role == "user" && c.parts.iter().any(|p| {
524                            if let Part::Text { text } = p {
525                                // These are likely instructions, not user queries
526                                !text.is_empty()
527                            } else {
528                                false
529                            }
530                        }))
531                        .count();
532
533                    // Take instructions
534                    filtered.extend(conversation_history.iter().take(instruction_count).cloned());
535
536                    // Take only the last user message (current turn)
537                    if let Some(last) = conversation_history.last() {
538                        if last.role == "user" {
539                            filtered.push(last.clone());
540                        }
541                    }
542
543                    filtered
544                }
545                adk_core::IncludeContents::Default => {
546                    // Default behavior - keep full conversation history
547                    conversation_history
548                }
549            };
550
551            // Build tool declarations for Gemini
552            // Uses enhanced_description() which includes NOTE for long-running tools
553            let mut tool_declarations = std::collections::HashMap::new();
554            for tool in &tools {
555                // Build FunctionDeclaration JSON with enhanced description
556                // For long-running tools, this includes a warning not to call again if pending
557                let mut decl = serde_json::json!({
558                    "name": tool.name(),
559                    "description": tool.enhanced_description(),
560                });
561
562                if let Some(params) = tool.parameters_schema() {
563                    decl["parameters"] = params;
564                }
565
566                if let Some(response) = tool.response_schema() {
567                    decl["response"] = response;
568                }
569
570                tool_declarations.insert(tool.name().to_string(), decl);
571            }
572
573            // Inject transfer_to_agent tool if sub-agents exist
574            if !sub_agents.is_empty() {
575                let transfer_tool_name = "transfer_to_agent";
576                let transfer_tool_decl = serde_json::json!({
577                    "name": transfer_tool_name,
578                    "description": "Transfer execution to another agent.",
579                    "parameters": {
580                        "type": "object",
581                        "properties": {
582                            "agent_name": {
583                                "type": "string",
584                                "description": "The name of the agent to transfer to."
585                            }
586                        },
587                        "required": ["agent_name"]
588                    }
589                });
590                tool_declarations.insert(transfer_tool_name.to_string(), transfer_tool_decl);
591            }
592
593
594            // Multi-turn loop with max iterations
595            let mut iteration = 0;
596
597            loop {
598                iteration += 1;
599                if iteration > max_iterations {
600                    yield Err(adk_core::AdkError::Agent(
601                        format!("Max iterations ({}) exceeded", max_iterations)
602                    ));
603                    return;
604                }
605
606                // Build request with conversation history
607                let config = output_schema.as_ref().map(|schema| {
608                    adk_core::GenerateContentConfig {
609                        temperature: None,
610                        top_p: None,
611                        top_k: None,
612                        max_output_tokens: None,
613                        response_schema: Some(schema.clone()),
614                    }
615                });
616
617                let request = LlmRequest {
618                    model: model.name().to_string(),
619                    contents: conversation_history.clone(),
620                    tools: tool_declarations.clone(),
621                    config,
622                };
623
624                // ===== BEFORE MODEL CALLBACKS =====
625                // These can modify the request or skip the model call by returning a response
626                let mut current_request = request;
627                let mut model_response_override = None;
628                for callback in before_model_callbacks.as_ref() {
629                    match callback(ctx.clone() as Arc<dyn CallbackContext>, current_request.clone()).await {
630                        Ok(BeforeModelResult::Continue(modified_request)) => {
631                            // Callback may have modified the request, continue with it
632                            current_request = modified_request;
633                        }
634                        Ok(BeforeModelResult::Skip(response)) => {
635                            // Callback returned a response - skip model call
636                            model_response_override = Some(response);
637                            break;
638                        }
639                        Err(e) => {
640                            // Callback failed - propagate error
641                            yield Err(e);
642                            return;
643                        }
644                    }
645                }
646                let request = current_request;
647
648                // Determine streaming source: cached response or real model
649                let mut accumulated_content: Option<Content> = None;
650
651                if let Some(cached_response) = model_response_override {
652                    // Use callback-provided response (e.g., from cache)
653                    // Yield it as an event
654                    let mut cached_event = Event::new(&invocation_id);
655                    cached_event.author = agent_name.clone();
656                    cached_event.llm_response.content = cached_response.content.clone();
657                    cached_event.llm_request = Some(serde_json::to_string(&request).unwrap_or_default());
658                    cached_event.gcp_llm_request = Some(serde_json::to_string(&request).unwrap_or_default());
659                    cached_event.gcp_llm_response = Some(serde_json::to_string(&cached_response).unwrap_or_default());
660
661                    // Populate long_running_tool_ids for function calls from long-running tools
662                    if let Some(ref content) = cached_response.content {
663                        let long_running_ids: Vec<String> = content.parts.iter()
664                            .filter_map(|p| {
665                                if let Part::FunctionCall { name, .. } = p {
666                                    if let Some(tool) = tools.iter().find(|t| t.name() == name) {
667                                        if tool.is_long_running() {
668                                            return Some(name.clone());
669                                        }
670                                    }
671                                }
672                                None
673                            })
674                            .collect();
675                        cached_event.long_running_tool_ids = long_running_ids;
676                    }
677
678                    yield Ok(cached_event);
679
680                    accumulated_content = cached_response.content;
681                } else {
682                    // Record LLM request for tracing
683                    let request_json = serde_json::to_string(&request).unwrap_or_default();
684
685                    // Create call_llm span with GCP attributes (works for all model types)
686                    let llm_ts = std::time::SystemTime::now()
687                        .duration_since(std::time::UNIX_EPOCH)
688                        .unwrap_or_default()
689                        .as_nanos();
690                    let llm_event_id = format!("{}_llm_{}", invocation_id, llm_ts);
691                    let llm_span = tracing::info_span!(
692                        "call_llm",
693                        "gcp.vertex.agent.event_id" = %llm_event_id,
694                        "gcp.vertex.agent.invocation_id" = %invocation_id,
695                        "gcp.vertex.agent.session_id" = %ctx.session_id(),
696                        "gcp.vertex.agent.llm_request" = %request_json,
697                        "gcp.vertex.agent.llm_response" = tracing::field::Empty  // Placeholder for later recording
698                    );
699                    let _llm_guard = llm_span.enter();
700
701                    // Check streaming mode from run config
702                    use adk_core::StreamingMode;
703                    let streaming_mode = ctx.run_config().streaming_mode;
704                    let should_stream_to_client = matches!(streaming_mode, StreamingMode::SSE | StreamingMode::Bidi);
705
706                    // Always use streaming internally for LLM calls
707                    let mut response_stream = model.generate_content(request, true).await?;
708
709                    use futures::StreamExt;
710
711                    // Track last chunk for final event metadata (used in None mode)
712                    let mut last_chunk: Option<LlmResponse> = None;
713
714                    // Stream and process chunks with AfterModel callbacks
715                    while let Some(chunk_result) = response_stream.next().await {
716                        let mut chunk = match chunk_result {
717                            Ok(c) => c,
718                            Err(e) => {
719                                yield Err(e);
720                                return;
721                            }
722                        };
723
724                        // ===== AFTER MODEL CALLBACKS (per chunk) =====
725                        // Callbacks can modify each streaming chunk
726                        for callback in after_model_callbacks.as_ref() {
727                            match callback(ctx.clone() as Arc<dyn CallbackContext>, chunk.clone()).await {
728                                Ok(Some(modified_chunk)) => {
729                                    // Callback modified this chunk
730                                    chunk = modified_chunk;
731                                    break;
732                                }
733                                Ok(None) => {
734                                    // Continue to next callback
735                                    continue;
736                                }
737                                Err(e) => {
738                                    // Callback failed - propagate error
739                                    yield Err(e);
740                                    return;
741                                }
742                            }
743                        }
744
745                        // Accumulate content for conversation history (always needed)
746                        if let Some(chunk_content) = chunk.content.clone() {
747                            if let Some(ref mut acc) = accumulated_content {
748                                acc.parts.extend(chunk_content.parts);
749                            } else {
750                                accumulated_content = Some(chunk_content);
751                            }
752                        }
753
754                        // For SSE/Bidi mode: yield each chunk immediately with stable event ID
755                        if should_stream_to_client {
756                            let mut partial_event = Event::with_id(&llm_event_id, &invocation_id);
757                            partial_event.author = agent_name.clone();
758                            partial_event.llm_request = Some(request_json.clone());
759                            partial_event.gcp_llm_request = Some(request_json.clone());
760                            partial_event.gcp_llm_response = Some(serde_json::to_string(&chunk).unwrap_or_default());
761                            partial_event.llm_response.partial = chunk.partial;
762                            partial_event.llm_response.turn_complete = chunk.turn_complete;
763                            partial_event.llm_response.finish_reason = chunk.finish_reason;
764                            partial_event.llm_response.usage_metadata = chunk.usage_metadata.clone();
765                            partial_event.llm_response.content = chunk.content.clone();
766
767                            // Populate long_running_tool_ids
768                            if let Some(ref content) = chunk.content {
769                                let long_running_ids: Vec<String> = content.parts.iter()
770                                    .filter_map(|p| {
771                                        if let Part::FunctionCall { name, .. } = p {
772                                            if let Some(tool) = tools.iter().find(|t| t.name() == name) {
773                                                if tool.is_long_running() {
774                                                    return Some(name.clone());
775                                                }
776                                            }
777                                        }
778                                        None
779                                    })
780                                    .collect();
781                                partial_event.long_running_tool_ids = long_running_ids;
782                            }
783
784                            yield Ok(partial_event);
785                        }
786
787                        // Store last chunk for final event metadata
788                        last_chunk = Some(chunk.clone());
789
790                        // Check if turn is complete
791                        if chunk.turn_complete {
792                            break;
793                        }
794                    }
795
796                    // For None mode: yield single final event with accumulated content
797                    if !should_stream_to_client {
798                        let mut final_event = Event::with_id(&llm_event_id, &invocation_id);
799                        final_event.author = agent_name.clone();
800                        final_event.llm_request = Some(request_json.clone());
801                        final_event.gcp_llm_request = Some(request_json.clone());
802                        final_event.llm_response.content = accumulated_content.clone();
803                        final_event.llm_response.partial = false;
804                        final_event.llm_response.turn_complete = true;
805
806                        // Copy metadata from last chunk
807                        if let Some(ref last) = last_chunk {
808                            final_event.llm_response.finish_reason = last.finish_reason;
809                            final_event.llm_response.usage_metadata = last.usage_metadata.clone();
810                            final_event.gcp_llm_response = Some(serde_json::to_string(last).unwrap_or_default());
811                        }
812
813                        // Populate long_running_tool_ids
814                        if let Some(ref content) = accumulated_content {
815                            let long_running_ids: Vec<String> = content.parts.iter()
816                                .filter_map(|p| {
817                                    if let Part::FunctionCall { name, .. } = p {
818                                        if let Some(tool) = tools.iter().find(|t| t.name() == name) {
819                                            if tool.is_long_running() {
820                                                return Some(name.clone());
821                                            }
822                                        }
823                                    }
824                                    None
825                                })
826                                .collect();
827                            final_event.long_running_tool_ids = long_running_ids;
828                        }
829
830                        yield Ok(final_event);
831                    }
832
833                    // Record LLM response to span before guard drops
834                    if let Some(ref content) = accumulated_content {
835                        let response_json = serde_json::to_string(content).unwrap_or_default();
836                        llm_span.record("gcp.vertex.agent.llm_response", &response_json);
837                    }
838                }
839
840                // After streaming/caching completes, check for function calls in accumulated content
841                let function_call_names: Vec<String> = accumulated_content.as_ref()
842                    .map(|c| c.parts.iter()
843                        .filter_map(|p| {
844                            if let Part::FunctionCall { name, .. } = p {
845                                Some(name.clone())
846                            } else {
847                                None
848                            }
849                        })
850                        .collect())
851                    .unwrap_or_default();
852
853                let has_function_calls = !function_call_names.is_empty();
854
855                // Check if ALL function calls are from long-running tools
856                // If so, we should NOT continue the loop - the tool returned a pending status
857                // and the agent/client will poll for completion later
858                let all_calls_are_long_running = has_function_calls && function_call_names.iter().all(|name| {
859                    tools.iter()
860                        .find(|t| t.name() == name)
861                        .map(|t| t.is_long_running())
862                        .unwrap_or(false)
863                });
864
865                // Add final content to history
866                if let Some(ref content) = accumulated_content {
867                    conversation_history.push(content.clone());
868
869                    // Handle output_key: save final agent output to state_delta
870                    if let Some(ref output_key) = output_key {
871                        if !has_function_calls {  // Only save if not calling tools
872                            let mut text_parts = String::new();
873                            for part in &content.parts {
874                                if let Part::Text { text } = part {
875                                    text_parts.push_str(text);
876                                }
877                            }
878                            if !text_parts.is_empty() {
879                                // Yield a final state update event
880                                let mut state_event = Event::new(&invocation_id);
881                                state_event.author = agent_name.clone();
882                                state_event.actions.state_delta.insert(
883                                    output_key.clone(),
884                                    serde_json::Value::String(text_parts),
885                                );
886                                yield Ok(state_event);
887                            }
888                        }
889                    }
890                }
891
892                if !has_function_calls {
893                    // No function calls, we're done
894                    // Record LLM response for tracing
895                    if let Some(ref content) = accumulated_content {
896                        let response_json = serde_json::to_string(content).unwrap_or_default();
897                        tracing::Span::current().record("gcp.vertex.agent.llm_response", &response_json);
898                    }
899
900                    tracing::info!(agent.name = %agent_name, "Agent execution complete");
901                    break;
902                }
903
904                // Execute function calls and add responses to history
905                if let Some(content) = &accumulated_content {
906                    for part in &content.parts {
907                        if let Part::FunctionCall { name, args, id } = part {
908                            // Handle transfer_to_agent specially
909                            if name == "transfer_to_agent" {
910                                let target_agent = args.get("agent_name")
911                                    .and_then(|v| v.as_str())
912                                    .unwrap_or_default()
913                                    .to_string();
914
915                                let mut transfer_event = Event::new(&invocation_id);
916                                transfer_event.author = agent_name.clone();
917                                transfer_event.actions.transfer_to_agent = Some(target_agent);
918
919                                yield Ok(transfer_event);
920                                return;
921                            }
922
923
924                            // Find and execute tool
925                            let (tool_result, tool_actions) = if let Some(tool) = tools.iter().find(|t| t.name() == name) {
926                                // ✅ Use AgentToolContext that preserves parent context
927                                let tool_ctx: Arc<dyn ToolContext> = Arc::new(AgentToolContext::new(
928                                    ctx.clone(),
929                                    format!("{}_{}", invocation_id, name),
930                                ));
931
932                                // Create span name following adk-go pattern: "execute_tool {name}"
933                                let span_name = format!("execute_tool {}", name);
934                                let tool_span = tracing::info_span!(
935                                    "",
936                                    otel.name = %span_name,
937                                    tool.name = %name,
938                                    "gcp.vertex.agent.event_id" = %format!("{}_{}", invocation_id, name),
939                                    "gcp.vertex.agent.invocation_id" = %invocation_id,
940                                    "gcp.vertex.agent.session_id" = %ctx.session_id()
941                                );
942
943                                // Use instrument() for proper async span handling
944                                let result = async {
945                                    tracing::info!(tool.name = %name, tool.args = %args, "tool_call");
946                                    match tool.execute(tool_ctx.clone(), args.clone()).await {
947                                        Ok(result) => {
948                                            tracing::info!(tool.name = %name, tool.result = %result, "tool_result");
949                                            result
950                                        }
951                                        Err(e) => {
952                                            tracing::warn!(tool.name = %name, error = %e, "tool_error");
953                                            serde_json::json!({ "error": e.to_string() })
954                                        }
955                                    }
956                                }.instrument(tool_span).await;
957
958                                (result, tool_ctx.actions())
959                            } else {
960                                (serde_json::json!({ "error": format!("Tool {} not found", name) }), EventActions::default())
961                            };
962
963                            // Yield tool execution event
964                            let mut tool_event = Event::new(&invocation_id);
965                            tool_event.author = agent_name.clone();
966                            tool_event.actions = tool_actions.clone();
967                            tool_event.llm_response.content = Some(Content {
968                                role: "function".to_string(),
969                                parts: vec![Part::FunctionResponse {
970                                    function_response: FunctionResponseData {
971                                        name: name.clone(),
972                                        response: tool_result.clone(),
973                                    },
974                                    id: id.clone(),
975                                }],
976                            });
977                            yield Ok(tool_event);
978
979                            // Check if tool requested escalation or skip_summarization
980                            if tool_actions.escalate || tool_actions.skip_summarization {
981                                // Tool wants to terminate agent loop
982                                return;
983                            }
984
985                            // Add function response to history
986                            conversation_history.push(Content {
987                                role: "function".to_string(),
988                                parts: vec![Part::FunctionResponse {
989                                    function_response: FunctionResponseData {
990                                        name: name.clone(),
991                                        response: tool_result,
992                                    },
993                                    id: id.clone(),
994                                }],
995                            });
996                        }
997                    }
998                }
999
1000                // If all function calls were from long-running tools, we need ONE more model call
1001                // to let the model generate a user-friendly response about the pending task
1002                // But we mark this as the final iteration to prevent infinite loops
1003                if all_calls_are_long_running {
1004                    // Continue to next iteration for model to respond, but this will be the last
1005                    // The model will see the tool response and generate text like "Started task X..."
1006                    // On next iteration, there won't be function calls, so we'll break naturally
1007                }
1008            }
1009
1010            // ===== AFTER AGENT CALLBACKS =====
1011            // Execute after the agent completes
1012            for callback in after_agent_callbacks.as_ref() {
1013                match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
1014                    Ok(Some(content)) => {
1015                        // Callback returned content - yield it
1016                        let mut after_event = Event::new(&invocation_id);
1017                        after_event.author = agent_name.clone();
1018                        after_event.llm_response.content = Some(content);
1019                        yield Ok(after_event);
1020                        break; // First callback that returns content wins
1021                    }
1022                    Ok(None) => {
1023                        // Continue to next callback
1024                        continue;
1025                    }
1026                    Err(e) => {
1027                        // Callback failed - propagate error
1028                        yield Err(e);
1029                        return;
1030                    }
1031                }
1032            }
1033        };
1034
1035        Ok(Box::pin(s))
1036    }
1037}