Skip to main content

adk_agent/workflow/
llm_conditional_agent.rs

1//! LLM-based intelligent conditional routing agent.
2//!
3//! `LlmConditionalAgent` provides **intelligent, LLM-based** conditional routing.
4//! The model classifies user input and routes to the appropriate sub-agent.
5//!
6//! # When to Use
7//!
8//! Use `LlmConditionalAgent` for **intelligent** routing decisions:
9//! - Intent classification (technical vs general vs creative)
10//! - Multi-way routing (more than 2 destinations)
11//! - Context-aware routing that requires understanding the content
12//!
13//! # For Rule-Based Routing
14//!
15//! If you need **deterministic, rule-based** routing (e.g., A/B testing,
//! feature flags), use [`crate::ConditionalAgent`] instead.
17//!
18//! # Example
19//!
20//! ```rust,ignore
//! let general_agent: Arc<dyn Agent> = Arc::new(general_agent);
//! let router = LlmConditionalAgent::builder("router", model)
//!     .instruction("Classify as 'technical', 'general', or 'creative'. \
//!                   Respond with ONLY the category name.")
//!     .route("technical", Arc::new(tech_agent))
//!     .route("general", Arc::clone(&general_agent))
//!     .route("creative", Arc::new(creative_agent))
//!     .default_route(general_agent)
28//!     .build()?;
29//! ```
30
31#[cfg(feature = "skills")]
32use crate::skill_shim::load_skill_index;
33use crate::skill_shim::{SelectionPolicy, SkillIndex};
34use adk_core::{
35    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Content, Event, EventStream,
36    InvocationContext, Llm, LlmRequest, Part, Result,
37};
38use async_stream::stream;
39use async_trait::async_trait;
40use futures::StreamExt;
41use std::collections::HashMap;
42use std::sync::Arc;
43
44/// LLM-based intelligent conditional routing agent.
45///
46/// Uses an LLM to classify user input and route to the appropriate sub-agent
47/// based on the classification result. Supports multi-way routing.
48///
49/// For rule-based routing (A/B testing, feature flags), use [`crate::ConditionalAgent`].
50///
51/// # Example
52///
53/// ```rust,ignore
54/// let router = LlmConditionalAgent::builder("router", model)
55///     .instruction("Classify as 'technical', 'general', or 'creative'.")
56///     .route("technical", tech_agent)
57///     .route("general", general_agent.clone())
58///     .route("creative", creative_agent)
59///     .default_route(general_agent)
60///     .build()?;
61/// ```
62pub struct LlmConditionalAgent {
63    name: String,
64    description: String,
65    model: Arc<dyn Llm>,
66    instruction: String,
67    routes: HashMap<String, Arc<dyn Agent>>,
68    default_agent: Option<Arc<dyn Agent>>,
69    /// Cached list of all route agents (+ default) for tree discovery via `sub_agents()`.
70    all_agents: Vec<Arc<dyn Agent>>,
71    skills_index: Option<Arc<SkillIndex>>,
72    skill_policy: SelectionPolicy,
73    max_skill_chars: usize,
74    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
75    after_callbacks: Arc<Vec<AfterAgentCallback>>,
76}
77
78/// Builder for constructing an [`LlmConditionalAgent`].
79pub struct LlmConditionalAgentBuilder {
80    name: String,
81    description: Option<String>,
82    model: Arc<dyn Llm>,
83    instruction: Option<String>,
84    routes: HashMap<String, Arc<dyn Agent>>,
85    default_agent: Option<Arc<dyn Agent>>,
86    skills_index: Option<Arc<SkillIndex>>,
87    skill_policy: SelectionPolicy,
88    max_skill_chars: usize,
89    before_callbacks: Vec<BeforeAgentCallback>,
90    after_callbacks: Vec<AfterAgentCallback>,
91}
92
93impl LlmConditionalAgentBuilder {
94    /// Create a new builder with the given name and model.
95    pub fn new(name: impl Into<String>, model: Arc<dyn Llm>) -> Self {
96        Self {
97            name: name.into(),
98            description: None,
99            model,
100            instruction: None,
101            routes: HashMap::new(),
102            default_agent: None,
103            skills_index: None,
104            skill_policy: SelectionPolicy::default(),
105            max_skill_chars: 2000,
106            before_callbacks: Vec::new(),
107            after_callbacks: Vec::new(),
108        }
109    }
110
111    /// Set a description for the agent.
112    pub fn description(mut self, desc: impl Into<String>) -> Self {
113        self.description = Some(desc.into());
114        self
115    }
116
117    /// Set the classification instruction.
118    ///
119    /// The instruction should tell the LLM to classify the user's input
120    /// and respond with ONLY the category name (matching a route key).
121    pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
122        self.instruction = Some(instruction.into());
123        self
124    }
125
126    /// Add a route mapping a classification label to an agent.
127    ///
128    /// When the LLM's response contains this label, execution transfers
129    /// to the specified agent.
130    pub fn route(mut self, label: impl Into<String>, agent: Arc<dyn Agent>) -> Self {
131        self.routes.insert(label.into().to_lowercase(), agent);
132        self
133    }
134
135    /// Set the default agent to use when no route matches.
136    pub fn default_route(mut self, agent: Arc<dyn Agent>) -> Self {
137        self.default_agent = Some(agent);
138        self
139    }
140
141    /// Set a preloaded skills index for this agent.
142    #[cfg(feature = "skills")]
143    pub fn with_skills(mut self, index: SkillIndex) -> Self {
144        self.skills_index = Some(Arc::new(index));
145        self
146    }
147
148    /// Auto-load skills from `.skills/` in the current working directory.
149    #[cfg(feature = "skills")]
150    pub fn with_auto_skills(self) -> Result<Self> {
151        self.with_skills_from_root(".")
152    }
153
154    /// Auto-load skills from `.skills/` under a custom root directory.
155    #[cfg(feature = "skills")]
156    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
157        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
158        self.skills_index = Some(Arc::new(index));
159        Ok(self)
160    }
161
162    /// Customize skill selection behavior.
163    #[cfg(feature = "skills")]
164    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
165        self.skill_policy = policy;
166        self
167    }
168
169    /// Limit injected skill content length.
170    #[cfg(feature = "skills")]
171    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
172        self.max_skill_chars = max_chars;
173        self
174    }
175
176    /// Add a before-agent callback.
177    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
178        self.before_callbacks.push(callback);
179        self
180    }
181
182    /// Add an after-agent callback.
183    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
184        self.after_callbacks.push(callback);
185        self
186    }
187
188    /// Build the LlmConditionalAgent.
189    pub fn build(self) -> Result<LlmConditionalAgent> {
190        let instruction = self.instruction.ok_or_else(|| {
191            adk_core::AdkError::agent("Instruction is required for LlmConditionalAgent")
192        })?;
193
194        if self.routes.is_empty() {
195            return Err(adk_core::AdkError::agent(
196                "At least one route is required for LlmConditionalAgent",
197            ));
198        }
199
200        // Collect all agents for sub_agents() tree discovery
201        let mut all_agents: Vec<Arc<dyn Agent>> = self.routes.values().cloned().collect();
202        if let Some(ref default) = self.default_agent {
203            all_agents.push(default.clone());
204        }
205
206        Ok(LlmConditionalAgent {
207            name: self.name,
208            description: self.description.unwrap_or_default(),
209            model: self.model,
210            instruction,
211            routes: self.routes,
212            default_agent: self.default_agent,
213            all_agents,
214            skills_index: self.skills_index,
215            skill_policy: self.skill_policy,
216            max_skill_chars: self.max_skill_chars,
217            before_callbacks: Arc::new(self.before_callbacks),
218            after_callbacks: Arc::new(self.after_callbacks),
219        })
220    }
221}
222
223impl LlmConditionalAgent {
224    /// Create a new builder for LlmConditionalAgent.
225    pub fn builder(name: impl Into<String>, model: Arc<dyn Llm>) -> LlmConditionalAgentBuilder {
226        LlmConditionalAgentBuilder::new(name, model)
227    }
228
229    fn resolve_route(
230        classification: &str,
231        routes: &HashMap<String, Arc<dyn Agent>>,
232    ) -> Option<Arc<dyn Agent>> {
233        if let Some(agent) = routes.get(classification) {
234            return Some(agent.clone());
235        }
236
237        let mut labels = routes.keys().collect::<Vec<_>>();
238        labels.sort_by(|left, right| right.len().cmp(&left.len()).then_with(|| left.cmp(right)));
239
240        labels
241            .into_iter()
242            .find(|label| classification.contains(label.as_str()))
243            .and_then(|label| routes.get(label).cloned())
244    }
245}
246
247#[async_trait]
248impl Agent for LlmConditionalAgent {
249    fn name(&self) -> &str {
250        &self.name
251    }
252
253    fn description(&self) -> &str {
254        &self.description
255    }
256
257    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
258        &self.all_agents
259    }
260
261    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
262        let run_ctx = super::skill_context::with_skill_injected_context(
263            ctx,
264            self.skills_index.as_ref(),
265            &self.skill_policy,
266            self.max_skill_chars,
267        );
268        let model = self.model.clone();
269        let instruction = self.instruction.clone();
270        let routes = self.routes.clone();
271        let default_agent = self.default_agent.clone();
272        let invocation_id = run_ctx.invocation_id().to_string();
273        let agent_name = self.name.clone();
274        let before_callbacks = self.before_callbacks.clone();
275        let after_callbacks = self.after_callbacks.clone();
276
277        let s = stream! {
278            // ===== BEFORE AGENT CALLBACKS =====
279            for callback in before_callbacks.as_ref() {
280                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
281                    Ok(Some(content)) => {
282                        let mut early_event = Event::new(&invocation_id);
283                        early_event.author = agent_name.clone();
284                        early_event.llm_response.content = Some(content);
285                        yield Ok(early_event);
286
287                        for after_cb in after_callbacks.as_ref() {
288                            match after_cb(run_ctx.clone() as Arc<dyn CallbackContext>).await {
289                                Ok(Some(after_content)) => {
290                                    let mut after_event = Event::new(&invocation_id);
291                                    after_event.author = agent_name.clone();
292                                    after_event.llm_response.content = Some(after_content);
293                                    yield Ok(after_event);
294                                    return;
295                                }
296                                Ok(None) => continue,
297                                Err(e) => { yield Err(e); return; }
298                            }
299                        }
300                        return;
301                    }
302                    Ok(None) => continue,
303                    Err(e) => { yield Err(e); return; }
304                }
305            }
306
307            // Build classification request
308            let user_content = run_ctx.user_content().clone();
309            let user_text: String = user_content.parts.iter()
310                .filter_map(|p| if let Part::Text { text } = p { Some(text.as_str()) } else { None })
311                .collect::<Vec<_>>()
312                .join(" ");
313
314            let classification_prompt = format!(
315                "{}\n\nUser input: {}",
316                instruction,
317                user_text
318            );
319
320            let request = LlmRequest {
321                model: model.name().to_string(),
322                contents: vec![Content::new("user").with_text(&classification_prompt)],
323                tools: HashMap::new(),
324                config: None,
325            };
326
327            // Call LLM for classification
328            let mut response_stream = match model.generate_content(request, false).await {
329                Ok(stream) => stream,
330                Err(e) => {
331                    yield Err(e);
332                    return;
333                }
334            };
335
336            // Collect classification response
337            let mut classification = String::new();
338            while let Some(chunk_result) = response_stream.next().await {
339                match chunk_result {
340                    Ok(chunk) => {
341                        if let Some(content) = chunk.content {
342                            for part in content.parts {
343                                if let Part::Text { text } = part {
344                                    classification.push_str(&text);
345                                }
346                            }
347                        }
348                    }
349                    Err(e) => {
350                        yield Err(e);
351                        return;
352                    }
353                }
354            }
355
356            // Normalize classification
357            let classification = classification.trim().to_lowercase();
358
359            // Emit routing event
360            let mut routing_event = Event::new(&invocation_id);
361            routing_event.author = agent_name.clone();
362            routing_event.llm_response.content = Some(
363                Content::new("model").with_text(format!("[Routing to: {}]", classification))
364            );
365            yield Ok(routing_event);
366
367            // Find matching route
368            let target_agent = Self::resolve_route(&classification, &routes).or(default_agent);
369
370            // Execute target agent
371            if let Some(agent) = target_agent {
372                match agent.run(run_ctx.clone()).await {
373                    Ok(mut stream) => {
374                        while let Some(event) = stream.next().await {
375                            yield event;
376                        }
377                    }
378                    Err(e) => {
379                        yield Err(e);
380                    }
381                }
382            } else {
383                // No matching route and no default
384                let mut error_event = Event::new(&invocation_id);
385                error_event.author = agent_name.clone();
386                error_event.llm_response.content = Some(
387                    Content::new("model").with_text(format!(
388                        "No route found for classification '{}'. Available routes: {:?}",
389                        classification,
390                        routes.keys().collect::<Vec<_>>()
391                    ))
392                );
393                yield Ok(error_event);
394            }
395
396            // ===== AFTER AGENT CALLBACKS =====
397            for callback in after_callbacks.as_ref() {
398                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
399                    Ok(Some(content)) => {
400                        let mut after_event = Event::new(&invocation_id);
401                        after_event.author = agent_name.clone();
402                        after_event.llm_response.content = Some(content);
403                        yield Ok(after_event);
404                        break;
405                    }
406                    Ok(None) => continue,
407                    Err(e) => { yield Err(e); return; }
408                }
409            }
410        };
411
412        Ok(Box::pin(s))
413    }
414}