1use adk_core::{
2 AfterAgentCallback, AfterModelCallback, AfterToolCallback, Agent, BeforeAgentCallback,
3 BeforeModelCallback, BeforeModelResult, BeforeToolCallback, CallbackContext, Content, Event,
4 EventActions, GlobalInstructionProvider, InstructionProvider, InvocationContext, Llm,
5 LlmRequest, MemoryEntry, Part, ReadonlyContext, Result, Tool, ToolContext,
6};
7use async_stream::stream;
8use async_trait::async_trait;
9use std::sync::{Arc, Mutex};
10
11pub struct LlmAgent {
12 name: String,
13 description: String,
14 model: Arc<dyn Llm>,
15 instruction: Option<String>,
16 instruction_provider: Option<Arc<InstructionProvider>>,
17 global_instruction: Option<String>,
18 global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
19 #[allow(dead_code)] input_schema: Option<serde_json::Value>,
21 output_schema: Option<serde_json::Value>,
22 #[allow(dead_code)] disallow_transfer_to_parent: bool,
24 #[allow(dead_code)] disallow_transfer_to_peers: bool,
26 include_contents: adk_core::IncludeContents,
27 tools: Vec<Arc<dyn Tool>>,
28 sub_agents: Vec<Arc<dyn Agent>>,
29 output_key: Option<String>,
30 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
31 after_callbacks: Arc<Vec<AfterAgentCallback>>,
32 before_model_callbacks: Arc<Vec<BeforeModelCallback>>,
33 after_model_callbacks: Arc<Vec<AfterModelCallback>>,
34 before_tool_callbacks: Arc<Vec<BeforeToolCallback>>,
35 after_tool_callbacks: Arc<Vec<AfterToolCallback>>,
36}
37
38impl std::fmt::Debug for LlmAgent {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("LlmAgent")
41 .field("name", &self.name)
42 .field("description", &self.description)
43 .field("model", &self.model.name())
44 .field("instruction", &self.instruction)
45 .field("tools_count", &self.tools.len())
46 .field("sub_agents_count", &self.sub_agents.len())
47 .finish()
48 }
49}
50
51pub struct LlmAgentBuilder {
52 name: String,
53 description: Option<String>,
54 model: Option<Arc<dyn Llm>>,
55 instruction: Option<String>,
56 instruction_provider: Option<Arc<InstructionProvider>>,
57 global_instruction: Option<String>,
58 global_instruction_provider: Option<Arc<GlobalInstructionProvider>>,
59 input_schema: Option<serde_json::Value>,
60 output_schema: Option<serde_json::Value>,
61 disallow_transfer_to_parent: bool,
62 disallow_transfer_to_peers: bool,
63 include_contents: adk_core::IncludeContents,
64 tools: Vec<Arc<dyn Tool>>,
65 sub_agents: Vec<Arc<dyn Agent>>,
66 output_key: Option<String>,
67 before_callbacks: Vec<BeforeAgentCallback>,
68 after_callbacks: Vec<AfterAgentCallback>,
69 before_model_callbacks: Vec<BeforeModelCallback>,
70 after_model_callbacks: Vec<AfterModelCallback>,
71 before_tool_callbacks: Vec<BeforeToolCallback>,
72 after_tool_callbacks: Vec<AfterToolCallback>,
73}
74
75impl LlmAgentBuilder {
76 pub fn new(name: impl Into<String>) -> Self {
77 Self {
78 name: name.into(),
79 description: None,
80 model: None,
81 instruction: None,
82 instruction_provider: None,
83 global_instruction: None,
84 global_instruction_provider: None,
85 input_schema: None,
86 output_schema: None,
87 disallow_transfer_to_parent: false,
88 disallow_transfer_to_peers: false,
89 include_contents: adk_core::IncludeContents::Default,
90 tools: Vec::new(),
91 sub_agents: Vec::new(),
92 output_key: None,
93 before_callbacks: Vec::new(),
94 after_callbacks: Vec::new(),
95 before_model_callbacks: Vec::new(),
96 after_model_callbacks: Vec::new(),
97 before_tool_callbacks: Vec::new(),
98 after_tool_callbacks: Vec::new(),
99 }
100 }
101
102 pub fn description(mut self, desc: impl Into<String>) -> Self {
103 self.description = Some(desc.into());
104 self
105 }
106
107 pub fn model(mut self, model: Arc<dyn Llm>) -> Self {
108 self.model = Some(model);
109 self
110 }
111
112 pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
113 self.instruction = Some(instruction.into());
114 self
115 }
116
117 pub fn instruction_provider(mut self, provider: InstructionProvider) -> Self {
118 self.instruction_provider = Some(Arc::new(provider));
119 self
120 }
121
122 pub fn global_instruction(mut self, instruction: impl Into<String>) -> Self {
123 self.global_instruction = Some(instruction.into());
124 self
125 }
126
127 pub fn global_instruction_provider(mut self, provider: GlobalInstructionProvider) -> Self {
128 self.global_instruction_provider = Some(Arc::new(provider));
129 self
130 }
131
132 pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
133 self.input_schema = Some(schema);
134 self
135 }
136
137 pub fn output_schema(mut self, schema: serde_json::Value) -> Self {
138 self.output_schema = Some(schema);
139 self
140 }
141
142 pub fn disallow_transfer_to_parent(mut self, disallow: bool) -> Self {
143 self.disallow_transfer_to_parent = disallow;
144 self
145 }
146
147 pub fn disallow_transfer_to_peers(mut self, disallow: bool) -> Self {
148 self.disallow_transfer_to_peers = disallow;
149 self
150 }
151
152 pub fn include_contents(mut self, include: adk_core::IncludeContents) -> Self {
153 self.include_contents = include;
154 self
155 }
156
157 pub fn output_key(mut self, key: impl Into<String>) -> Self {
158 self.output_key = Some(key.into());
159 self
160 }
161
162 pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
163 self.tools.push(tool);
164 self
165 }
166
167 pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
168 self.sub_agents.push(agent);
169 self
170 }
171
172 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
173 self.before_callbacks.push(callback);
174 self
175 }
176
177 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
178 self.after_callbacks.push(callback);
179 self
180 }
181
182 pub fn before_model_callback(mut self, callback: BeforeModelCallback) -> Self {
183 self.before_model_callbacks.push(callback);
184 self
185 }
186
187 pub fn after_model_callback(mut self, callback: AfterModelCallback) -> Self {
188 self.after_model_callbacks.push(callback);
189 self
190 }
191
192 pub fn before_tool_callback(mut self, callback: BeforeToolCallback) -> Self {
193 self.before_tool_callbacks.push(callback);
194 self
195 }
196
197 pub fn after_tool_callback(mut self, callback: AfterToolCallback) -> Self {
198 self.after_tool_callbacks.push(callback);
199 self
200 }
201
202 pub fn build(self) -> Result<LlmAgent> {
203 let model =
204 self.model.ok_or_else(|| adk_core::AdkError::Agent("Model is required".to_string()))?;
205
206 Ok(LlmAgent {
207 name: self.name,
208 description: self.description.unwrap_or_default(),
209 model,
210 instruction: self.instruction,
211 instruction_provider: self.instruction_provider,
212 global_instruction: self.global_instruction,
213 global_instruction_provider: self.global_instruction_provider,
214 input_schema: self.input_schema,
215 output_schema: self.output_schema,
216 disallow_transfer_to_parent: self.disallow_transfer_to_parent,
217 disallow_transfer_to_peers: self.disallow_transfer_to_peers,
218 include_contents: self.include_contents,
219 tools: self.tools,
220 sub_agents: self.sub_agents,
221 output_key: self.output_key,
222 before_callbacks: Arc::new(self.before_callbacks),
223 after_callbacks: Arc::new(self.after_callbacks),
224 before_model_callbacks: Arc::new(self.before_model_callbacks),
225 after_model_callbacks: Arc::new(self.after_model_callbacks),
226 before_tool_callbacks: Arc::new(self.before_tool_callbacks),
227 after_tool_callbacks: Arc::new(self.after_tool_callbacks),
228 })
229 }
230}
231
232struct AgentToolContext {
235 parent_ctx: Arc<dyn InvocationContext>,
236 function_call_id: String,
237 actions: Mutex<EventActions>,
238}
239
240impl AgentToolContext {
241 fn new(parent_ctx: Arc<dyn InvocationContext>, function_call_id: String) -> Self {
242 Self { parent_ctx, function_call_id, actions: Mutex::new(EventActions::default()) }
243 }
244}
245
246#[async_trait]
247impl ReadonlyContext for AgentToolContext {
248 fn invocation_id(&self) -> &str {
249 self.parent_ctx.invocation_id()
250 }
251
252 fn agent_name(&self) -> &str {
253 self.parent_ctx.agent_name()
254 }
255
256 fn user_id(&self) -> &str {
257 self.parent_ctx.user_id()
259 }
260
261 fn app_name(&self) -> &str {
262 self.parent_ctx.app_name()
264 }
265
266 fn session_id(&self) -> &str {
267 self.parent_ctx.session_id()
269 }
270
271 fn branch(&self) -> &str {
272 self.parent_ctx.branch()
273 }
274
275 fn user_content(&self) -> &Content {
276 self.parent_ctx.user_content()
277 }
278}
279
280#[async_trait]
281impl CallbackContext for AgentToolContext {
282 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
283 self.parent_ctx.artifacts()
285 }
286}
287
288#[async_trait]
289impl ToolContext for AgentToolContext {
290 fn function_call_id(&self) -> &str {
291 &self.function_call_id
292 }
293
294 fn actions(&self) -> EventActions {
295 self.actions.lock().unwrap().clone()
296 }
297
298 fn set_actions(&self, actions: EventActions) {
299 *self.actions.lock().unwrap() = actions;
300 }
301
302 async fn search_memory(&self, query: &str) -> Result<Vec<MemoryEntry>> {
303 if let Some(memory) = self.parent_ctx.memory() {
305 memory.search(query).await
306 } else {
307 Ok(vec![])
308 }
309 }
310}
311
312#[async_trait]
313impl Agent for LlmAgent {
314 fn name(&self) -> &str {
315 &self.name
316 }
317
318 fn description(&self) -> &str {
319 &self.description
320 }
321
322 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
323 &self.sub_agents
324 }
325
326 #[adk_telemetry::instrument(
327 skip(self, ctx),
328 fields(
329 agent.name = %self.name,
330 agent.description = %self.description,
331 invocation.id = %ctx.invocation_id(),
332 user.id = %ctx.user_id(),
333 session.id = %ctx.session_id()
334 )
335 )]
336 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<adk_core::EventStream> {
337 adk_telemetry::info!("Starting agent execution");
338
339 let agent_name = self.name.clone();
340 let invocation_id = ctx.invocation_id().to_string();
341 let model = self.model.clone();
342 let tools = self.tools.clone();
343 let sub_agents = self.sub_agents.clone();
344
345 let instruction = self.instruction.clone();
346 let instruction_provider = self.instruction_provider.clone();
347 let global_instruction = self.global_instruction.clone();
348 let global_instruction_provider = self.global_instruction_provider.clone();
349 let output_key = self.output_key.clone();
350 let output_schema = self.output_schema.clone();
351 let include_contents = self.include_contents;
352 let before_agent_callbacks = self.before_callbacks.clone();
354 let after_agent_callbacks = self.after_callbacks.clone();
355 let before_model_callbacks = self.before_model_callbacks.clone();
356 let after_model_callbacks = self.after_model_callbacks.clone();
357 let _before_tool_callbacks = self.before_tool_callbacks.clone();
358 let _after_tool_callbacks = self.after_tool_callbacks.clone();
359
360 let s = stream! {
361 for callback in before_agent_callbacks.as_ref() {
365 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
366 Ok(Some(content)) => {
367 let mut early_event = Event::new(&invocation_id);
369 early_event.author = agent_name.clone();
370 early_event.llm_response.content = Some(content);
371 yield Ok(early_event);
372
373 for after_callback in after_agent_callbacks.as_ref() {
375 match after_callback(ctx.clone() as Arc<dyn CallbackContext>).await {
376 Ok(Some(after_content)) => {
377 let mut after_event = Event::new(&invocation_id);
378 after_event.author = agent_name.clone();
379 after_event.llm_response.content = Some(after_content);
380 yield Ok(after_event);
381 return;
382 }
383 Ok(None) => continue,
384 Err(e) => {
385 yield Err(e);
386 return;
387 }
388 }
389 }
390 return;
391 }
392 Ok(None) => {
393 continue;
395 }
396 Err(e) => {
397 yield Err(e);
399 return;
400 }
401 }
402 }
403
404 let mut conversation_history = Vec::new();
406
407 if let Some(provider) = &global_instruction_provider {
410 let global_inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
412 if !global_inst.is_empty() {
413 conversation_history.push(Content {
414 role: "user".to_string(),
415 parts: vec![Part::Text { text: global_inst }],
416 });
417 }
418 } else if let Some(ref template) = global_instruction {
419 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
421 if !processed.is_empty() {
422 conversation_history.push(Content {
423 role: "user".to_string(),
424 parts: vec![Part::Text { text: processed }],
425 });
426 }
427 }
428
429 if let Some(provider) = &instruction_provider {
432 let inst = provider(ctx.clone() as Arc<dyn ReadonlyContext>).await?;
434 if !inst.is_empty() {
435 conversation_history.push(Content {
436 role: "user".to_string(),
437 parts: vec![Part::Text { text: inst }],
438 });
439 }
440 } else if let Some(ref template) = instruction {
441 let processed = adk_core::inject_session_state(ctx.as_ref(), template).await?;
443 if !processed.is_empty() {
444 conversation_history.push(Content {
445 role: "user".to_string(),
446 parts: vec![Part::Text { text: processed }],
447 });
448 }
449 }
450
451 let session_history = ctx.session().conversation_history();
454 conversation_history.extend(session_history);
455
456 conversation_history.push(ctx.user_content().clone());
458
459 let mut conversation_history = match include_contents {
462 adk_core::IncludeContents::None => {
463 let mut filtered = Vec::new();
466
467 let instruction_count = conversation_history.iter()
469 .take_while(|c| c.role == "user" && c.parts.iter().any(|p| {
470 if let Part::Text { text } = p {
471 !text.is_empty()
473 } else {
474 false
475 }
476 }))
477 .count();
478
479 filtered.extend(conversation_history.iter().take(instruction_count).cloned());
481
482 if let Some(last) = conversation_history.last() {
484 if last.role == "user" {
485 filtered.push(last.clone());
486 }
487 }
488
489 filtered
490 }
491 adk_core::IncludeContents::Default => {
492 conversation_history
494 }
495 };
496
497 let mut tool_declarations = std::collections::HashMap::new();
500 for tool in &tools {
501 let mut decl = serde_json::json!({
504 "name": tool.name(),
505 "description": tool.enhanced_description(),
506 });
507
508 if let Some(params) = tool.parameters_schema() {
509 decl["parameters"] = params;
510 }
511
512 if let Some(response) = tool.response_schema() {
513 decl["response"] = response;
514 }
515
516 tool_declarations.insert(tool.name().to_string(), decl);
517 }
518
519 if !sub_agents.is_empty() {
521 let transfer_tool_name = "transfer_to_agent";
522 let transfer_tool_decl = serde_json::json!({
523 "name": transfer_tool_name,
524 "description": "Transfer execution to another agent.",
525 "parameters": {
526 "type": "object",
527 "properties": {
528 "agent_name": {
529 "type": "string",
530 "description": "The name of the agent to transfer to."
531 }
532 },
533 "required": ["agent_name"]
534 }
535 });
536 tool_declarations.insert(transfer_tool_name.to_string(), transfer_tool_decl);
537 }
538
539
540 let max_iterations = 10;
542 let mut iteration = 0;
543
544 loop {
545 iteration += 1;
546 if iteration > max_iterations {
547 yield Err(adk_core::AdkError::Agent(
548 format!("Max iterations ({}) exceeded", max_iterations)
549 ));
550 return;
551 }
552
553 let config = output_schema.as_ref().map(|schema| {
555 adk_core::GenerateContentConfig {
556 temperature: None,
557 top_p: None,
558 top_k: None,
559 max_output_tokens: None,
560 response_schema: Some(schema.clone()),
561 }
562 });
563
564 let request = LlmRequest {
565 model: model.name().to_string(),
566 contents: conversation_history.clone(),
567 tools: tool_declarations.clone(),
568 config,
569 };
570
571 let mut current_request = request;
574 let mut model_response_override = None;
575 for callback in before_model_callbacks.as_ref() {
576 match callback(ctx.clone() as Arc<dyn CallbackContext>, current_request.clone()).await {
577 Ok(BeforeModelResult::Continue(modified_request)) => {
578 current_request = modified_request;
580 }
581 Ok(BeforeModelResult::Skip(response)) => {
582 model_response_override = Some(response);
584 break;
585 }
586 Err(e) => {
587 yield Err(e);
589 return;
590 }
591 }
592 }
593 let request = current_request;
594
595 let mut accumulated_content: Option<Content> = None;
597
598 if let Some(cached_response) = model_response_override {
599 let mut cached_event = Event::new(&invocation_id);
602 cached_event.author = agent_name.clone();
603 cached_event.llm_response.content = cached_response.content.clone();
604
605 if let Some(ref content) = cached_response.content {
607 let long_running_ids: Vec<String> = content.parts.iter()
608 .filter_map(|p| {
609 if let Part::FunctionCall { name, .. } = p {
610 if let Some(tool) = tools.iter().find(|t| t.name() == name) {
611 if tool.is_long_running() {
612 return Some(name.clone());
613 }
614 }
615 }
616 None
617 })
618 .collect();
619 cached_event.long_running_tool_ids = long_running_ids;
620 }
621
622 yield Ok(cached_event);
623
624 accumulated_content = cached_response.content;
625 } else {
626 let mut response_stream = model.generate_content(request, true).await?;
628
629 use futures::StreamExt;
630
631 while let Some(chunk_result) = response_stream.next().await {
633 let mut chunk = match chunk_result {
634 Ok(c) => c,
635 Err(e) => {
636 yield Err(e);
637 return;
638 }
639 };
640
641 for callback in after_model_callbacks.as_ref() {
644 match callback(ctx.clone() as Arc<dyn CallbackContext>, chunk.clone()).await {
645 Ok(Some(modified_chunk)) => {
646 chunk = modified_chunk;
648 break;
649 }
650 Ok(None) => {
651 continue;
653 }
654 Err(e) => {
655 yield Err(e);
657 return;
658 }
659 }
660 }
661
662 let mut partial_event = Event::new(&invocation_id);
664 partial_event.author = agent_name.clone();
665 partial_event.llm_response.content = chunk.content.clone();
666
667 if let Some(ref content) = chunk.content {
669 let long_running_ids: Vec<String> = content.parts.iter()
670 .filter_map(|p| {
671 if let Part::FunctionCall { name, .. } = p {
672 if let Some(tool) = tools.iter().find(|t| t.name() == name) {
674 if tool.is_long_running() {
675 return Some(name.clone());
677 }
678 }
679 }
680 None
681 })
682 .collect();
683 partial_event.long_running_tool_ids = long_running_ids;
684 }
685
686 yield Ok(partial_event);
687
688 if let Some(chunk_content) = chunk.content {
690 if let Some(ref mut acc) = accumulated_content {
691 acc.parts.extend(chunk_content.parts);
693 } else {
694 accumulated_content = Some(chunk_content);
696 }
697 }
698
699 if chunk.turn_complete {
701 break;
702 }
703 }
704 }
705
706 let function_call_names: Vec<String> = accumulated_content.as_ref()
708 .map(|c| c.parts.iter()
709 .filter_map(|p| {
710 if let Part::FunctionCall { name, .. } = p {
711 Some(name.clone())
712 } else {
713 None
714 }
715 })
716 .collect())
717 .unwrap_or_default();
718
719 let has_function_calls = !function_call_names.is_empty();
720
721 let all_calls_are_long_running = has_function_calls && function_call_names.iter().all(|name| {
725 tools.iter()
726 .find(|t| t.name() == name)
727 .map(|t| t.is_long_running())
728 .unwrap_or(false)
729 });
730
731 if let Some(ref content) = accumulated_content {
733 conversation_history.push(content.clone());
734
735 if let Some(ref output_key) = output_key {
737 if !has_function_calls { let mut text_parts = String::new();
739 for part in &content.parts {
740 if let Part::Text { text } = part {
741 text_parts.push_str(text);
742 }
743 }
744 if !text_parts.is_empty() {
745 let mut state_event = Event::new(&invocation_id);
747 state_event.author = agent_name.clone();
748 state_event.actions.state_delta.insert(
749 output_key.clone(),
750 serde_json::Value::String(text_parts),
751 );
752 yield Ok(state_event);
753 }
754 }
755 }
756 }
757
758 if !has_function_calls {
759 break;
761 }
762
763 if let Some(content) = &accumulated_content {
765 for part in &content.parts {
766 if let Part::FunctionCall { name, args, id } = part {
767 if name == "transfer_to_agent" {
769 let target_agent = args.get("agent_name")
770 .and_then(|v| v.as_str())
771 .unwrap_or_default()
772 .to_string();
773
774 let mut transfer_event = Event::new(&invocation_id);
775 transfer_event.author = agent_name.clone();
776 transfer_event.actions.transfer_to_agent = Some(target_agent);
777
778 yield Ok(transfer_event);
779 return;
780 }
781
782
783 let (tool_result, tool_actions) = if let Some(tool) = tools.iter().find(|t| t.name() == name) {
785 let tool_ctx: Arc<dyn ToolContext> = Arc::new(AgentToolContext::new(
787 ctx.clone(),
788 format!("{}_{}", invocation_id, name),
789 ));
790
791 let result = match tool.execute(tool_ctx.clone(), args.clone()).await {
792 Ok(result) => result,
793 Err(e) => serde_json::json!({ "error": e.to_string() }),
794 };
795
796 (result, tool_ctx.actions())
797 } else {
798 (serde_json::json!({ "error": format!("Tool {} not found", name) }), EventActions::default())
799 };
800
801 let mut tool_event = Event::new(&invocation_id);
803 tool_event.author = agent_name.clone();
804 tool_event.actions = tool_actions.clone();
805 tool_event.llm_response.content = Some(Content {
806 role: "function".to_string(),
807 parts: vec![Part::FunctionResponse {
808 name: name.clone(),
809 response: tool_result.clone(),
810 id: id.clone(),
811 }],
812 });
813 yield Ok(tool_event);
814
815 if tool_actions.escalate || tool_actions.skip_summarization {
817 return;
819 }
820
821 conversation_history.push(Content {
823 role: "function".to_string(),
824 parts: vec![Part::FunctionResponse {
825 name: name.clone(),
826 response: tool_result,
827 id: id.clone(),
828 }],
829 });
830 }
831 }
832 }
833
834 if all_calls_are_long_running {
837 break;
838 }
839 }
840
841 for callback in after_agent_callbacks.as_ref() {
844 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
845 Ok(Some(content)) => {
846 let mut after_event = Event::new(&invocation_id);
848 after_event.author = agent_name.clone();
849 after_event.llm_response.content = Some(content);
850 yield Ok(after_event);
851 break; }
853 Ok(None) => {
854 continue;
856 }
857 Err(e) => {
858 yield Err(e);
860 return;
861 }
862 }
863 }
864 };
865
866 Ok(Box::pin(s))
867 }
868}