Skip to main content

adk_agent/workflow/
llm_conditional_agent.rs

1//! LLM-based intelligent conditional routing agent.
2//!
3//! `LlmConditionalAgent` provides **intelligent, LLM-based** conditional routing.
4//! The model classifies user input and routes to the appropriate sub-agent.
5//!
6//! # When to Use
7//!
8//! Use `LlmConditionalAgent` for **intelligent** routing decisions:
9//! - Intent classification (technical vs general vs creative)
10//! - Multi-way routing (more than 2 destinations)
11//! - Context-aware routing that requires understanding the content
12//!
13//! # For Rule-Based Routing
14//!
15//! If you need **deterministic, rule-based** routing (e.g., A/B testing,
//! feature flags), use [`crate::ConditionalAgent`] instead.
17//!
18//! # Example
19//!
20//! ```rust,ignore
//! // Route targets are `Arc<dyn Agent>`; wrap shared agents once and clone the `Arc`.
//! let general_agent: Arc<dyn Agent> = Arc::new(general_agent);
//! let router = LlmConditionalAgent::builder("router", model)
//!     .instruction("Classify as 'technical', 'general', or 'creative'.
//!                   Respond with ONLY the category name.")
//!     .route("technical", Arc::new(tech_agent))
//!     .route("general", general_agent.clone())
//!     .route("creative", Arc::new(creative_agent))
//!     .default_route(general_agent)
28//!     .build()?;
29//! ```
30
31use adk_core::{
32    Agent, Content, Event, EventStream, InvocationContext, Llm, LlmRequest, Part, Result,
33};
34use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index};
35use async_stream::stream;
36use async_trait::async_trait;
37use futures::StreamExt;
38use std::collections::HashMap;
39use std::sync::Arc;
40
/// LLM-based intelligent conditional routing agent.
///
/// On each run, sends the configured instruction plus the user's text to the
/// classification model, normalizes the reply (trimmed, lowercased) and forwards the
/// invocation to the sub-agent registered under the matching route label. When no
/// label matches, it falls back to the default route if one was configured. Any
/// number of routes is supported.
///
/// For rule-based routing (A/B testing, feature flags), use [`crate::ConditionalAgent`].
///
/// # Example
///
/// ```rust,ignore
/// // Each route target is an `Arc<dyn Agent>`.
/// let router = LlmConditionalAgent::builder("router", model)
///     .instruction("Classify as 'technical', 'general', or 'creative'.")
///     .route("technical", tech_agent)
///     .route("general", general_agent.clone())
///     .route("creative", creative_agent)
///     .default_route(general_agent)
///     .build()?;
/// ```
pub struct LlmConditionalAgent {
    /// Agent name; also used as the author of the events this agent emits itself.
    name: String,
    /// Human-readable description; empty when the builder was given none.
    description: String,
    /// Model that performs the classification.
    model: Arc<dyn Llm>,
    /// Classification prompt, placed before the user's text in the request.
    instruction: String,
    /// Route targets keyed by lowercased label.
    routes: HashMap<String, Arc<dyn Agent>>,
    /// Fallback agent used when the classification matches no route.
    default_agent: Option<Arc<dyn Agent>>,
    /// Cached list of all route agents (+ default) for tree discovery via `sub_agents()`.
    all_agents: Vec<Arc<dyn Agent>>,
    /// Optional skill index handed to the skill-context helper on each run.
    skills_index: Option<Arc<SkillIndex>>,
    /// Skill selection policy handed to the skill-context helper.
    skill_policy: SelectionPolicy,
    /// Character budget handed to the skill-context helper.
    max_skill_chars: usize,
}
72
/// Builder for [`LlmConditionalAgent`], usually obtained via [`LlmConditionalAgent::builder`].
///
/// An instruction and at least one route are required by `build`; everything else is
/// optional.
pub struct LlmConditionalAgentBuilder {
    /// Name of the agent being built.
    name: String,
    /// Optional description; becomes an empty string at build time when unset.
    description: Option<String>,
    /// Model used for classification.
    model: Arc<dyn Llm>,
    /// Classification instruction; `build` fails when this is `None`.
    instruction: Option<String>,
    /// Route targets keyed by normalized (lowercased) label.
    routes: HashMap<String, Arc<dyn Agent>>,
    /// Optional fallback agent used when no route matches.
    default_agent: Option<Arc<dyn Agent>>,
    /// Optional skill index for skill injection.
    skills_index: Option<Arc<SkillIndex>>,
    /// Skill selection policy; `SelectionPolicy::default()` unless overridden.
    skill_policy: SelectionPolicy,
    /// Skill-injection character budget.
    max_skill_chars: usize,
}
84
85impl LlmConditionalAgentBuilder {
86    /// Create a new builder with the given name and model.
87    pub fn new(name: impl Into<String>, model: Arc<dyn Llm>) -> Self {
88        Self {
89            name: name.into(),
90            description: None,
91            model,
92            instruction: None,
93            routes: HashMap::new(),
94            default_agent: None,
95            skills_index: None,
96            skill_policy: SelectionPolicy::default(),
97            max_skill_chars: 2000,
98        }
99    }
100
101    /// Set a description for the agent.
102    pub fn description(mut self, desc: impl Into<String>) -> Self {
103        self.description = Some(desc.into());
104        self
105    }
106
107    /// Set the classification instruction.
108    ///
109    /// The instruction should tell the LLM to classify the user's input
110    /// and respond with ONLY the category name (matching a route key).
111    pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
112        self.instruction = Some(instruction.into());
113        self
114    }
115
116    /// Add a route mapping a classification label to an agent.
117    ///
118    /// When the LLM's response contains this label, execution transfers
119    /// to the specified agent.
120    pub fn route(mut self, label: impl Into<String>, agent: Arc<dyn Agent>) -> Self {
121        self.routes.insert(label.into().to_lowercase(), agent);
122        self
123    }
124
125    /// Set the default agent to use when no route matches.
126    pub fn default_route(mut self, agent: Arc<dyn Agent>) -> Self {
127        self.default_agent = Some(agent);
128        self
129    }
130
131    pub fn with_skills(mut self, index: SkillIndex) -> Self {
132        self.skills_index = Some(Arc::new(index));
133        self
134    }
135
136    pub fn with_auto_skills(self) -> Result<Self> {
137        self.with_skills_from_root(".")
138    }
139
140    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
141        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::Agent(e.to_string()))?;
142        self.skills_index = Some(Arc::new(index));
143        Ok(self)
144    }
145
146    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
147        self.skill_policy = policy;
148        self
149    }
150
151    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
152        self.max_skill_chars = max_chars;
153        self
154    }
155
156    /// Build the LlmConditionalAgent.
157    pub fn build(self) -> Result<LlmConditionalAgent> {
158        let instruction = self.instruction.ok_or_else(|| {
159            adk_core::AdkError::Agent("Instruction is required for LlmConditionalAgent".to_string())
160        })?;
161
162        if self.routes.is_empty() {
163            return Err(adk_core::AdkError::Agent(
164                "At least one route is required for LlmConditionalAgent".to_string(),
165            ));
166        }
167
168        // Collect all agents for sub_agents() tree discovery
169        let mut all_agents: Vec<Arc<dyn Agent>> = self.routes.values().cloned().collect();
170        if let Some(ref default) = self.default_agent {
171            all_agents.push(default.clone());
172        }
173
174        Ok(LlmConditionalAgent {
175            name: self.name,
176            description: self.description.unwrap_or_default(),
177            model: self.model,
178            instruction,
179            routes: self.routes,
180            default_agent: self.default_agent,
181            all_agents,
182            skills_index: self.skills_index,
183            skill_policy: self.skill_policy,
184            max_skill_chars: self.max_skill_chars,
185        })
186    }
187}
188
189impl LlmConditionalAgent {
190    /// Create a new builder for LlmConditionalAgent.
191    pub fn builder(name: impl Into<String>, model: Arc<dyn Llm>) -> LlmConditionalAgentBuilder {
192        LlmConditionalAgentBuilder::new(name, model)
193    }
194
195    fn resolve_route(
196        classification: &str,
197        routes: &HashMap<String, Arc<dyn Agent>>,
198    ) -> Option<Arc<dyn Agent>> {
199        if let Some(agent) = routes.get(classification) {
200            return Some(agent.clone());
201        }
202
203        let mut labels = routes.keys().collect::<Vec<_>>();
204        labels.sort_by(|left, right| right.len().cmp(&left.len()).then_with(|| left.cmp(right)));
205
206        labels
207            .into_iter()
208            .find(|label| classification.contains(label.as_str()))
209            .and_then(|label| routes.get(label).cloned())
210    }
211}
212
#[async_trait]
impl Agent for LlmConditionalAgent {
    fn name(&self) -> &str {
        &self.name
    }

