1use adk_core::{
2 AfterAgentCallback, AfterModelCallback, AfterToolCallback, AfterToolCallbackFull, Agent,
3 BeforeAgentCallback, BeforeModelCallback, BeforeModelResult, BeforeToolCallback,
4 CallbackContext, Content, Event, EventActions, FunctionResponseData, GlobalInstructionProvider,
5 InstructionProvider, InvocationContext, Llm, LlmRequest, LlmResponse, MemoryEntry,
6 OnToolErrorCallback, Part, ReadonlyContext, Result, RetryBudget, Tool,
7 ToolConfirmationDecision, ToolConfirmationPolicy, ToolConfirmationRequest, ToolContext,
8 ToolOutcome, Toolset,
9};
10use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index, select_skill_prompt_block};
11use async_stream::stream;
12use async_trait::async_trait;
13use std::sync::{Arc, Mutex};
14use tracing::Instrument;
15
16use crate::{
17 guardrails::{GuardrailSet, enforce_guardrails},
18 tool_call_markup::normalize_option_content,
19 workflow::with_user_content_override,
20};
21
/// Default cap on the agent loop: each iteration issues one model request, and
/// exceeding the cap makes `run` fail with a "Max iterations" error.
pub const DEFAULT_MAX_ITERATIONS: u32 = 100;

