adk_agent/workflow/
conditional_agent.rs1#[cfg(feature = "skills")]
30use crate::skill_shim::load_skill_index;
31use crate::skill_shim::{SelectionPolicy, SkillIndex};
32use adk_core::{
33 AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
34 InvocationContext, Result,
35};
36use async_stream::stream;
37use async_trait::async_trait;
38use futures::StreamExt;
39use std::sync::Arc;
40
41type ConditionFn = Arc<dyn Fn(&dyn InvocationContext) -> bool + Send + Sync>;
42
43pub struct ConditionalAgent {
59 name: String,
60 description: String,
61 condition: ConditionFn,
62 if_agent: Arc<dyn Agent>,
63 else_agent: Option<Arc<dyn Agent>>,
64 all_agents: Vec<Arc<dyn Agent>>,
66 skills_index: Option<Arc<SkillIndex>>,
67 skill_policy: SelectionPolicy,
68 max_skill_chars: usize,
69 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
70 after_callbacks: Arc<Vec<AfterAgentCallback>>,
71}
72
73impl ConditionalAgent {
74 pub fn new<F>(name: impl Into<String>, condition: F, if_agent: Arc<dyn Agent>) -> Self
76 where
77 F: Fn(&dyn InvocationContext) -> bool + Send + Sync + 'static,
78 {
79 let all_agents = vec![if_agent.clone()];
80 Self {
81 name: name.into(),
82 description: String::new(),
83 condition: Arc::new(condition),
84 if_agent,
85 else_agent: None,
86 all_agents,
87 skills_index: None,
88 skill_policy: SelectionPolicy::default(),
89 max_skill_chars: 2000,
90 before_callbacks: Arc::new(Vec::new()),
91 after_callbacks: Arc::new(Vec::new()),
92 }
93 }
94
95 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
97 self.description = desc.into();
98 self
99 }
100
101 pub fn with_else(mut self, else_agent: Arc<dyn Agent>) -> Self {
103 self.all_agents.push(else_agent.clone());
104 self.else_agent = Some(else_agent);
105 self
106 }
107
108 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
110 if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
111 callbacks.push(callback);
112 }
113 self
114 }
115
116 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
118 if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
119 callbacks.push(callback);
120 }
121 self
122 }
123
124 #[cfg(feature = "skills")]
126 pub fn with_skills(mut self, index: SkillIndex) -> Self {
127 self.skills_index = Some(Arc::new(index));
128 self
129 }
130
131 #[cfg(feature = "skills")]
133 pub fn with_auto_skills(self) -> Result<Self> {
134 self.with_skills_from_root(".")
135 }
136
137 #[cfg(feature = "skills")]
139 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
140 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
141 self.skills_index = Some(Arc::new(index));
142 Ok(self)
143 }
144
145 #[cfg(feature = "skills")]
147 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
148 self.skill_policy = policy;
149 self
150 }
151
152 #[cfg(feature = "skills")]
154 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
155 self.max_skill_chars = max_chars;
156 self
157 }
158}
159
160#[async_trait]
161impl Agent for ConditionalAgent {
162 fn name(&self) -> &str {
163 &self.name
164 }
165
166 fn description(&self) -> &str {
167 &self.description
168 }
169
170 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
171 &self.all_agents
172 }
173
174 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
175 let run_ctx = super::skill_context::with_skill_injected_context(
176 ctx,
177 self.skills_index.as_ref(),
178 &self.skill_policy,
179 self.max_skill_chars,
180 );
181 let before_callbacks = self.before_callbacks.clone();
182 let after_callbacks = self.after_callbacks.clone();
183 let if_agent = self.if_agent.clone();
184 let else_agent = self.else_agent.clone();
185 let agent_name = self.name.clone();
186 let invocation_id = run_ctx.invocation_id().to_string();
187 let condition = self.condition.clone();
188
189 let s = stream! {
190 for callback in before_callbacks.as_ref() {
191 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
192 Ok(Some(content)) => {
193 let mut early_event = Event::new(&invocation_id);
194 early_event.author = agent_name.clone();
195 early_event.llm_response.content = Some(content);
196 yield Ok(early_event);
197
198 for after_callback in after_callbacks.as_ref() {
199 match after_callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
200 Ok(Some(after_content)) => {
201 let mut after_event = Event::new(&invocation_id);
202 after_event.author = agent_name.clone();
203 after_event.llm_response.content = Some(after_content);
204 yield Ok(after_event);
205 return;
206 }
207 Ok(None) => continue,
208 Err(e) => {
209 yield Err(e);
210 return;
211 }
212 }
213 }
214 return;
215 }
216 Ok(None) => continue,
217 Err(e) => {
218 yield Err(e);
219 return;
220 }
221 }
222 }
223
224 let target_agent = if condition(run_ctx.as_ref()) {
225 Some(if_agent)
226 } else {
227 else_agent
228 };
229
230 if let Some(agent) = target_agent {
231 let mut stream = match agent.run(run_ctx.clone()).await {
232 Ok(stream) => stream,
233 Err(e) => {
234 yield Err(e);
235 return;
236 }
237 };
238
239 while let Some(result) = stream.next().await {
240 match result {
241 Ok(event) => yield Ok(event),
242 Err(e) => {
243 yield Err(e);
244 return;
245 }
246 }
247 }
248 }
249
250 for callback in after_callbacks.as_ref() {
251 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
252 Ok(Some(content)) => {
253 let mut after_event = Event::new(&invocation_id);
254 after_event.author = agent_name.clone();
255 after_event.llm_response.content = Some(content);
256 yield Ok(after_event);
257 break;
258 }
259 Ok(None) => continue,
260 Err(e) => {
261 yield Err(e);
262 return;
263 }
264 }
265 }
266 };
267
268 Ok(Box::pin(s))
269 }
270}