1#[cfg(feature = "skills")]
32use crate::skill_shim::load_skill_index;
33use crate::skill_shim::{SelectionPolicy, SkillIndex};
34use adk_core::{
35 AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Content, Event, EventStream,
36 InvocationContext, Llm, LlmRequest, Part, Result,
37};
38use async_stream::stream;
39use async_trait::async_trait;
40use futures::StreamExt;
41use std::collections::HashMap;
42use std::sync::Arc;
43
44pub struct LlmConditionalAgent {
63 name: String,
64 description: String,
65 model: Arc<dyn Llm>,
66 instruction: String,
67 routes: HashMap<String, Arc<dyn Agent>>,
68 default_agent: Option<Arc<dyn Agent>>,
69 all_agents: Vec<Arc<dyn Agent>>,
71 skills_index: Option<Arc<SkillIndex>>,
72 skill_policy: SelectionPolicy,
73 max_skill_chars: usize,
74 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
75 after_callbacks: Arc<Vec<AfterAgentCallback>>,
76}
77
78pub struct LlmConditionalAgentBuilder {
80 name: String,
81 description: Option<String>,
82 model: Arc<dyn Llm>,
83 instruction: Option<String>,
84 routes: HashMap<String, Arc<dyn Agent>>,
85 default_agent: Option<Arc<dyn Agent>>,
86 skills_index: Option<Arc<SkillIndex>>,
87 skill_policy: SelectionPolicy,
88 max_skill_chars: usize,
89 before_callbacks: Vec<BeforeAgentCallback>,
90 after_callbacks: Vec<AfterAgentCallback>,
91}
92
93impl LlmConditionalAgentBuilder {
94 pub fn new(name: impl Into<String>, model: Arc<dyn Llm>) -> Self {
96 Self {
97 name: name.into(),
98 description: None,
99 model,
100 instruction: None,
101 routes: HashMap::new(),
102 default_agent: None,
103 skills_index: None,
104 skill_policy: SelectionPolicy::default(),
105 max_skill_chars: 2000,
106 before_callbacks: Vec::new(),
107 after_callbacks: Vec::new(),
108 }
109 }
110
111 pub fn description(mut self, desc: impl Into<String>) -> Self {
113 self.description = Some(desc.into());
114 self
115 }
116
117 pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
122 self.instruction = Some(instruction.into());
123 self
124 }
125
126 pub fn route(mut self, label: impl Into<String>, agent: Arc<dyn Agent>) -> Self {
131 self.routes.insert(label.into().to_lowercase(), agent);
132 self
133 }
134
135 pub fn default_route(mut self, agent: Arc<dyn Agent>) -> Self {
137 self.default_agent = Some(agent);
138 self
139 }
140
141 #[cfg(feature = "skills")]
143 pub fn with_skills(mut self, index: SkillIndex) -> Self {
144 self.skills_index = Some(Arc::new(index));
145 self
146 }
147
148 #[cfg(feature = "skills")]
150 pub fn with_auto_skills(self) -> Result<Self> {
151 self.with_skills_from_root(".")
152 }
153
154 #[cfg(feature = "skills")]
156 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
157 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
158 self.skills_index = Some(Arc::new(index));
159 Ok(self)
160 }
161
162 #[cfg(feature = "skills")]
164 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
165 self.skill_policy = policy;
166 self
167 }
168
169 #[cfg(feature = "skills")]
171 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
172 self.max_skill_chars = max_chars;
173 self
174 }
175
176 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
178 self.before_callbacks.push(callback);
179 self
180 }
181
182 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
184 self.after_callbacks.push(callback);
185 self
186 }
187
188 pub fn build(self) -> Result<LlmConditionalAgent> {
190 let instruction = self.instruction.ok_or_else(|| {
191 adk_core::AdkError::agent("Instruction is required for LlmConditionalAgent")
192 })?;
193
194 if self.routes.is_empty() {
195 return Err(adk_core::AdkError::agent(
196 "At least one route is required for LlmConditionalAgent",
197 ));
198 }
199
200 let mut all_agents: Vec<Arc<dyn Agent>> = self.routes.values().cloned().collect();
202 if let Some(ref default) = self.default_agent {
203 all_agents.push(default.clone());
204 }
205
206 Ok(LlmConditionalAgent {
207 name: self.name,
208 description: self.description.unwrap_or_default(),
209 model: self.model,
210 instruction,
211 routes: self.routes,
212 default_agent: self.default_agent,
213 all_agents,
214 skills_index: self.skills_index,
215 skill_policy: self.skill_policy,
216 max_skill_chars: self.max_skill_chars,
217 before_callbacks: Arc::new(self.before_callbacks),
218 after_callbacks: Arc::new(self.after_callbacks),
219 })
220 }
221}
222
223impl LlmConditionalAgent {
224 pub fn builder(name: impl Into<String>, model: Arc<dyn Llm>) -> LlmConditionalAgentBuilder {
226 LlmConditionalAgentBuilder::new(name, model)
227 }
228
229 fn resolve_route(
230 classification: &str,
231 routes: &HashMap<String, Arc<dyn Agent>>,
232 ) -> Option<Arc<dyn Agent>> {
233 if let Some(agent) = routes.get(classification) {
234 return Some(agent.clone());
235 }
236
237 let mut labels = routes.keys().collect::<Vec<_>>();
238 labels.sort_by(|left, right| right.len().cmp(&left.len()).then_with(|| left.cmp(right)));
239
240 labels
241 .into_iter()
242 .find(|label| classification.contains(label.as_str()))
243 .and_then(|label| routes.get(label).cloned())
244 }
245}
246
247#[async_trait]
248impl Agent for LlmConditionalAgent {
249 fn name(&self) -> &str {
250 &self.name
251 }
252
253 fn description(&self) -> &str {
254 &self.description
255 }
256
257 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
258 &self.all_agents
259 }
260
261 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
262 let run_ctx = super::skill_context::with_skill_injected_context(
263 ctx,
264 self.skills_index.as_ref(),
265 &self.skill_policy,
266 self.max_skill_chars,
267 );
268 let model = self.model.clone();
269 let instruction = self.instruction.clone();
270 let routes = self.routes.clone();
271 let default_agent = self.default_agent.clone();
272 let invocation_id = run_ctx.invocation_id().to_string();
273 let agent_name = self.name.clone();
274 let before_callbacks = self.before_callbacks.clone();
275 let after_callbacks = self.after_callbacks.clone();
276
277 let s = stream! {
278 for callback in before_callbacks.as_ref() {
280 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
281 Ok(Some(content)) => {
282 let mut early_event = Event::new(&invocation_id);
283 early_event.author = agent_name.clone();
284 early_event.llm_response.content = Some(content);
285 yield Ok(early_event);
286
287 for after_cb in after_callbacks.as_ref() {
288 match after_cb(run_ctx.clone() as Arc<dyn CallbackContext>).await {
289 Ok(Some(after_content)) => {
290 let mut after_event = Event::new(&invocation_id);
291 after_event.author = agent_name.clone();
292 after_event.llm_response.content = Some(after_content);
293 yield Ok(after_event);
294 return;
295 }
296 Ok(None) => continue,
297 Err(e) => { yield Err(e); return; }
298 }
299 }
300 return;
301 }
302 Ok(None) => continue,
303 Err(e) => { yield Err(e); return; }
304 }
305 }
306
307 let user_content = run_ctx.user_content().clone();
309 let user_text: String = user_content.parts.iter()
310 .filter_map(|p| if let Part::Text { text } = p { Some(text.as_str()) } else { None })
311 .collect::<Vec<_>>()
312 .join(" ");
313
314 let classification_prompt = format!(
315 "{}\n\nUser input: {}",
316 instruction,
317 user_text
318 );
319
320 let request = LlmRequest {
321 model: model.name().to_string(),
322 contents: vec![Content::new("user").with_text(&classification_prompt)],
323 tools: HashMap::new(),
324 config: None,
325 };
326
327 let mut response_stream = match model.generate_content(request, false).await {
329 Ok(stream) => stream,
330 Err(e) => {
331 yield Err(e);
332 return;
333 }
334 };
335
336 let mut classification = String::new();
338 while let Some(chunk_result) = response_stream.next().await {
339 match chunk_result {
340 Ok(chunk) => {
341 if let Some(content) = chunk.content {
342 for part in content.parts {
343 if let Part::Text { text } = part {
344 classification.push_str(&text);
345 }
346 }
347 }
348 }
349 Err(e) => {
350 yield Err(e);
351 return;
352 }
353 }
354 }
355
356 let classification = classification.trim().to_lowercase();
358
359 let mut routing_event = Event::new(&invocation_id);
361 routing_event.author = agent_name.clone();
362 routing_event.llm_response.content = Some(
363 Content::new("model").with_text(format!("[Routing to: {}]", classification))
364 );
365 yield Ok(routing_event);
366
367 let target_agent = Self::resolve_route(&classification, &routes).or(default_agent);
369
370 if let Some(agent) = target_agent {
372 match agent.run(run_ctx.clone()).await {
373 Ok(mut stream) => {
374 while let Some(event) = stream.next().await {
375 yield event;
376 }
377 }
378 Err(e) => {
379 yield Err(e);
380 }
381 }
382 } else {
383 let mut error_event = Event::new(&invocation_id);
385 error_event.author = agent_name.clone();
386 error_event.llm_response.content = Some(
387 Content::new("model").with_text(format!(
388 "No route found for classification '{}'. Available routes: {:?}",
389 classification,
390 routes.keys().collect::<Vec<_>>()
391 ))
392 );
393 yield Ok(error_event);
394 }
395
396 for callback in after_callbacks.as_ref() {
398 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
399 Ok(Some(content)) => {
400 let mut after_event = Event::new(&invocation_id);
401 after_event.author = agent_name.clone();
402 after_event.llm_response.content = Some(content);
403 yield Ok(after_event);
404 break;
405 }
406 Ok(None) => continue,
407 Err(e) => { yield Err(e); return; }
408 }
409 }
410 };
411
412 Ok(Box::pin(s))
413 }
414}