adk_agent/
llm_agent.rs

1use adk_core::{
2    AfterAgentCallback, AfterModelCallback, AfterToolCallback, Agent, BeforeAgentCallback,
3    BeforeModelCallback, BeforeModelResult, BeforeToolCallback, CallbackContext, Content, Event,
4    EventActions, GlobalInstructionProvider, InstructionProvider, InvocationContext, Llm,
5    LlmRequest, MemoryEntry, Part, ReadonlyContext, Result, Tool, ToolContext,
6};
7use async_stream::stream;
8use async_trait::async_trait;
9use std::sync::{Arc, Mutex};
10
11use crate::guardrails::GuardrailSet;
12
/// LLM-backed [`Agent`] assembled by [`LlmAgentBuilder`].
///
/// Each `run` builds a conversation from the instructions, the session history
/// and the current user content, then loops over model calls while invoking
/// the configured callbacks.
pub struct LlmAgent {
    /// Agent name; also written as the `author` of the events emitted by `run`.
    name: String,
    /// Description returned by `Agent::description` (empty if the builder set none).
    description: String,
    /// Model that receives every request built during `run`.
    model: Arc<dyn Llm>,
    /// Static instruction template; session state is injected into it before
    /// it is sent. Ignored when `instruction_provider` is set.
    instruction: Option<String>,
    /// Dynamic instruction source; takes precedence over `instruction`.
    instruction_provider: Option<Arc<InstructionProvider>>,
    /// Static global instruction template, placed before the agent
    /// instruction. Ignored when `global_instruction_provider` is set.
    global_instruction: Option<String>,
    /// Dynamic global instruction source; takes precedence over `global_instruction`.
    global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
    /// Schema supplied through the builder.
    /// NOTE(review): no read of this field is visible in this file section — confirm usage.
    #[allow(dead_code)] // Part of public API via builder
    input_schema: Option<serde_json::Value>,
    /// When set, forwarded as `response_schema` in the generation config of
    /// each model request.
    output_schema: Option<serde_json::Value>,
    /// Flag supplied through the builder.
    /// NOTE(review): presumably restricts transfers to the parent agent — confirm where it is enforced.
    #[allow(dead_code)] // Part of public API via builder
    disallow_transfer_to_parent: bool,
    /// Flag supplied through the builder.
    /// NOTE(review): presumably restricts transfers to peer agents — confirm where it is enforced.
    #[allow(dead_code)] // Part of public API via builder
    disallow_transfer_to_peers: bool,
    /// Selects how much conversation history `run` keeps before calling the model.
    include_contents: adk_core::IncludeContents,
    /// Tools declared to the model (name, enhanced description, optional
    /// parameter/response schemas) in every request.
    tools: Vec<Arc<dyn Tool>>,
    /// Child agents; when non-empty, a `transfer_to_agent` tool declaration is
    /// added to the request.
    sub_agents: Vec<Arc<dyn Agent>>,
    /// Key consulted by `run` for storing the final (non-tool-call) output.
    /// NOTE(review): the code comment says it is saved to `state_delta` — confirm against the rest of `run`.
    output_key: Option<String>,
    /// Run in order before the main execution; the first one returning content
    /// short-circuits the run.
    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
    /// Run after agent execution; the first one returning content ends the stream.
    after_callbacks: Arc<Vec<AfterAgentCallback>>,
    /// Run before each model call; each may replace the request or supply a
    /// response that skips the model call.
    before_model_callbacks: Arc<Vec<BeforeModelCallback>>,
    /// Run on every streamed model chunk; the first one returning a
    /// replacement chunk wins.
    after_model_callbacks: Arc<Vec<AfterModelCallback>>,
    /// Callbacks registered for the pre-tool hook.
    /// NOTE(review): confirm invocation point in the tool-dispatch part of `run`.
    before_tool_callbacks: Arc<Vec<BeforeToolCallback>>,
    /// Callbacks registered for the post-tool hook.
    /// NOTE(review): confirm invocation point in the tool-dispatch part of `run`.
    after_tool_callbacks: Arc<Vec<AfterToolCallback>>,
    /// Guardrail set supplied via `LlmAgentBuilder::input_guardrails`.
    #[allow(dead_code)] // Used when guardrails feature is enabled
    input_guardrails: GuardrailSet,
    /// Guardrail set supplied via `LlmAgentBuilder::output_guardrails`.
    #[allow(dead_code)] // Used when guardrails feature is enabled
    output_guardrails: GuardrailSet,
}
43
44impl std::fmt::Debug for LlmAgent {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.debug_struct("LlmAgent")
47            .field("name", &self.name)
48            .field("description", &self.description)
49            .field("model", &self.model.name())
50            .field("instruction", &self.instruction)
51            .field("tools_count", &self.tools.len())
52            .field("sub_agents_count", &self.sub_agents.len())
53            .finish()
54    }
55}
56
/// Builder for [`LlmAgent`].
///
/// Created with [`LlmAgentBuilder::new`]; every setter consumes and returns the
/// builder, and [`LlmAgentBuilder::build`] fails if no model was supplied.
pub struct LlmAgentBuilder {
    /// Agent name (the only value required by `new`).
    name: String,
    /// Optional description; `build` substitutes an empty string when unset.
    description: Option<String>,
    /// Model to use; `build` returns an error when this is still `None`.
    model: Option<Arc<dyn Llm>>,
    /// Static instruction template.
    instruction: Option<String>,
    /// Dynamic instruction source.
    instruction_provider: Option<Arc<InstructionProvider>>,
    /// Static global instruction template.
    global_instruction: Option<String>,
    /// Dynamic global instruction source.
    global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
    /// Optional input schema, passed through to the agent unchanged.
    input_schema: Option<serde_json::Value>,
    /// Optional output schema, passed through to the agent unchanged.
    output_schema: Option<serde_json::Value>,
    /// Transfer-to-parent restriction flag; `false` by default.
    disallow_transfer_to_parent: bool,
    /// Transfer-to-peers restriction flag; `false` by default.
    disallow_transfer_to_peers: bool,
    /// History inclusion mode; `IncludeContents::Default` by default.
    include_contents: adk_core::IncludeContents,
    /// Tools accumulated through `tool`, in insertion order.
    tools: Vec<Arc<dyn Tool>>,
    /// Sub-agents accumulated through `sub_agent`, in insertion order.
    sub_agents: Vec<Arc<dyn Agent>>,
    /// Optional output key, passed through to the agent unchanged.
    output_key: Option<String>,
    /// Before-agent callbacks, in registration order.
    before_callbacks: Vec<BeforeAgentCallback>,
    /// After-agent callbacks, in registration order.
    after_callbacks: Vec<AfterAgentCallback>,
    /// Before-model callbacks, in registration order.
    before_model_callbacks: Vec<BeforeModelCallback>,
    /// After-model callbacks, in registration order.
    after_model_callbacks: Vec<AfterModelCallback>,
    /// Before-tool callbacks, in registration order.
    before_tool_callbacks: Vec<BeforeToolCallback>,
    /// After-tool callbacks, in registration order.
    after_tool_callbacks: Vec<AfterToolCallback>,
    /// Input guardrail set; `GuardrailSet::default()` unless replaced.
    input_guardrails: GuardrailSet,
    /// Output guardrail set; `GuardrailSet::default()` unless replaced.
    output_guardrails: GuardrailSet,
}
82
83impl LlmAgentBuilder {
84    pub fn new(name: impl Into<String>) -> Self {
85        Self {
86            name: name.into(),
87            description: None,
88            model: None,
89            instruction: None,
90            instruction_provider: None,
91            global_instruction: None,
92            global_instruction_provider: None,
93            input_schema: None,
94            output_schema: None,
95            disallow_transfer_to_parent: false,
96            disallow_transfer_to_peers: false,
97            include_contents: adk_core::IncludeContents::Default,
98            tools: Vec::new(),
99            sub_agents: Vec::new(),
100            output_key: None,
101            before_callbacks: Vec::new(),
102            after_callbacks: Vec::new(),
103            before_model_callbacks: Vec::new(),
104            after_model_callbacks: Vec::new(),
105            before_tool_callbacks: Vec::new(),
106            after_tool_callbacks: Vec::new(),
107            input_guardrails: GuardrailSet::default(),
108            output_guardrails: GuardrailSet::default(),
109        }
110    }
111
112    pub fn description(mut self, desc: impl Into<String>) -> Self {
113        self.description = Some(desc.into());
114        self
115    }
116
117    pub fn model(mut self, model: Arc<dyn Llm>) -> Self {
118        self.model = Some(model);
119        self
120    }
121
122    pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
123        self.instruction = Some(instruction.into());
124        self
125    }
126
127    pub fn instruction_provider(mut self, provider: InstructionProvider) -> Self {
128        self.instruction_provider = Some(Arc::new(provider));
129        self
130    }
131
132    pub fn global_instruction(mut self, instruction: impl Into<String>) -> Self {
133        self.global_instruction = Some(instruction.into());
134        self
135    }
136
137    pub fn global_instruction_provider(mut self, provider: GlobalInstructionProvider) -> Self {
138        self.global_instruction_provider = Some(Arc::new(provider));
139        self
140    }
141
142    pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
143        self.input_schema = Some(schema);
144        self
145    }
146
147    pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
148        self.output_schema = Some(schema);
149        self
150    }
151
152    pub fn disallow_transfer_to_parent(mut self, disallow: bool) -> Self {
153        self.disallow_transfer_to_parent = disallow;
154        self
155    }
156
157    pub fn disallow_transfer_to_peers(mut self, disallow: bool) -> Self {
158        self.disallow_transfer_to_peers = disallow;
159        self
160    }
161
162    pub fn include_contents(mut self, include: adk_core::IncludeContents) -> Self {
163        self.include_contents = include;
164        self
165    }
166
167    pub fn output_key(mut self, key: impl Into<String>) -> Self {
168        self.output_key = Some(key.into());
169        self
170    }
171
172    pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
173        self.tools.push(tool);
174        self
175    }
176
177    pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
178        self.sub_agents.push(agent);
179        self
180    }
181
182    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
183        self.before_callbacks.push(callback);
184        self
185    }
186
187    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
188        self.after_callbacks.push(callback);
189        self
190    }
191
192    pub fn before_model_callback(mut self, callback: BeforeModelCallback) -> Self {
193        self.before_model_callbacks.push(callback);
194        self
195    }
196
197    pub fn after_model_callback(mut self, callback: AfterModelCallback) -> Self {
198        self.after_model_callbacks.push(callback);
199        self
200    }
201
202    pub fn before_tool_callback(mut self, callback: BeforeToolCallback) -> Self {
203        self.before_tool_callbacks.push(callback);
204        self
205    }
206
207    pub fn after_tool_callback(mut self, callback: AfterToolCallback) -> Self {
208        self.after_tool_callbacks.push(callback);
209        self
210    }
211
212    /// Set input guardrails to validate user input before processing.
213    ///
214    /// Input guardrails run before the agent processes the request and can:
215    /// - Block harmful or off-topic content
216    /// - Redact PII from user input
217    /// - Enforce input length limits
218    ///
219    /// Requires the `guardrails` feature.
220    pub fn input_guardrails(mut self, guardrails: GuardrailSet) -> Self {
221        self.input_guardrails = guardrails;
222        self
223    }
224
225    /// Set output guardrails to validate agent responses.
226    ///
227    /// Output guardrails run after the agent generates a response and can:
228    /// - Enforce JSON schema compliance
229    /// - Redact PII from responses
230    /// - Block harmful content in responses
231    ///
232    /// Requires the `guardrails` feature.
233    pub fn output_guardrails(mut self, guardrails: GuardrailSet) -> Self {
234        self.output_guardrails = guardrails;
235        self
236    }
237
238    pub fn build(self) -> Result<LlmAgent> {
239        let model =
240            self.model.ok_or_else(|| adk_core::AdkError::Agent("Model is required".to_string()))?;
241
242        Ok(LlmAgent {
243            name: self.name,
244            description: self.description.unwrap_or_default(),
245            model,
246            instruction: self.instruction,
247            instruction_provider: self.instruction_provider,
248            global_instruction: self.global_instruction,
249            global_instruction_provider: self.global_instruction_provider,
250            input_schema: self.input_schema,
251            output_schema: self.output_schema,
252            disallow_transfer_to_parent: self.disallow_transfer_to_parent,
253            disallow_transfer_to_peers: self.disallow_transfer_to_peers,
254            include_contents: self.include_contents,
255            tools: self.tools,
256            sub_agents: self.sub_agents,
257            output_key: self.output_key,
258            before_callbacks: Arc::new(self.before_callbacks),
259            after_callbacks: Arc::new(self.after_callbacks),
260            before_model_callbacks: Arc::new(self.before_model_callbacks),
261            after_model_callbacks: Arc::new(self.after_model_callbacks),
262            before_tool_callbacks: Arc::new(self.before_tool_callbacks),
263            after_tool_callbacks: Arc::new(self.after_tool_callbacks),
264            input_guardrails: self.input_guardrails,
265            output_guardrails: self.output_guardrails,
266        })
267    }
268}
269
/// `ToolContext` implementation that wraps the parent [`InvocationContext`].
///
/// All context information is preserved by delegating to the parent, rather
/// than being discarded as the earlier `SimpleToolContext` did.
struct AgentToolContext {
    /// Wrapped parent context; the read-only accessors, artifacts and memory
    /// search all forward to it.
    parent_ctx: Arc<dyn InvocationContext>,
    /// Identifier returned by `ToolContext::function_call_id`.
    function_call_id: String,
    /// Current event actions; the `Mutex` lets `set_actions` replace them
    /// through a shared reference.
    actions: Mutex<EventActions>,
}
277
278impl AgentToolContext {
279    fn new(parent_ctx: Arc<dyn InvocationContext>, function_call_id: String) -> Self {
280        Self { parent_ctx, function_call_id, actions: Mutex::new(EventActions::default()) }
281    }
282}
283
// Read-only accessors: every value is forwarded unchanged from the wrapped
// parent invocation context.
#[async_trait]
impl ReadonlyContext for AgentToolContext {
    /// Invocation id reported by the parent context.
    fn invocation_id(&self) -> &str {
        self.parent_ctx.invocation_id()
    }

