1use adk_core::{
32 AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Content, Event, EventStream,
33 InvocationContext, Llm, LlmRequest, Part, Result,
34};
35use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index};
36use async_stream::stream;
37use async_trait::async_trait;
38use futures::StreamExt;
39use std::collections::HashMap;
40use std::sync::Arc;
41
/// Agent that asks an LLM to classify the user's input and then hands the
/// invocation to the sub-agent registered for the resulting label.
///
/// Instances are assembled with [`LlmConditionalAgentBuilder`].
pub struct LlmConditionalAgent {
    /// Name returned by `Agent::name`; also used as the author of events this agent emits.
    name: String,
    /// Text returned by `Agent::description`; empty when the builder was given none.
    description: String,
    /// Model that produces the classification text.
    model: Arc<dyn Llm>,
    /// Classification instruction; placed ahead of the user text in the prompt.
    instruction: String,
    /// Target agents keyed by route label (labels are lower-cased by the builder).
    routes: HashMap<String, Arc<dyn Agent>>,
    /// Agent used when no route label matches the classification.
    default_agent: Option<Arc<dyn Agent>>,
    /// Every route target followed by the default agent, if any; returned by `Agent::sub_agents`.
    all_agents: Vec<Arc<dyn Agent>>,
    /// Optional skill index handed to `skill_context::with_skill_injected_context` on each run.
    skills_index: Option<Arc<SkillIndex>>,
    /// Selection policy handed to `skill_context::with_skill_injected_context` on each run.
    skill_policy: SelectionPolicy,
    /// Character budget handed to `skill_context::with_skill_injected_context` on each run.
    max_skill_chars: usize,
    /// Callbacks invoked, in order, before classification starts.
    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
    /// Callbacks invoked, in order, once routing has finished.
    after_callbacks: Arc<Vec<AfterAgentCallback>>,
}
75
/// Builder for [`LlmConditionalAgent`].
///
/// `build` fails unless an instruction and at least one route have been supplied.
pub struct LlmConditionalAgentBuilder {
    /// Agent name, fixed when the builder is created.
    name: String,
    /// Optional description; `build` substitutes an empty string when unset.
    description: Option<String>,
    /// Model that will produce the classification text.
    model: Arc<dyn Llm>,
    /// Classification instruction; required by `build`.
    instruction: Option<String>,
    /// Target agents keyed by lower-cased route label.
    routes: HashMap<String, Arc<dyn Agent>>,
    /// Optional fallback agent for classifications that match no route.
    default_agent: Option<Arc<dyn Agent>>,
    /// Optional skill index, set directly or loaded from a directory root.
    skills_index: Option<Arc<SkillIndex>>,
    /// Skill selection policy; starts as `SelectionPolicy::default()`.
    skill_policy: SelectionPolicy,
    /// Skill character budget; starts at 2000.
    max_skill_chars: usize,
    /// Before-agent callbacks in registration order.
    before_callbacks: Vec<BeforeAgentCallback>,
    /// After-agent callbacks in registration order.
    after_callbacks: Vec<AfterAgentCallback>,
}
89
90impl LlmConditionalAgentBuilder {
91 pub fn new(name: impl Into<String>, model: Arc<dyn Llm>) -> Self {
93 Self {
94 name: name.into(),
95 description: None,
96 model,
97 instruction: None,
98 routes: HashMap::new(),
99 default_agent: None,
100 skills_index: None,
101 skill_policy: SelectionPolicy::default(),
102 max_skill_chars: 2000,
103 before_callbacks: Vec::new(),
104 after_callbacks: Vec::new(),
105 }
106 }
107
108 pub fn description(mut self, desc: impl Into<String>) -> Self {
110 self.description = Some(desc.into());
111 self
112 }
113
114 pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
119 self.instruction = Some(instruction.into());
120 self
121 }
122
123 pub fn route(mut self, label: impl Into<String>, agent: Arc<dyn Agent>) -> Self {
128 self.routes.insert(label.into().to_lowercase(), agent);
129 self
130 }
131
132 pub fn default_route(mut self, agent: Arc<dyn Agent>) -> Self {
134 self.default_agent = Some(agent);
135 self
136 }
137
138 pub fn with_skills(mut self, index: SkillIndex) -> Self {
139 self.skills_index = Some(Arc::new(index));
140 self
141 }
142
143 pub fn with_auto_skills(self) -> Result<Self> {
144 self.with_skills_from_root(".")
145 }
146
147 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
148 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
149 self.skills_index = Some(Arc::new(index));
150 Ok(self)
151 }
152
153 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
154 self.skill_policy = policy;
155 self
156 }
157
158 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
159 self.max_skill_chars = max_chars;
160 self
161 }
162
163 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
165 self.before_callbacks.push(callback);
166 self
167 }
168
169 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
171 self.after_callbacks.push(callback);
172 self
173 }
174
175 pub fn build(self) -> Result<LlmConditionalAgent> {
177 let instruction = self.instruction.ok_or_else(|| {
178 adk_core::AdkError::agent("Instruction is required for LlmConditionalAgent")
179 })?;
180
181 if self.routes.is_empty() {
182 return Err(adk_core::AdkError::agent(
183 "At least one route is required for LlmConditionalAgent",
184 ));
185 }
186
187 let mut all_agents: Vec<Arc<dyn Agent>> = self.routes.values().cloned().collect();
189 if let Some(ref default) = self.default_agent {
190 all_agents.push(default.clone());
191 }
192
193 Ok(LlmConditionalAgent {
194 name: self.name,
195 description: self.description.unwrap_or_default(),
196 model: self.model,
197 instruction,
198 routes: self.routes,
199 default_agent: self.default_agent,
200 all_agents,
201 skills_index: self.skills_index,
202 skill_policy: self.skill_policy,
203 max_skill_chars: self.max_skill_chars,
204 before_callbacks: Arc::new(self.before_callbacks),
205 after_callbacks: Arc::new(self.after_callbacks),
206 })
207 }
208}
209
210impl LlmConditionalAgent {
211 pub fn builder(name: impl Into<String>, model: Arc<dyn Llm>) -> LlmConditionalAgentBuilder {
213 LlmConditionalAgentBuilder::new(name, model)
214 }
215
216 fn resolve_route(
217 classification: &str,
218 routes: &HashMap<String, Arc<dyn Agent>>,
219 ) -> Option<Arc<dyn Agent>> {
220 if let Some(agent) = routes.get(classification) {
221 return Some(agent.clone());
222 }
223
224 let mut labels = routes.keys().collect::<Vec<_>>();
225 labels.sort_by(|left, right| right.len().cmp(&left.len()).then_with(|| left.cmp(right)));
226
227 labels
228 .into_iter()
229 .find(|label| classification.contains(label.as_str()))
230 .and_then(|label| routes.get(label).cloned())
231 }
232}
233
234#[async_trait]
235impl Agent for LlmConditionalAgent {
236 fn name(&self) -> &str {
237 &self.name
238 }
239
240 fn description(&self) -> &str {
241 &self.description
242 }
243
244 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
245 &self.all_agents
246 }
247
248 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
249 let run_ctx = super::skill_context::with_skill_injected_context(
250 ctx,
251 self.skills_index.as_ref(),
252 &self.skill_policy,
253 self.max_skill_chars,
254 );
255 let model = self.model.clone();
256 let instruction = self.instruction.clone();
257 let routes = self.routes.clone();
258 let default_agent = self.default_agent.clone();
259 let invocation_id = run_ctx.invocation_id().to_string();
260 let agent_name = self.name.clone();
261 let before_callbacks = self.before_callbacks.clone();
262 let after_callbacks = self.after_callbacks.clone();
263
264 let s = stream! {
265 for callback in before_callbacks.as_ref() {
267 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
268 Ok(Some(content)) => {
269 let mut early_event = Event::new(&invocation_id);
270 early_event.author = agent_name.clone();
271 early_event.llm_response.content = Some(content);
272 yield Ok(early_event);
273
274 for after_cb in after_callbacks.as_ref() {
275 match after_cb(run_ctx.clone() as Arc<dyn CallbackContext>).await {
276 Ok(Some(after_content)) => {
277 let mut after_event = Event::new(&invocation_id);
278 after_event.author = agent_name.clone();
279 after_event.llm_response.content = Some(after_content);
280 yield Ok(after_event);
281 return;
282 }
283 Ok(None) => continue,
284 Err(e) => { yield Err(e); return; }
285 }
286 }
287 return;
288 }
289 Ok(None) => continue,
290 Err(e) => { yield Err(e); return; }
291 }
292 }
293
294 let user_content = run_ctx.user_content().clone();
296 let user_text: String = user_content.parts.iter()
297 .filter_map(|p| if let Part::Text { text } = p { Some(text.as_str()) } else { None })
298 .collect::<Vec<_>>()
299 .join(" ");
300
301 let classification_prompt = format!(
302 "{}\n\nUser input: {}",
303 instruction,
304 user_text
305 );
306
307 let request = LlmRequest {
308 model: model.name().to_string(),
309 contents: vec![Content::new("user").with_text(&classification_prompt)],
310 tools: HashMap::new(),
311 config: None,
312 };
313
314 let mut response_stream = match model.generate_content(request, false).await {
316 Ok(stream) => stream,
317 Err(e) => {
318 yield Err(e);
319 return;
320 }
321 };
322
323 let mut classification = String::new();
325 while let Some(chunk_result) = response_stream.next().await {
326 match chunk_result {
327 Ok(chunk) => {
328 if let Some(content) = chunk.content {
329 for part in content.parts {
330 if let Part::Text { text } = part {
331 classification.push_str(&text);
332 }
333 }
334 }
335 }
336 Err(e) => {
337 yield Err(e);
338 return;
339 }
340 }
341 }
342
343 let classification = classification.trim().to_lowercase();
345
346 let mut routing_event = Event::new(&invocation_id);
348 routing_event.author = agent_name.clone();
349 routing_event.llm_response.content = Some(
350 Content::new("model").with_text(format!("[Routing to: {}]", classification))
351 );
352 yield Ok(routing_event);
353
354 let target_agent = Self::resolve_route(&classification, &routes).or(default_agent);
356
357 if let Some(agent) = target_agent {
359 match agent.run(run_ctx.clone()).await {
360 Ok(mut stream) => {
361 while let Some(event) = stream.next().await {
362 yield event;
363 }
364 }
365 Err(e) => {
366 yield Err(e);
367 }
368 }
369 } else {
370 let mut error_event = Event::new(&invocation_id);
372 error_event.author = agent_name.clone();
373 error_event.llm_response.content = Some(
374 Content::new("model").with_text(format!(
375 "No route found for classification '{}'. Available routes: {:?}",
376 classification,
377 routes.keys().collect::<Vec<_>>()
378 ))
379 );
380 yield Ok(error_event);
381 }
382
383 for callback in after_callbacks.as_ref() {
385 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
386 Ok(Some(content)) => {
387 let mut after_event = Event::new(&invocation_id);
388 after_event.author = agent_name.clone();
389 after_event.llm_response.content = Some(content);
390 yield Ok(after_event);
391 break;
392 }
393 Ok(None) => continue,
394 Err(e) => { yield Err(e); return; }
395 }
396 }
397 };
398
399 Ok(Box::pin(s))
400 }
401}