adk_agent/workflow/
conditional_agent.rs1use adk_core::{
30 AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
31 InvocationContext, Result,
32};
33use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index};
34use async_stream::stream;
35use async_trait::async_trait;
36use futures::StreamExt;
37use std::sync::Arc;
38
39type ConditionFn = Arc<dyn Fn(&dyn InvocationContext) -> bool + Send + Sync>;
40
41pub struct ConditionalAgent {
57 name: String,
58 description: String,
59 condition: ConditionFn,
60 if_agent: Arc<dyn Agent>,
61 else_agent: Option<Arc<dyn Agent>>,
62 all_agents: Vec<Arc<dyn Agent>>,
64 skills_index: Option<Arc<SkillIndex>>,
65 skill_policy: SelectionPolicy,
66 max_skill_chars: usize,
67 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
68 after_callbacks: Arc<Vec<AfterAgentCallback>>,
69}
70
71impl ConditionalAgent {
72 pub fn new<F>(name: impl Into<String>, condition: F, if_agent: Arc<dyn Agent>) -> Self
73 where
74 F: Fn(&dyn InvocationContext) -> bool + Send + Sync + 'static,
75 {
76 let all_agents = vec![if_agent.clone()];
77 Self {
78 name: name.into(),
79 description: String::new(),
80 condition: Arc::new(condition),
81 if_agent,
82 else_agent: None,
83 all_agents,
84 skills_index: None,
85 skill_policy: SelectionPolicy::default(),
86 max_skill_chars: 2000,
87 before_callbacks: Arc::new(Vec::new()),
88 after_callbacks: Arc::new(Vec::new()),
89 }
90 }
91
92 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
93 self.description = desc.into();
94 self
95 }
96
97 pub fn with_else(mut self, else_agent: Arc<dyn Agent>) -> Self {
98 self.all_agents.push(else_agent.clone());
99 self.else_agent = Some(else_agent);
100 self
101 }
102
103 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
104 Arc::get_mut(&mut self.before_callbacks)
105 .expect("before_callbacks not yet shared")
106 .push(callback);
107 self
108 }
109
110 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
111 Arc::get_mut(&mut self.after_callbacks)
112 .expect("after_callbacks not yet shared")
113 .push(callback);
114 self
115 }
116
117 pub fn with_skills(mut self, index: SkillIndex) -> Self {
118 self.skills_index = Some(Arc::new(index));
119 self
120 }
121
122 pub fn with_auto_skills(self) -> Result<Self> {
123 self.with_skills_from_root(".")
124 }
125
126 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
127 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::Agent(e.to_string()))?;
128 self.skills_index = Some(Arc::new(index));
129 Ok(self)
130 }
131
132 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
133 self.skill_policy = policy;
134 self
135 }
136
137 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
138 self.max_skill_chars = max_chars;
139 self
140 }
141}
142
143#[async_trait]
144impl Agent for ConditionalAgent {
145 fn name(&self) -> &str {
146 &self.name
147 }
148
149 fn description(&self) -> &str {
150 &self.description
151 }
152
153 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
154 &self.all_agents
155 }
156
157 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
158 let run_ctx = super::skill_context::with_skill_injected_context(
159 ctx,
160 self.skills_index.as_ref(),
161 &self.skill_policy,
162 self.max_skill_chars,
163 );
164 let before_callbacks = self.before_callbacks.clone();
165 let after_callbacks = self.after_callbacks.clone();
166 let if_agent = self.if_agent.clone();
167 let else_agent = self.else_agent.clone();
168 let agent_name = self.name.clone();
169 let invocation_id = run_ctx.invocation_id().to_string();
170 let condition = self.condition.clone();
171
172 let s = stream! {
173 for callback in before_callbacks.as_ref() {
174 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
175 Ok(Some(content)) => {
176 let mut early_event = Event::new(&invocation_id);
177 early_event.author = agent_name.clone();
178 early_event.llm_response.content = Some(content);
179 yield Ok(early_event);
180
181 for after_callback in after_callbacks.as_ref() {
182 match after_callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
183 Ok(Some(after_content)) => {
184 let mut after_event = Event::new(&invocation_id);
185 after_event.author = agent_name.clone();
186 after_event.llm_response.content = Some(after_content);
187 yield Ok(after_event);
188 return;
189 }
190 Ok(None) => continue,
191 Err(e) => {
192 yield Err(e);
193 return;
194 }
195 }
196 }
197 return;
198 }
199 Ok(None) => continue,
200 Err(e) => {
201 yield Err(e);
202 return;
203 }
204 }
205 }
206
207 let target_agent = if condition(run_ctx.as_ref()) {
208 Some(if_agent)
209 } else {
210 else_agent
211 };
212
213 if let Some(agent) = target_agent {
214 let mut stream = match agent.run(run_ctx.clone()).await {
215 Ok(stream) => stream,
216 Err(e) => {
217 yield Err(e);
218 return;
219 }
220 };
221
222 while let Some(result) = stream.next().await {
223 match result {
224 Ok(event) => yield Ok(event),
225 Err(e) => {
226 yield Err(e);
227 return;
228 }
229 }
230 }
231 }
232
233 for callback in after_callbacks.as_ref() {
234 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
235 Ok(Some(content)) => {
236 let mut after_event = Event::new(&invocation_id);
237 after_event.author = agent_name.clone();
238 after_event.llm_response.content = Some(content);
239 yield Ok(after_event);
240 break;
241 }
242 Ok(None) => continue,
243 Err(e) => {
244 yield Err(e);
245 return;
246 }
247 }
248 }
249 };
250
251 Ok(Box::pin(s))
252 }
253}