// adk_agent/workflow/llm_conditional_agent.rs
//! LLM-based intelligent conditional routing agent.
//!
//! `LlmConditionalAgent` provides **intelligent, LLM-based** conditional routing.
//! The model classifies user input and routes to the appropriate sub-agent.
//!
//! # When to Use
//!
//! Use `LlmConditionalAgent` for **intelligent** routing decisions:
//! - Intent classification (technical vs general vs creative)
//! - Multi-way routing (more than 2 destinations)
//! - Context-aware routing that requires understanding the content
//!
//! # For Rule-Based Routing
//!
//! If you need **deterministic, rule-based** routing (e.g., A/B testing,
//! feature flags), use [`ConditionalAgent`] instead.
//!
//! # Example
//!
//! ```rust,ignore
//! let router = LlmConditionalAgent::builder("router", model)
//!     .instruction("Classify as 'technical', 'general', or 'creative'. \
//!                   Respond with ONLY the category name.")
//!     .route("technical", Arc::new(tech_agent))
//!     .route("general", Arc::new(general_agent))
//!     .route("creative", Arc::new(creative_agent))
//!     .default_route(Arc::new(general_agent))
//!     .build()?;
//! ```

#[cfg(feature = "skills")]
use crate::skill_shim::load_skill_index;
use crate::skill_shim::{SelectionPolicy, SkillIndex};
use adk_core::{
    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Content, Event, EventStream,
    InvocationContext, Llm, LlmRequest, Part, Result,
};
use async_stream::stream;
use async_trait::async_trait;
use futures::StreamExt;
use std::collections::HashMap;
use std::sync::Arc;