    /// Agent name reported by the parent context.
    fn agent_name(&self) -> &str {
        self.parent_ctx.agent_name()
    }

    /// User id reported by the parent context.
    fn user_id(&self) -> &str {
        // Forwarded from the parent so tools see the real user id.
        self.parent_ctx.user_id()
    }

    /// Application name reported by the parent context.
    fn app_name(&self) -> &str {
        // Forwarded from the parent so tools see the real application name.
        self.parent_ctx.app_name()
    }

    /// Session id reported by the parent context.
    fn session_id(&self) -> &str {
        // Forwarded from the parent so tools see the real session id.
        self.parent_ctx.session_id()
    }

    /// Branch reported by the parent context.
    fn branch(&self) -> &str {
        self.parent_ctx.branch()
    }

    /// User content reported by the parent context.
    fn user_content(&self) -> &Content {
        self.parent_ctx.user_content()
    }
}
317
#[async_trait]
impl CallbackContext for AgentToolContext {
    /// Returns the parent context's artifact service, or `None` when the
    /// parent has none.
    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
        // Forwarded from the parent so tools can access artifacts.
        self.parent_ctx.artifacts()
    }
}
325
326#[async_trait]
327impl ToolContext for AgentToolContext {
328    fn function_call_id(&self) -> &str {
329        &self.function_call_id
330    }
331
332    fn actions(&self) -> EventActions {
333        self.actions.lock().unwrap().clone()
334    }
335
336    fn set_actions(&self, actions: EventActions) {
337        *self.actions.lock().unwrap() = actions;
338    }
339
340    async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>> {
341        // ✅ Delegate to parent's memory if available
342        if let Some(memory) = self.parent_ctx.memory() {
343            memory.search(query).await
344        } else {
345            Ok(vec![])
346        }
347    }
348}
349
350#[async_trait]
351impl Agent for LlmAgent {
352    fn name(&self) -> &str {
353        &self.name
354    }
355
356    fn description(&self) -> &str {
357        &self.description
358    }
359
360    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
361        &self.sub_agents
362    }
363
364    #[adk_telemetry::instrument(
365        skip(self, ctx),
366        fields(
367            agent.name = %self.name,
368            agent.description = %self.description,
369            invocation.id = %ctx.invocation_id(),
370            user.id = %ctx.user_id(),
371            session.id = %ctx.session_id()
372        )
373    )]
374    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
375        adk_telemetry::info!("Starting agent execution");
376
377        let agent_name = self.name.clone();
378        let invocation_id = ctx.invocation_id().to_string();
379        let model = self.model.clone();
380        let tools = self.tools.clone();
381        let sub_agents = self.sub_agents.clone();
382
383        let instruction = self.instruction.clone();
384        let instruction_provider = self.instruction_provider.clone();
385        let global_instruction = self.global_instruction.clone();
386        let global_instruction_provider = self.global_instruction_provider.clone();
387        let output_key = self.output_key.clone();
388        let output_schema = self.output_schema.clone();
389        let include_contents = self.include_contents;
390        // Clone Arc references (cheap)
391        let before_agent_callbacks = self.before_callbacks.clone();
392        let after_agent_callbacks = self.after_callbacks.clone();
393        let before_model_callbacks = self.before_model_callbacks.clone();
394        let after_model_callbacks = self.after_model_callbacks.clone();
395        let _before_tool_callbacks = self.before_tool_callbacks.clone();
396        let _after_tool_callbacks = self.after_tool_callbacks.clone();
397
398        let s = stream! {
399            // ===== BEFORE AGENT CALLBACKS =====
400            // Execute before the agent starts running
401            // If any returns content, skip agent execution
402            for callback in before_agent_callbacks.as_ref() {
403                match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
404                    Ok(Some(content)) => {
405                        // Callback returned content - yield it and skip agent execution
406                        let mut early_event = Event::new(&invocation_id);
407                        early_event.author = agent_name.clone();
408                        early_event.llm_response.content = Some(content);
409                        yield Ok(early_event);
410
411                        // Skip rest of agent execution and go to after callbacks
412                        for after_callback in after_agent_callbacks.as_ref() {
413                            match after_callback(ctx.clone() as Arc<dyn CallbackContext>).await {
414                                Ok(Some(after_content)) => {
415                                    let mut after_event = Event::new(&invocation_id);
416                                    after_event.author = agent_name.clone();
417                                    after_event.llm_response.content = Some(after_content);
418                                    yield Ok(after_event);
419                                    return;
420                                }
421                                Ok(None) => continue,
422                                Err(e) => {
423                                    yield Err(e);
424                                    return;
425                                }
426                            }
427                        }
428                        return;
429                    }
430                    Ok(None) => {
431                        // Continue to next callback
432                        continue;
433                    }
434                    Err(e) => {
435                        // Callback failed - propagate error
436                        yield Err(e);
437                        return;
438                    }
439                }
440            }
441
442            // ===== MAIN AGENT EXECUTION =====
443            let mut conversation_history = Vec::new();
444
445            // ===== PROCESS GLOBAL INSTRUCTION =====
446            // GlobalInstruction provides tree-wide personality/identity
447            if let Some(provider) = &global_instruction_provider {
448                // Dynamic global instruction via provider
449                let global_inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
450                if !global_inst.is_empty() {
451                    conversation_history.push(Content {
452                        role: "user".to_string(),
453                        parts: vec![Part::Text { text: global_inst }],
454                    });
455                }
456            } else if let Some(ref template) = global_instruction {
457                // Static global instruction with template injection
458                let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
459                if !processed.is_empty() {
460                    conversation_history.push(Content {
461                        role: "user".to_string(),
462                        parts: vec![Part::Text { text: processed }],
463                    });
464                }
465            }
466
467            // ===== PROCESS AGENT INSTRUCTION =====
468            // Agent-specific instruction
469            if let Some(provider) = &instruction_provider {
470                // Dynamic instruction via provider
471                let inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
472                if !inst.is_empty() {
473                    conversation_history.push(Content {
474                        role: "user".to_string(),
475                        parts: vec![Part::Text { text: inst }],
476                    });
477                }
478            } else if let Some(ref template) = instruction {
479                // Static instruction with template injection
480                let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
481                if !processed.is_empty() {
482                    conversation_history.push(Content {
483                        role: "user".to_string(),
484                        parts: vec![Part::Text { text: processed }],
485                    });
486                }
487            }
488
489            // ===== LOAD SESSION HISTORY =====
490            // Load previous conversation turns from the session
491            let session_history = ctx.session().conversation_history();
492            conversation_history.extend(session_history);
493
494            // Add user content (current turn)
495            conversation_history.push(ctx.user_content().clone());
496
497            // ===== APPLY INCLUDE_CONTENTS FILTERING =====
498            // Control what conversation history the agent sees
499            let mut conversation_history = match include_contents {
500                adk_core::IncludeContents::None => {
501                    // Agent operates solely on current turn - only keep the latest user input
502                    // Remove all previous history except instructions and current user message
503                    let mut filtered = Vec::new();
504
505                    // Keep global and agent instructions (already added above)
506                    let instruction_count = conversation_history.iter()
507                        .take_while(|c| c.role == "user" && c.parts.iter().any(|p| {
508                            if let Part::Text { text } = p {
509                                // These are likely instructions, not user queries
510                                !text.is_empty()
511                            } else {
512                                false
513                            }
514                        }))
515                        .count();
516
517                    // Take instructions
518                    filtered.extend(conversation_history.iter().take(instruction_count).cloned());
519
520                    // Take only the last user message (current turn)
521                    if let Some(last) = conversation_history.last() {
522                        if last.role == "user" {
523                            filtered.push(last.clone());
524                        }
525                    }
526
527                    filtered
528                }
529                adk_core::IncludeContents::Default => {
530                    // Default behavior - keep full conversation history
531                    conversation_history
532                }
533            };
534
535            // Build tool declarations for Gemini
536            // Uses enhanced_description() which includes NOTE for long-running tools
537            let mut tool_declarations = std::collections::HashMap::new();
538            for tool in &tools {
539                // Build FunctionDeclaration JSON with enhanced description
540                // For long-running tools, this includes a warning not to call again if pending
541                let mut decl = serde_json::json!({
542                    "name": tool.name(),
543                    "description": tool.enhanced_description(),
544                });
545
546                if let Some(params) = tool.parameters_schema() {
547                    decl["parameters"] = params;
548                }
549
550                if let Some(response) = tool.response_schema() {
551                    decl["response"] = response;
552                }
553
554                tool_declarations.insert(tool.name().to_string(), decl);
555            }
556
557            // Inject transfer_to_agent tool if sub-agents exist
558            if !sub_agents.is_empty() {
559                let transfer_tool_name = "transfer_to_agent";
560                let transfer_tool_decl = serde_json::json!({
561                    "name": transfer_tool_name,
562                    "description": "Transfer execution to another agent.",
563                    "parameters": {
564                        "type": "object",
565                        "properties": {
566                            "agent_name": {
567                                "type": "string",
568                                "description": "The name of the agent to transfer to."
569                            }
570                        },
571                        "required": ["agent_name"]
572                    }
573                });
574                tool_declarations.insert(transfer_tool_name.to_string(), transfer_tool_decl);
575            }
576
577
578            // Multi-turn loop with max iterations
579            let max_iterations = 10;
580            let mut iteration = 0;
581
582            loop {
583                iteration += 1;
584                if iteration > max_iterations {
585                    yield Err(adk_core::AdkError::Agent(
586                        format!("Max iterations ({}) exceeded", max_iterations)
587                    ));
588                    return;
589                }
590
591                // Build request with conversation history
592                let config = output_schema.as_ref().map(|schema| {
593                    adk_core::GenerateContentConfig {
594                        temperature: None,
595                        top_p: None,
596                        top_k: None,
597                        max_output_tokens: None,
598                        response_schema: Some(schema.clone()),
599                    }
600                });
601
602                let request = LlmRequest {
603                    model: model.name().to_string(),
604                    contents: conversation_history.clone(),
605                    tools: tool_declarations.clone(),
606                    config,
607                };
608
609                // ===== BEFORE MODEL CALLBACKS =====
610                // These can modify the request or skip the model call by returning a response
611                let mut current_request = request;
612                let mut model_response_override = None;
613                for callback in before_model_callbacks.as_ref() {
614                    match callback(ctx.clone() as Arc<dyn CallbackContext>, current_request.clone()).await {
615                        Ok(BeforeModelResult::Continue(modified_request)) => {
616                            // Callback may have modified the request, continue with it
617                            current_request = modified_request;
618                        }
619                        Ok(BeforeModelResult::Skip(response)) => {
620                            // Callback returned a response - skip model call
621                            model_response_override = Some(response);
622                            break;
623                        }
624                        Err(e) => {
625                            // Callback failed - propagate error
626                            yield Err(e);
627                            return;
628                        }
629                    }
630                }
631                let request = current_request;
632
633                // Determine streaming source: cached response or real model
634                let mut accumulated_content: Option<Content> = None;
635
636                if let Some(cached_response) = model_response_override {
637                    // Use callback-provided response (e.g., from cache)
638                    // Yield it as an event
639                    let mut cached_event = Event::new(&invocation_id);
640                    cached_event.author = agent_name.clone();
641                    cached_event.llm_response.content = cached_response.content.clone();
642
643                    // Populate long_running_tool_ids for function calls from long-running tools
644                    if let Some(ref content) = cached_response.content {
645                        let long_running_ids: Vec<String> = content.parts.iter()
646                            .filter_map(|p| {
647                                if let Part::FunctionCall { name, .. } = p {
648                                    if let Some(tool) = tools.iter().find(|t| t.name() == name) {
649                                        if tool.is_long_running() {
650                                            return Some(name.clone());
651                                        }
652                                    }
653                                }
654                                None
655                            })
656                            .collect();
657                        cached_event.long_running_tool_ids = long_running_ids;
658                    }
659
660                    yield Ok(cached_event);
661
662                    accumulated_content = cached_response.content;
663                } else {
664                    // Call model with STREAMING ENABLED
665                    let mut response_stream = model.generate_content(request, true).await?;
666
667                    use futures::StreamExt;
668
669                    // Stream and process chunks with AfterModel callbacks
670                    while let Some(chunk_result) = response_stream.next().await {
671                        let mut chunk = match chunk_result {
672                            Ok(c) => c,
673                            Err(e) => {
674                                yield Err(e);
675                                return;
676                            }
677                        };
678
679                        // ===== AFTER MODEL CALLBACKS (per chunk) =====
680                        // Callbacks can modify each streaming chunk
681                        for callback in after_model_callbacks.as_ref() {
682                            match callback(ctx.clone() as Arc<dyn CallbackContext>, chunk.clone()).await {
683                                Ok(Some(modified_chunk)) => {
684                                    // Callback modified this chunk
685                                    chunk = modified_chunk;
686                                    break;
687                                }
688                                Ok(None) => {
689                                    // Continue to next callback
690                                    continue;
691                                }
692                                Err(e) => {
693                                    // Callback failed - propagate error
694                                    yield Err(e);
695                                    return;
696                                }
697                            }
698                        }
699
700                        // Yield the (possibly modified) partial event
701                        let mut partial_event = Event::new(&invocation_id);
702                        partial_event.author = agent_name.clone();
703                        partial_event.llm_response.content = chunk.content.clone();
704
705                        // Populate long_running_tool_ids for function calls from long-running tools
706                        if let Some(ref content) = chunk.content {
707                            let long_running_ids: Vec<String> = content.parts.iter()
708                                .filter_map(|p| {
709                                    if let Part::FunctionCall { name, .. } = p {
710                                        // Check if this tool is long-running
711                                        if let Some(tool) = tools.iter().find(|t| t.name() == name) {
712                                            if tool.is_long_running() {
713                                                // Use tool name as ID (we don't have explicit call IDs)
714                                                return Some(name.clone());
715                                            }
716                                        }
717                                    }
718                                    None
719                                })
720                                .collect();
721                            partial_event.long_running_tool_ids = long_running_ids;
722                        }
723
724                        yield Ok(partial_event);
725
726                        // Accumulate content for history
727                        if let Some(chunk_content) = chunk.content {
728                            if let Some(ref mut acc) = accumulated_content {
729                                // Merge parts from this chunk into accumulated content
730                                acc.parts.extend(chunk_content.parts);
731                            } else {
732                                // First chunk - initialize accumulator
733                                accumulated_content = Some(chunk_content);
734                            }
735                        }
736
737                        // Check if turn is complete
738                        if chunk.turn_complete {
739                            break;
740                        }
741                    }
742                }
743
744                // After streaming/caching completes, check for function calls in accumulated content
745                let function_call_names: Vec<String> = accumulated_content.as_ref()
746                    .map(|c| c.parts.iter()
747                        .filter_map(|p| {
748                            if let Part::FunctionCall { name, .. } = p {
749                                Some(name.clone())
750                            } else {
751                                None
752                            }
753                        })
754                        .collect())
755                    .unwrap_or_default();
756
757                let has_function_calls = !function_call_names.is_empty();
758
759                // Check if ALL function calls are from long-running tools
760                // If so, we should NOT continue the loop - the tool returned a pending status
761                // and the agent/client will poll for completion later
762                let all_calls_are_long_running = has_function_calls && function_call_names.iter().all(|name| {
763                    tools.iter()
764                        .find(|t| t.name() == name)
765                        .map(|t| t.is_long_running())
766                        .unwrap_or(false)
767                });
768
769                // Add final content to history
770                if let Some(ref content) = accumulated_content {
771                    conversation_history.push(content.clone());
772
773                    // Handle output_key: save final agent output to state_delta
774                    if let Some(ref output_key) = output_key {
775                        if !has_function_calls {  // Only save if not calling tools
776                            let mut text_parts = String::new();
777                            for part in &content.parts {
778                                if let Part::Text { text } = part {
779                                    text_parts.push_str(text);
780                                }
781                            }
782                            if !text_parts.is_empty() {
783                                // Yield a final state update event
784                                let mut state_event = Event::new(&invocation_id);
785                                state_event.author = agent_name.clone();
786                                state_event.actions.state_delta.insert(
787                                    output_key.clone(),
788                                    serde_json::Value::String(text_parts),
789                                );
790                                yield Ok(state_event);
791                            }
792                        }
793                    }
794                }
795
796                if !has_function_calls {
797                    // No function calls, we're done
798                    break;
799                }
800
801                // Execute function calls and add responses to history
802                if let Some(content) = &accumulated_content {
803                    for part in &content.parts {
804                        if let Part::FunctionCall { name, args, id } = part {
805                            // Handle transfer_to_agent specially
806                            if name == "transfer_to_agent" {
807                                let target_agent = args.get("agent_name")
808                                    .and_then(|v| v.as_str())
809                                    .unwrap_or_default()
810                                    .to_string();
811
812                                let mut transfer_event = Event::new(&invocation_id);
813                                transfer_event.author = agent_name.clone();
814                                transfer_event.actions.transfer_to_agent = Some(target_agent);
815
816                                yield Ok(transfer_event);
817                                return;
818                            }
819
820
821                            // Find and execute tool
822                            let (tool_result, tool_actions) = if let Some(tool) = tools.iter().find(|t| t.name() == name) {
823                                // ✅ Use AgentToolContext that preserves parent context
824                                let tool_ctx: Arc<dyn ToolContext> = Arc::new(AgentToolContext::new(
825                                    ctx.clone(),
826                                    format!("{}_{}", invocation_id, name),
827                                ));
828
829                                let result = match tool.execute(tool_ctx.clone(), args.clone()).await {
830                                    Ok(result) => result,
831                                    Err(e) => serde_json::json!({ "error": e.to_string() }),
832                                };
833
834                                (result, tool_ctx.actions())
835                            } else {
836                                (serde_json::json!({ "error": format!("Tool {} not found", name) }), EventActions::default())
837                            };
838
839                            // Yield tool execution event
840                            let mut tool_event = Event::new(&invocation_id);
841                            tool_event.author = agent_name.clone();
842                            tool_event.actions = tool_actions.clone();
843                            tool_event.llm_response.content = Some(Content {
844                                role: "function".to_string(),
845                                parts: vec![Part::FunctionResponse {
846                                    name: name.clone(),
847                                    response: tool_result.clone(),
848                                    id: id.clone(),
849                                }],
850                            });
851                            yield Ok(tool_event);
852
853                            // Check if tool requested escalation or skip_summarization
854                            if tool_actions.escalate || tool_actions.skip_summarization {
855                                // Tool wants to terminate agent loop
856                                return;
857                            }
858
859                            // Add function response to history
860                            conversation_history.push(Content {
861                                role: "function".to_string(),
862                                parts: vec![Part::FunctionResponse {
863                                    name: name.clone(),
864                                    response: tool_result,
865                                    id: id.clone(),
866                                }],
867                            });
868                        }
869                    }
870                }
871
872                // If all function calls were from long-running tools, treat as final response
873                // The tools have been executed and returned pending status - don't continue the loop
874                if all_calls_are_long_running {
875                    break;
876                }
877            }
878
879            // ===== AFTER AGENT CALLBACKS =====
880            // Execute after the agent completes
881            for callback in after_agent_callbacks.as_ref() {
882                match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
883                    Ok(Some(content)) => {
884                        // Callback returned content - yield it
885                        let mut after_event = Event::new(&invocation_id);
886                        after_event.author = agent_name.clone();
887                        after_event.llm_response.content = Some(content);
888                        yield Ok(after_event);
889                        break; // First callback that returns content wins
890                    }
891                    Ok(None) => {
892                        // Continue to next callback
893                        continue;
894                    }
895                    Err(e) => {
896                        // Callback failed - propagate error
897                        yield Err(e);
898                        return;
899                    }
900                }
901            }
902        };
903
904        Ok(Box::pin(s))
905    }
906}