1#[cfg(feature = "skills")]
32use crate::skill_shim::load_skill_index;
33use crate::skill_shim::{SelectionPolicy, SkillIndex};
34use adk_core::{
35 AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Content, Event, EventStream,
36 InvocationContext, Llm, LlmRequest, Part, Result,
37};
38use async_stream::stream;
39use async_trait::async_trait;
40use futures::StreamExt;
41use std::collections::HashMap;
42use std::sync::Arc;
43
/// An agent that asks an LLM to classify the user's input and then delegates the
/// invocation to the sub-agent registered under the resulting label.
///
/// Construct it with [`LlmConditionalAgent::builder`].
pub struct LlmConditionalAgent {
    /// Name reported by `Agent::name`.
    name: String,
    /// Description reported by `Agent::description` (empty if none was set).
    description: String,
    /// Model used to produce the routing classification.
    model: Arc<dyn Llm>,
    /// Classification instruction sent to the model ahead of the user's text.
    instruction: String,
    /// Route table keyed by lowercased label.
    routes: HashMap<String, Arc<dyn Agent>>,
    /// Agent used when the classification matches no route.
    default_agent: Option<Arc<dyn Agent>>,
    /// Route targets plus the default agent, exposed through `Agent::sub_agents`.
    all_agents: Vec<Arc<dyn Agent>>,
    /// Optional skills index handed to the skill-context injection on each run.
    skills_index: Option<Arc<SkillIndex>>,
    /// Skill selection policy handed to the skill-context injection.
    skill_policy: SelectionPolicy,
    /// Character budget handed to the skill-context injection.
    max_skill_chars: usize,
    /// Callbacks run before classification; shared with the spawned stream via `Arc`.
    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
    /// Callbacks run after delegation; shared with the spawned stream via `Arc`.
    after_callbacks: Arc<Vec<AfterAgentCallback>>,
}
77
/// Builder for [`LlmConditionalAgent`].
///
/// An instruction and at least one route are required; `build` fails otherwise.
pub struct LlmConditionalAgentBuilder {
    /// Agent name.
    name: String,
    /// Optional description; `build` substitutes an empty string when unset.
    description: Option<String>,
    /// Model used for classification.
    model: Arc<dyn Llm>,
    /// Classification instruction; required by `build`.
    instruction: Option<String>,
    /// Route table keyed by normalized (lowercased) label.
    routes: HashMap<String, Arc<dyn Agent>>,
    /// Fallback agent for unmatched classifications.
    default_agent: Option<Arc<dyn Agent>>,
    /// Optional skills index.
    skills_index: Option<Arc<SkillIndex>>,
    /// Skill selection policy (defaults to `SelectionPolicy::default()`).
    skill_policy: SelectionPolicy,
    /// Skill character budget (defaults to 2000 in `new`).
    max_skill_chars: usize,
    /// Before-agent callbacks, in registration order.
    before_callbacks: Vec<BeforeAgentCallback>,
    /// After-agent callbacks, in registration order.
    after_callbacks: Vec<AfterAgentCallback>,
}
91
92impl LlmConditionalAgentBuilder {
93 pub fn new(name: impl Into<String>, model: Arc<dyn Llm>) -> Self {
95 Self {
96 name: name.into(),
97 description: None,
98 model,
99 instruction: None,
100 routes: HashMap::new(),
101 default_agent: None,
102 skills_index: None,
103 skill_policy: SelectionPolicy::default(),
104 max_skill_chars: 2000,
105 before_callbacks: Vec::new(),
106 after_callbacks: Vec::new(),
107 }
108 }
109
110 pub fn description(mut self, desc: impl Into<String>) -> Self {
112 self.description = Some(desc.into());
113 self
114 }
115
116 pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
121 self.instruction = Some(instruction.into());
122 self
123 }
124
125 pub fn route(mut self, label: impl Into<String>, agent: Arc<dyn Agent>) -> Self {
130 self.routes.insert(label.into().to_lowercase(), agent);
131 self
132 }
133
134 pub fn default_route(mut self, agent: Arc<dyn Agent>) -> Self {
136 self.default_agent = Some(agent);
137 self
138 }
139
140 #[cfg(feature = "skills")]
141 pub fn with_skills(mut self, index: SkillIndex) -> Self {
142 self.skills_index = Some(Arc::new(index));
143 self
144 }
145
146 #[cfg(feature = "skills")]
147 pub fn with_auto_skills(self) -> Result<Self> {
148 self.with_skills_from_root(".")
149 }
150
151 #[cfg(feature = "skills")]
152 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
153 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
154 self.skills_index = Some(Arc::new(index));
155 Ok(self)
156 }
157
158 #[cfg(feature = "skills")]
159 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
160 self.skill_policy = policy;
161 self
162 }
163
164 #[cfg(feature = "skills")]
165 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
166 self.max_skill_chars = max_chars;
167 self
168 }
169
170 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
172 self.before_callbacks.push(callback);
173 self
174 }
175
176 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
178 self.after_callbacks.push(callback);
179 self
180 }
181
182 pub fn build(self) -> Result<LlmConditionalAgent> {
184 let instruction = self.instruction.ok_or_else(|| {
185 adk_core::AdkError::agent("Instruction is required for LlmConditionalAgent")
186 })?;
187
188 if self.routes.is_empty() {
189 return Err(adk_core::AdkError::agent(
190 "At least one route is required for LlmConditionalAgent",
191 ));
192 }
193
194 let mut all_agents: Vec<Arc<dyn Agent>> = self.routes.values().cloned().collect();
196 if let Some(ref default) = self.default_agent {
197 all_agents.push(default.clone());
198 }
199
200 Ok(LlmConditionalAgent {
201 name: self.name,
202 description: self.description.unwrap_or_default(),
203 model: self.model,
204 instruction,
205 routes: self.routes,
206 default_agent: self.default_agent,
207 all_agents,
208 skills_index: self.skills_index,
209 skill_policy: self.skill_policy,
210 max_skill_chars: self.max_skill_chars,
211 before_callbacks: Arc::new(self.before_callbacks),
212 after_callbacks: Arc::new(self.after_callbacks),
213 })
214 }
215}
216
217impl LlmConditionalAgent {
218 pub fn builder(name: impl Into<String>, model: Arc<dyn Llm>) -> LlmConditionalAgentBuilder {
220 LlmConditionalAgentBuilder::new(name, model)
221 }
222
223 fn resolve_route(
224 classification: &str,
225 routes: &HashMap<String, Arc<dyn Agent>>,
226 ) -> Option<Arc<dyn Agent>> {
227 if let Some(agent) = routes.get(classification) {
228 return Some(agent.clone());
229 }
230
231 let mut labels = routes.keys().collect::<Vec<_>>();
232 labels.sort_by(|left, right| right.len().cmp(&left.len()).then_with(|| left.cmp(right)));
233
234 labels
235 .into_iter()
236 .find(|label| classification.contains(label.as_str()))
237 .and_then(|label| routes.get(label).cloned())
238 }
239}
240
241#[async_trait]
242impl Agent for LlmConditionalAgent {
243 fn name(&self) -> &str {
244 &self.name
245 }
246
247 fn description(&self) -> &str {
248 &self.description
249 }
250
251 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
252 &self.all_agents
253 }
254
255 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
256 let run_ctx = super::skill_context::with_skill_injected_context(
257 ctx,
258 self.skills_index.as_ref(),
259 &self.skill_policy,
260 self.max_skill_chars,
261 );
262 let model = self.model.clone();
263 let instruction = self.instruction.clone();
264 let routes = self.routes.clone();
265 let default_agent = self.default_agent.clone();
266 let invocation_id = run_ctx.invocation_id().to_string();
267 let agent_name = self.name.clone();
268 let before_callbacks = self.before_callbacks.clone();
269 let after_callbacks = self.after_callbacks.clone();
270
271 let s = stream! {
272 for callback in before_callbacks.as_ref() {
274 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
275 Ok(Some(content)) => {
276 let mut early_event = Event::new(&invocation_id);
277 early_event.author = agent_name.clone();
278 early_event.llm_response.content = Some(content);
279 yield Ok(early_event);
280
281 for after_cb in after_callbacks.as_ref() {
282 match after_cb(run_ctx.clone() as Arc<dyn CallbackContext>).await {
283 Ok(Some(after_content)) => {
284 let mut after_event = Event::new(&invocation_id);
285 after_event.author = agent_name.clone();
286 after_event.llm_response.content = Some(after_content);
287 yield Ok(after_event);
288 return;
289 }
290 Ok(None) => continue,
291 Err(e) => { yield Err(e); return; }
292 }
293 }
294 return;
295 }
296 Ok(None) => continue,
297 Err(e) => { yield Err(e); return; }
298 }
299 }
300
301 let user_content = run_ctx.user_content().clone();
303 let user_text: String = user_content.parts.iter()
304 .filter_map(|p| if let Part::Text { text } = p { Some(text.as_str()) } else { None })
305 .collect::<Vec<_>>()
306 .join(" ");
307
308 let classification_prompt = format!(
309 "{}\n\nUser input: {}",
310 instruction,
311 user_text
312 );
313
314 let request = LlmRequest {
315 model: model.name().to_string(),
316 contents: vec![Content::new("user").with_text(&classification_prompt)],
317 tools: HashMap::new(),
318 config: None,
319 };
320
321 let mut response_stream = match model.generate_content(request, false).await {
323 Ok(stream) => stream,
324 Err(e) => {
325 yield Err(e);
326 return;
327 }
328 };
329
330 let mut classification = String::new();
332 while let Some(chunk_result) = response_stream.next().await {
333 match chunk_result {
334 Ok(chunk) => {
335 if let Some(content) = chunk.content {
336 for part in content.parts {
337 if let Part::Text { text } = part {
338 classification.push_str(&text);
339 }
340 }
341 }
342 }
343 Err(e) => {
344 yield Err(e);
345 return;
346 }
347 }
348 }
349
350 let classification = classification.trim().to_lowercase();
352
353 let mut routing_event = Event::new(&invocation_id);
355 routing_event.author = agent_name.clone();
356 routing_event.llm_response.content = Some(
357 Content::new("model").with_text(format!("[Routing to: {}]", classification))
358 );
359 yield Ok(routing_event);
360
361 let target_agent = Self::resolve_route(&classification, &routes).or(default_agent);
363
364 if let Some(agent) = target_agent {
366 match agent.run(run_ctx.clone()).await {
367 Ok(mut stream) => {
368 while let Some(event) = stream.next().await {
369 yield event;
370 }
371 }
372 Err(e) => {
373 yield Err(e);
374 }
375 }
376 } else {
377 let mut error_event = Event::new(&invocation_id);
379 error_event.author = agent_name.clone();
380 error_event.llm_response.content = Some(
381 Content::new("model").with_text(format!(
382 "No route found for classification '{}'. Available routes: {:?}",
383 classification,
384 routes.keys().collect::<Vec<_>>()
385 ))
386 );
387 yield Ok(error_event);
388 }
389
390 for callback in after_callbacks.as_ref() {
392 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
393 Ok(Some(content)) => {
394 let mut after_event = Event::new(&invocation_id);
395 after_event.author = agent_name.clone();
396 after_event.llm_response.content = Some(content);
397 yield Ok(after_event);
398 break;
399 }
400 Ok(None) => continue,
401 Err(e) => { yield Err(e); return; }
402 }
403 }
404 };
405
406 Ok(Box::pin(s))
407 }
408}