1use adk_core::{
2 AfterAgentCallback, AfterModelCallback, AfterToolCallback, Agent, BeforeAgentCallback,
3 BeforeModelCallback, BeforeModelResult, BeforeToolCallback, CallbackContext, Content, Event,
4 EventActions, FunctionResponseData, GlobalInstructionProvider, InstructionProvider,
5 InvocationContext, Llm, LlmRequest, LlmResponse, MemoryEntry, Part, ReadonlyContext, Result,
6 Tool, ToolContext,
7};
8use async_stream::stream;
9use async_trait::async_trait;
10use std::sync::{Arc, Mutex};
11use tracing::Instrument;
12
13use crate::guardrails::GuardrailSet;
14
/// Iteration cap used by [`LlmAgentBuilder::new`] when `max_iterations` is not set.
///
/// One iteration is one model request inside `LlmAgent::run`; exceeding the cap makes the
/// event stream yield an error and stop.
pub const DEFAULT_MAX_ITERATIONS: u32 = 100;
17
18pub struct LlmAgent {
19 name: String,
20 description: String,
21 model: Arc<dyn Llm>,
22 instruction: Option<String>,
23 instruction_provider: Option<Arc<InstructionProvider>>,
24 global_instruction: Option<String>,
25 global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
26 #[allow(dead_code)] input_schema: Option<serde_json::Value>,
28 output_schema: Option<serde_json::Value>,
29 #[allow(dead_code)] disallow_transfer_to_parent: bool,
31 #[allow(dead_code)] disallow_transfer_to_peers: bool,
33 include_contents: adk_core::IncludeContents,
34 tools: Vec<Arc<dyn Tool>>,
35 sub_agents: Vec<Arc<dyn Agent>>,
36 output_key: Option<String>,
37 max_iterations: u32,
39 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
40 after_callbacks: Arc<Vec<AfterAgentCallback>>,
41 before_model_callbacks: Arc<Vec<BeforeModelCallback>>,
42 after_model_callbacks: Arc<Vec<AfterModelCallback>>,
43 before_tool_callbacks: Arc<Vec<BeforeToolCallback>>,
44 after_tool_callbacks: Arc<Vec<AfterToolCallback>>,
45 #[allow(dead_code)] input_guardrails: GuardrailSet,
47 #[allow(dead_code)] output_guardrails: GuardrailSet,
49}
50
51impl std::fmt::Debug for LlmAgent {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("LlmAgent")
54 .field("name", &self.name)
55 .field("description", &self.description)
56 .field("model", &self.model.name())
57 .field("instruction", &self.instruction)
58 .field("tools_count", &self.tools.len())
59 .field("sub_agents_count", &self.sub_agents.len())
60 .finish()
61 }
62}
63
64pub struct LlmAgentBuilder {
65 name: String,
66 description: Option<String>,
67 model: Option<Arc<dyn Llm>>,
68 instruction: Option<String>,
69 instruction_provider: Option<Arc<InstructionProvider>>,
70 global_instruction: Option<String>,
71 global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
72 input_schema: Option<serde_json::Value>,
73 output_schema: Option<serde_json::Value>,
74 disallow_transfer_to_parent: bool,
75 disallow_transfer_to_peers: bool,
76 include_contents: adk_core::IncludeContents,
77 tools: Vec<Arc<dyn Tool>>,
78 sub_agents: Vec<Arc<dyn Agent>>,
79 output_key: Option<String>,
80 max_iterations: u32,
81 before_callbacks: Vec<BeforeAgentCallback>,
82 after_callbacks: Vec<AfterAgentCallback>,
83 before_model_callbacks: Vec<BeforeModelCallback>,
84 after_model_callbacks: Vec<AfterModelCallback>,
85 before_tool_callbacks: Vec<BeforeToolCallback>,
86 after_tool_callbacks: Vec<AfterToolCallback>,
87 input_guardrails: GuardrailSet,
88 output_guardrails: GuardrailSet,
89}
90
91impl LlmAgentBuilder {
92 pub fn new(name: impl Into<String>) -> Self {
93 Self {
94 name: name.into(),
95 description: None,
96 model: None,
97 instruction: None,
98 instruction_provider: None,
99 global_instruction: None,
100 global_instruction_provider: None,
101 input_schema: None,
102 output_schema: None,
103 disallow_transfer_to_parent: false,
104 disallow_transfer_to_peers: false,
105 include_contents: adk_core::IncludeContents::Default,
106 tools: Vec::new(),
107 sub_agents: Vec::new(),
108 output_key: None,
109 max_iterations: DEFAULT_MAX_ITERATIONS,
110 before_callbacks: Vec::new(),
111 after_callbacks: Vec::new(),
112 before_model_callbacks: Vec::new(),
113 after_model_callbacks: Vec::new(),
114 before_tool_callbacks: Vec::new(),
115 after_tool_callbacks: Vec::new(),
116 input_guardrails: GuardrailSet::new(),
117 output_guardrails: GuardrailSet::new(),
118 }
119 }
120
121 pub fn description(mut self, desc: impl Into<String>) -> Self {
122 self.description = Some(desc.into());
123 self
124 }
125
126 pub fn model(mut self, model: Arc<dyn Llm>) -> Self {
127 self.model = Some(model);
128 self
129 }
130
131 pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
132 self.instruction = Some(instruction.into());
133 self
134 }
135
136 pub fn instruction_provider(mut self, provider: InstructionProvider) -> Self {
137 self.instruction_provider = Some(Arc::new(provider));
138 self
139 }
140
141 pub fn global_instruction(mut self, instruction: impl Into<String>) -> Self {
142 self.global_instruction = Some(instruction.into());
143 self
144 }
145
146 pub fn global_instruction_provider(mut self, provider: GlobalInstructionProvider) -> Self {
147 self.global_instruction_provider = Some(Arc::new(provider));
148 self
149 }
150
151 pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
152 self.input_schema = Some(schema);
153 self
154 }
155
156 pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
157 self.output_schema = Some(schema);
158 self
159 }
160
161 pub fn disallow_transfer_to_parent(mut self, disallow: bool) -> Self {
162 self.disallow_transfer_to_parent = disallow;
163 self
164 }
165
166 pub fn disallow_transfer_to_peers(mut self, disallow: bool) -> Self {
167 self.disallow_transfer_to_peers = disallow;
168 self
169 }
170
171 pub fn include_contents(mut self, include: adk_core::IncludeContents) -> Self {
172 self.include_contents = include;
173 self
174 }
175
176 pub fn output_key(mut self, key: impl Into<String>) -> Self {
177 self.output_key = Some(key.into());
178 self
179 }
180
181 pub fn max_iterations(mut self, max: u32) -> Self {
184 self.max_iterations = max;
185 self
186 }
187
188 pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
189 self.tools.push(tool);
190 self
191 }
192
193 pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
194 self.sub_agents.push(agent);
195 self
196 }
197
198 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
199 self.before_callbacks.push(callback);
200 self
201 }
202
203 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
204 self.after_callbacks.push(callback);
205 self
206 }
207
208 pub fn before_model_callback(mut self, callback: BeforeModelCallback) -> Self {
209 self.before_model_callbacks.push(callback);
210 self
211 }
212
213 pub fn after_model_callback(mut self, callback: AfterModelCallback) -> Self {
214 self.after_model_callbacks.push(callback);
215 self
216 }
217
218 pub fn before_tool_callback(mut self, callback: BeforeToolCallback) -> Self {
219 self.before_tool_callbacks.push(callback);
220 self
221 }
222
223 pub fn after_tool_callback(mut self, callback: AfterToolCallback) -> Self {
224 self.after_tool_callbacks.push(callback);
225 self
226 }
227
228 pub fn input_guardrails(mut self, guardrails: GuardrailSet) -> Self {
237 self.input_guardrails = guardrails;
238 self
239 }
240
241 pub fn output_guardrails(mut self, guardrails: GuardrailSet) -> Self {
250 self.output_guardrails = guardrails;
251 self
252 }
253
254 pub fn build(self) -> Result<LlmAgent> {
255 let model =
256 self.model.ok_or_else(|| adk_core::AdkError::Agent("Model is required".to_string()))?;
257
258 Ok(LlmAgent {
259 name: self.name,
260 description: self.description.unwrap_or_default(),
261 model,
262 instruction: self.instruction,
263 instruction_provider: self.instruction_provider,
264 global_instruction: self.global_instruction,
265 global_instruction_provider: self.global_instruction_provider,
266 input_schema: self.input_schema,
267 output_schema: self.output_schema,
268 disallow_transfer_to_parent: self.disallow_transfer_to_parent,
269 disallow_transfer_to_peers: self.disallow_transfer_to_peers,
270 include_contents: self.include_contents,
271 tools: self.tools,
272 sub_agents: self.sub_agents,
273 output_key: self.output_key,
274 max_iterations: self.max_iterations,
275 before_callbacks: Arc::new(self.before_callbacks),
276 after_callbacks: Arc::new(self.after_callbacks),
277 before_model_callbacks: Arc::new(self.before_model_callbacks),
278 after_model_callbacks: Arc::new(self.after_model_callbacks),
279 before_tool_callbacks: Arc::new(self.before_tool_callbacks),
280 after_tool_callbacks: Arc::new(self.after_tool_callbacks),
281 input_guardrails: self.input_guardrails,
282 output_guardrails: self.output_guardrails,
283 })
284 }
285}
286
287struct AgentToolContext {
290 parent_ctx: Arc<dyn InvocationContext>,
291 function_call_id: String,
292 actions: Mutex<EventActions>,
293}
294
295impl AgentToolContext {
296 fn new(parent_ctx: Arc<dyn InvocationContext>, function_call_id: String) -> Self {
297 Self { parent_ctx, function_call_id, actions: Mutex::new(EventActions::default()) }
298 }
299}
300
301#[async_trait]
302impl ReadonlyContext for AgentToolContext {
303 fn invocation_id(&self) -> &str {
304 self.parent_ctx.invocation_id()
305 }
306
307 fn agent_name(&self) -> &str {
308 self.parent_ctx.agent_name()
309 }
310
311 fn user_id(&self) -> &str {
312 self.parent_ctx.user_id()
314 }
315
316 fn app_name(&self) -> &str {
317 self.parent_ctx.app_name()
319 }
320
321 fn session_id(&self) -> &str {
322 self.parent_ctx.session_id()
324 }
325
326 fn branch(&self) -> &str {
327 self.parent_ctx.branch()
328 }
329
330 fn user_content(&self) -> &Content {
331 self.parent_ctx.user_content()
332 }
333}
334
335#[async_trait]
336impl CallbackContext for AgentToolContext {
337 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
338 self.parent_ctx.artifacts()
340 }
341}
342
343#[async_trait]
344impl ToolContext for AgentToolContext {
345 fn function_call_id(&self) -> &str {
346 &self.function_call_id
347 }
348
349 fn actions(&self) -> EventActions {
350 self.actions.lock().unwrap().clone()
351 }
352
353 fn set_actions(&self, actions: EventActions) {
354 *self.actions.lock().unwrap() = actions;
355 }
356
357 async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>> {
358 if let Some(memory) = self.parent_ctx.memory() {
360 memory.search(query).await
361 } else {
362 Ok(vec![])
363 }
364 }
365}
366
367#[async_trait]
368impl Agent for LlmAgent {
369 fn name(&self) -> &str {
370 &self.name
371 }
372
373 fn description(&self) -> &str {
374 &self.description
375 }
376
377 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
378 &self.sub_agents
379 }
380
381 #[adk_telemetry::instrument(
382 skip(self, ctx),
383 fields(
384 agent.name = %self.name,
385 agent.description = %self.description,
386 invocation.id = %ctx.invocation_id(),
387 user.id = %ctx.user_id(),
388 session.id = %ctx.session_id()
389 )
390 )]
391 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
392 adk_telemetry::info!("Starting agent execution");
393
394 let agent_name = self.name.clone();
395 let invocation_id = ctx.invocation_id().to_string();
396 let model = self.model.clone();
397 let tools = self.tools.clone();
398 let sub_agents = self.sub_agents.clone();
399
400 let instruction = self.instruction.clone();
401 let instruction_provider = self.instruction_provider.clone();
402 let global_instruction = self.global_instruction.clone();
403 let global_instruction_provider = self.global_instruction_provider.clone();
404 let output_key = self.output_key.clone();
405 let output_schema = self.output_schema.clone();
406 let include_contents = self.include_contents;
407 let max_iterations = self.max_iterations;
408 let before_agent_callbacks = self.before_callbacks.clone();
410 let after_agent_callbacks = self.after_callbacks.clone();
411 let before_model_callbacks = self.before_model_callbacks.clone();
412 let after_model_callbacks = self.after_model_callbacks.clone();
413 let _before_tool_callbacks = self.before_tool_callbacks.clone();
414 let _after_tool_callbacks = self.after_tool_callbacks.clone();
415
416 let s = stream! {
417 for callback in before_agent_callbacks.as_ref() {
421 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
422 Ok(Some(content)) => {
423 let mut early_event = Event::new(&invocation_id);
425 early_event.author = agent_name.clone();
426 early_event.llm_response.content = Some(content);
427 yield Ok(early_event);
428
429 for after_callback in after_agent_callbacks.as_ref() {
431 match after_callback(ctx.clone() as Arc<dyn CallbackContext>).await {
432 Ok(Some(after_content)) => {
433 let mut after_event = Event::new(&invocation_id);
434 after_event.author = agent_name.clone();
435 after_event.llm_response.content = Some(after_content);
436 yield Ok(after_event);
437 return;
438 }
439 Ok(None) => continue,
440 Err(e) => {
441 yield Err(e);
442 return;
443 }
444 }
445 }
446 return;
447 }
448 Ok(None) => {
449 continue;
451 }
452 Err(e) => {
453 yield Err(e);
455 return;
456 }
457 }
458 }
459
460 let mut conversation_history = Vec::new();
462
463 if let Some(provider) = &global_instruction_provider {
466 let global_inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
468 if !global_inst.is_empty() {
469 conversation_history.push(Content {
470 role: "user".to_string(),
471 parts: vec![Part::Text { text: global_inst }],
472 });
473 }
474 } else if let Some(ref template) = global_instruction {
475 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
477 if !processed.is_empty() {
478 conversation_history.push(Content {
479 role: "user".to_string(),
480 parts: vec![Part::Text { text: processed }],
481 });
482 }
483 }
484
485 if let Some(provider) = &instruction_provider {
488 let inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
490 if !inst.is_empty() {
491 conversation_history.push(Content {
492 role: "user".to_string(),
493 parts: vec![Part::Text { text: inst }],
494 });
495 }
496 } else if let Some(ref template) = instruction {
497 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
499 if !processed.is_empty() {
500 conversation_history.push(Content {
501 role: "user".to_string(),
502 parts: vec![Part::Text { text: processed }],
503 });
504 }
505 }
506
507 let session_history = ctx.session().conversation_history();
511 conversation_history.extend(session_history);
512
513 let mut conversation_history = match include_contents {
516 adk_core::IncludeContents::None => {
517 let mut filtered = Vec::new();
520
521 let instruction_count = conversation_history.iter()
523 .take_while(|c| c.role == "user" && c.parts.iter().any(|p| {
524 if let Part::Text { text } = p {
525 !text.is_empty()
527 } else {
528 false
529 }
530 }))
531 .count();
532
533 filtered.extend(conversation_history.iter().take(instruction_count).cloned());
535
536 if let Some(last) = conversation_history.last() {
538 if last.role == "user" {
539 filtered.push(last.clone());
540 }
541 }
542
543 filtered
544 }
545 adk_core::IncludeContents::Default => {
546 conversation_history
548 }
549 };
550
551 let mut tool_declarations = std::collections::HashMap::new();
554 for tool in &tools {
555 let mut decl = serde_json::json!({
558 "name": tool.name(),
559 "description": tool.enhanced_description(),
560 });
561
562 if let Some(params) = tool.parameters_schema() {
563 decl["parameters"] = params;
564 }
565
566 if let Some(response) = tool.response_schema() {
567 decl["response"] = response;
568 }
569
570 tool_declarations.insert(tool.name().to_string(), decl);
571 }
572
573 if !sub_agents.is_empty() {
575 let transfer_tool_name = "transfer_to_agent";
576 let transfer_tool_decl = serde_json::json!({
577 "name": transfer_tool_name,
578 "description": "Transfer execution to another agent.",
579 "parameters": {
580 "type": "object",
581 "properties": {
582 "agent_name": {
583 "type": "string",
584 "description": "The name of the agent to transfer to."
585 }
586 },
587 "required": ["agent_name"]
588 }
589 });
590 tool_declarations.insert(transfer_tool_name.to_string(), transfer_tool_decl);
591 }
592
593
594 let mut iteration = 0;
596
597 loop {
598 iteration += 1;
599 if iteration > max_iterations {
600 yield Err(adk_core::AdkError::Agent(
601 format!("Max iterations ({}) exceeded", max_iterations)
602 ));
603 return;
604 }
605
606 let config = output_schema.as_ref().map(|schema| {
608 adk_core::GenerateContentConfig {
609 temperature: None,
610 top_p: None,
611 top_k: None,
612 max_output_tokens: None,
613 response_schema: Some(schema.clone()),
614 }
615 });
616
617 let request = LlmRequest {
618 model: model.name().to_string(),
619 contents: conversation_history.clone(),
620 tools: tool_declarations.clone(),
621 config,
622 };
623
624 let mut current_request = request;
627 let mut model_response_override = None;
628 for callback in before_model_callbacks.as_ref() {
629 match callback(ctx.clone() as Arc<dyn CallbackContext>, current_request.clone()).await {
630 Ok(BeforeModelResult::Continue(modified_request)) => {
631 current_request = modified_request;
633 }
634 Ok(BeforeModelResult::Skip(response)) => {
635 model_response_override = Some(response);
637 break;
638 }
639 Err(e) => {
640 yield Err(e);
642 return;
643 }
644 }
645 }
646 let request = current_request;
647
648 let mut accumulated_content: Option<Content> = None;
650
651 if let Some(cached_response) = model_response_override {
652 let mut cached_event = Event::new(&invocation_id);
655 cached_event.author = agent_name.clone();
656 cached_event.llm_response.content = cached_response.content.clone();
657 cached_event.llm_request = Some(serde_json::to_string(&request).unwrap_or_default());
658 cached_event.gcp_llm_request = Some(serde_json::to_string(&request).unwrap_or_default());
659 cached_event.gcp_llm_response = Some(serde_json::to_string(&cached_response).unwrap_or_default());
660
661 if let Some(ref content) = cached_response.content {
663 let long_running_ids: Vec<String> = content.parts.iter()
664 .filter_map(|p| {
665 if let Part::FunctionCall { name, .. } = p {
666 if let Some(tool) = tools.iter().find(|t| t.name() == name) {
667 if tool.is_long_running() {
668 return Some(name.clone());
669 }
670 }
671 }
672 None
673 })
674 .collect();
675 cached_event.long_running_tool_ids = long_running_ids;
676 }
677
678 yield Ok(cached_event);
679
680 accumulated_content = cached_response.content;
681 } else {
682 let request_json = serde_json::to_string(&request).unwrap_or_default();
684
685 let llm_ts = std::time::SystemTime::now()
687 .duration_since(std::time::UNIX_EPOCH)
688 .unwrap_or_default()
689 .as_nanos();
690 let llm_event_id = format!("{}_llm_{}", invocation_id, llm_ts);
691 let llm_span = tracing::info_span!(
692 "call_llm",
693 "gcp.vertex.agent.event_id" = %llm_event_id,
694 "gcp.vertex.agent.invocation_id" = %invocation_id,
695 "gcp.vertex.agent.session_id" = %ctx.session_id(),
696 "gcp.vertex.agent.llm_request" = %request_json,
697 "gcp.vertex.agent.llm_response" = tracing::field::Empty );
699 let _llm_guard = llm_span.enter();
700
701 use adk_core::StreamingMode;
703 let streaming_mode = ctx.run_config().streaming_mode;
704 let should_stream_to_client = matches!(streaming_mode, StreamingMode::SSE | StreamingMode::Bidi);
705
706 let mut response_stream = model.generate_content(request, true).await?;
708
709 use futures::StreamExt;
710
711 let mut last_chunk: Option<LlmResponse> = None;
713
714 while let Some(chunk_result) = response_stream.next().await {
716 let mut chunk = match chunk_result {
717 Ok(c) => c,
718 Err(e) => {
719 yield Err(e);
720 return;
721 }
722 };
723
724 for callback in after_model_callbacks.as_ref() {
727 match callback(ctx.clone() as Arc<dyn CallbackContext>, chunk.clone()).await {
728 Ok(Some(modified_chunk)) => {
729 chunk = modified_chunk;
731 break;
732 }
733 Ok(None) => {
734 continue;
736 }
737 Err(e) => {
738 yield Err(e);
740 return;
741 }
742 }
743 }
744
745 if let Some(chunk_content) = chunk.content.clone() {
747 if let Some(ref mut acc) = accumulated_content {
748 acc.parts.extend(chunk_content.parts);
749 } else {
750 accumulated_content = Some(chunk_content);
751 }
752 }
753
754 if should_stream_to_client {
756 let mut partial_event = Event::with_id(&llm_event_id, &invocation_id);
757 partial_event.author = agent_name.clone();
758 partial_event.llm_request = Some(request_json.clone());
759 partial_event.gcp_llm_request = Some(request_json.clone());
760 partial_event.gcp_llm_response = Some(serde_json::to_string(&chunk).unwrap_or_default());
761 partial_event.llm_response.partial = chunk.partial;
762 partial_event.llm_response.turn_complete = chunk.turn_complete;
763 partial_event.llm_response.finish_reason = chunk.finish_reason;
764 partial_event.llm_response.usage_metadata = chunk.usage_metadata.clone();
765 partial_event.llm_response.content = chunk.content.clone();
766
767 if let Some(ref content) = chunk.content {
769 let long_running_ids: Vec<String> = content.parts.iter()
770 .filter_map(|p| {
771 if let Part::FunctionCall { name, .. } = p {
772 if let Some(tool) = tools.iter().find(|t| t.name() == name) {
773 if tool.is_long_running() {
774 return Some(name.clone());
775 }
776 }
777 }
778 None
779 })
780 .collect();
781 partial_event.long_running_tool_ids = long_running_ids;
782 }
783
784 yield Ok(partial_event);
785 }
786
787 last_chunk = Some(chunk.clone());
789
790 if chunk.turn_complete {
792 break;
793 }
794 }
795
796 if !should_stream_to_client {
798 let mut final_event = Event::with_id(&llm_event_id, &invocation_id);
799 final_event.author = agent_name.clone();
800 final_event.llm_request = Some(request_json.clone());
801 final_event.gcp_llm_request = Some(request_json.clone());
802 final_event.llm_response.content = accumulated_content.clone();
803 final_event.llm_response.partial = false;
804 final_event.llm_response.turn_complete = true;
805
806 if let Some(ref last) = last_chunk {
808 final_event.llm_response.finish_reason = last.finish_reason;
809 final_event.llm_response.usage_metadata = last.usage_metadata.clone();
810 final_event.gcp_llm_response = Some(serde_json::to_string(last).unwrap_or_default());
811 }
812
813 if let Some(ref content) = accumulated_content {
815 let long_running_ids: Vec<String> = content.parts.iter()
816 .filter_map(|p| {
817 if let Part::FunctionCall { name, .. } = p {
818 if let Some(tool) = tools.iter().find(|t| t.name() == name) {
819 if tool.is_long_running() {
820 return Some(name.clone());
821 }
822 }
823 }
824 None
825 })
826 .collect();
827 final_event.long_running_tool_ids = long_running_ids;
828 }
829
830 yield Ok(final_event);
831 }
832
833 if let Some(ref content) = accumulated_content {
835 let response_json = serde_json::to_string(content).unwrap_or_default();
836 llm_span.record("gcp.vertex.agent.llm_response", &response_json);
837 }
838 }
839
840 let function_call_names: Vec<String> = accumulated_content.as_ref()
842 .map(|c| c.parts.iter()
843 .filter_map(|p| {
844 if let Part::FunctionCall { name, .. } = p {
845 Some(name.clone())
846 } else {
847 None
848 }
849 })
850 .collect())
851 .unwrap_or_default();
852
853 let has_function_calls = !function_call_names.is_empty();
854
855 let all_calls_are_long_running = has_function_calls && function_call_names.iter().all(|name| {
859 tools.iter()
860 .find(|t| t.name() == name)
861 .map(|t| t.is_long_running())
862 .unwrap_or(false)
863 });
864
865 if let Some(ref content) = accumulated_content {
867 conversation_history.push(content.clone());
868
869 if let Some(ref output_key) = output_key {
871 if !has_function_calls { let mut text_parts = String::new();
873 for part in &content.parts {
874 if let Part::Text { text } = part {
875 text_parts.push_str(text);
876 }
877 }
878 if !text_parts.is_empty() {
879 let mut state_event = Event::new(&invocation_id);
881 state_event.author = agent_name.clone();
882 state_event.actions.state_delta.insert(
883 output_key.clone(),
884 serde_json::Value::String(text_parts),
885 );
886 yield Ok(state_event);
887 }
888 }
889 }
890 }
891
892 if !has_function_calls {
893 if let Some(ref content) = accumulated_content {
896 let response_json = serde_json::to_string(content).unwrap_or_default();
897 tracing::Span::current().record("gcp.vertex.agent.llm_response", &response_json);
898 }
899
900 tracing::info!(agent.name = %agent_name, "Agent execution complete");
901 break;
902 }
903
904 if let Some(content) = &accumulated_content {
906 for part in &content.parts {
907 if let Part::FunctionCall { name, args, id } = part {
908 if name == "transfer_to_agent" {
910 let target_agent = args.get("agent_name")
911 .and_then(|v| v.as_str())
912 .unwrap_or_default()
913 .to_string();
914
915 let mut transfer_event = Event::new(&invocation_id);
916 transfer_event.author = agent_name.clone();
917 transfer_event.actions.transfer_to_agent = Some(target_agent);
918
919 yield Ok(transfer_event);
920 return;
921 }
922
923
924 let (tool_result, tool_actions) = if let Some(tool) = tools.iter().find(|t| t.name() == name) {
926 let tool_ctx: Arc<dyn ToolContext> = Arc::new(AgentToolContext::new(
928 ctx.clone(),
929 format!("{}_{}", invocation_id, name),
930 ));
931
932 let span_name = format!("execute_tool {}", name);
934 let tool_span = tracing::info_span!(
935 "",
936 otel.name = %span_name,
937 tool.name = %name,
938 "gcp.vertex.agent.event_id" = %format!("{}_{}", invocation_id, name),
939 "gcp.vertex.agent.invocation_id" = %invocation_id,
940 "gcp.vertex.agent.session_id" = %ctx.session_id()
941 );
942
943 let result = async {
945 tracing::info!(tool.name = %name, tool.args = %args, "tool_call");
946 match tool.execute(tool_ctx.clone(), args.clone()).await {
947 Ok(result) => {
948 tracing::info!(tool.name = %name, tool.result = %result, "tool_result");
949 result
950 }
951 Err(e) => {
952 tracing::warn!(tool.name = %name, error = %e, "tool_error");
953 serde_json::json!({ "error": e.to_string() })
954 }
955 }
956 }.instrument(tool_span).await;
957
958 (result, tool_ctx.actions())
959 } else {
960 (serde_json::json!({ "error": format!("Tool {} not found", name) }), EventActions::default())
961 };
962
963 let mut tool_event = Event::new(&invocation_id);
965 tool_event.author = agent_name.clone();
966 tool_event.actions = tool_actions.clone();
967 tool_event.llm_response.content = Some(Content {
968 role: "function".to_string(),
969 parts: vec![Part::FunctionResponse {
970 function_response: FunctionResponseData {
971 name: name.clone(),
972 response: tool_result.clone(),
973 },
974 id: id.clone(),
975 }],
976 });
977 yield Ok(tool_event);
978
979 if tool_actions.escalate || tool_actions.skip_summarization {
981 return;
983 }
984
985 conversation_history.push(Content {
987 role: "function".to_string(),
988 parts: vec![Part::FunctionResponse {
989 function_response: FunctionResponseData {
990 name: name.clone(),
991 response: tool_result,
992 },
993 id: id.clone(),
994 }],
995 });
996 }
997 }
998 }
999
1000 if all_calls_are_long_running {
1004 }
1008 }
1009
1010 for callback in after_agent_callbacks.as_ref() {
1013 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
1014 Ok(Some(content)) => {
1015 let mut after_event = Event::new(&invocation_id);
1017 after_event.author = agent_name.clone();
1018 after_event.llm_response.content = Some(content);
1019 yield Ok(after_event);
1020 break; }
1022 Ok(None) => {
1023 continue;
1025 }
1026 Err(e) => {
1027 yield Err(e);
1029 return;
1030 }
1031 }
1032 }
1033 };
1034
1035 Ok(Box::pin(s))
1036 }
1037}