adk_agent/llm_agent.rs

use adk_core::{
    AfterAgentCallback, AfterModelCallback, AfterToolCallback, Agent, BeforeAgentCallback,
    BeforeModelCallback, BeforeModelResult, BeforeToolCallback, CallbackContext, Content, Event,
    EventActions, FunctionResponseData, GlobalInstructionProvider, InstructionProvider,
    InvocationContext, Llm, LlmRequest, LlmResponse, MemoryEntry, Part, ReadonlyContext, Result,
    Tool, ToolContext,
};
use async_stream::stream;
use async_trait::async_trait;
use std::sync::{Arc, Mutex};
use tracing::Instrument;

use crate::guardrails::GuardrailSet;

pub struct LlmAgent {
    name: String,
    description: String,
    model: Arc<dyn Llm>,
    instruction: Option<String>,
    instruction_provider: Option<Arc<InstructionProvider>>,
    global_instruction: Option<String>,
    global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
    #[allow(dead_code)] // Part of public API via builder
    input_schema: Option<serde_json::Value>,
    output_schema: Option<serde_json::Value>,
    #[allow(dead_code)] // Part of public API via builder
    disallow_transfer_to_parent: bool,
    #[allow(dead_code)] // Part of public API via builder
    disallow_transfer_to_peers: bool,
    include_contents: adk_core::IncludeContents,
    tools: Vec<Arc<dyn Tool>>,
    sub_agents: Vec<Arc<dyn Agent>>,
    output_key: Option<String>,
    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
    after_callbacks: Arc<Vec<AfterAgentCallback>>,
    before_model_callbacks: Arc<Vec<BeforeModelCallback>>,
    after_model_callbacks: Arc<Vec<AfterModelCallback>>,
    before_tool_callbacks: Arc<Vec<BeforeToolCallback>>,
    after_tool_callbacks: Arc<Vec<AfterToolCallback>>,
    #[allow(dead_code)] // Used when guardrails feature is enabled
    input_guardrails: GuardrailSet,
    #[allow(dead_code)] // Used when guardrails feature is enabled
    output_guardrails: GuardrailSet,
}

impl std::fmt::Debug for LlmAgent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlmAgent")
            .field("name", &self.name)
            .field("description", &self.description)
            .field("model", &self.model.name())
            .field("instruction", &self.instruction)
            .field("tools_count", &self.tools.len())
            .field("sub_agents_count", &self.sub_agents.len())
            .finish()
    }
}

pub struct LlmAgentBuilder {
    name: String,
    description: Option<String>,
    model: Option<Arc<dyn Llm>>,
    instruction: Option<String>,
    instruction_provider: Option<Arc<InstructionProvider>>,
    global_instruction: Option<String>,
    global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
    input_schema: Option<serde_json::Value>,
    output_schema: Option<serde_json::Value>,
    disallow_transfer_to_parent: bool,
    disallow_transfer_to_peers: bool,
    include_contents: adk_core::IncludeContents,
    tools: Vec<Arc<dyn Tool>>,
    sub_agents: Vec<Arc<dyn Agent>>,
    output_key: Option<String>,
    before_callbacks: Vec<BeforeAgentCallback>,
    after_callbacks: Vec<AfterAgentCallback>,
    before_model_callbacks: Vec<BeforeModelCallback>,
    after_model_callbacks: Vec<AfterModelCallback>,
    before_tool_callbacks: Vec<BeforeToolCallback>,
    after_tool_callbacks: Vec<AfterToolCallback>,
    input_guardrails: GuardrailSet,
    output_guardrails: GuardrailSet,
}

impl LlmAgentBuilder {
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: None,
            model: None,
            instruction: None,
            instruction_provider: None,
            global_instruction: None,
            global_instruction_provider: None,
            input_schema: None,
            output_schema: None,
            disallow_transfer_to_parent: false,
            disallow_transfer_to_peers: false,
            include_contents: adk_core::IncludeContents::Default,
            tools: Vec::new(),
            sub_agents: Vec::new(),
            output_key: None,
            before_callbacks: Vec::new(),
            after_callbacks: Vec::new(),
            before_model_callbacks: Vec::new(),
            after_model_callbacks: Vec::new(),
            before_tool_callbacks: Vec::new(),
            after_tool_callbacks: Vec::new(),
            input_guardrails: GuardrailSet::new(),
            output_guardrails: GuardrailSet::new(),
        }
    }

    pub fn description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }

    pub fn model(mut self, model: Arc<dyn Llm>) -> Self {
        self.model = Some(model);
        self
    }

    pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
        self.instruction = Some(instruction.into());
        self
    }

    pub fn instruction_provider(mut self, provider: InstructionProvider) -> Self {
        self.instruction_provider = Some(Arc::new(provider));
        self
    }

    pub fn global_instruction(mut self, instruction: impl Into<String>) -> Self {
        self.global_instruction = Some(instruction.into());
        self
    }

    pub fn global_instruction_provider(mut self, provider: GlobalInstructionProvider) -> Self {
        self.global_instruction_provider = Some(Arc::new(provider));
        self
    }

    pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
        self.input_schema = Some(schema);
        self
    }

    pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
        self.output_schema = Some(schema);
        self
    }

    pub fn disallow_transfer_to_parent(mut self, disallow: bool) -> Self {
        self.disallow_transfer_to_parent = disallow;
        self
    }

    pub fn disallow_transfer_to_peers(mut self, disallow: bool) -> Self {
        self.disallow_transfer_to_peers = disallow;
        self
    }

    pub fn include_contents(mut self, include: adk_core::IncludeContents) -> Self {
        self.include_contents = include;
        self
    }

    pub fn output_key(mut self, key: impl Into<String>) -> Self {
        self.output_key = Some(key.into());
        self
    }

    pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
        self.tools.push(tool);
        self
    }

    pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
        self.sub_agents.push(agent);
        self
    }

    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
        self.before_callbacks.push(callback);
        self
    }

    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
        self.after_callbacks.push(callback);
        self
    }

    pub fn before_model_callback(mut self, callback: BeforeModelCallback) -> Self {
        self.before_model_callbacks.push(callback);
        self
    }

    pub fn after_model_callback(mut self, callback: AfterModelCallback) -> Self {
        self.after_model_callbacks.push(callback);
        self
    }

    pub fn before_tool_callback(mut self, callback: BeforeToolCallback) -> Self {
        self.before_tool_callbacks.push(callback);
        self
    }

    pub fn after_tool_callback(mut self, callback: AfterToolCallback) -> Self {
        self.after_tool_callbacks.push(callback);
        self
    }

    /// Set input guardrails to validate user input before processing.
    ///
    /// Input guardrails run before the agent processes the request and can:
    /// - Block harmful or off-topic content
    /// - Redact PII from user input
    /// - Enforce input length limits
    ///
    /// Requires the `guardrails` feature.
    pub fn input_guardrails(mut self, guardrails: GuardrailSet) -> Self {
        self.input_guardrails = guardrails;
        self
    }

    /// Set output guardrails to validate agent responses.
    ///
    /// Output guardrails run after the agent generates a response and can:
    /// - Enforce JSON schema compliance
    /// - Redact PII from responses
    /// - Block harmful content in responses
    ///
    /// Requires the `guardrails` feature.
    pub fn output_guardrails(mut self, guardrails: GuardrailSet) -> Self {
        self.output_guardrails = guardrails;
        self
    }
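
    // Sketch (assumption, for illustration only): both guardrail hooks are
    // attached through the builder, e.g.
    //   .input_guardrails(my_input_rules).output_guardrails(my_output_rules)
    // where `my_input_rules` / `my_output_rules` are GuardrailSet values
    // constructed elsewhere behind the `guardrails` feature.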

    pub fn build(self) -> Result<LlmAgent> {
        let model =
            self.model.ok_or_else(|| adk_core::AdkError::Agent("Model is required".to_string()))?;

        Ok(LlmAgent {
            name: self.name,
            description: self.description.unwrap_or_default(),
            model,
            instruction: self.instruction,
            instruction_provider: self.instruction_provider,
            global_instruction: self.global_instruction,
            global_instruction_provider: self.global_instruction_provider,
            input_schema: self.input_schema,
            output_schema: self.output_schema,
            disallow_transfer_to_parent: self.disallow_transfer_to_parent,
            disallow_transfer_to_peers: self.disallow_transfer_to_peers,
            include_contents: self.include_contents,
            tools: self.tools,
            sub_agents: self.sub_agents,
            output_key: self.output_key,
            before_callbacks: Arc::new(self.before_callbacks),
            after_callbacks: Arc::new(self.after_callbacks),
            before_model_callbacks: Arc::new(self.before_model_callbacks),
            after_model_callbacks: Arc::new(self.after_model_callbacks),
            before_tool_callbacks: Arc::new(self.before_tool_callbacks),
            after_tool_callbacks: Arc::new(self.after_tool_callbacks),
            input_guardrails: self.input_guardrails,
            output_guardrails: self.output_guardrails,
        })
    }
}
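
// Illustrative sketch (not part of the original API surface): one way a caller
// might assemble an agent with this builder. The concrete `Llm` and `Tool`
// implementations live outside this module, so they are passed in here; the
// names below are assumptions for the example only.
#[allow(dead_code)]
fn example_build_agent(model: Arc<dyn Llm>, calculator: Arc<dyn Tool>) -> Result<LlmAgent> {
    LlmAgentBuilder::new("assistant")
        .description("Answers questions and may call a calculator tool")
        .model(model)
        .instruction("You are a concise, helpful assistant.")
        .tool(calculator)
        .output_key("last_answer")
        .build()
}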

// AgentToolContext wraps the parent InvocationContext and preserves all of its
// context (identity, artifacts, memory), unlike the earlier SimpleToolContext,
// which discarded it.
struct AgentToolContext {
    parent_ctx: Arc<dyn InvocationContext>,
    function_call_id: String,
    actions: Mutex<EventActions>,
}

impl AgentToolContext {
    fn new(parent_ctx: Arc<dyn InvocationContext>, function_call_id: String) -> Self {
        Self { parent_ctx, function_call_id, actions: Mutex::new(EventActions::default()) }
    }
}

#[async_trait]
impl ReadonlyContext for AgentToolContext {
    fn invocation_id(&self) -> &str {
        self.parent_ctx.invocation_id()
    }

    fn agent_name(&self) -> &str {
        self.parent_ctx.agent_name()
    }

    fn user_id(&self) -> &str {
        // Delegate to the parent context so tools see the real user_id.
        self.parent_ctx.user_id()
    }

    fn app_name(&self) -> &str {
        // Delegate to the parent context so tools see the real app_name.
        self.parent_ctx.app_name()
    }

    fn session_id(&self) -> &str {
        // Delegate to the parent context so tools see the real session_id.
        self.parent_ctx.session_id()
    }

    fn branch(&self) -> &str {
        self.parent_ctx.branch()
    }

    fn user_content(&self) -> &Content {
        self.parent_ctx.user_content()
    }
}

#[async_trait]
impl CallbackContext for AgentToolContext {
    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
        // Delegate to the parent context so tools can access artifacts.
        self.parent_ctx.artifacts()
    }
}

#[async_trait]
impl ToolContext for AgentToolContext {
    fn function_call_id(&self) -> &str {
        &self.function_call_id
    }

    fn actions(&self) -> EventActions {
        self.actions.lock().unwrap().clone()
    }

    fn set_actions(&self, actions: EventActions) {
        *self.actions.lock().unwrap() = actions;
    }

    async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>> {
        // Delegate to the parent context's memory service when one is configured.
        if let Some(memory) = self.parent_ctx.memory() {
            memory.search(query).await
        } else {
            Ok(vec![])
        }
    }
}
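
// Sketch (illustrative only): because AgentToolContext delegates to the parent
// InvocationContext, a tool's search_memory call reaches the session's memory
// service when one is configured instead of always coming back empty. The
// helper below is an assumption for demonstration and is not used elsewhere.
#[allow(dead_code)]
async fn recall_for_tool(ctx: &dyn ToolContext, query: &str) -> Result<usize> {
    // Count memory entries matching the query via the delegated context.
    Ok(ctx.search_memory(query).await?.len())
}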

#[async_trait]
impl Agent for LlmAgent {
    fn name(&self) -> &str {
        &self.name
    }

    fn description(&self) -> &str {
        &self.description
    }

    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
        &self.sub_agents
    }

    #[adk_telemetry::instrument(
        skip(self, ctx),
        fields(
            agent.name = %self.name,
            agent.description = %self.description,
            invocation.id = %ctx.invocation_id(),
            user.id = %ctx.user_id(),
            session.id = %ctx.session_id()
        )
    )]
    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
        adk_telemetry::info!("Starting agent execution");

        let agent_name = self.name.clone();
        let invocation_id = ctx.invocation_id().to_string();
        let model = self.model.clone();
        let tools = self.tools.clone();
        let sub_agents = self.sub_agents.clone();

        let instruction = self.instruction.clone();
        let instruction_provider = self.instruction_provider.clone();
        let global_instruction = self.global_instruction.clone();
        let global_instruction_provider = self.global_instruction_provider.clone();
        let output_key = self.output_key.clone();
        let output_schema = self.output_schema.clone();
        let include_contents = self.include_contents;
        // Clone Arc references (cheap)
        let before_agent_callbacks = self.before_callbacks.clone();
        let after_agent_callbacks = self.after_callbacks.clone();
        let before_model_callbacks = self.before_model_callbacks.clone();
        let after_model_callbacks = self.after_model_callbacks.clone();
        let _before_tool_callbacks = self.before_tool_callbacks.clone();
        let _after_tool_callbacks = self.after_tool_callbacks.clone();

        let s = stream! {
            // ===== BEFORE AGENT CALLBACKS =====
            // Execute before the agent starts running
            // If any returns content, skip agent execution
            for callback in before_agent_callbacks.as_ref() {
                match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
                    Ok(Some(content)) => {
                        // Callback returned content - yield it and skip agent execution
                        let mut early_event = Event::new(&invocation_id);
                        early_event.author = agent_name.clone();
                        early_event.llm_response.content = Some(content);
                        yield Ok(early_event);

                        // Skip rest of agent execution and go to after callbacks
                        for after_callback in after_agent_callbacks.as_ref() {
                            match after_callback(ctx.clone() as Arc<dyn CallbackContext>).await {
                                Ok(Some(after_content)) => {
                                    let mut after_event = Event::new(&invocation_id);
                                    after_event.author = agent_name.clone();
                                    after_event.llm_response.content = Some(after_content);
                                    yield Ok(after_event);
                                    return;
                                }
                                Ok(None) => continue,
                                Err(e) => {
                                    yield Err(e);
                                    return;
                                }
                            }
                        }
                        return;
                    }
                    Ok(None) => {
                        // Continue to next callback
                        continue;
                    }
                    Err(e) => {
                        // Callback failed - propagate error
                        yield Err(e);
                        return;
                    }
                }
            }

            // ===== MAIN AGENT EXECUTION =====
            let mut conversation_history = Vec::new();

            // ===== PROCESS GLOBAL INSTRUCTION =====
            // GlobalInstruction provides tree-wide personality/identity
            if let Some(provider) = &global_instruction_provider {
                // Dynamic global instruction via provider
                let global_inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
                if !global_inst.is_empty() {
                    conversation_history.push(Content {
                        role: "user".to_string(),
                        parts: vec![Part::Text { text: global_inst }],
                    });
                }
            } else if let Some(ref template) = global_instruction {
                // Static global instruction with template injection
                let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
                if !processed.is_empty() {
                    conversation_history.push(Content {
                        role: "user".to_string(),
                        parts: vec![Part::Text { text: processed }],
                    });
                }
            }

            // ===== PROCESS AGENT INSTRUCTION =====
            // Agent-specific instruction
            if let Some(provider) = &instruction_provider {
                // Dynamic instruction via provider
                let inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
                if !inst.is_empty() {
                    conversation_history.push(Content {
                        role: "user".to_string(),
                        parts: vec![Part::Text { text: inst }],
                    });
                }
            } else if let Some(ref template) = instruction {
                // Static instruction with template injection
                let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
                if !processed.is_empty() {
                    conversation_history.push(Content {
                        role: "user".to_string(),
                        parts: vec![Part::Text { text: processed }],
                    });
                }
            }

            // ===== LOAD SESSION HISTORY =====
            // Load previous conversation turns from the session
            // NOTE: Session history already includes the current user message (added by Runner before agent runs)
            let session_history = ctx.session().conversation_history();
            conversation_history.extend(session_history);

            // ===== APPLY INCLUDE_CONTENTS FILTERING =====
            // Control what conversation history the agent sees
            let mut conversation_history = match include_contents {
                adk_core::IncludeContents::None => {
                    // Agent operates solely on current turn - only keep the latest user input
                    // Remove all previous history except instructions and current user message
                    let mut filtered = Vec::new();

                    // Keep global and agent instructions (already added above)
                    let instruction_count = conversation_history.iter()
                        .take_while(|c| c.role == "user" && c.parts.iter().any(|p| {
                            if let Part::Text { text } = p {
                                // These are likely instructions, not user queries
                                !text.is_empty()
                            } else {
                                false
                            }
                        }))
                        .count();

                    // Take instructions
                    filtered.extend(conversation_history.iter().take(instruction_count).cloned());

                    // Take only the last user message (current turn)
                    if let Some(last) = conversation_history.last() {
                        if last.role == "user" {
                            filtered.push(last.clone());
                        }
                    }

                    filtered
                }
                adk_core::IncludeContents::Default => {
                    // Default behavior - keep full conversation history
                    conversation_history
                }
            };
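            // Example (illustrative): with IncludeContents::None, a history of
            //   [instruction, user turn 1, model turn 1, user turn 2]
            // is reduced to roughly [instruction, user turn 2]; with
            // IncludeContents::Default the model sees the full history.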

            // Build tool declarations for Gemini
            // Uses enhanced_description() which includes NOTE for long-running tools
            let mut tool_declarations = std::collections::HashMap::new();
            for tool in &tools {
                // Build FunctionDeclaration JSON with enhanced description
                // For long-running tools, this includes a warning not to call again if pending
                let mut decl = serde_json::json!({
                    "name": tool.name(),
                    "description": tool.enhanced_description(),
                });

                if let Some(params) = tool.parameters_schema() {
                    decl["parameters"] = params;
                }

                if let Some(response) = tool.response_schema() {
                    decl["response"] = response;
                }

                tool_declarations.insert(tool.name().to_string(), decl);
            }

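            // For illustration (hypothetical "get_weather" tool), a generated
            // declaration has roughly this shape:
            //   { "name": "get_weather",
            //     "description": "Returns the current weather for a city.",
            //     "parameters": { "type": "object",
            //                     "properties": { "city": { "type": "string" } } } }
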
            // Inject transfer_to_agent tool if sub-agents exist
            if !sub_agents.is_empty() {
                let transfer_tool_name = "transfer_to_agent";
                let transfer_tool_decl = serde_json::json!({
                    "name": transfer_tool_name,
                    "description": "Transfer execution to another agent.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "agent_name": {
                                "type": "string",
                                "description": "The name of the agent to transfer to."
                            }
                        },
                        "required": ["agent_name"]
                    }
                });
                tool_declarations.insert(transfer_tool_name.to_string(), transfer_tool_decl);
            }

            // Multi-turn loop with max iterations
            let max_iterations = 10;
            let mut iteration = 0;

            loop {
                iteration += 1;
                if iteration > max_iterations {
                    yield Err(adk_core::AdkError::Agent(
                        format!("Max iterations ({}) exceeded", max_iterations)
                    ));
                    return;
                }

                // Build request with conversation history
                let config = output_schema.as_ref().map(|schema| {
                    adk_core::GenerateContentConfig {
                        temperature: None,
                        top_p: None,
                        top_k: None,
                        max_output_tokens: None,
                        response_schema: Some(schema.clone()),
                    }
                });

                let request = LlmRequest {
                    model: model.name().to_string(),
                    contents: conversation_history.clone(),
                    tools: tool_declarations.clone(),
                    config,
                };

                // ===== BEFORE MODEL CALLBACKS =====
                // These can modify the request or skip the model call by returning a response
                let mut current_request = request;
                let mut model_response_override = None;
                for callback in before_model_callbacks.as_ref() {
                    match callback(ctx.clone() as Arc<dyn CallbackContext>, current_request.clone()).await {
                        Ok(BeforeModelResult::Continue(modified_request)) => {
                            // Callback may have modified the request, continue with it
                            current_request = modified_request;
                        }
                        Ok(BeforeModelResult::Skip(response)) => {
                            // Callback returned a response - skip model call
                            model_response_override = Some(response);
                            break;
                        }
                        Err(e) => {
                            // Callback failed - propagate error
                            yield Err(e);
                            return;
                        }
                    }
                }
                let request = current_request;

                // Determine streaming source: cached response or real model
                let mut accumulated_content: Option<Content> = None;

                if let Some(cached_response) = model_response_override {
                    // Use callback-provided response (e.g., from cache)
                    // Yield it as an event
                    let mut cached_event = Event::new(&invocation_id);
                    cached_event.author = agent_name.clone();
                    cached_event.llm_response.content = cached_response.content.clone();
                    cached_event.llm_request = Some(serde_json::to_string(&request).unwrap_or_default());
                    cached_event.gcp_llm_request = Some(serde_json::to_string(&request).unwrap_or_default());
                    cached_event.gcp_llm_response = Some(serde_json::to_string(&cached_response).unwrap_or_default());

                    // Populate long_running_tool_ids for function calls from long-running tools
                    if let Some(ref content) = cached_response.content {
                        let long_running_ids: Vec<String> = content.parts.iter()
                            .filter_map(|p| {
                                if let Part::FunctionCall { name, .. } = p {
                                    if let Some(tool) = tools.iter().find(|t| t.name() == name) {
                                        if tool.is_long_running() {
                                            return Some(name.clone());
                                        }
                                    }
                                }
                                None
                            })
                            .collect();
                        cached_event.long_running_tool_ids = long_running_ids;
                    }

                    yield Ok(cached_event);

                    accumulated_content = cached_response.content;
                } else {
                    // Record LLM request for tracing
                    let request_json = serde_json::to_string(&request).unwrap_or_default();

                    // Create call_llm span with GCP attributes (works for all model types)
                    let llm_ts = std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .unwrap_or_default()
                        .as_nanos();
                    let llm_event_id = format!("{}_llm_{}", invocation_id, llm_ts);
                    let llm_span = tracing::info_span!(
                        "call_llm",
                        "gcp.vertex.agent.event_id" = %llm_event_id,
                        "gcp.vertex.agent.invocation_id" = %invocation_id,
                        "gcp.vertex.agent.session_id" = %ctx.session_id(),
                        "gcp.vertex.agent.llm_request" = %request_json,
                        "gcp.vertex.agent.llm_response" = tracing::field::Empty  // Placeholder for later recording
                    );
                    let _llm_guard = llm_span.enter();

                    // Check streaming mode from run config
                    use adk_core::StreamingMode;
                    let streaming_mode = ctx.run_config().streaming_mode;
                    let should_stream_to_client = matches!(streaming_mode, StreamingMode::SSE | StreamingMode::Bidi);

                    // Always use streaming internally for LLM calls
                    let mut response_stream = model.generate_content(request, true).await?;

                    use futures::StreamExt;

                    // Track last chunk for final event metadata (used in None mode)
                    let mut last_chunk: Option<LlmResponse> = None;

                    // Stream and process chunks with AfterModel callbacks
                    while let Some(chunk_result) = response_stream.next().await {
                        let mut chunk = match chunk_result {
                            Ok(c) => c,
                            Err(e) => {
                                yield Err(e);
                                return;
                            }
                        };

                        // ===== AFTER MODEL CALLBACKS (per chunk) =====
                        // Callbacks can modify each streaming chunk
                        for callback in after_model_callbacks.as_ref() {
                            match callback(ctx.clone() as Arc<dyn CallbackContext>, chunk.clone()).await {
                                Ok(Some(modified_chunk)) => {
                                    // Callback modified this chunk
                                    chunk = modified_chunk;
                                    break;
                                }
                                Ok(None) => {
                                    // Continue to next callback
                                    continue;
                                }
                                Err(e) => {
                                    // Callback failed - propagate error
                                    yield Err(e);
                                    return;
                                }
                            }
                        }

                        // Accumulate content for conversation history (always needed)
                        if let Some(chunk_content) = chunk.content.clone() {
                            if let Some(ref mut acc) = accumulated_content {
                                acc.parts.extend(chunk_content.parts);
                            } else {
                                accumulated_content = Some(chunk_content);
                            }
                        }

                        // For SSE/Bidi mode: yield each chunk immediately with stable event ID
                        if should_stream_to_client {
                            let mut partial_event = Event::with_id(&llm_event_id, &invocation_id);
                            partial_event.author = agent_name.clone();
                            partial_event.llm_request = Some(request_json.clone());
                            partial_event.gcp_llm_request = Some(request_json.clone());
                            partial_event.gcp_llm_response = Some(serde_json::to_string(&chunk).unwrap_or_default());
                            partial_event.llm_response.partial = chunk.partial;
                            partial_event.llm_response.turn_complete = chunk.turn_complete;
                            partial_event.llm_response.finish_reason = chunk.finish_reason;
                            partial_event.llm_response.usage_metadata = chunk.usage_metadata.clone();
                            partial_event.llm_response.content = chunk.content.clone();

                            // Populate long_running_tool_ids
                            if let Some(ref content) = chunk.content {
                                let long_running_ids: Vec<String> = content.parts.iter()
                                    .filter_map(|p| {
                                        if let Part::FunctionCall { name, .. } = p {
                                            if let Some(tool) = tools.iter().find(|t| t.name() == name) {
                                                if tool.is_long_running() {
                                                    return Some(name.clone());
                                                }
                                            }
                                        }
                                        None
                                    })
                                    .collect();
                                partial_event.long_running_tool_ids = long_running_ids;
                            }

                            yield Ok(partial_event);
                        }

                        // Store last chunk for final event metadata
                        last_chunk = Some(chunk.clone());

                        // Check if turn is complete
                        if chunk.turn_complete {
                            break;
                        }
                    }

                    // For None mode: yield single final event with accumulated content
                    if !should_stream_to_client {
                        let mut final_event = Event::with_id(&llm_event_id, &invocation_id);
                        final_event.author = agent_name.clone();
                        final_event.llm_request = Some(request_json.clone());
                        final_event.gcp_llm_request = Some(request_json.clone());
                        final_event.llm_response.content = accumulated_content.clone();
                        final_event.llm_response.partial = false;
                        final_event.llm_response.turn_complete = true;

                        // Copy metadata from last chunk
                        if let Some(ref last) = last_chunk {
                            final_event.llm_response.finish_reason = last.finish_reason;
                            final_event.llm_response.usage_metadata = last.usage_metadata.clone();
                            final_event.gcp_llm_response = Some(serde_json::to_string(last).unwrap_or_default());
                        }

                        // Populate long_running_tool_ids
                        if let Some(ref content) = accumulated_content {
                            let long_running_ids: Vec<String> = content.parts.iter()
                                .filter_map(|p| {
                                    if let Part::FunctionCall { name, .. } = p {
                                        if let Some(tool) = tools.iter().find(|t| t.name() == name) {
                                            if tool.is_long_running() {
                                                return Some(name.clone());
                                            }
                                        }
                                    }
                                    None
                                })
                                .collect();
                            final_event.long_running_tool_ids = long_running_ids;
                        }

                        yield Ok(final_event);
                    }

                    // Record LLM response to span before guard drops
                    if let Some(ref content) = accumulated_content {
                        let response_json = serde_json::to_string(content).unwrap_or_default();
                        llm_span.record("gcp.vertex.agent.llm_response", &response_json);
                    }
                }

                // After streaming/caching completes, check for function calls in accumulated content
                let function_call_names: Vec<String> = accumulated_content.as_ref()
                    .map(|c| c.parts.iter()
                        .filter_map(|p| {
                            if let Part::FunctionCall { name, .. } = p {
                                Some(name.clone())
                            } else {
                                None
                            }
                        })
                        .collect())
                    .unwrap_or_default();

                let has_function_calls = !function_call_names.is_empty();

                // Check if ALL function calls are from long-running tools
                // If so, we should NOT continue the loop - the tool returned a pending status
                // and the agent/client will poll for completion later
                let all_calls_are_long_running = has_function_calls && function_call_names.iter().all(|name| {
                    tools.iter()
                        .find(|t| t.name() == name)
                        .map(|t| t.is_long_running())
                        .unwrap_or(false)
                });

                // Add final content to history
                if let Some(ref content) = accumulated_content {
                    conversation_history.push(content.clone());

                    // Handle output_key: save final agent output to state_delta
                    if let Some(ref output_key) = output_key {
                        if !has_function_calls {  // Only save if not calling tools
                            let mut text_parts = String::new();
                            for part in &content.parts {
                                if let Part::Text { text } = part {
                                    text_parts.push_str(text);
                                }
                            }
                            if !text_parts.is_empty() {
                                // Yield a final state update event
                                let mut state_event = Event::new(&invocation_id);
                                state_event.author = agent_name.clone();
                                state_event.actions.state_delta.insert(
                                    output_key.clone(),
                                    serde_json::Value::String(text_parts),
                                );
                                yield Ok(state_event);
                            }
                        }
                    }
                }

                if !has_function_calls {
                    // No function calls, we're done
                    // Record LLM response for tracing
                    if let Some(ref content) = accumulated_content {
                        let response_json = serde_json::to_string(content).unwrap_or_default();
                        tracing::Span::current().record("gcp.vertex.agent.llm_response", &response_json);
                    }

                    tracing::info!(agent.name = %agent_name, "Agent execution complete");
                    break;
                }

                // Execute function calls and add responses to history
                if let Some(content) = &accumulated_content {
                    for part in &content.parts {
                        if let Part::FunctionCall { name, args, id } = part {
                            // Handle transfer_to_agent specially
                            if name == "transfer_to_agent" {
                                let target_agent = args.get("agent_name")
                                    .and_then(|v| v.as_str())
                                    .unwrap_or_default()
                                    .to_string();

                                let mut transfer_event = Event::new(&invocation_id);
                                transfer_event.author = agent_name.clone();
                                transfer_event.actions.transfer_to_agent = Some(target_agent);

                                yield Ok(transfer_event);
                                return;
                            }

                            // Find and execute tool
                            let (tool_result, tool_actions) = if let Some(tool) = tools.iter().find(|t| t.name() == name) {
                                // Use AgentToolContext, which preserves the parent invocation context
                                let tool_ctx: Arc<dyn ToolContext> = Arc::new(AgentToolContext::new(
                                    ctx.clone(),
                                    format!("{}_{}", invocation_id, name),
                                ));

                                // Create span name following adk-go pattern: "execute_tool {name}"
                                let span_name = format!("execute_tool {}", name);
                                let tool_span = tracing::info_span!(
                                    "",
                                    otel.name = %span_name,
                                    tool.name = %name,
                                    "gcp.vertex.agent.event_id" = %format!("{}_{}", invocation_id, name),
                                    "gcp.vertex.agent.invocation_id" = %invocation_id,
                                    "gcp.vertex.agent.session_id" = %ctx.session_id()
                                );

                                // Use instrument() for proper async span handling
                                let result = async {
                                    tracing::info!(tool.name = %name, tool.args = %args, "tool_call");
                                    match tool.execute(tool_ctx.clone(), args.clone()).await {
                                        Ok(result) => {
                                            tracing::info!(tool.name = %name, tool.result = %result, "tool_result");
                                            result
                                        }
                                        Err(e) => {
                                            tracing::warn!(tool.name = %name, error = %e, "tool_error");
                                            serde_json::json!({ "error": e.to_string() })
                                        }
                                    }
                                }.instrument(tool_span).await;

                                (result, tool_ctx.actions())
                            } else {
                                (serde_json::json!({ "error": format!("Tool {} not found", name) }), EventActions::default())
                            };

                            // Yield tool execution event
                            let mut tool_event = Event::new(&invocation_id);
                            tool_event.author = agent_name.clone();
                            tool_event.actions = tool_actions.clone();
                            tool_event.llm_response.content = Some(Content {
                                role: "function".to_string(),
                                parts: vec![Part::FunctionResponse {
                                    function_response: FunctionResponseData {
                                        name: name.clone(),
                                        response: tool_result.clone(),
                                    },
                                    id: id.clone(),
                                }],
                            });
                            yield Ok(tool_event);

                            // Check if tool requested escalation or skip_summarization
                            if tool_actions.escalate || tool_actions.skip_summarization {
                                // Tool wants to terminate agent loop
                                return;
                            }

                            // Add function response to history
                            conversation_history.push(Content {
                                role: "function".to_string(),
                                parts: vec![Part::FunctionResponse {
                                    function_response: FunctionResponseData {
                                        name: name.clone(),
                                        response: tool_result,
                                    },
                                    id: id.clone(),
                                }],
                            });
                        }
                    }
                }

                // If all function calls were from long-running tools, the model needs ONE more call
                // so it can produce a user-friendly response about the pending task. No explicit flag
                // is set here: that follow-up response contains no function calls, so the loop exits
                // naturally on the next iteration instead of spinning.
                if all_calls_are_long_running {
                    // Intentionally empty: the next iteration lets the model respond to the pending
                    // tool result (e.g. "Started task X..."), after which the loop breaks on its own.
                }
            }

            // ===== AFTER AGENT CALLBACKS =====
            // Execute after the agent completes
            for callback in after_agent_callbacks.as_ref() {
                match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
                    Ok(Some(content)) => {
                        // Callback returned content - yield it
                        let mut after_event = Event::new(&invocation_id);
                        after_event.author = agent_name.clone();
                        after_event.llm_response.content = Some(content);
                        yield Ok(after_event);
                        break; // First callback that returns content wins
                    }
                    Ok(None) => {
                        // Continue to next callback
                        continue;
                    }
                    Err(e) => {
                        // Callback failed - propagate error
                        yield Err(e);
                        return;
                    }
                }
            }
        };

        Ok(Box::pin(s))
    }
}