adk_agent/workflow/
llm_conditional_agent.rs1use adk_core::{
32 Agent, Content, Event, EventStream, InvocationContext, Llm, LlmRequest, Part, Result,
33};
34use async_stream::stream;
35use async_trait::async_trait;
36use futures::StreamExt;
37use std::collections::HashMap;
38use std::sync::Arc;
39
40pub struct LlmConditionalAgent {
59 name: String,
60 description: String,
61 model: Arc<dyn Llm>,
62 instruction: String,
63 routes: HashMap<String, Arc<dyn Agent>>,
64 default_agent: Option<Arc<dyn Agent>>,
65}
66
67pub struct LlmConditionalAgentBuilder {
68 name: String,
69 description: Option<String>,
70 model: Arc<dyn Llm>,
71 instruction: Option<String>,
72 routes: HashMap<String, Arc<dyn Agent>>,
73 default_agent: Option<Arc<dyn Agent>>,
74}
75
76impl LlmConditionalAgentBuilder {
77 pub fn new(name: impl Into<String>, model: Arc<dyn Llm>) -> Self {
79 Self {
80 name: name.into(),
81 description: None,
82 model,
83 instruction: None,
84 routes: HashMap::new(),
85 default_agent: None,
86 }
87 }
88
89 pub fn description(mut self, desc: impl Into<String>) -> Self {
91 self.description = Some(desc.into());
92 self
93 }
94
95 pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
100 self.instruction = Some(instruction.into());
101 self
102 }
103
104 pub fn route(mut self, label: impl Into<String>, agent: Arc<dyn Agent>) -> Self {
109 self.routes.insert(label.into().to_lowercase(), agent);
110 self
111 }
112
113 pub fn default_route(mut self, agent: Arc<dyn Agent>) -> Self {
115 self.default_agent = Some(agent);
116 self
117 }
118
119 pub fn build(self) -> Result<LlmConditionalAgent> {
121 let instruction = self.instruction.ok_or_else(|| {
122 adk_core::AdkError::Agent("Instruction is required for LlmConditionalAgent".to_string())
123 })?;
124
125 if self.routes.is_empty() {
126 return Err(adk_core::AdkError::Agent(
127 "At least one route is required for LlmConditionalAgent".to_string(),
128 ));
129 }
130
131 Ok(LlmConditionalAgent {
132 name: self.name,
133 description: self.description.unwrap_or_default(),
134 model: self.model,
135 instruction,
136 routes: self.routes,
137 default_agent: self.default_agent,
138 })
139 }
140}
141
142impl LlmConditionalAgent {
143 pub fn builder(name: impl Into<String>, model: Arc<dyn Llm>) -> LlmConditionalAgentBuilder {
145 LlmConditionalAgentBuilder::new(name, model)
146 }
147}
148
149#[async_trait]
150impl Agent for LlmConditionalAgent {
151 fn name(&self) -> &str {
152 &self.name
153 }
154
155 fn description(&self) -> &str {
156 &self.description
157 }
158
159 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
160 &[]
162 }
163
164 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
165 let model = self.model.clone();
166 let instruction = self.instruction.clone();
167 let routes = self.routes.clone();
168 let default_agent = self.default_agent.clone();
169 let invocation_id = ctx.invocation_id().to_string();
170 let agent_name = self.name.clone();
171
172 let s = stream! {
173 let user_content = ctx.user_content().clone();
175 let user_text: String = user_content.parts.iter()
176 .filter_map(|p| if let Part::Text { text } = p { Some(text.as_str()) } else { None })
177 .collect::<Vec<_>>()
178 .join(" ");
179
180 let classification_prompt = format!(
181 "{}\n\nUser input: {}",
182 instruction,
183 user_text
184 );
185
186 let request = LlmRequest {
187 model: model.name().to_string(),
188 contents: vec![Content::new("user").with_text(&classification_prompt)],
189 tools: HashMap::new(),
190 config: None,
191 };
192
193 let mut response_stream = match model.generate_content(request, false).await {
195 Ok(stream) => stream,
196 Err(e) => {
197 yield Err(e);
198 return;
199 }
200 };
201
202 let mut classification = String::new();
204 while let Some(chunk_result) = response_stream.next().await {
205 match chunk_result {
206 Ok(chunk) => {
207 if let Some(content) = chunk.content {
208 for part in content.parts {
209 if let Part::Text { text } = part {
210 classification.push_str(&text);
211 }
212 }
213 }
214 }
215 Err(e) => {
216 yield Err(e);
217 return;
218 }
219 }
220 }
221
222 let classification = classification.trim().to_lowercase();
224
225 let mut routing_event = Event::new(&invocation_id);
227 routing_event.author = agent_name.clone();
228 routing_event.llm_response.content = Some(
229 Content::new("model").with_text(format!("[Routing to: {}]", classification))
230 );
231 yield Ok(routing_event);
232
233 let target_agent = routes.iter()
235 .find(|(label, _)| classification.contains(label.as_str()))
236 .map(|(_, agent)| agent.clone())
237 .or(default_agent);
238
239 if let Some(agent) = target_agent {
241 match agent.run(ctx.clone()).await {
242 Ok(mut stream) => {
243 while let Some(event) = stream.next().await {
244 yield event;
245 }
246 }
247 Err(e) => {
248 yield Err(e);
249 }
250 }
251 } else {
252 let mut error_event = Event::new(&invocation_id);
254 error_event.author = agent_name;
255 error_event.llm_response.content = Some(
256 Content::new("model").with_text(format!(
257 "No route found for classification '{}'. Available routes: {:?}",
258 classification,
259 routes.keys().collect::<Vec<_>>()
260 ))
261 );
262 yield Ok(error_event);
263 }
264 };
265
266 Ok(Box::pin(s))
267 }
268}