adk_agent/workflow/llm_conditional_agent.rs

//! LLM-based intelligent conditional routing agent.
//!
//! `LlmConditionalAgent` provides **intelligent, LLM-based** conditional routing.
//! The model classifies user input and routes to the appropriate sub-agent.
//!
//! # When to Use
//!
//! Use `LlmConditionalAgent` for **intelligent** routing decisions:
//! - Intent classification (technical vs general vs creative)
//! - Multi-way routing (more than 2 destinations)
//! - Context-aware routing that requires understanding the content
//!
//! # For Rule-Based Routing
//!
//! If you need **deterministic, rule-based** routing (e.g., A/B testing,
//! feature flags), use [`ConditionalAgent`] instead.
//!
//! # Example
//!
//! ```rust,ignore
//! let general: Arc<dyn Agent> = Arc::new(general_agent);
//! let router = LlmConditionalAgent::builder("router", model)
//!     .instruction(
//!         "Classify as 'technical', 'general', or 'creative'. \
//!          Respond with ONLY the category name.",
//!     )
//!     .route("technical", Arc::new(tech_agent))
//!     .route("general", general.clone())
//!     .route("creative", Arc::new(creative_agent))
//!     .default_route(general)
//!     .build()?;
//! ```

use adk_core::{
    Agent, Content, Event, EventStream, InvocationContext, Llm, LlmRequest, Part, Result,
};
use async_stream::stream;
use async_trait::async_trait;
use futures::StreamExt;
use std::collections::HashMap;
use std::sync::Arc;

/// LLM-based intelligent conditional routing agent.
///
/// Uses an LLM to classify user input and route to the appropriate sub-agent
/// based on the classification result. Supports multi-way routing.
///
/// For rule-based routing (A/B testing, feature flags), use [`ConditionalAgent`].
///
/// # Example
///
/// ```rust,ignore
/// let router = LlmConditionalAgent::builder("router", model)
///     .instruction("Classify as 'technical', 'general', or 'creative'.")
///     .route("technical", tech_agent)
///     .route("general", general_agent.clone())
///     .route("creative", creative_agent)
///     .default_route(general_agent)
///     .build()?;
/// ```
pub struct LlmConditionalAgent {
    name: String,
    description: String,
    model: Arc<dyn Llm>,
    instruction: String,
    routes: HashMap<String, Arc<dyn Agent>>,
    default_agent: Option<Arc<dyn Agent>>,
}

/// Builder for [`LlmConditionalAgent`], created via [`LlmConditionalAgent::builder`].
pub struct LlmConditionalAgentBuilder {
    name: String,
    description: Option<String>,
    model: Arc<dyn Llm>,
    instruction: Option<String>,
    routes: HashMap<String, Arc<dyn Agent>>,
    default_agent: Option<Arc<dyn Agent>>,
}

impl LlmConditionalAgentBuilder {
    /// Create a new builder with the given name and model.
    pub fn new(name: impl Into<String>, model: Arc<dyn Llm>) -> Self {
        Self {
            name: name.into(),
            description: None,
            model,
            instruction: None,
            routes: HashMap::new(),
            default_agent: None,
        }
    }

    /// Set a description for the agent.
    pub fn description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }

    /// Set the classification instruction.
    ///
    /// The instruction should tell the LLM to classify the user's input
    /// and respond with ONLY the category name (matching a route key).
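    ///
    /// A sketch of a typical instruction; the category names here are
    /// illustrative and must match the labels registered via [`route`](Self::route):
    ///
    /// ```rust,ignore
    /// let builder = builder.instruction(
    ///     "Classify the user's request as 'billing', 'support', or 'sales'. \
    ///      Respond with ONLY the category name.",
    /// );
    /// ```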
    pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
        self.instruction = Some(instruction.into());
        self
    }

    /// Add a route mapping a classification label to an agent.
    ///
    /// The label is stored lowercased. When the LLM's (trimmed, lowercased)
    /// classification response contains this label as a substring, execution
    /// transfers to the specified agent.
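    ///
    /// A sketch of registering two routes; `billing_agent` and `support_agent`
    /// are illustrative `Arc<dyn Agent>` values:
    ///
    /// ```rust,ignore
    /// let builder = builder
    ///     .route("billing", billing_agent.clone())
    ///     .route("support", support_agent.clone());
    /// ```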
    pub fn route(mut self, label: impl Into<String>, agent: Arc<dyn Agent>) -> Self {
        self.routes.insert(label.into().to_lowercase(), agent);
        self
    }

    /// Set the default agent to use when no route matches.
    pub fn default_route(mut self, agent: Arc<dyn Agent>) -> Self {
        self.default_agent = Some(agent);
        self
    }

    /// Build the [`LlmConditionalAgent`].
    ///
    /// Returns an error if no instruction was set or no routes were added.
    pub fn build(self) -> Result<LlmConditionalAgent> {
        let instruction = self.instruction.ok_or_else(|| {
            adk_core::AdkError::Agent("Instruction is required for LlmConditionalAgent".to_string())
        })?;

        if self.routes.is_empty() {
            return Err(adk_core::AdkError::Agent(
                "At least one route is required for LlmConditionalAgent".to_string(),
            ));
        }

        Ok(LlmConditionalAgent {
            name: self.name,
            description: self.description.unwrap_or_default(),
            model: self.model,
            instruction,
            routes: self.routes,
            default_agent: self.default_agent,
        })
    }
}

impl LlmConditionalAgent {
    /// Create a new builder for LlmConditionalAgent.
    pub fn builder(name: impl Into<String>, model: Arc<dyn Llm>) -> LlmConditionalAgentBuilder {
        LlmConditionalAgentBuilder::new(name, model)
    }
}

#[async_trait]
impl Agent for LlmConditionalAgent {
    fn name(&self) -> &str {
        &self.name
    }

    fn description(&self) -> &str {
        &self.description
    }

    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
        // Return empty - routes are internal
        &[]
    }

    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
        let model = self.model.clone();
        let instruction = self.instruction.clone();
        let routes = self.routes.clone();
        let default_agent = self.default_agent.clone();
        let invocation_id = ctx.invocation_id().to_string();
        let agent_name = self.name.clone();

        let s = stream! {
            // Build classification request
            let user_content = ctx.user_content().clone();
            let user_text: String = user_content.parts.iter()
                .filter_map(|p| if let Part::Text { text } = p { Some(text.as_str()) } else { None })
                .collect::<Vec<_>>()
                .join(" ");

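            // Sketch of the prompt sent for classification (contents illustrative):
            //
            //   Classify as 'technical', 'general', or 'creative'. ...
            //
            //   User input: How do I fix this segfault?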
            let classification_prompt = format!(
                "{}\n\nUser input: {}",
                instruction,
                user_text
            );

            let request = LlmRequest {
                model: model.name().to_string(),
                contents: vec![Content::new("user").with_text(&classification_prompt)],
                tools: HashMap::new(),
                config: None,
            };

            // Call LLM for classification
            let mut response_stream = match model.generate_content(request, false).await {
                Ok(stream) => stream,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            // Collect classification response
            let mut classification = String::new();
            while let Some(chunk_result) = response_stream.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        if let Some(content) = chunk.content {
                            for part in content.parts {
                                if let Part::Text { text } = part {
                                    classification.push_str(&text);
                                }
                            }
                        }
                    }
                    Err(e) => {
                        yield Err(e);
                        return;
                    }
                }
            }

            // Normalize classification
            let classification = classification.trim().to_lowercase();

            // Emit routing event
            let mut routing_event = Event::new(&invocation_id);
            routing_event.author = agent_name.clone();
            routing_event.llm_response.content = Some(
                Content::new("model").with_text(format!("[Routing to: {}]", classification))
            );
            yield Ok(routing_event);

            // Find a matching route: the first label that appears as a substring of
            // the normalized classification wins. HashMap iteration order is
            // unspecified, so if multiple labels match, the choice is arbitrary.
            let target_agent = routes.iter()
                .find(|(label, _)| classification.contains(label.as_str()))
                .map(|(_, agent)| agent.clone())
                .or(default_agent);

            // Execute target agent
            if let Some(agent) = target_agent {
                match agent.run(ctx.clone()).await {
                    Ok(mut stream) => {
                        while let Some(event) = stream.next().await {
                            yield event;
                        }
                    }
                    Err(e) => {
                        yield Err(e);
                    }
                }
            } else {
                // No matching route and no default
                let mut error_event = Event::new(&invocation_id);
                error_event.author = agent_name;
                error_event.llm_response.content = Some(
                    Content::new("model").with_text(format!(
                        "No route found for classification '{}'. Available routes: {:?}",
                        classification,
                        routes.keys().collect::<Vec<_>>()
                    ))
                );
                yield Ok(error_event);
            }
        };

        Ok(Box::pin(s))
    }
}