    fn description(&self) -> &str {
        &self.description
    }

    /// Every route agent plus the default agent, as cached by the builder.
    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
        &self.all_agents
    }

    /// Classify the user's input with the model, then delegate to the matching agent.
    ///
    /// The returned stream yields, in order:
    /// 1. `Err` and stops, if the model call or any response chunk fails; otherwise
    /// 2. one `Ok` routing event authored by this agent, with text
    ///    `"[Routing to: <classification>]"`. The text shows the raw normalized reply,
    ///    not the resolved label. Then either
    /// 3. every event produced by the resolved route agent (or the default agent), or an
    ///    `Err` if starting that agent fails; or
    /// 4. when nothing matches and there is no default, an `Ok` event (not an `Err`)
    ///    whose text lists the available route labels.
    ///
    /// All of this happens inside `stream!`, so no work starts until the stream is
    /// polled.
    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
        // The skill-injected context is used both to read the user content and to run the
        // target agent.
        let run_ctx = super::skill_context::with_skill_injected_context(
            ctx,
            self.skills_index.as_ref(),
            &self.skill_policy,
            self.max_skill_chars,
        );
        // Owned copies so the stream below does not borrow `self`.
        let model = self.model.clone();
        let instruction = self.instruction.clone();
        let routes = self.routes.clone();
        let default_agent = self.default_agent.clone();
        let invocation_id = run_ctx.invocation_id().to_string();
        let agent_name = self.name.clone();

        let s = stream! {
            // Build classification request
            // Only text parts are considered; they are joined with single spaces.
            let user_content = run_ctx.user_content().clone();
            let user_text: String = user_content.parts.iter()
                .filter_map(|p| if let Part::Text { text } = p { Some(text.as_str()) } else { None })
                .collect::<Vec<_>>()
                .join(" ");

            let classification_prompt = format!(
                "{}\n\nUser input: {}",
                instruction,
                user_text
            );

            // Single-turn request with no tools and no generation config.
            let request = LlmRequest {
                model: model.name().to_string(),
                contents: vec![Content::new("user").with_text(&classification_prompt)],
                tools: HashMap::new(),
                config: None,
            };

            // Call LLM for classification
            // NOTE(review): the `false` argument presumably disables streaming — confirm
            // against `Llm::generate_content`.
            let mut response_stream = match model.generate_content(request, false).await {
                Ok(stream) => stream,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            // Collect classification response
            // Text from every chunk is concatenated; non-text parts are ignored.
            let mut classification = String::new();
            while let Some(chunk_result) = response_stream.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        if let Some(content) = chunk.content {
                            for part in content.parts {
                                if let Part::Text { text } = part {
                                    classification.push_str(&text);
                                }
                            }
                        }
                    }
                    Err(e) => {
                        yield Err(e);
                        return;
                    }
                }
            }

            // Normalize classification
            // Route labels are stored lowercased by the builder, so matching is
            // case-insensitive.
            let classification = classification.trim().to_lowercase();

            // Emit routing event
            let mut routing_event = Event::new(&invocation_id);
            routing_event.author = agent_name.clone();
            routing_event.llm_response.content = Some(
                Content::new("model").with_text(format!("[Routing to: {}]", classification))
            );
            yield Ok(routing_event);

            // Find matching route
            let target_agent = Self::resolve_route(&classification, &routes).or(default_agent);

            // Execute target agent
            if let Some(agent) = target_agent {
                match agent.run(run_ctx.clone()).await {
                    Ok(mut stream) => {
                        // Forward the target agent's events (including its errors) unchanged.
                        while let Some(event) = stream.next().await {
                            yield event;
                        }
                    }
                    Err(e) => {
                        yield Err(e);
                    }
                }
            } else {
                // No matching route and no default
                // Reported as a normal event rather than an error. The listed labels come
                // out in `HashMap` iteration order.
                let mut error_event = Event::new(&invocation_id);
                error_event.author = agent_name;
                error_event.llm_response.content = Some(
                    Content::new("model").with_text(format!(
                        "No route found for classification '{}'. Available routes: {:?}",
                        classification,
                        routes.keys().collect::<Vec<_>>()
                    ))
                );
                yield Ok(error_event);
            }
        };

        Ok(Box::pin(s))
    }
}