Skip to main content

adk_agent/workflow/
conditional_agent.rs

1//! Rule-based conditional routing agent.
2//!
3//! `ConditionalAgent` provides **synchronous, rule-based** conditional routing.
4//! The condition function is evaluated synchronously and must return a boolean.
5//!
6//! # When to Use
7//!
8//! Use `ConditionalAgent` for **deterministic** routing decisions:
9//! - A/B testing based on session state or flags
10//! - Environment-based routing (e.g., production vs staging)
11//! - Feature flag checks
12//!
13//! # For Intelligent Routing
14//!
15//! If you need **LLM-based intelligent routing** where the model classifies
16//! user intent and routes accordingly, use [`crate::workflow::LlmConditionalAgent`] instead:
17//!
18//! ```rust,ignore
19//! // LLM decides which agent to route to
20//! let router = LlmConditionalAgent::builder("router", model)
21//!     .instruction("Classify as 'technical' or 'general'")
22//!     .route("technical", tech_agent)
23//!     .route("general", general_agent)
24//!     .build()?;
25//! ```
26//!
27//! See [`crate::workflow::LlmConditionalAgent`] for details.
28
29#[cfg(feature = "skills")]
30use crate::skill_shim::load_skill_index;
31use crate::skill_shim::{SelectionPolicy, SkillIndex};
32use adk_core::{
33    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
34    InvocationContext, Result,
35};
36use async_stream::stream;
37use async_trait::async_trait;
38use futures::StreamExt;
39use std::sync::Arc;
40
41type ConditionFn = Arc<dyn Fn(&dyn InvocationContext) -> bool + Send + Sync>;
42
43/// Rule-based conditional routing agent.
44///
45/// Executes one of two sub-agents based on a synchronous condition function.
46/// For LLM-based intelligent routing, use [`crate::LlmConditionalAgent`] instead.
47///
48/// # Example
49///
50/// ```rust,ignore
51/// // Route based on session state flag
52/// let router = ConditionalAgent::new(
53///     "premium_router",
54///     |ctx| ctx.session().state().get("is_premium").map(|v| v.as_bool()).flatten().unwrap_or(false),
55///     Arc::new(premium_agent),
56/// ).with_else(Arc::new(basic_agent));
57/// ```
58pub struct ConditionalAgent {
59    name: String,
60    description: String,
61    condition: ConditionFn,
62    if_agent: Arc<dyn Agent>,
63    else_agent: Option<Arc<dyn Agent>>,
64    /// Cached list of all branch agents for tree discovery via `sub_agents()`.
65    all_agents: Vec<Arc<dyn Agent>>,
66    skills_index: Option<Arc<SkillIndex>>,
67    skill_policy: SelectionPolicy,
68    max_skill_chars: usize,
69    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
70    after_callbacks: Arc<Vec<AfterAgentCallback>>,
71}
72
73impl ConditionalAgent {
74    /// Create a new conditional agent with a condition function and the if-branch agent.
75    pub fn new<F>(name: impl Into<String>, condition: F, if_agent: Arc<dyn Agent>) -> Self
76    where
77        F: Fn(&dyn InvocationContext) -> bool + Send + Sync + 'static,
78    {
79        let all_agents = vec![if_agent.clone()];
80        Self {
81            name: name.into(),
82            description: String::new(),
83            condition: Arc::new(condition),
84            if_agent,
85            else_agent: None,
86            all_agents,
87            skills_index: None,
88            skill_policy: SelectionPolicy::default(),
89            max_skill_chars: 2000,
90            before_callbacks: Arc::new(Vec::new()),
91            after_callbacks: Arc::new(Vec::new()),
92        }
93    }
94
95    /// Set the agent description.
96    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
97        self.description = desc.into();
98        self
99    }
100
101    /// Set the else-branch agent executed when the condition is false.
102    pub fn with_else(mut self, else_agent: Arc<dyn Agent>) -> Self {
103        self.all_agents.push(else_agent.clone());
104        self.else_agent = Some(else_agent);
105        self
106    }
107
108    /// Add a before-agent callback.
109    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
110        if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
111            callbacks.push(callback);
112        }
113        self
114    }
115
116    /// Add an after-agent callback.
117    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
118        if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
119            callbacks.push(callback);
120        }
121        self
122    }
123
124    /// Set a preloaded skills index for this agent.
125    #[cfg(feature = "skills")]
126    pub fn with_skills(mut self, index: SkillIndex) -> Self {
127        self.skills_index = Some(Arc::new(index));
128        self
129    }
130
131    /// Auto-load skills from `.skills/` in the current working directory.
132    #[cfg(feature = "skills")]
133    pub fn with_auto_skills(self) -> Result<Self> {
134        self.with_skills_from_root(".")
135    }
136
137    /// Auto-load skills from `.skills/` under a custom root directory.
138    #[cfg(feature = "skills")]
139    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
140        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
141        self.skills_index = Some(Arc::new(index));
142        Ok(self)
143    }
144
145    /// Customize skill selection behavior.
146    #[cfg(feature = "skills")]
147    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
148        self.skill_policy = policy;
149        self
150    }
151
152    /// Limit injected skill content length.
153    #[cfg(feature = "skills")]
154    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
155        self.max_skill_chars = max_chars;
156        self
157    }
158}
159
160#[async_trait]
161impl Agent for ConditionalAgent {
162    fn name(&self) -> &str {
163        &self.name
164    }
165
166    fn description(&self) -> &str {
167        &self.description
168    }
169
170    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
171        &self.all_agents
172    }
173
174    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
175        let run_ctx = super::skill_context::with_skill_injected_context(
176            ctx,
177            self.skills_index.as_ref(),
178            &self.skill_policy,
179            self.max_skill_chars,
180        );
181        let before_callbacks = self.before_callbacks.clone();
182        let after_callbacks = self.after_callbacks.clone();
183        let if_agent = self.if_agent.clone();
184        let else_agent = self.else_agent.clone();
185        let agent_name = self.name.clone();
186        let invocation_id = run_ctx.invocation_id().to_string();
187        let condition = self.condition.clone();
188
189        let s = stream! {
190            for callback in before_callbacks.as_ref() {
191                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
192                    Ok(Some(content)) => {
193                        let mut early_event = Event::new(&invocation_id);
194                        early_event.author = agent_name.clone();
195                        early_event.llm_response.content = Some(content);
196                        yield Ok(early_event);
197
198                        for after_callback in after_callbacks.as_ref() {
199                            match after_callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
200                                Ok(Some(after_content)) => {
201                                    let mut after_event = Event::new(&invocation_id);
202                                    after_event.author = agent_name.clone();
203                                    after_event.llm_response.content = Some(after_content);
204                                    yield Ok(after_event);
205                                    return;
206                                }
207                                Ok(None) => continue,
208                                Err(e) => {
209                                    yield Err(e);
210                                    return;
211                                }
212                            }
213                        }
214                        return;
215                    }
216                    Ok(None) => continue,
217                    Err(e) => {
218                        yield Err(e);
219                        return;
220                    }
221                }
222            }
223
224            let target_agent = if condition(run_ctx.as_ref()) {
225                Some(if_agent)
226            } else {
227                else_agent
228            };
229
230            if let Some(agent) = target_agent {
231                let mut stream = match agent.run(run_ctx.clone()).await {
232                    Ok(stream) => stream,
233                    Err(e) => {
234                        yield Err(e);
235                        return;
236                    }
237                };
238
239                while let Some(result) = stream.next().await {
240                    match result {
241                        Ok(event) => yield Ok(event),
242                        Err(e) => {
243                            yield Err(e);
244                            return;
245                        }
246                    }
247                }
248            }
249
250            for callback in after_callbacks.as_ref() {
251                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
252                    Ok(Some(content)) => {
253                        let mut after_event = Event::new(&invocation_id);
254                        after_event.author = agent_name.clone();
255                        after_event.llm_response.content = Some(content);
256                        yield Ok(after_event);
257                        break;
258                    }
259                    Ok(None) => continue,
260                    Err(e) => {
261                        yield Err(e);
262                        return;
263                    }
264                }
265            }
266        };
267
268        Ok(Box::pin(s))
269    }
270}