/// Default value for [`LlmAgentBuilder::tool_timeout`]: five minutes (300 s).
pub const DEFAULT_TOOL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5 * 60);
27
28pub struct LlmAgent {
29 name: String,
30 description: String,
31 model: Arc<dyn Llm>,
32 instruction: Option<String>,
33 instruction_provider: Option<Arc<InstructionProvider>>,
34 global_instruction: Option<String>,
35 global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
36 skills_index: Option<Arc<SkillIndex>>,
37 skill_policy: SelectionPolicy,
38 max_skill_chars: usize,
39 #[allow(dead_code)] input_schema: Option<serde_json::Value>,
41 output_schema: Option<serde_json::Value>,
42 disallow_transfer_to_parent: bool,
43 disallow_transfer_to_peers: bool,
44 include_contents: adk_core::IncludeContents,
45 tools: Vec<Arc<dyn Tool>>,
46 #[allow(dead_code)] toolsets: Vec<Arc<dyn Toolset>>,
48 sub_agents: Vec<Arc<dyn Agent>>,
49 output_key: Option<String>,
50 generate_content_config: Option<adk_core::GenerateContentConfig>,
52 max_iterations: u32,
54 tool_timeout: std::time::Duration,
56 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
57 after_callbacks: Arc<Vec<AfterAgentCallback>>,
58 before_model_callbacks: Arc<Vec<BeforeModelCallback>>,
59 after_model_callbacks: Arc<Vec<AfterModelCallback>>,
60 before_tool_callbacks: Arc<Vec<BeforeToolCallback>>,
61 after_tool_callbacks: Arc<Vec<AfterToolCallback>>,
62 on_tool_error_callbacks: Arc<Vec<OnToolErrorCallback>>,
63 after_tool_callbacks_full: Arc<Vec<AfterToolCallbackFull>>,
65 default_retry_budget: Option<RetryBudget>,
67 tool_retry_budgets: std::collections::HashMap<String, RetryBudget>,
69 circuit_breaker_threshold: Option<u32>,
72 tool_confirmation_policy: ToolConfirmationPolicy,
73 input_guardrails: Arc<GuardrailSet>,
74 output_guardrails: Arc<GuardrailSet>,
75}
76
77impl std::fmt::Debug for LlmAgent {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("LlmAgent")
80 .field("name", &self.name)
81 .field("description", &self.description)
82 .field("model", &self.model.name())
83 .field("instruction", &self.instruction)
84 .field("tools_count", &self.tools.len())
85 .field("sub_agents_count", &self.sub_agents.len())
86 .finish()
87 }
88}
89
90impl LlmAgent {
91 async fn apply_input_guardrails(
92 ctx: Arc<dyn InvocationContext>,
93 input_guardrails: Arc<GuardrailSet>,
94 ) -> Result<Arc<dyn InvocationContext>> {
95 let content =
96 enforce_guardrails(input_guardrails.as_ref(), ctx.user_content(), "input").await?;
97 if content.role != ctx.user_content().role || content.parts != ctx.user_content().parts {
98 Ok(with_user_content_override(ctx, content))
99 } else {
100 Ok(ctx)
101 }
102 }
103
104 async fn apply_output_guardrails(
105 output_guardrails: &GuardrailSet,
106 content: Content,
107 ) -> Result<Content> {
108 enforce_guardrails(output_guardrails, &content, "output").await
109 }
110}
111
112pub struct LlmAgentBuilder {
113 name: String,
114 description: Option<String>,
115 model: Option<Arc<dyn Llm>>,
116 instruction: Option<String>,
117 instruction_provider: Option<Arc<InstructionProvider>>,
118 global_instruction: Option<String>,
119 global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
120 skills_index: Option<Arc<SkillIndex>>,
121 skill_policy: SelectionPolicy,
122 max_skill_chars: usize,
123 input_schema: Option<serde_json::Value>,
124 output_schema: Option<serde_json::Value>,
125 disallow_transfer_to_parent: bool,
126 disallow_transfer_to_peers: bool,
127 include_contents: adk_core::IncludeContents,
128 tools: Vec<Arc<dyn Tool>>,
129 toolsets: Vec<Arc<dyn Toolset>>,
130 sub_agents: Vec<Arc<dyn Agent>>,
131 output_key: Option<String>,
132 generate_content_config: Option<adk_core::GenerateContentConfig>,
133 max_iterations: u32,
134 tool_timeout: std::time::Duration,
135 before_callbacks: Vec<BeforeAgentCallback>,
136 after_callbacks: Vec<AfterAgentCallback>,
137 before_model_callbacks: Vec<BeforeModelCallback>,
138 after_model_callbacks: Vec<AfterModelCallback>,
139 before_tool_callbacks: Vec<BeforeToolCallback>,
140 after_tool_callbacks: Vec<AfterToolCallback>,
141 on_tool_error_callbacks: Vec<OnToolErrorCallback>,
142 after_tool_callbacks_full: Vec<AfterToolCallbackFull>,
143 default_retry_budget: Option<RetryBudget>,
144 tool_retry_budgets: std::collections::HashMap<String, RetryBudget>,
145 circuit_breaker_threshold: Option<u32>,
146 tool_confirmation_policy: ToolConfirmationPolicy,
147 input_guardrails: GuardrailSet,
148 output_guardrails: GuardrailSet,
149}
150
151impl LlmAgentBuilder {
152 pub fn new(name: impl Into<String>) -> Self {
153 Self {
154 name: name.into(),
155 description: None,
156 model: None,
157 instruction: None,
158 instruction_provider: None,
159 global_instruction: None,
160 global_instruction_provider: None,
161 skills_index: None,
162 skill_policy: SelectionPolicy::default(),
163 max_skill_chars: 2000,
164 input_schema: None,
165 output_schema: None,
166 disallow_transfer_to_parent: false,
167 disallow_transfer_to_peers: false,
168 include_contents: adk_core::IncludeContents::Default,
169 tools: Vec::new(),
170 toolsets: Vec::new(),
171 sub_agents: Vec::new(),
172 output_key: None,
173 generate_content_config: None,
174 max_iterations: DEFAULT_MAX_ITERATIONS,
175 tool_timeout: DEFAULT_TOOL_TIMEOUT,
176 before_callbacks: Vec::new(),
177 after_callbacks: Vec::new(),
178 before_model_callbacks: Vec::new(),
179 after_model_callbacks: Vec::new(),
180 before_tool_callbacks: Vec::new(),
181 after_tool_callbacks: Vec::new(),
182 on_tool_error_callbacks: Vec::new(),
183 after_tool_callbacks_full: Vec::new(),
184 default_retry_budget: None,
185 tool_retry_budgets: std::collections::HashMap::new(),
186 circuit_breaker_threshold: None,
187 tool_confirmation_policy: ToolConfirmationPolicy::Never,
188 input_guardrails: GuardrailSet::new(),
189 output_guardrails: GuardrailSet::new(),
190 }
191 }
192
193 pub fn description(mut self, desc: impl Into<String>) -> Self {
194 self.description = Some(desc.into());
195 self
196 }
197
198 pub fn model(mut self, model: Arc<dyn Llm>) -> Self {
199 self.model = Some(model);
200 self
201 }
202
203 pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
204 self.instruction = Some(instruction.into());
205 self
206 }
207
208 pub fn instruction_provider(mut self, provider: InstructionProvider) -> Self {
209 self.instruction_provider = Some(Arc::new(provider));
210 self
211 }
212
213 pub fn global_instruction(mut self, instruction: impl Into<String>) -> Self {
214 self.global_instruction = Some(instruction.into());
215 self
216 }
217
218 pub fn global_instruction_provider(mut self, provider: GlobalInstructionProvider) -> Self {
219 self.global_instruction_provider = Some(Arc::new(provider));
220 self
221 }
222
223 pub fn with_skills(mut self, index: SkillIndex) -> Self {
225 self.skills_index = Some(Arc::new(index));
226 self
227 }
228
229 pub fn with_auto_skills(self) -> Result<Self> {
231 self.with_skills_from_root(".")
232 }
233
234 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
236 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::Agent(e.to_string()))?;
237 self.skills_index = Some(Arc::new(index));
238 Ok(self)
239 }
240
241 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
243 self.skill_policy = policy;
244 self
245 }
246
247 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
249 self.max_skill_chars = max_chars;
250 self
251 }
252
253 pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
254 self.input_schema = Some(schema);
255 self
256 }
257
258 pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
259 self.output_schema = Some(schema);
260 self
261 }
262
263 pub fn disallow_transfer_to_parent(mut self, disallow: bool) -> Self {
264 self.disallow_transfer_to_parent = disallow;
265 self
266 }
267
268 pub fn disallow_transfer_to_peers(mut self, disallow: bool) -> Self {
269 self.disallow_transfer_to_peers = disallow;
270 self
271 }
272
273 pub fn include_contents(mut self, include: adk_core::IncludeContents) -> Self {
274 self.include_contents = include;
275 self
276 }
277
278 pub fn output_key(mut self, key: impl Into<String>) -> Self {
279 self.output_key = Some(key.into());
280 self
281 }
282
283 pub fn generate_content_config(mut self, config: adk_core::GenerateContentConfig) -> Self {
304 self.generate_content_config = Some(config);
305 self
306 }
307
308 pub fn temperature(mut self, temperature: f32) -> Self {
311 self.generate_content_config
312 .get_or_insert(adk_core::GenerateContentConfig::default())
313 .temperature = Some(temperature);
314 self
315 }
316
317 pub fn top_p(mut self, top_p: f32) -> Self {
319 self.generate_content_config
320 .get_or_insert(adk_core::GenerateContentConfig::default())
321 .top_p = Some(top_p);
322 self
323 }
324
325 pub fn top_k(mut self, top_k: i32) -> Self {
327 self.generate_content_config
328 .get_or_insert(adk_core::GenerateContentConfig::default())
329 .top_k = Some(top_k);
330 self
331 }
332
333 pub fn max_output_tokens(mut self, max_tokens: i32) -> Self {
335 self.generate_content_config
336 .get_or_insert(adk_core::GenerateContentConfig::default())
337 .max_output_tokens = Some(max_tokens);
338 self
339 }
340
341 pub fn max_iterations(mut self, max: u32) -> Self {
344 self.max_iterations = max;
345 self
346 }
347
348 pub fn tool_timeout(mut self, timeout: std::time::Duration) -> Self {
351 self.tool_timeout = timeout;
352 self
353 }
354
355 pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
356 self.tools.push(tool);
357 self
358 }
359
360 pub fn toolset(mut self, toolset: Arc<dyn Toolset>) -> Self {
366 self.toolsets.push(toolset);
367 self
368 }
369
370 pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
371 self.sub_agents.push(agent);
372 self
373 }
374
375 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
376 self.before_callbacks.push(callback);
377 self
378 }
379
380 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
381 self.after_callbacks.push(callback);
382 self
383 }
384
385 pub fn before_model_callback(mut self, callback: BeforeModelCallback) -> Self {
386 self.before_model_callbacks.push(callback);
387 self
388 }
389
390 pub fn after_model_callback(mut self, callback: AfterModelCallback) -> Self {
391 self.after_model_callbacks.push(callback);
392 self
393 }
394
395 pub fn before_tool_callback(mut self, callback: BeforeToolCallback) -> Self {
396 self.before_tool_callbacks.push(callback);
397 self
398 }
399
400 pub fn after_tool_callback(mut self, callback: AfterToolCallback) -> Self {
401 self.after_tool_callbacks.push(callback);
402 self
403 }
404
405 pub fn after_tool_callback_full(mut self, callback: AfterToolCallbackFull) -> Self {
420 self.after_tool_callbacks_full.push(callback);
421 self
422 }
423
424 pub fn on_tool_error(mut self, callback: OnToolErrorCallback) -> Self {
432 self.on_tool_error_callbacks.push(callback);
433 self
434 }
435
436 pub fn default_retry_budget(mut self, budget: RetryBudget) -> Self {
443 self.default_retry_budget = Some(budget);
444 self
445 }
446
447 pub fn tool_retry_budget(mut self, tool_name: impl Into<String>, budget: RetryBudget) -> Self {
452 self.tool_retry_budgets.insert(tool_name.into(), budget);
453 self
454 }
455
456 pub fn circuit_breaker_threshold(mut self, threshold: u32) -> Self {
463 self.circuit_breaker_threshold = Some(threshold);
464 self
465 }
466
467 pub fn tool_confirmation_policy(mut self, policy: ToolConfirmationPolicy) -> Self {
469 self.tool_confirmation_policy = policy;
470 self
471 }
472
473 pub fn require_tool_confirmation(mut self, tool_name: impl Into<String>) -> Self {
475 self.tool_confirmation_policy = self.tool_confirmation_policy.with_tool(tool_name);
476 self
477 }
478
479 pub fn require_tool_confirmation_for_all(mut self) -> Self {
481 self.tool_confirmation_policy = ToolConfirmationPolicy::Always;
482 self
483 }
484
485 pub fn input_guardrails(mut self, guardrails: GuardrailSet) -> Self {
494 self.input_guardrails = guardrails;
495 self
496 }
497
498 pub fn output_guardrails(mut self, guardrails: GuardrailSet) -> Self {
507 self.output_guardrails = guardrails;
508 self
509 }
510
511 pub fn build(self) -> Result<LlmAgent> {
512 let model =
513 self.model.ok_or_else(|| adk_core::AdkError::Agent("Model is required".to_string()))?;
514
515 let mut seen_names = std::collections::HashSet::new();
516 for agent in &self.sub_agents {
517 if !seen_names.insert(agent.name()) {
518 return Err(adk_core::AdkError::Agent(format!(
519 "Duplicate sub-agent name: {}",
520 agent.name()
521 )));
522 }
523 }
524
525 Ok(LlmAgent {
526 name: self.name,
527 description: self.description.unwrap_or_default(),
528 model,
529 instruction: self.instruction,
530 instruction_provider: self.instruction_provider,
531 global_instruction: self.global_instruction,
532 global_instruction_provider: self.global_instruction_provider,
533 skills_index: self.skills_index,
534 skill_policy: self.skill_policy,
535 max_skill_chars: self.max_skill_chars,
536 input_schema: self.input_schema,
537 output_schema: self.output_schema,
538 disallow_transfer_to_parent: self.disallow_transfer_to_parent,
539 disallow_transfer_to_peers: self.disallow_transfer_to_peers,
540 include_contents: self.include_contents,
541 tools: self.tools,
542 toolsets: self.toolsets,
543 sub_agents: self.sub_agents,
544 output_key: self.output_key,
545 generate_content_config: self.generate_content_config,
546 max_iterations: self.max_iterations,
547 tool_timeout: self.tool_timeout,
548 before_callbacks: Arc::new(self.before_callbacks),
549 after_callbacks: Arc::new(self.after_callbacks),
550 before_model_callbacks: Arc::new(self.before_model_callbacks),
551 after_model_callbacks: Arc::new(self.after_model_callbacks),
552 before_tool_callbacks: Arc::new(self.before_tool_callbacks),
553 after_tool_callbacks: Arc::new(self.after_tool_callbacks),
554 on_tool_error_callbacks: Arc::new(self.on_tool_error_callbacks),
555 after_tool_callbacks_full: Arc::new(self.after_tool_callbacks_full),
556 default_retry_budget: self.default_retry_budget,
557 tool_retry_budgets: self.tool_retry_budgets,
558 circuit_breaker_threshold: self.circuit_breaker_threshold,
559 tool_confirmation_policy: self.tool_confirmation_policy,
560 input_guardrails: Arc::new(self.input_guardrails),
561 output_guardrails: Arc::new(self.output_guardrails),
562 })
563 }
564}
565
566struct AgentToolContext {
569 parent_ctx: Arc<dyn InvocationContext>,
570 function_call_id: String,
571 actions: Mutex<EventActions>,
572}
573
574impl AgentToolContext {
575 fn new(parent_ctx: Arc<dyn InvocationContext>, function_call_id: String) -> Self {
576 Self { parent_ctx, function_call_id, actions: Mutex::new(EventActions::default()) }
577 }
578
579 fn actions_guard(&self) -> std::sync::MutexGuard<'_, EventActions> {
580 self.actions.lock().unwrap_or_else(|e| e.into_inner())
581 }
582}
583
584#[async_trait]
585impl ReadonlyContext for AgentToolContext {
586 fn invocation_id(&self) -> &str {
587 self.parent_ctx.invocation_id()
588 }
589
590 fn agent_name(&self) -> &str {
591 self.parent_ctx.agent_name()
592 }
593
594 fn user_id(&self) -> &str {
595 self.parent_ctx.user_id()
597 }
598
599 fn app_name(&self) -> &str {
600 self.parent_ctx.app_name()
602 }
603
604 fn session_id(&self) -> &str {
605 self.parent_ctx.session_id()
607 }
608
609 fn branch(&self) -> &str {
610 self.parent_ctx.branch()
611 }
612
613 fn user_content(&self) -> &Content {
614 self.parent_ctx.user_content()
615 }
616}
617
618#[async_trait]
619impl CallbackContext for AgentToolContext {
620 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
621 self.parent_ctx.artifacts()
623 }
624}
625
626#[async_trait]
627impl ToolContext for AgentToolContext {
628 fn function_call_id(&self) -> &str {
629 &self.function_call_id
630 }
631
632 fn actions(&self) -> EventActions {
633 self.actions_guard().clone()
634 }
635
636 fn set_actions(&self, actions: EventActions) {
637 *self.actions_guard() = actions;
638 }
639
640 async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>> {
641 if let Some(memory) = self.parent_ctx.memory() {
643 memory.search(query).await
644 } else {
645 Ok(vec![])
646 }
647 }
648
649 fn user_scopes(&self) -> Vec<String> {
650 self.parent_ctx.user_scopes()
651 }
652}
653
654struct ToolCallbackContext {
658 inner: Arc<dyn CallbackContext>,
659 outcome: ToolOutcome,
660}
661
662#[async_trait]
663impl ReadonlyContext for ToolCallbackContext {
664 fn invocation_id(&self) -> &str {
665 self.inner.invocation_id()
666 }
667
668 fn agent_name(&self) -> &str {
669 self.inner.agent_name()
670 }
671
672 fn user_id(&self) -> &str {
673 self.inner.user_id()
674 }
675
676 fn app_name(&self) -> &str {
677 self.inner.app_name()
678 }
679
680 fn session_id(&self) -> &str {
681 self.inner.session_id()
682 }
683
684 fn branch(&self) -> &str {
685 self.inner.branch()
686 }
687
688 fn user_content(&self) -> &Content {
689 self.inner.user_content()
690 }
691}
692
693#[async_trait]
694impl CallbackContext for ToolCallbackContext {
695 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
696 self.inner.artifacts()
697 }
698
699 fn tool_outcome(&self) -> Option<ToolOutcome> {
700 Some(self.outcome.clone())
701 }
702}
703
704struct CircuitBreakerState {
714 threshold: u32,
715 failures: std::collections::HashMap<String, u32>,
717}
718
719impl CircuitBreakerState {
720 fn new(threshold: u32) -> Self {
721 Self { threshold, failures: std::collections::HashMap::new() }
722 }
723
724 fn is_open(&self, tool_name: &str) -> bool {
726 self.failures.get(tool_name).copied().unwrap_or(0) >= self.threshold
727 }
728
729 fn record(&mut self, outcome: &ToolOutcome) {
731 if outcome.success {
732 self.failures.remove(&outcome.tool_name);
733 } else {
734 let count = self.failures.entry(outcome.tool_name.clone()).or_insert(0);
735 *count += 1;
736 }
737 }
738}
739
740#[async_trait]
741impl Agent for LlmAgent {
742 fn name(&self) -> &str {
743 &self.name
744 }
745
746 fn description(&self) -> &str {
747 &self.description
748 }
749
750 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
751 &self.sub_agents
752 }
753
754 #[adk_telemetry::instrument(
755 skip(self, ctx),
756 fields(
757 agent.name = %self.name,
758 agent.description = %self.description,
759 invocation.id = %ctx.invocation_id(),
760 user.id = %ctx.user_id(),
761 session.id = %ctx.session_id()
762 )
763 )]
764 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
765 adk_telemetry::info!("Starting agent execution");
766 let ctx = Self::apply_input_guardrails(ctx, self.input_guardrails.clone()).await?;
767
768 let agent_name = self.name.clone();
769 let invocation_id = ctx.invocation_id().to_string();
770 let model = self.model.clone();
771 let tools = self.tools.clone();
772 let toolsets = self.toolsets.clone();
773 let sub_agents = self.sub_agents.clone();
774
775 let instruction = self.instruction.clone();
776 let instruction_provider = self.instruction_provider.clone();
777 let global_instruction = self.global_instruction.clone();
778 let global_instruction_provider = self.global_instruction_provider.clone();
779 let skills_index = self.skills_index.clone();
780 let skill_policy = self.skill_policy.clone();
781 let max_skill_chars = self.max_skill_chars;
782 let output_key = self.output_key.clone();
783 let output_schema = self.output_schema.clone();
784 let generate_content_config = self.generate_content_config.clone();
785 let include_contents = self.include_contents;
786 let max_iterations = self.max_iterations;
787 let tool_timeout = self.tool_timeout;
788 let before_agent_callbacks = self.before_callbacks.clone();
790 let after_agent_callbacks = self.after_callbacks.clone();
791 let before_model_callbacks = self.before_model_callbacks.clone();
792 let after_model_callbacks = self.after_model_callbacks.clone();
793 let before_tool_callbacks = self.before_tool_callbacks.clone();
794 let after_tool_callbacks = self.after_tool_callbacks.clone();
795 let on_tool_error_callbacks = self.on_tool_error_callbacks.clone();
796 let after_tool_callbacks_full = self.after_tool_callbacks_full.clone();
797 let default_retry_budget = self.default_retry_budget.clone();
798 let tool_retry_budgets = self.tool_retry_budgets.clone();
799 let circuit_breaker_threshold = self.circuit_breaker_threshold;
800 let tool_confirmation_policy = self.tool_confirmation_policy.clone();
801 let disallow_transfer_to_parent = self.disallow_transfer_to_parent;
802 let disallow_transfer_to_peers = self.disallow_transfer_to_peers;
803 let output_guardrails = self.output_guardrails.clone();
804
805 let s = stream! {
806 for callback in before_agent_callbacks.as_ref() {
810 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
811 Ok(Some(content)) => {
812 let mut early_event = Event::new(&invocation_id);
814 early_event.author = agent_name.clone();
815 early_event.llm_response.content = Some(content);
816 yield Ok(early_event);
817
818 for after_callback in after_agent_callbacks.as_ref() {
820 match after_callback(ctx.clone() as Arc<dyn CallbackContext>).await {
821 Ok(Some(after_content)) => {
822 let mut after_event = Event::new(&invocation_id);
823 after_event.author = agent_name.clone();
824 after_event.llm_response.content = Some(after_content);
825 yield Ok(after_event);
826 return;
827 }
828 Ok(None) => continue,
829 Err(e) => {
830 yield Err(e);
831 return;
832 }
833 }
834 }
835 return;
836 }
837 Ok(None) => {
838 continue;
840 }
841 Err(e) => {
842 yield Err(e);
844 return;
845 }
846 }
847 }
848
849 let mut prompt_preamble = Vec::new();
851
852 if let Some(index) = &skills_index {
856 let user_query = ctx
857 .user_content()
858 .parts
859 .iter()
860 .filter_map(|part| match part {
861 Part::Text { text } => Some(text.as_str()),
862 _ => None,
863 })
864 .collect::<Vec<_>>()
865 .join("\n");
866
867 if let Some((_matched, skill_block)) = select_skill_prompt_block(
868 index.as_ref(),
869 &user_query,
870 &skill_policy,
871 max_skill_chars,
872 ) {
873 prompt_preamble.push(Content {
874 role: "user".to_string(),
875 parts: vec![Part::Text { text: skill_block }],
876 });
877 }
878 }
879
880 if let Some(provider) = &global_instruction_provider {
883 let global_inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
885 if !global_inst.is_empty() {
886 prompt_preamble.push(Content {
887 role: "user".to_string(),
888 parts: vec![Part::Text { text: global_inst }],
889 });
890 }
891 } else if let Some(ref template) = global_instruction {
892 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
894 if !processed.is_empty() {
895 prompt_preamble.push(Content {
896 role: "user".to_string(),
897 parts: vec![Part::Text { text: processed }],
898 });
899 }
900 }
901
902 if let Some(provider) = &instruction_provider {
905 let inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
907 if !inst.is_empty() {
908 prompt_preamble.push(Content {
909 role: "user".to_string(),
910 parts: vec![Part::Text { text: inst }],
911 });
912 }
913 } else if let Some(ref template) = instruction {
914 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
916 if !processed.is_empty() {
917 prompt_preamble.push(Content {
918 role: "user".to_string(),
919 parts: vec![Part::Text { text: processed }],
920 });
921 }
922 }
923
924 let session_history = if !ctx.run_config().transfer_targets.is_empty() {
930 ctx.session().conversation_history_for_agent(&agent_name)
931 } else {
932 ctx.session().conversation_history()
933 };
934 let mut session_history = session_history;
935 let current_user_content = ctx.user_content().clone();
936 if let Some(index) = session_history.iter().rposition(|content| content.role == "user") {
937 session_history[index] = current_user_content.clone();
938 } else {
939 session_history.push(current_user_content.clone());
940 }
941
942 let mut conversation_history = match include_contents {
945 adk_core::IncludeContents::None => {
946 let mut filtered = prompt_preamble.clone();
947 filtered.push(current_user_content);
948 filtered
949 }
950 adk_core::IncludeContents::Default => {
951 let mut full_history = prompt_preamble;
952 full_history.extend(session_history);
953 full_history
954 }
955 };
956
957 let mut resolved_tools: Vec<Arc<dyn Tool>> = tools.clone();
960 let static_tool_names: std::collections::HashSet<String> =
961 tools.iter().map(|t| t.name().to_string()).collect();
962
963 let mut toolset_source: std::collections::HashMap<String, String> =
965 std::collections::HashMap::new();
966
967 for toolset in &toolsets {
968 let toolset_tools = match toolset
969 .tools(ctx.clone() as Arc<dyn ReadonlyContext>)
970 .await
971 {
972 Ok(t) => t,
973 Err(e) => {
974 yield Err(e);
975 return;
976 }
977 };
978 for tool in &toolset_tools {
979 let name = tool.name().to_string();
980 if static_tool_names.contains(&name) {
982 yield Err(adk_core::AdkError::Agent(format!(
983 "Duplicate tool name '{}': conflict between static tool and toolset '{}'",
984 name,
985 toolset.name()
986 )));
987 return;
988 }
989 if let Some(other_toolset_name) = toolset_source.get(&name) {
991 yield Err(adk_core::AdkError::Agent(format!(
992 "Duplicate tool name '{}': conflict between toolset '{}' and toolset '{}'",
993 name,
994 other_toolset_name,
995 toolset.name()
996 )));
997 return;
998 }
999 toolset_source.insert(name, toolset.name().to_string());
1000 resolved_tools.push(tool.clone());
1001 }
1002 }
1003
1004 let tool_map: std::collections::HashMap<String, Arc<dyn Tool>> = resolved_tools
1006 .iter()
1007 .map(|t| (t.name().to_string(), t.clone()))
1008 .collect();
1009
1010 let collect_long_running_ids = |content: &Content| -> Vec<String> {
1012 content.parts.iter()
1013 .filter_map(|p| {
1014 if let Part::FunctionCall { name, .. } = p {
1015 if let Some(tool) = tool_map.get(name) {
1016 if tool.is_long_running() {
1017 return Some(name.clone());
1018 }
1019 }
1020 }
1021 None
1022 })
1023 .collect()
1024 };
1025
1026 let mut tool_declarations = std::collections::HashMap::new();
1029 for tool in &resolved_tools {
1030 let mut decl = serde_json::json!({
1033 "name": tool.name(),
1034 "description": tool.enhanced_description(),
1035 });
1036
1037 if let Some(params) = tool.parameters_schema() {
1038 decl["parameters"] = params;
1039 }
1040
1041 if let Some(response) = tool.response_schema() {
1042 decl["response"] = response;
1043 }
1044
1045 tool_declarations.insert(tool.name().to_string(), decl);
1046 }
1047
1048 let mut valid_transfer_targets: Vec<String> = sub_agents
1053 .iter()
1054 .map(|a| a.name().to_string())
1055 .collect();
1056
1057 let run_config_targets = &ctx.run_config().transfer_targets;
1059 let parent_agent_name = ctx.run_config().parent_agent.clone();
1060 let sub_agent_names: std::collections::HashSet<&str> = sub_agents
1061 .iter()
1062 .map(|a| a.name())
1063 .collect();
1064
1065 for target in run_config_targets {
1066 if sub_agent_names.contains(target.as_str()) {
1068 continue;
1069 }
1070
1071 let is_parent = parent_agent_name.as_deref() == Some(target.as_str());
1073 if is_parent && disallow_transfer_to_parent {
1074 continue;
1075 }
1076 if !is_parent && disallow_transfer_to_peers {
1077 continue;
1078 }
1079
1080 valid_transfer_targets.push(target.clone());
1081 }
1082
1083 if !valid_transfer_targets.is_empty() {
1085 let transfer_tool_name = "transfer_to_agent";
1086 let transfer_tool_decl = serde_json::json!({
1087 "name": transfer_tool_name,
1088 "description": format!(
1089 "Transfer execution to another agent. Valid targets: {}",
1090 valid_transfer_targets.join(", ")
1091 ),
1092 "parameters": {
1093 "type": "object",
1094 "properties": {
1095 "agent_name": {
1096 "type": "string",
1097 "description": "The name of the agent to transfer to.",
1098 "enum": valid_transfer_targets
1099 }
1100 },
1101 "required": ["agent_name"]
1102 }
1103 });
1104 tool_declarations.insert(transfer_tool_name.to_string(), transfer_tool_decl);
1105 }
1106
1107
1108 let mut circuit_breaker_state = circuit_breaker_threshold.map(CircuitBreakerState::new);
1111
1112 let mut iteration = 0;
1114
1115 loop {
1116 iteration += 1;
1117 if iteration > max_iterations {
1118 yield Err(adk_core::AdkError::Agent(
1119 format!("Max iterations ({}) exceeded", max_iterations)
1120 ));
1121 return;
1122 }
1123
1124 let config = match (&generate_content_config, &output_schema) {
1131 (Some(base), Some(schema)) => {
1132 let mut merged = base.clone();
1133 merged.response_schema = Some(schema.clone());
1134 Some(merged)
1135 }
1136 (Some(base), None) => Some(base.clone()),
1137 (None, Some(schema)) => Some(adk_core::GenerateContentConfig {
1138 response_schema: Some(schema.clone()),
1139 ..Default::default()
1140 }),
1141 (None, None) => None,
1142 };
1143
1144 let config = if let Some(ref cached) = ctx.run_config().cached_content {
1146 let mut cfg = config.unwrap_or_default();
1147 if cfg.cached_content.is_none() {
1149 cfg.cached_content = Some(cached.clone());
1150 }
1151 Some(cfg)
1152 } else {
1153 config
1154 };
1155
1156 let request = LlmRequest {
1157 model: model.name().to_string(),
1158 contents: conversation_history.clone(),
1159 tools: tool_declarations.clone(),
1160 config,
1161 };
1162
1163 let mut current_request = request;
1166 let mut model_response_override = None;
1167 for callback in before_model_callbacks.as_ref() {
1168 match callback(ctx.clone() as Arc<dyn CallbackContext>, current_request.clone()).await {
1169 Ok(BeforeModelResult::Continue(modified_request)) => {
1170 current_request = modified_request;
1172 }
1173 Ok(BeforeModelResult::Skip(response)) => {
1174 model_response_override = Some(response);
1176 break;
1177 }
1178 Err(e) => {
1179 yield Err(e);
1181 return;
1182 }
1183 }
1184 }
1185 let request = current_request;
1186
1187 let mut accumulated_content: Option<Content> = None;
1189
1190 if let Some(cached_response) = model_response_override {
1191 accumulated_content = cached_response.content.clone();
1194 normalize_option_content(&mut accumulated_content);
1195 if let Some(content) = accumulated_content.take() {
1196 let has_function_calls = content
1197 .parts
1198 .iter()
1199 .any(|part| matches!(part, Part::FunctionCall { .. }));
1200 let content = if has_function_calls {
1201 content
1202 } else {
1203 Self::apply_output_guardrails(output_guardrails.as_ref(), content).await?
1204 };
1205 accumulated_content = Some(content);
1206 }
1207
1208 let mut cached_event = Event::new(&invocation_id);
1209 cached_event.author = agent_name.clone();
1210 cached_event.llm_response.content = accumulated_content.clone();
1211 cached_event.llm_request = Some(serde_json::to_string(&request).unwrap_or_default());
1212 cached_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), serde_json::to_string(&request).unwrap_or_default());
1213 cached_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(&cached_response).unwrap_or_default());
1214
1215 if let Some(ref content) = accumulated_content {
1217 cached_event.long_running_tool_ids = collect_long_running_ids(content);
1218 }
1219
1220 yield Ok(cached_event);
1221 } else {
1222 let request_json = serde_json::to_string(&request).unwrap_or_default();
1224
1225 let llm_ts = std::time::SystemTime::now()
1227 .duration_since(std::time::UNIX_EPOCH)
1228 .unwrap_or_default()
1229 .as_nanos();
1230 let llm_event_id = format!("{}_llm_{}", invocation_id, llm_ts);
1231 let llm_span = tracing::info_span!(
1232 "call_llm",
1233 "gcp.vertex.agent.event_id" = %llm_event_id,
1234 "gcp.vertex.agent.invocation_id" = %invocation_id,
1235 "gcp.vertex.agent.session_id" = %ctx.session_id(),
1236 "gen_ai.conversation.id" = %ctx.session_id(),
1237 "gcp.vertex.agent.llm_request" = %request_json,
1238 "gcp.vertex.agent.llm_response" = tracing::field::Empty );
1240 let _llm_guard = llm_span.enter();
1241
1242 use adk_core::StreamingMode;
1244 let streaming_mode = ctx.run_config().streaming_mode;
1245 let should_stream_to_client = matches!(streaming_mode, StreamingMode::SSE | StreamingMode::Bidi)
1246 && output_guardrails.is_empty();
1247
1248 let mut response_stream = model.generate_content(request, true).await?;
1250
1251 use futures::StreamExt;
1252
1253 let mut last_chunk: Option<LlmResponse> = None;
1255
1256 while let Some(chunk_result) = response_stream.next().await {
1258 let mut chunk = match chunk_result {
1259 Ok(c) => c,
1260 Err(e) => {
1261 yield Err(e);
1262 return;
1263 }
1264 };
1265
1266 for callback in after_model_callbacks.as_ref() {
1269 match callback(ctx.clone() as Arc<dyn CallbackContext>, chunk.clone()).await {
1270 Ok(Some(modified_chunk)) => {
1271 chunk = modified_chunk;
1273 break;
1274 }
1275 Ok(None) => {
1276 continue;
1278 }
1279 Err(e) => {
1280 yield Err(e);
1282 return;
1283 }
1284 }
1285 }
1286
1287 normalize_option_content(&mut chunk.content);
1288
1289 if let Some(chunk_content) = chunk.content.clone() {
1291 if let Some(ref mut acc) = accumulated_content {
1292 acc.parts.extend(chunk_content.parts);
1293 } else {
1294 accumulated_content = Some(chunk_content);
1295 }
1296 }
1297
1298 if should_stream_to_client {
1300 let mut partial_event = Event::with_id(&llm_event_id, &invocation_id);
1301 partial_event.author = agent_name.clone();
1302 partial_event.llm_request = Some(request_json.clone());
1303 partial_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), request_json.clone());
1304 partial_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(&chunk).unwrap_or_default());
1305 partial_event.llm_response.partial = chunk.partial;
1306 partial_event.llm_response.turn_complete = chunk.turn_complete;
1307 partial_event.llm_response.finish_reason = chunk.finish_reason;
1308 partial_event.llm_response.usage_metadata = chunk.usage_metadata.clone();
1309 partial_event.llm_response.content = chunk.content.clone();
1310
1311 if let Some(ref content) = chunk.content {
1313 partial_event.long_running_tool_ids = collect_long_running_ids(content);
1314 }
1315
1316 yield Ok(partial_event);
1317 }
1318
1319 last_chunk = Some(chunk.clone());
1321
1322 if chunk.turn_complete {
1324 break;
1325 }
1326 }
1327
1328 if !should_stream_to_client {
1330 if let Some(content) = accumulated_content.take() {
1331 let has_function_calls = content
1332 .parts
1333 .iter()
1334 .any(|part| matches!(part, Part::FunctionCall { .. }));
1335 let content = if has_function_calls {
1336 content
1337 } else {
1338 Self::apply_output_guardrails(output_guardrails.as_ref(), content).await?
1339 };
1340 accumulated_content = Some(content);
1341 }
1342
1343 let mut final_event = Event::with_id(&llm_event_id, &invocation_id);
1344 final_event.author = agent_name.clone();
1345 final_event.llm_request = Some(request_json.clone());
1346 final_event.provider_metadata.insert("gcp.vertex.agent.llm_request".to_string(), request_json.clone());
1347 final_event.llm_response.content = accumulated_content.clone();
1348 final_event.llm_response.partial = false;
1349 final_event.llm_response.turn_complete = true;
1350
1351 if let Some(ref last) = last_chunk {
1353 final_event.llm_response.finish_reason = last.finish_reason;
1354 final_event.llm_response.usage_metadata = last.usage_metadata.clone();
1355 final_event.provider_metadata.insert("gcp.vertex.agent.llm_response".to_string(), serde_json::to_string(last).unwrap_or_default());
1356 }
1357
1358 if let Some(ref content) = accumulated_content {
1360 final_event.long_running_tool_ids = collect_long_running_ids(content);
1361 }
1362
1363 yield Ok(final_event);
1364 }
1365
1366 if let Some(ref content) = accumulated_content {
1368 let response_json = serde_json::to_string(content).unwrap_or_default();
1369 llm_span.record("gcp.vertex.agent.llm_response", &response_json);
1370 }
1371 }
1372
1373 let function_call_names: Vec<String> = accumulated_content.as_ref()
1375 .map(|c| c.parts.iter()
1376 .filter_map(|p| {
1377 if let Part::FunctionCall { name, .. } = p {
1378 Some(name.clone())
1379 } else {
1380 None
1381 }
1382 })
1383 .collect())
1384 .unwrap_or_default();
1385
1386 let has_function_calls = !function_call_names.is_empty();
1387
1388 let all_calls_are_long_running = has_function_calls && function_call_names.iter().all(|name| {
1392 tool_map.get(name)
1393 .map(|t| t.is_long_running())
1394 .unwrap_or(false)
1395 });
1396
1397 if let Some(ref content) = accumulated_content {
1399 conversation_history.push(content.clone());
1400
1401 if let Some(ref output_key) = output_key {
1403 if !has_function_calls { let mut text_parts = String::new();
1405 for part in &content.parts {
1406 if let Part::Text { text } = part {
1407 text_parts.push_str(text);
1408 }
1409 }
1410 if !text_parts.is_empty() {
1411 let mut state_event = Event::new(&invocation_id);
1413 state_event.author = agent_name.clone();
1414 state_event.actions.state_delta.insert(
1415 output_key.clone(),
1416 serde_json::Value::String(text_parts),
1417 );
1418 yield Ok(state_event);
1419 }
1420 }
1421 }
1422 }
1423
1424 if !has_function_calls {
1425 if let Some(ref content) = accumulated_content {
1428 let response_json = serde_json::to_string(content).unwrap_or_default();
1429 tracing::Span::current().record("gcp.vertex.agent.llm_response", &response_json);
1430 }
1431
1432 tracing::info!(agent.name = %agent_name, "Agent execution complete");
1433 break;
1434 }
1435
1436 if let Some(content) = &accumulated_content {
1438 let mut tool_call_index = 0usize;
1439 for part in &content.parts {
1440 if let Part::FunctionCall { name, args, id, .. } = part {
1441 let fallback_call_id =
1442 format!("{}_{}_{}", invocation_id, name, tool_call_index);
1443 tool_call_index += 1;
1444 let function_call_id = id.clone().unwrap_or(fallback_call_id);
1445
1446 if name == "transfer_to_agent" {
1448 let target_agent = args.get("agent_name")
1449 .and_then(|v| v.as_str())
1450 .unwrap_or_default()
1451 .to_string();
1452
1453 let valid_target = valid_transfer_targets.iter().any(|n| n == &target_agent);
1456 if !valid_target {
1457 let error_content = Content {
1459 role: "function".to_string(),
1460 parts: vec![Part::FunctionResponse {
1461 function_response: FunctionResponseData {
1462 name: name.clone(),
1463 response: serde_json::json!({
1464 "error": format!(
1465 "Agent '{}' not found. Available agents: {:?}",
1466 target_agent,
1467 valid_transfer_targets
1468 )
1469 }),
1470 },
1471 id: id.clone(),
1472 }],
1473 };
1474 conversation_history.push(error_content.clone());
1475
1476 let mut error_event = Event::new(&invocation_id);
1477 error_event.author = agent_name.clone();
1478 error_event.llm_response.content = Some(error_content);
1479 yield Ok(error_event);
1480 continue;
1481 }
1482
1483 let mut transfer_event = Event::new(&invocation_id);
1484 transfer_event.author = agent_name.clone();
1485 transfer_event.actions.transfer_to_agent = Some(target_agent);
1486
1487 yield Ok(transfer_event);
1488 return;
1489 }
1490
1491 let mut tool_actions = EventActions::default();
1494 let mut response_content: Option<Content> = None;
1495 let mut run_after_tool_callbacks = true;
1496 let mut tool_outcome_for_callback: Option<ToolOutcome> = None;
1497 let mut executed_tool: Option<Arc<dyn Tool>> = None;
1499 let mut executed_tool_response: Option<serde_json::Value> = None;
1500
1501 if tool_confirmation_policy.requires_confirmation(name) {
1506 match ctx.run_config().tool_confirmation_decisions.get(name).copied()
1507 {
1508 Some(ToolConfirmationDecision::Approve) => {
1509 tool_actions.tool_confirmation_decision =
1510 Some(ToolConfirmationDecision::Approve);
1511 }
1512 Some(ToolConfirmationDecision::Deny) => {
1513 tool_actions.tool_confirmation_decision =
1514 Some(ToolConfirmationDecision::Deny);
1515 response_content = Some(Content {
1516 role: "function".to_string(),
1517 parts: vec![Part::FunctionResponse {
1518 function_response: FunctionResponseData {
1519 name: name.clone(),
1520 response: serde_json::json!({
1521 "error": format!(
1522 "Tool '{}' execution denied by confirmation policy",
1523 name
1524 ),
1525 }),
1526 },
1527 id: id.clone(),
1528 }],
1529 });
1530 run_after_tool_callbacks = false;
1531 }
1532 None => {
1533 let mut confirmation_event = Event::new(&invocation_id);
1534 confirmation_event.author = agent_name.clone();
1535 confirmation_event.llm_response.interrupted = true;
1536 confirmation_event.llm_response.turn_complete = true;
1537 confirmation_event.llm_response.content = Some(Content {
1538 role: "model".to_string(),
1539 parts: vec![Part::Text {
1540 text: format!(
1541 "Tool confirmation required for '{}'. Provide approve/deny decision to continue.",
1542 name
1543 ),
1544 }],
1545 });
1546 confirmation_event.actions.tool_confirmation =
1547 Some(ToolConfirmationRequest {
1548 tool_name: name.clone(),
1549 function_call_id: Some(function_call_id),
1550 args: args.clone(),
1551 });
1552 yield Ok(confirmation_event);
1553 return;
1554 }
1555 }
1556 }
1557
1558 if response_content.is_none() {
1559 for callback in before_tool_callbacks.as_ref() {
1560 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
1561 Ok(Some(content)) => {
1562 response_content = Some(content);
1563 break;
1564 }
1565 Ok(None) => continue,
1566 Err(e) => {
1567 yield Err(e);
1568 return;
1569 }
1570 }
1571 }
1572 }
1573
1574 if response_content.is_none() {
1576 if let Some(ref cb_state) = circuit_breaker_state {
1580 if cb_state.is_open(name) {
1581 let error_msg = format!(
1582 "Tool '{}' is temporarily disabled after {} consecutive failures",
1583 name, cb_state.threshold
1584 );
1585 tracing::warn!(tool.name = %name, "circuit breaker open, skipping tool execution");
1586 response_content = Some(Content {
1587 role: "function".to_string(),
1588 parts: vec![Part::FunctionResponse {
1589 function_response: FunctionResponseData {
1590 name: name.clone(),
1591 response: serde_json::json!({ "error": error_msg }),
1592 },
1593 id: id.clone(),
1594 }],
1595 });
1596 run_after_tool_callbacks = false;
1597 }
1598 }
1599 }
1600
1601 if response_content.is_none() {
1602 if let Some(tool) = tool_map.get(name) {
1603 let tool_ctx: Arc<dyn ToolContext> = Arc::new(AgentToolContext::new(
1605 ctx.clone(),
1606 function_call_id.clone(),
1607 ));
1608
1609 let span_name = format!("execute_tool {}", name);
1611 let tool_span = tracing::info_span!(
1612 "",
1613 otel.name = %span_name,
1614 tool.name = %name,
1615 "gcp.vertex.agent.event_id" = %format!("{}_{}", invocation_id, name),
1616 "gcp.vertex.agent.invocation_id" = %invocation_id,
1617 "gcp.vertex.agent.session_id" = %ctx.session_id(),
1618 "gen_ai.conversation.id" = %ctx.session_id()
1619 );
1620
1621 let budget = tool_retry_budgets.get(name)
1625 .or(default_retry_budget.as_ref());
1626 let max_attempts = budget.map(|b| b.max_retries + 1).unwrap_or(1);
1627 let retry_delay = budget.map(|b| b.delay).unwrap_or_default();
1628
1629 let tool_clone = tool.clone();
1632 let tool_start = std::time::Instant::now();
1633
1634 let mut last_error = String::new();
1635 let mut final_attempt: u32 = 0;
1636 let mut retry_result: Option<serde_json::Value> = None;
1637
1638 for attempt in 0..max_attempts {
1639 final_attempt = attempt;
1640
1641 if attempt > 0 {
1642 tokio::time::sleep(retry_delay).await;
1643 }
1644
1645 match async {
1646 tracing::info!(tool.name = %name, tool.args = %args, attempt = attempt, "tool_call");
1647 let exec_future = tool_clone.execute(tool_ctx.clone(), args.clone());
1648 tokio::time::timeout(tool_timeout, exec_future).await
1649 }.instrument(tool_span.clone()).await {
1650 Ok(Ok(value)) => {
1651 tracing::info!(tool.name = %name, tool.result = %value, "tool_result");
1652 retry_result = Some(value);
1653 break;
1654 }
1655 Ok(Err(e)) => {
1656 last_error = e.to_string();
1657 if attempt + 1 < max_attempts {
1658 tracing::warn!(tool.name = %name, attempt = attempt, error = %last_error, "tool execution failed, retrying");
1659 } else {
1660 tracing::warn!(tool.name = %name, error = %last_error, "tool_error");
1661 }
1662 }
1663 Err(_) => {
1664 last_error = format!(
1665 "Tool '{}' timed out after {} seconds",
1666 name, tool_timeout.as_secs()
1667 );
1668 if attempt + 1 < max_attempts {
1669 tracing::warn!(tool.name = %name, attempt = attempt, timeout_secs = tool_timeout.as_secs(), "tool timed out, retrying");
1670 } else {
1671 tracing::warn!(tool.name = %name, timeout_secs = tool_timeout.as_secs(), "tool_timeout");
1672 }
1673 }
1674 }
1675 }
1676
1677 let tool_duration = tool_start.elapsed();
1678
1679 let (tool_success, tool_error_message, function_response) = match retry_result {
1681 Some(value) => (true, None, value),
1682 None => (false, Some(last_error.clone()), serde_json::json!({ "error": last_error })),
1683 };
1684
1685 let outcome = ToolOutcome {
1688 tool_name: name.clone(),
1689 tool_args: args.clone(),
1690 success: tool_success,
1691 duration: tool_duration,
1692 error_message: tool_error_message.clone(),
1693 attempt: final_attempt,
1694 };
1695 tool_outcome_for_callback = Some(outcome);
1696
1697 if let Some(ref mut cb_state) = circuit_breaker_state {
1699 cb_state.record(tool_outcome_for_callback.as_ref().unwrap());
1700 }
1701
1702 let final_function_response = if !tool_success {
1707 let mut fallback_result = None;
1708 let error_msg = tool_error_message.clone().unwrap_or_default();
1709 for callback in on_tool_error_callbacks.as_ref() {
1710 match callback(
1711 ctx.clone() as Arc<dyn CallbackContext>,
1712 tool.clone(),
1713 args.clone(),
1714 error_msg.clone(),
1715 ).await {
1716 Ok(Some(result)) => {
1717 fallback_result = Some(result);
1718 break;
1719 }
1720 Ok(None) => continue,
1721 Err(e) => {
1722 tracing::warn!(error = %e, "on_tool_error callback failed");
1723 break;
1724 }
1725 }
1726 }
1727 if let Some(fallback) = fallback_result {
1728 fallback
1729 } else {
1730 function_response
1731 }
1732 } else {
1733 function_response
1734 };
1735
1736 let confirmation_decision =
1737 tool_actions.tool_confirmation_decision;
1738 tool_actions = tool_ctx.actions();
1739 if tool_actions.tool_confirmation_decision.is_none() {
1740 tool_actions.tool_confirmation_decision =
1741 confirmation_decision;
1742 }
1743 executed_tool = Some(tool.clone());
1745 executed_tool_response = Some(final_function_response.clone());
1746 response_content = Some(Content {
1747 role: "function".to_string(),
1748 parts: vec![Part::FunctionResponse {
1749 function_response: FunctionResponseData {
1750 name: name.clone(),
1751 response: final_function_response,
1752 },
1753 id: id.clone(),
1754 }],
1755 });
1756 } else {
1757 response_content = Some(Content {
1758 role: "function".to_string(),
1759 parts: vec![Part::FunctionResponse {
1760 function_response: FunctionResponseData {
1761 name: name.clone(),
1762 response: serde_json::json!({
1763 "error": format!("Tool {} not found", name)
1764 }),
1765 },
1766 id: id.clone(),
1767 }],
1768 });
1769 }
1770 }
1771
1772 let mut response_content = response_content.expect("tool response content is set");
1775 if run_after_tool_callbacks {
1776 let cb_ctx: Arc<dyn CallbackContext> = match tool_outcome_for_callback {
1778 Some(outcome) => Arc::new(ToolCallbackContext {
1779 inner: ctx.clone() as Arc<dyn CallbackContext>,
1780 outcome,
1781 }),
1782 None => ctx.clone() as Arc<dyn CallbackContext>,
1783 };
1784 for callback in after_tool_callbacks.as_ref() {
1785 match callback(cb_ctx.clone()).await {
1786 Ok(Some(modified_content)) => {
1787 response_content = modified_content;
1788 break;
1789 }
1790 Ok(None) => continue,
1791 Err(e) => {
1792 yield Err(e);
1793 return;
1794 }
1795 }
1796 }
1797
1798 if let (Some(tool_ref), Some(tool_resp)) = (&executed_tool, executed_tool_response) {
1803 for callback in after_tool_callbacks_full.as_ref() {
1804 match callback(
1805 cb_ctx.clone(),
1806 tool_ref.clone(),
1807 args.clone(),
1808 tool_resp.clone(),
1809 ).await {
1810 Ok(Some(modified_value)) => {
1811 response_content = Content {
1813 role: "function".to_string(),
1814 parts: vec![Part::FunctionResponse {
1815 function_response: FunctionResponseData {
1816 name: name.clone(),
1817 response: modified_value,
1818 },
1819 id: id.clone(),
1820 }],
1821 };
1822 break;
1823 }
1824 Ok(None) => continue,
1825 Err(e) => {
1826 yield Err(e);
1827 return;
1828 }
1829 }
1830 }
1831 }
1832 }
1833
1834 let mut tool_event = Event::new(&invocation_id);
1836 tool_event.author = agent_name.clone();
1837 tool_event.actions = tool_actions.clone();
1838 tool_event.llm_response.content = Some(response_content.clone());
1839 yield Ok(tool_event);
1840
1841 if tool_actions.escalate || tool_actions.skip_summarization {
1843 return;
1845 }
1846
1847 conversation_history.push(response_content);
1849 }
1850 }
1851 }
1852
1853 if all_calls_are_long_running {
1857 }
1861 }
1862
1863 for callback in after_agent_callbacks.as_ref() {
1866 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
1867 Ok(Some(content)) => {
1868 let mut after_event = Event::new(&invocation_id);
1870 after_event.author = agent_name.clone();
1871 after_event.llm_response.content = Some(content);
1872 yield Ok(after_event);
1873 break; }
1875 Ok(None) => {
1876 continue;
1878 }
1879 Err(e) => {
1880 yield Err(e);
1882 return;
1883 }
1884 }
1885 }
1886 };
1887
1888 Ok(Box::pin(s))
1889 }
1890}