Skip to main content

adk_agent/workflow/
conditional_agent.rs

1//! Rule-based conditional routing agent.
2//!
3//! `ConditionalAgent` provides **synchronous, rule-based** conditional routing.
4//! The condition function is evaluated synchronously and must return a boolean.
5//!
6//! # When to Use
7//!
8//! Use `ConditionalAgent` for **deterministic** routing decisions:
9//! - A/B testing based on session state or flags
10//! - Environment-based routing (e.g., production vs staging)
11//! - Feature flag checks
12//!
13//! # For Intelligent Routing
14//!
15//! If you need **LLM-based intelligent routing** where the model classifies
16//! user intent and routes accordingly, use [`LlmConditionalAgent`](crate::workflow::LlmConditionalAgent) instead:
17//!
18//! ```rust,ignore
19//! // LLM decides which agent to route to
20//! let router = LlmConditionalAgent::builder("router", model)
21//!     .instruction("Classify as 'technical' or 'general'")
22//!     .route("technical", tech_agent)
23//!     .route("general", general_agent)
24//!     .build()?;
25//! ```
26//!
27//! See [`crate::workflow::LlmConditionalAgent`] for details.
28
29#[cfg(feature = "skills")]
30use crate::skill_shim::load_skill_index;
31use crate::skill_shim::{SelectionPolicy, SkillIndex};
32use adk_core::{
33    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
34    InvocationContext, Result,
35};
36use async_stream::stream;
37use async_trait::async_trait;
38use futures::StreamExt;
39use std::sync::Arc;
40
41type ConditionFn = Arc<dyn Fn(&dyn InvocationContext) -> bool + Send + Sync>;
42
43/// Rule-based conditional routing agent.
44///
45/// Executes one of two sub-agents based on a synchronous condition function.
46/// For LLM-based intelligent routing, use [`crate::LlmConditionalAgent`] instead.
47///
48/// # Example
49///
50/// ```rust,ignore
51/// // Route based on session state flag
52/// let router = ConditionalAgent::new(
53///     "premium_router",
54///     |ctx| ctx.session().state().get("is_premium").map(|v| v.as_bool()).flatten().unwrap_or(false),
55///     Arc::new(premium_agent),
56/// ).with_else(Arc::new(basic_agent));
57/// ```
58pub struct ConditionalAgent {
59    name: String,
60    description: String,
61    condition: ConditionFn,
62    if_agent: Arc<dyn Agent>,
63    else_agent: Option<Arc<dyn Agent>>,
64    /// Cached list of all branch agents for tree discovery via `sub_agents()`.
65    all_agents: Vec<Arc<dyn Agent>>,
66    skills_index: Option<Arc<SkillIndex>>,
67    skill_policy: SelectionPolicy,
68    max_skill_chars: usize,
69    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
70    after_callbacks: Arc<Vec<AfterAgentCallback>>,
71}
72
73impl ConditionalAgent {
74    pub fn new<F>(name: impl Into<String>, condition: F, if_agent: Arc<dyn Agent>) -> Self
75    where
76        F: Fn(&dyn InvocationContext) -> bool + Send + Sync + 'static,
77    {
78        let all_agents = vec![if_agent.clone()];
79        Self {
80            name: name.into(),
81            description: String::new(),
82            condition: Arc::new(condition),
83            if_agent,
84            else_agent: None,
85            all_agents,
86            skills_index: None,
87            skill_policy: SelectionPolicy::default(),
88            max_skill_chars: 2000,
89            before_callbacks: Arc::new(Vec::new()),
90            after_callbacks: Arc::new(Vec::new()),
91        }
92    }
93
94    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
95        self.description = desc.into();
96        self
97    }
98
99    pub fn with_else(mut self, else_agent: Arc<dyn Agent>) -> Self {
100        self.all_agents.push(else_agent.clone());
101        self.else_agent = Some(else_agent);
102        self
103    }
104
105    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
106        if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
107            callbacks.push(callback);
108        }
109        self
110    }
111
112    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
113        if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
114            callbacks.push(callback);
115        }
116        self
117    }
118
119    #[cfg(feature = "skills")]
120    pub fn with_skills(mut self, index: SkillIndex) -> Self {
121        self.skills_index = Some(Arc::new(index));
122        self
123    }
124
125    #[cfg(feature = "skills")]
126    pub fn with_auto_skills(self) -> Result<Self> {
127        self.with_skills_from_root(".")
128    }
129
130    #[cfg(feature = "skills")]
131    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
132        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
133        self.skills_index = Some(Arc::new(index));
134        Ok(self)
135    }
136
137    #[cfg(feature = "skills")]
138    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
139        self.skill_policy = policy;
140        self
141    }
142
143    #[cfg(feature = "skills")]
144    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
145        self.max_skill_chars = max_chars;
146        self
147    }
148}
149
150#[async_trait]
151impl Agent for ConditionalAgent {
152    fn name(&self) -> &str {
153        &self.name
154    }
155
156    fn description(&self) -> &str {
157        &self.description
158    }
159
160    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
161        &self.all_agents
162    }
163
164    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
165        let run_ctx = super::skill_context::with_skill_injected_context(
166            ctx,
167            self.skills_index.as_ref(),
168            &self.skill_policy,
169            self.max_skill_chars,
170        );
171        let before_callbacks = self.before_callbacks.clone();
172        let after_callbacks = self.after_callbacks.clone();
173        let if_agent = self.if_agent.clone();
174        let else_agent = self.else_agent.clone();
175        let agent_name = self.name.clone();
176        let invocation_id = run_ctx.invocation_id().to_string();
177        let condition = self.condition.clone();
178
179        let s = stream! {
180            for callback in before_callbacks.as_ref() {
181                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
182                    Ok(Some(content)) => {
183                        let mut early_event = Event::new(&invocation_id);
184                        early_event.author = agent_name.clone();
185                        early_event.llm_response.content = Some(content);
186                        yield Ok(early_event);
187
188                        for after_callback in after_callbacks.as_ref() {
189                            match after_callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
190                                Ok(Some(after_content)) => {
191                                    let mut after_event = Event::new(&invocation_id);
192                                    after_event.author = agent_name.clone();
193                                    after_event.llm_response.content = Some(after_content);
194                                    yield Ok(after_event);
195                                    return;
196                                }
197                                Ok(None) => continue,
198                                Err(e) => {
199                                    yield Err(e);
200                                    return;
201                                }
202                            }
203                        }
204                        return;
205                    }
206                    Ok(None) => continue,
207                    Err(e) => {
208                        yield Err(e);
209                        return;
210                    }
211                }
212            }
213
214            let target_agent = if condition(run_ctx.as_ref()) {
215                Some(if_agent)
216            } else {
217                else_agent
218            };
219
220            if let Some(agent) = target_agent {
221                let mut stream = match agent.run(run_ctx.clone()).await {
222                    Ok(stream) => stream,
223                    Err(e) => {
224                        yield Err(e);
225                        return;
226                    }
227                };
228
229                while let Some(result) = stream.next().await {
230                    match result {
231                        Ok(event) => yield Ok(event),
232                        Err(e) => {
233                            yield Err(e);
234                            return;
235                        }
236                    }
237                }
238            }
239
240            for callback in after_callbacks.as_ref() {
241                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
242                    Ok(Some(content)) => {
243                        let mut after_event = Event::new(&invocation_id);
244                        after_event.author = agent_name.clone();
245                        after_event.llm_response.content = Some(content);
246                        yield Ok(after_event);
247                        break;
248                    }
249                    Ok(None) => continue,
250                    Err(e) => {
251                        yield Err(e);
252                        return;
253                    }
254                }
255            }
256        };
257
258        Ok(Box::pin(s))
259    }
260}