Skip to main content

adk_agent/workflow/
conditional_agent.rs

//! Rule-based conditional routing agent.
//!
//! `ConditionalAgent` provides **synchronous, rule-based** conditional routing.
//! The condition function is evaluated synchronously and must return a boolean.
//!
//! # When to Use
//!
//! Use `ConditionalAgent` for **deterministic** routing decisions:
//! - A/B testing based on session state or flags
//! - Environment-based routing (e.g., production vs. staging)
//! - Feature flag checks
//!
//! # For Intelligent Routing
//!
//! If you need **LLM-based intelligent routing**, where the model classifies
//! user intent and routes accordingly, use
//! [`LlmConditionalAgent`](crate::workflow::LlmConditionalAgent) instead:
//!
//! ```rust,ignore
//! // LLM decides which agent to route to
//! let router = LlmConditionalAgent::builder("router", model)
//!     .instruction("Classify as 'technical' or 'general'")
//!     .route("technical", tech_agent)
//!     .route("general", general_agent)
//!     .build()?;
//! ```
//!
//! See [`crate::workflow::LlmConditionalAgent`] for details.

29use adk_core::{
30    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
31    InvocationContext, Result,
32};
33use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index};
34use async_stream::stream;
35use async_trait::async_trait;
36use futures::StreamExt;
37use std::sync::Arc;
38
39type ConditionFn = Arc<dyn Fn(&dyn InvocationContext) -> bool + Send + Sync>;
40
41/// Rule-based conditional routing agent.
42///
43/// Executes one of two sub-agents based on a synchronous condition function.
44/// For LLM-based intelligent routing, use [`crate::LlmConditionalAgent`] instead.
45///
46/// # Example
47///
48/// ```rust,ignore
49/// // Route based on session state flag
50/// let router = ConditionalAgent::new(
51///     "premium_router",
52///     |ctx| ctx.session().state().get("is_premium").map(|v| v.as_bool()).flatten().unwrap_or(false),
53///     Arc::new(premium_agent),
54/// ).with_else(Arc::new(basic_agent));
55/// ```
56pub struct ConditionalAgent {
57    name: String,
58    description: String,
59    condition: ConditionFn,
60    if_agent: Arc<dyn Agent>,
61    else_agent: Option<Arc<dyn Agent>>,
62    /// Cached list of all branch agents for tree discovery via `sub_agents()`.
63    all_agents: Vec<Arc<dyn Agent>>,
64    skills_index: Option<Arc<SkillIndex>>,
65    skill_policy: SelectionPolicy,
66    max_skill_chars: usize,
67    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
68    after_callbacks: Arc<Vec<AfterAgentCallback>>,
69}
70
71impl ConditionalAgent {
72    pub fn new<F>(name: impl Into<String>, condition: F, if_agent: Arc<dyn Agent>) -> Self
73    where
74        F: Fn(&dyn InvocationContext) -> bool + Send + Sync + 'static,
75    {
76        let all_agents = vec![if_agent.clone()];
77        Self {
78            name: name.into(),
79            description: String::new(),
80            condition: Arc::new(condition),
81            if_agent,
82            else_agent: None,
83            all_agents,
84            skills_index: None,
85            skill_policy: SelectionPolicy::default(),
86            max_skill_chars: 2000,
87            before_callbacks: Arc::new(Vec::new()),
88            after_callbacks: Arc::new(Vec::new()),
89        }
90    }
91
92    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
93        self.description = desc.into();
94        self
95    }
96
97    pub fn with_else(mut self, else_agent: Arc<dyn Agent>) -> Self {
98        self.all_agents.push(else_agent.clone());
99        self.else_agent = Some(else_agent);
100        self
101    }
102
103    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
104        Arc::get_mut(&mut self.before_callbacks)
105            .expect("before_callbacks not yet shared")
106            .push(callback);
107        self
108    }
109
110    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
111        Arc::get_mut(&mut self.after_callbacks)
112            .expect("after_callbacks not yet shared")
113            .push(callback);
114        self
115    }
116
117    pub fn with_skills(mut self, index: SkillIndex) -> Self {
118        self.skills_index = Some(Arc::new(index));
119        self
120    }
121
122    pub fn with_auto_skills(self) -> Result<Self> {
123        self.with_skills_from_root(".")
124    }
125
126    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
127        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::Agent(e.to_string()))?;
128        self.skills_index = Some(Arc::new(index));
129        Ok(self)
130    }
131
132    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
133        self.skill_policy = policy;
134        self
135    }
136
137    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
138        self.max_skill_chars = max_chars;
139        self
140    }
141}
142
143#[async_trait]
144impl Agent for ConditionalAgent {
145    fn name(&self) -> &str {
146        &self.name
147    }
148
149    fn description(&self) -> &str {
150        &self.description
151    }
152
153    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
154        &self.all_agents
155    }
156
157    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
158        let run_ctx = super::skill_context::with_skill_injected_context(
159            ctx,
160            self.skills_index.as_ref(),
161            &self.skill_policy,
162            self.max_skill_chars,
163        );
164        let before_callbacks = self.before_callbacks.clone();
165        let after_callbacks = self.after_callbacks.clone();
166        let if_agent = self.if_agent.clone();
167        let else_agent = self.else_agent.clone();
168        let agent_name = self.name.clone();
169        let invocation_id = run_ctx.invocation_id().to_string();
170        let condition = self.condition.clone();
171
172        let s = stream! {
173            for callback in before_callbacks.as_ref() {
174                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
175                    Ok(Some(content)) => {
176                        let mut early_event = Event::new(&invocation_id);
177                        early_event.author = agent_name.clone();
178                        early_event.llm_response.content = Some(content);
179                        yield Ok(early_event);
180
181                        for after_callback in after_callbacks.as_ref() {
182                            match after_callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
183                                Ok(Some(after_content)) => {
184                                    let mut after_event = Event::new(&invocation_id);
185                                    after_event.author = agent_name.clone();
186                                    after_event.llm_response.content = Some(after_content);
187                                    yield Ok(after_event);
188                                    return;
189                                }
190                                Ok(None) => continue,
191                                Err(e) => {
192                                    yield Err(e);
193                                    return;
194                                }
195                            }
196                        }
197                        return;
198                    }
199                    Ok(None) => continue,
200                    Err(e) => {
201                        yield Err(e);
202                        return;
203                    }
204                }
205            }
206
207            let target_agent = if condition(run_ctx.as_ref()) {
208                Some(if_agent)
209            } else {
210                else_agent
211            };
212
213            if let Some(agent) = target_agent {
214                let mut stream = match agent.run(run_ctx.clone()).await {
215                    Ok(stream) => stream,
216                    Err(e) => {
217                        yield Err(e);
218                        return;
219                    }
220                };
221
222                while let Some(result) = stream.next().await {
223                    match result {
224                        Ok(event) => yield Ok(event),
225                        Err(e) => {
226                            yield Err(e);
227                            return;
228                        }
229                    }
230                }
231            }
232
233            for callback in after_callbacks.as_ref() {
234                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
235                    Ok(Some(content)) => {
236                        let mut after_event = Event::new(&invocation_id);
237                        after_event.author = agent_name.clone();
238                        after_event.llm_response.content = Some(content);
239                        yield Ok(after_event);
240                        break;
241                    }
242                    Ok(None) => continue,
243                    Err(e) => {
244                        yield Err(e);
245                        return;
246                    }
247                }
248            }
249        };
250
251        Ok(Box::pin(s))
252    }
253}