44/// LLM-based intelligent conditional routing agent.
45///
46/// Uses an LLM to classify user input and route to the appropriate sub-agent
47/// based on the classification result. Supports multi-way routing.
48///
49/// For rule-based routing (A/B testing, feature flags), use [`crate::ConditionalAgent`].
50///
51/// # Example
52///
53/// ```rust,ignore
54/// let router = LlmConditionalAgent::builder("router", model)
55///     .instruction("Classify as 'technical', 'general', or 'creative'.")
56///     .route("technical", tech_agent)
57///     .route("general", general_agent.clone())
58///     .route("creative", creative_agent)
59///     .default_route(general_agent)
60///     .build()?;
61/// ```
62pub struct LlmConditionalAgent {
63    name: String,
64    description: String,
65    model: Arc<dyn Llm>,
66    instruction: String,
67    routes: HashMap<String, Arc<dyn Agent>>,
68    default_agent: Option<Arc<dyn Agent>>,
69    /// Cached list of all route agents (+ default) for tree discovery via `sub_agents()`.
70    all_agents: Vec<Arc<dyn Agent>>,
71    skills_index: Option<Arc<SkillIndex>>,
72    skill_policy: SelectionPolicy,
73    max_skill_chars: usize,
74    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
75    after_callbacks: Arc<Vec<AfterAgentCallback>>,
76}
77
78pub struct LlmConditionalAgentBuilder {
79    name: String,
80    description: Option<String>,
81    model: Arc<dyn Llm>,
82    instruction: Option<String>,
83    routes: HashMap<String, Arc<dyn Agent>>,
84    default_agent: Option<Arc<dyn Agent>>,
85    skills_index: Option<Arc<SkillIndex>>,
86    skill_policy: SelectionPolicy,
87    max_skill_chars: usize,
88    before_callbacks: Vec<BeforeAgentCallback>,
89    after_callbacks: Vec<AfterAgentCallback>,
90}
91
92impl LlmConditionalAgentBuilder {
93    /// Create a new builder with the given name and model.
94    pub fn new(name: impl Into<String>, model: Arc<dyn Llm>) -> Self {
95        Self {
96            name: name.into(),
97            description: None,
98            model,
99            instruction: None,
100            routes: HashMap::new(),
101            default_agent: None,
102            skills_index: None,
103            skill_policy: SelectionPolicy::default(),
104            max_skill_chars: 2000,
105            before_callbacks: Vec::new(),
106            after_callbacks: Vec::new(),
107        }
108    }
109
110    /// Set a description for the agent.
111    pub fn description(mut self, desc: impl Into<String>) -> Self {
112        self.description = Some(desc.into());
113        self
114    }
115
116    /// Set the classification instruction.
117    ///
118    /// The instruction should tell the LLM to classify the user's input
119    /// and respond with ONLY the category name (matching a route key).
120    pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
121        self.instruction = Some(instruction.into());
122        self
123    }
124
125    /// Add a route mapping a classification label to an agent.
126    ///
127    /// When the LLM's response contains this label, execution transfers
128    /// to the specified agent.
129    pub fn route(mut self, label: impl Into<String>, agent: Arc<dyn Agent>) -> Self {
130        self.routes.insert(label.into().to_lowercase(), agent);
131        self
132    }
133
134    /// Set the default agent to use when no route matches.
135    pub fn default_route(mut self, agent: Arc<dyn Agent>) -> Self {
136        self.default_agent = Some(agent);
137        self
138    }
139
140    #[cfg(feature = "skills")]
141    pub fn with_skills(mut self, index: SkillIndex) -> Self {
142        self.skills_index = Some(Arc::new(index));
143        self
144    }
145
146    #[cfg(feature = "skills")]
147    pub fn with_auto_skills(self) -> Result<Self> {
148        self.with_skills_from_root(".")
149    }
150
151    #[cfg(feature = "skills")]
152    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
153        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
154        self.skills_index = Some(Arc::new(index));
155        Ok(self)
156    }
157
158    #[cfg(feature = "skills")]
159    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
160        self.skill_policy = policy;
161        self
162    }
163
164    #[cfg(feature = "skills")]
165    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
166        self.max_skill_chars = max_chars;
167        self
168    }
169
170    /// Add a before-agent callback.
171    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
172        self.before_callbacks.push(callback);
173        self
174    }
175
176    /// Add an after-agent callback.
177    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
178        self.after_callbacks.push(callback);
179        self
180    }
181
182    /// Build the LlmConditionalAgent.
183    pub fn build(self) -> Result<LlmConditionalAgent> {
184        let instruction = self.instruction.ok_or_else(|| {
185            adk_core::AdkError::agent("Instruction is required for LlmConditionalAgent")
186        })?;
187
188        if self.routes.is_empty() {
189            return Err(adk_core::AdkError::agent(
190                "At least one route is required for LlmConditionalAgent",
191            ));
192        }
193
194        // Collect all agents for sub_agents() tree discovery
195        let mut all_agents: Vec<Arc<dyn Agent>> = self.routes.values().cloned().collect();
196        if let Some(ref default) = self.default_agent {
197            all_agents.push(default.clone());
198        }
199
200        Ok(LlmConditionalAgent {
201            name: self.name,
202            description: self.description.unwrap_or_default(),
203            model: self.model,
204            instruction,
205            routes: self.routes,
206            default_agent: self.default_agent,
207            all_agents,
208            skills_index: self.skills_index,
209            skill_policy: self.skill_policy,
210            max_skill_chars: self.max_skill_chars,
211            before_callbacks: Arc::new(self.before_callbacks),
212            after_callbacks: Arc::new(self.after_callbacks),
213        })
214    }
215}
216
217impl LlmConditionalAgent {
218    /// Create a new builder for LlmConditionalAgent.
219    pub fn builder(name: impl Into<String>, model: Arc<dyn Llm>) -> LlmConditionalAgentBuilder {
220        LlmConditionalAgentBuilder::new(name, model)
221    }
222
223    fn resolve_route(
224        classification: &str,
225        routes: &HashMap<String, Arc<dyn Agent>>,
226    ) -> Option<Arc<dyn Agent>> {
227        if let Some(agent) = routes.get(classification) {
228            return Some(agent.clone());
229        }
230
231        let mut labels = routes.keys().collect::<Vec<_>>();
232        labels.sort_by(|left, right| right.len().cmp(&left.len()).then_with(|| left.cmp(right)));
233
234        labels
235            .into_iter()
236            .find(|label| classification.contains(label.as_str()))
237            .and_then(|label| routes.get(label).cloned())
238    }
239}
240
#[async_trait]
impl Agent for LlmConditionalAgent {
    fn name(&self) -> &str {
        &self.name
    }

    fn description(&self) -> &str {
        &self.description
    }

