1use adk_core::{
32 Agent, Content, Event, EventStream, InvocationContext, Llm, LlmRequest, Part, Result,
33};
34use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index};
35use async_stream::stream;
36use async_trait::async_trait;
37use futures::StreamExt;
38use std::collections::HashMap;
39use std::sync::Arc;
40
41pub struct LlmConditionalAgent {
60 name: String,
61 description: String,
62 model: Arc<dyn Llm>,
63 instruction: String,
64 routes: HashMap<String, Arc<dyn Agent>>,
65 default_agent: Option<Arc<dyn Agent>>,
66 all_agents: Vec<Arc<dyn Agent>>,
68 skills_index: Option<Arc<SkillIndex>>,
69 skill_policy: SelectionPolicy,
70 max_skill_chars: usize,
71}
72
73pub struct LlmConditionalAgentBuilder {
74 name: String,
75 description: Option<String>,
76 model: Arc<dyn Llm>,
77 instruction: Option<String>,
78 routes: HashMap<String, Arc<dyn Agent>>,
79 default_agent: Option<Arc<dyn Agent>>,
80 skills_index: Option<Arc<SkillIndex>>,
81 skill_policy: SelectionPolicy,
82 max_skill_chars: usize,
83}
84
85impl LlmConditionalAgentBuilder {
86 pub fn new(name: impl Into<String>, model: Arc<dyn Llm>) -> Self {
88 Self {
89 name: name.into(),
90 description: None,
91 model,
92 instruction: None,
93 routes: HashMap::new(),
94 default_agent: None,
95 skills_index: None,
96 skill_policy: SelectionPolicy::default(),
97 max_skill_chars: 2000,
98 }
99 }
100
101 pub fn description(mut self, desc: impl Into<String>) -> Self {
103 self.description = Some(desc.into());
104 self
105 }
106
107 pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
112 self.instruction = Some(instruction.into());
113 self
114 }
115
116 pub fn route(mut self, label: impl Into<String>, agent: Arc<dyn Agent>) -> Self {
121 self.routes.insert(label.into().to_lowercase(), agent);
122 self
123 }
124
125 pub fn default_route(mut self, agent: Arc<dyn Agent>) -> Self {
127 self.default_agent = Some(agent);
128 self
129 }
130
131 pub fn with_skills(mut self, index: SkillIndex) -> Self {
132 self.skills_index = Some(Arc::new(index));
133 self
134 }
135
136 pub fn with_auto_skills(self) -> Result<Self> {
137 self.with_skills_from_root(".")
138 }
139
140 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
141 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::Agent(e.to_string()))?;
142 self.skills_index = Some(Arc::new(index));
143 Ok(self)
144 }
145
146 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
147 self.skill_policy = policy;
148 self
149 }
150
151 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
152 self.max_skill_chars = max_chars;
153 self
154 }
155
156 pub fn build(self) -> Result<LlmConditionalAgent> {
158 let instruction = self.instruction.ok_or_else(|| {
159 adk_core::AdkError::Agent("Instruction is required for LlmConditionalAgent".to_string())
160 })?;
161
162 if self.routes.is_empty() {
163 return Err(adk_core::AdkError::Agent(
164 "At least one route is required for LlmConditionalAgent".to_string(),
165 ));
166 }
167
168 let mut all_agents: Vec<Arc<dyn Agent>> = self.routes.values().cloned().collect();
170 if let Some(ref default) = self.default_agent {
171 all_agents.push(default.clone());
172 }
173
174 Ok(LlmConditionalAgent {
175 name: self.name,
176 description: self.description.unwrap_or_default(),
177 model: self.model,
178 instruction,
179 routes: self.routes,
180 default_agent: self.default_agent,
181 all_agents,
182 skills_index: self.skills_index,
183 skill_policy: self.skill_policy,
184 max_skill_chars: self.max_skill_chars,
185 })
186 }
187}
188
189impl LlmConditionalAgent {
190 pub fn builder(name: impl Into<String>, model: Arc<dyn Llm>) -> LlmConditionalAgentBuilder {
192 LlmConditionalAgentBuilder::new(name, model)
193 }
194
195 fn resolve_route(
196 classification: &str,
197 routes: &HashMap<String, Arc<dyn Agent>>,
198 ) -> Option<Arc<dyn Agent>> {
199 if let Some(agent) = routes.get(classification) {
200 return Some(agent.clone());
201 }
202
203 let mut labels = routes.keys().collect::<Vec<_>>();
204 labels.sort_by(|left, right| right.len().cmp(&left.len()).then_with(|| left.cmp(right)));
205
206 labels
207 .into_iter()
208 .find(|label| classification.contains(label.as_str()))
209 .and_then(|label| routes.get(label).cloned())
210 }
211}
212
213#[async_trait]
214impl Agent for LlmConditionalAgent {
215 fn name(&self) -> &str {
216 &self.name
217 }
218
219 fn description(&self) -> &str {
220 &self.description
221 }
222
223 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
224 &self.all_agents
225 }
226
227 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
228 let run_ctx = super::skill_context::with_skill_injected_context(
229 ctx,
230 self.skills_index.as_ref(),
231 &self.skill_policy,
232 self.max_skill_chars,
233 );
234 let model = self.model.clone();
235 let instruction = self.instruction.clone();
236 let routes = self.routes.clone();
237 let default_agent = self.default_agent.clone();
238 let invocation_id = run_ctx.invocation_id().to_string();
239 let agent_name = self.name.clone();
240
241 let s = stream! {
242 let user_content = run_ctx.user_content().clone();
244 let user_text: String = user_content.parts.iter()
245 .filter_map(|p| if let Part::Text { text } = p { Some(text.as_str()) } else { None })
246 .collect::<Vec<_>>()
247 .join(" ");
248
249 let classification_prompt = format!(
250 "{}\n\nUser input: {}",
251 instruction,
252 user_text
253 );
254
255 let request = LlmRequest {
256 model: model.name().to_string(),
257 contents: vec![Content::new("user").with_text(&classification_prompt)],
258 tools: HashMap::new(),
259 config: None,
260 };
261
262 let mut response_stream = match model.generate_content(request, false).await {
264 Ok(stream) => stream,
265 Err(e) => {
266 yield Err(e);
267 return;
268 }
269 };
270
271 let mut classification = String::new();
273 while let Some(chunk_result) = response_stream.next().await {
274 match chunk_result {
275 Ok(chunk) => {
276 if let Some(content) = chunk.content {
277 for part in content.parts {
278 if let Part::Text { text } = part {
279 classification.push_str(&text);
280 }
281 }
282 }
283 }
284 Err(e) => {
285 yield Err(e);
286 return;
287 }
288 }
289 }
290
291 let classification = classification.trim().to_lowercase();
293
294 let mut routing_event = Event::new(&invocation_id);
296 routing_event.author = agent_name.clone();
297 routing_event.llm_response.content = Some(
298 Content::new("model").with_text(format!("[Routing to: {}]", classification))
299 );
300 yield Ok(routing_event);
301
302 let target_agent = Self::resolve_route(&classification, &routes).or(default_agent);
304
305 if let Some(agent) = target_agent {
307 match agent.run(run_ctx.clone()).await {
308 Ok(mut stream) => {
309 while let Some(event) = stream.next().await {
310 yield event;
311 }
312 }
313 Err(e) => {
314 yield Err(e);
315 }
316 }
317 } else {
318 let mut error_event = Event::new(&invocation_id);
320 error_event.author = agent_name;
321 error_event.llm_response.content = Some(
322 Content::new("model").with_text(format!(
323 "No route found for classification '{}'. Available routes: {:?}",
324 classification,
325 routes.keys().collect::<Vec<_>>()
326 ))
327 );
328 yield Ok(error_event);
329 }
330 };
331
332 Ok(Box::pin(s))
333 }
334}