Skip to main content

adk_agent/workflow/
llm_conditional_agent.rs

1//! LLM-based intelligent conditional routing agent.
2//!
3//! `LlmConditionalAgent` provides **intelligent, LLM-based** conditional routing.
4//! The model classifies user input and routes to the appropriate sub-agent.
5//!
6//! # When to Use
7//!
8//! Use `LlmConditionalAgent` for **intelligent** routing decisions:
9//! - Intent classification (technical vs general vs creative)
10//! - Multi-way routing (more than 2 destinations)
11//! - Context-aware routing that requires understanding the content
12//!
13//! # For Rule-Based Routing
14//!
//! If you need **deterministic, rule-based** routing (e.g., A/B testing,
//! feature flags), use [`crate::ConditionalAgent`] instead.
17//!
18//! # Example
19//!
//! ```rust,ignore
//! // One shared `Arc` serves both the "general" route and the fallback.
//! let general: Arc<dyn Agent> = Arc::new(general_agent);
//!
//! let router = LlmConditionalAgent::builder("router", model)
//!     .instruction("Classify as 'technical', 'general', or 'creative'. \
//!                   Respond with ONLY the category name.")
//!     .route("technical", Arc::new(tech_agent))
//!     .route("general", general.clone())
//!     .route("creative", Arc::new(creative_agent))
//!     .default_route(general)
//!     .build()?;
//! ```
30
31use adk_core::{
32    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Content, Event, EventStream,
33    InvocationContext, Llm, LlmRequest, Part, Result,
34};
35use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index};
36use async_stream::stream;
37use async_trait::async_trait;
38use futures::StreamExt;
39use std::collections::HashMap;
40use std::sync::Arc;
41
42/// LLM-based intelligent conditional routing agent.
43///
44/// Uses an LLM to classify user input and route to the appropriate sub-agent
45/// based on the classification result. Supports multi-way routing.
46///
47/// For rule-based routing (A/B testing, feature flags), use [`crate::ConditionalAgent`].
48///
49/// # Example
50///
51/// ```rust,ignore
52/// let router = LlmConditionalAgent::builder("router", model)
53///     .instruction("Classify as 'technical', 'general', or 'creative'.")
54///     .route("technical", tech_agent)
55///     .route("general", general_agent.clone())
56///     .route("creative", creative_agent)
57///     .default_route(general_agent)
58///     .build()?;
59/// ```
60pub struct LlmConditionalAgent {
61    name: String,
62    description: String,
63    model: Arc<dyn Llm>,
64    instruction: String,
65    routes: HashMap<String, Arc<dyn Agent>>,
66    default_agent: Option<Arc<dyn Agent>>,
67    /// Cached list of all route agents (+ default) for tree discovery via `sub_agents()`.
68    all_agents: Vec<Arc<dyn Agent>>,
69    skills_index: Option<Arc<SkillIndex>>,
70    skill_policy: SelectionPolicy,
71    max_skill_chars: usize,
72    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
73    after_callbacks: Arc<Vec<AfterAgentCallback>>,
74}
75
76pub struct LlmConditionalAgentBuilder {
77    name: String,
78    description: Option<String>,
79    model: Arc<dyn Llm>,
80    instruction: Option<String>,
81    routes: HashMap<String, Arc<dyn Agent>>,
82    default_agent: Option<Arc<dyn Agent>>,
83    skills_index: Option<Arc<SkillIndex>>,
84    skill_policy: SelectionPolicy,
85    max_skill_chars: usize,
86    before_callbacks: Vec<BeforeAgentCallback>,
87    after_callbacks: Vec<AfterAgentCallback>,
88}
89
90impl LlmConditionalAgentBuilder {
91    /// Create a new builder with the given name and model.
92    pub fn new(name: impl Into<String>, model: Arc<dyn Llm>) -> Self {
93        Self {
94            name: name.into(),
95            description: None,
96            model,
97            instruction: None,
98            routes: HashMap::new(),
99            default_agent: None,
100            skills_index: None,
101            skill_policy: SelectionPolicy::default(),
102            max_skill_chars: 2000,
103            before_callbacks: Vec::new(),
104            after_callbacks: Vec::new(),
105        }
106    }
107
108    /// Set a description for the agent.
109    pub fn description(mut self, desc: impl Into<String>) -> Self {
110        self.description = Some(desc.into());
111        self
112    }
113
114    /// Set the classification instruction.
115    ///
116    /// The instruction should tell the LLM to classify the user's input
117    /// and respond with ONLY the category name (matching a route key).
118    pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
119        self.instruction = Some(instruction.into());
120        self
121    }
122
123    /// Add a route mapping a classification label to an agent.
124    ///
125    /// When the LLM's response contains this label, execution transfers
126    /// to the specified agent.
127    pub fn route(mut self, label: impl Into<String>, agent: Arc<dyn Agent>) -> Self {
128        self.routes.insert(label.into().to_lowercase(), agent);
129        self
130    }
131
132    /// Set the default agent to use when no route matches.
133    pub fn default_route(mut self, agent: Arc<dyn Agent>) -> Self {
134        self.default_agent = Some(agent);
135        self
136    }
137
138    pub fn with_skills(mut self, index: SkillIndex) -> Self {
139        self.skills_index = Some(Arc::new(index));
140        self
141    }
142
143    pub fn with_auto_skills(self) -> Result<Self> {
144        self.with_skills_from_root(".")
145    }
146
147    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
148        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
149        self.skills_index = Some(Arc::new(index));
150        Ok(self)
151    }
152
153    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
154        self.skill_policy = policy;
155        self
156    }
157
158    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
159        self.max_skill_chars = max_chars;
160        self
161    }
162
163    /// Add a before-agent callback.
164    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
165        self.before_callbacks.push(callback);
166        self
167    }
168
169    /// Add an after-agent callback.
170    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
171        self.after_callbacks.push(callback);
172        self
173    }
174
175    /// Build the LlmConditionalAgent.
176    pub fn build(self) -> Result<LlmConditionalAgent> {
177        let instruction = self.instruction.ok_or_else(|| {
178            adk_core::AdkError::agent("Instruction is required for LlmConditionalAgent")
179        })?;
180
181        if self.routes.is_empty() {
182            return Err(adk_core::AdkError::agent(
183                "At least one route is required for LlmConditionalAgent",
184            ));
185        }
186
187        // Collect all agents for sub_agents() tree discovery
188        let mut all_agents: Vec<Arc<dyn Agent>> = self.routes.values().cloned().collect();
189        if let Some(ref default) = self.default_agent {
190            all_agents.push(default.clone());
191        }
192
193        Ok(LlmConditionalAgent {
194            name: self.name,
195            description: self.description.unwrap_or_default(),
196            model: self.model,
197            instruction,
198            routes: self.routes,
199            default_agent: self.default_agent,
200            all_agents,
201            skills_index: self.skills_index,
202            skill_policy: self.skill_policy,
203            max_skill_chars: self.max_skill_chars,
204            before_callbacks: Arc::new(self.before_callbacks),
205            after_callbacks: Arc::new(self.after_callbacks),
206        })
207    }
208}
209
210impl LlmConditionalAgent {
211    /// Create a new builder for LlmConditionalAgent.
212    pub fn builder(name: impl Into<String>, model: Arc<dyn Llm>) -> LlmConditionalAgentBuilder {
213        LlmConditionalAgentBuilder::new(name, model)
214    }
215
216    fn resolve_route(
217        classification: &str,
218        routes: &HashMap<String, Arc<dyn Agent>>,
219    ) -> Option<Arc<dyn Agent>> {
220        if let Some(agent) = routes.get(classification) {
221            return Some(agent.clone());
222        }
223
224        let mut labels = routes.keys().collect::<Vec<_>>();
225        labels.sort_by(|left, right| right.len().cmp(&left.len()).then_with(|| left.cmp(right)));
226
227        labels
228            .into_iter()
229            .find(|label| classification.contains(label.as_str()))
230            .and_then(|label| routes.get(label).cloned())
231    }
232}
233
234#[async_trait]
235impl Agent for LlmConditionalAgent {
236    fn name(&self) -> &str {
237        &self.name
238    }
239
240    fn description(&self) -> &str {
241        &self.description
242    }
243
244    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
245        &self.all_agents
246    }
247
248    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
249        let run_ctx = super::skill_context::with_skill_injected_context(
250            ctx,
251            self.skills_index.as_ref(),
252            &self.skill_policy,
253            self.max_skill_chars,
254        );
255        let model = self.model.clone();
256        let instruction = self.instruction.clone();
257        let routes = self.routes.clone();
258        let default_agent = self.default_agent.clone();
259        let invocation_id = run_ctx.invocation_id().to_string();
260        let agent_name = self.name.clone();
261        let before_callbacks = self.before_callbacks.clone();
262        let after_callbacks = self.after_callbacks.clone();
263
264        let s = stream! {
265            // ===== BEFORE AGENT CALLBACKS =====
266            for callback in before_callbacks.as_ref() {
267                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
268                    Ok(Some(content)) => {
269                        let mut early_event = Event::new(&invocation_id);
270                        early_event.author = agent_name.clone();
271                        early_event.llm_response.content = Some(content);
272                        yield Ok(early_event);
273
274                        for after_cb in after_callbacks.as_ref() {
275                            match after_cb(run_ctx.clone() as Arc<dyn CallbackContext>).await {
276                                Ok(Some(after_content)) => {
277                                    let mut after_event = Event::new(&invocation_id);
278                                    after_event.author = agent_name.clone();
279                                    after_event.llm_response.content = Some(after_content);
280                                    yield Ok(after_event);
281                                    return;
282                                }
283                                Ok(None) => continue,
284                                Err(e) => { yield Err(e); return; }
285                            }
286                        }
287                        return;
288                    }
289                    Ok(None) => continue,
290                    Err(e) => { yield Err(e); return; }
291                }
292            }
293
294            // Build classification request
295            let user_content = run_ctx.user_content().clone();
296            let user_text: String = user_content.parts.iter()
297                .filter_map(|p| if let Part::Text { text } = p { Some(text.as_str()) } else { None })
298                .collect::<Vec<_>>()
299                .join(" ");
300
301            let classification_prompt = format!(
302                "{}\n\nUser input: {}",
303                instruction,
304                user_text
305            );
306
307            let request = LlmRequest {
308                model: model.name().to_string(),
309                contents: vec![Content::new("user").with_text(&classification_prompt)],
310                tools: HashMap::new(),
311                config: None,
312            };
313
314            // Call LLM for classification
315            let mut response_stream = match model.generate_content(request, false).await {
316                Ok(stream) => stream,
317                Err(e) => {
318                    yield Err(e);
319                    return;
320                }
321            };
322
323            // Collect classification response
324            let mut classification = String::new();
325            while let Some(chunk_result) = response_stream.next().await {
326                match chunk_result {
327                    Ok(chunk) => {
328                        if let Some(content) = chunk.content {
329                            for part in content.parts {
330                                if let Part::Text { text } = part {
331                                    classification.push_str(&text);
332                                }
333                            }
334                        }
335                    }
336                    Err(e) => {
337                        yield Err(e);
338                        return;
339                    }
340                }
341            }
342
343            // Normalize classification
344            let classification = classification.trim().to_lowercase();
345
346            // Emit routing event
347            let mut routing_event = Event::new(&invocation_id);
348            routing_event.author = agent_name.clone();
349            routing_event.llm_response.content = Some(
350                Content::new("model").with_text(format!("[Routing to: {}]", classification))
351            );
352            yield Ok(routing_event);
353
354            // Find matching route
355            let target_agent = Self::resolve_route(&classification, &routes).or(default_agent);
356
357            // Execute target agent
358            if let Some(agent) = target_agent {
359                match agent.run(run_ctx.clone()).await {
360                    Ok(mut stream) => {
361                        while let Some(event) = stream.next().await {
362                            yield event;
363                        }
364                    }
365                    Err(e) => {
366                        yield Err(e);
367                    }
368                }
369            } else {
370                // No matching route and no default
371                let mut error_event = Event::new(&invocation_id);
372                error_event.author = agent_name.clone();
373                error_event.llm_response.content = Some(
374                    Content::new("model").with_text(format!(
375                        "No route found for classification '{}'. Available routes: {:?}",
376                        classification,
377                        routes.keys().collect::<Vec<_>>()
378                    ))
379                );
380                yield Ok(error_event);
381            }
382
383            // ===== AFTER AGENT CALLBACKS =====
384            for callback in after_callbacks.as_ref() {
385                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
386                    Ok(Some(content)) => {
387                        let mut after_event = Event::new(&invocation_id);
388                        after_event.author = agent_name.clone();
389                        after_event.llm_response.content = Some(content);
390                        yield Ok(after_event);
391                        break;
392                    }
393                    Ok(None) => continue,
394                    Err(e) => { yield Err(e); return; }
395                }
396            }
397        };
398
399        Ok(Box::pin(s))
400    }
401}