    /// All route agents plus the default (if set), so framework tree
    /// discovery sees every possible child of this router.
    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
        &self.all_agents
    }

    /// Classify the user's input with the model, then stream the events of
    /// the matching route agent (or the default, or an explanatory event if
    /// neither matches).
    ///
    /// Event order: before-callback short-circuit events (if any), a
    /// `[Routing to: …]` event, the target agent's events, then
    /// after-callback events.
    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
        // Wrap the context so skill content (if configured) is injected
        // before any sub-agent sees it.
        let run_ctx = super::skill_context::with_skill_injected_context(
            ctx,
            self.skills_index.as_ref(),
            &self.skill_policy,
            self.max_skill_chars,
        );
        // Clone everything the 'static stream closure needs; `self` cannot
        // be borrowed into the returned stream.
        let model = self.model.clone();
        let instruction = self.instruction.clone();
        let routes = self.routes.clone();
        let default_agent = self.default_agent.clone();
        let invocation_id = run_ctx.invocation_id().to_string();
        let agent_name = self.name.clone();
        let before_callbacks = self.before_callbacks.clone();
        let after_callbacks = self.after_callbacks.clone();

        let s = stream! {
            // ===== BEFORE AGENT CALLBACKS =====
            // A callback returning Some(content) short-circuits the whole run:
            // its content is emitted, after-callbacks still get a chance to
            // run (first Some among them also ends the stream), then we stop
            // without ever calling the model.
            for callback in before_callbacks.as_ref() {
                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
                    Ok(Some(content)) => {
                        let mut early_event = Event::new(&invocation_id);
                        early_event.author = agent_name.clone();
                        early_event.llm_response.content = Some(content);
                        yield Ok(early_event);

                        for after_cb in after_callbacks.as_ref() {
                            match after_cb(run_ctx.clone() as Arc<dyn CallbackContext>).await {
                                Ok(Some(after_content)) => {
                                    let mut after_event = Event::new(&invocation_id);
                                    after_event.author = agent_name.clone();
                                    after_event.llm_response.content = Some(after_content);
                                    yield Ok(after_event);
                                    return;
                                }
                                Ok(None) => continue,
                                Err(e) => { yield Err(e); return; }
                            }
                        }
                        return;
                    }
                    Ok(None) => continue,
                    Err(e) => { yield Err(e); return; }
                }
            }

            // Build the classification request: flatten all text parts of the
            // user content into one string (non-text parts are ignored).
            let user_content = run_ctx.user_content().clone();
            let user_text: String = user_content.parts.iter()
                .filter_map(|p| if let Part::Text { text } = p { Some(text.as_str()) } else { None })
                .collect::<Vec<_>>()
                .join(" ");

            let classification_prompt = format!(
                "{}\n\nUser input: {}",
                instruction,
                user_text
            );

            let request = LlmRequest {
                model: model.name().to_string(),
                contents: vec![Content::new("user").with_text(&classification_prompt)],
                tools: HashMap::new(),
                config: None,
            };

            // Call LLM for classification.
            // NOTE(review): the `false` flag presumably disables streaming
            // partials for this request — confirm against the Llm trait docs.
            let mut response_stream = match model.generate_content(request, false).await {
                Ok(stream) => stream,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            // Collect the classification response: concatenate every text
            // part from every chunk; any stream error aborts the run.
            let mut classification = String::new();
            while let Some(chunk_result) = response_stream.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        if let Some(content) = chunk.content {
                            for part in content.parts {
                                if let Part::Text { text } = part {
                                    classification.push_str(&text);
                                }
                            }
                        }
                    }
                    Err(e) => {
                        yield Err(e);
                        return;
                    }
                }
            }

            // Normalize classification to match the lowercased route keys.
            let classification = classification.trim().to_lowercase();

            // Emit routing event so the trace shows which branch was taken.
            let mut routing_event = Event::new(&invocation_id);
            routing_event.author = agent_name.clone();
            routing_event.llm_response.content = Some(
                Content::new("model").with_text(format!("[Routing to: {}]", classification))
            );
            yield Ok(routing_event);

            // Find matching route (exact, then longest substring match),
            // falling back to the default agent if none matched.
            let target_agent = Self::resolve_route(&classification, &routes).or(default_agent);

            // Execute target agent, forwarding its events verbatim. An error
            // here does not `return` — after-callbacks below still run.
            if let Some(agent) = target_agent {
                match agent.run(run_ctx.clone()).await {
                    Ok(mut stream) => {
                        while let Some(event) = stream.next().await {
                            yield event;
                        }
                    }
                    Err(e) => {
                        yield Err(e);
                    }
                }
            } else {
                // No matching route and no default: surface a diagnostic
                // event instead of failing the stream.
                let mut error_event = Event::new(&invocation_id);
                error_event.author = agent_name.clone();
                error_event.llm_response.content = Some(
                    Content::new("model").with_text(format!(
                        "No route found for classification '{}'. Available routes: {:?}",
                        classification,
                        routes.keys().collect::<Vec<_>>()
                    ))
                );
                yield Ok(error_event);
            }

            // ===== AFTER AGENT CALLBACKS =====
            // First callback returning Some(content) emits a final event and
            // stops consulting further callbacks; None means "keep asking".
            for callback in after_callbacks.as_ref() {
                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
                    Ok(Some(content)) => {
                        let mut after_event = Event::new(&invocation_id);
                        after_event.author = agent_name.clone();
                        after_event.llm_response.content = Some(content);
                        yield Ok(after_event);
                        break;
                    }
                    Ok(None) => continue,
                    Err(e) => { yield Err(e); return; }
                }
            }
        };

        Ok(Box::pin(s))
    }
}