adk_agent/workflow/
conditional_agent.rs1#[cfg(feature = "skills")]
30use crate::skill_shim::load_skill_index;
31use crate::skill_shim::{SelectionPolicy, SkillIndex};
32use adk_core::{
33 AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
34 InvocationContext, Result,
35};
36use async_stream::stream;
37use async_trait::async_trait;
38use futures::StreamExt;
39use std::sync::Arc;
40
41type ConditionFn = Arc<dyn Fn(&dyn InvocationContext) -> bool + Send + Sync>;
42
43pub struct ConditionalAgent {
59 name: String,
60 description: String,
61 condition: ConditionFn,
62 if_agent: Arc<dyn Agent>,
63 else_agent: Option<Arc<dyn Agent>>,
64 all_agents: Vec<Arc<dyn Agent>>,
66 skills_index: Option<Arc<SkillIndex>>,
67 skill_policy: SelectionPolicy,
68 max_skill_chars: usize,
69 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
70 after_callbacks: Arc<Vec<AfterAgentCallback>>,
71}
72
73impl ConditionalAgent {
74 pub fn new<F>(name: impl Into<String>, condition: F, if_agent: Arc<dyn Agent>) -> Self
75 where
76 F: Fn(&dyn InvocationContext) -> bool + Send + Sync + 'static,
77 {
78 let all_agents = vec![if_agent.clone()];
79 Self {
80 name: name.into(),
81 description: String::new(),
82 condition: Arc::new(condition),
83 if_agent,
84 else_agent: None,
85 all_agents,
86 skills_index: None,
87 skill_policy: SelectionPolicy::default(),
88 max_skill_chars: 2000,
89 before_callbacks: Arc::new(Vec::new()),
90 after_callbacks: Arc::new(Vec::new()),
91 }
92 }
93
94 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
95 self.description = desc.into();
96 self
97 }
98
99 pub fn with_else(mut self, else_agent: Arc<dyn Agent>) -> Self {
100 self.all_agents.push(else_agent.clone());
101 self.else_agent = Some(else_agent);
102 self
103 }
104
105 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
106 if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
107 callbacks.push(callback);
108 }
109 self
110 }
111
112 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
113 if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
114 callbacks.push(callback);
115 }
116 self
117 }
118
119 #[cfg(feature = "skills")]
120 pub fn with_skills(mut self, index: SkillIndex) -> Self {
121 self.skills_index = Some(Arc::new(index));
122 self
123 }
124
125 #[cfg(feature = "skills")]
126 pub fn with_auto_skills(self) -> Result<Self> {
127 self.with_skills_from_root(".")
128 }
129
130 #[cfg(feature = "skills")]
131 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
132 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
133 self.skills_index = Some(Arc::new(index));
134 Ok(self)
135 }
136
137 #[cfg(feature = "skills")]
138 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
139 self.skill_policy = policy;
140 self
141 }
142
143 #[cfg(feature = "skills")]
144 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
145 self.max_skill_chars = max_chars;
146 self
147 }
148}
149
150#[async_trait]
151impl Agent for ConditionalAgent {
152 fn name(&self) -> &str {
153 &self.name
154 }
155
156 fn description(&self) -> &str {
157 &self.description
158 }
159
160 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
161 &self.all_agents
162 }
163
164 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
165 let run_ctx = super::skill_context::with_skill_injected_context(
166 ctx,
167 self.skills_index.as_ref(),
168 &self.skill_policy,
169 self.max_skill_chars,
170 );
171 let before_callbacks = self.before_callbacks.clone();
172 let after_callbacks = self.after_callbacks.clone();
173 let if_agent = self.if_agent.clone();
174 let else_agent = self.else_agent.clone();
175 let agent_name = self.name.clone();
176 let invocation_id = run_ctx.invocation_id().to_string();
177 let condition = self.condition.clone();
178
179 let s = stream! {
180 for callback in before_callbacks.as_ref() {
181 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
182 Ok(Some(content)) => {
183 let mut early_event = Event::new(&invocation_id);
184 early_event.author = agent_name.clone();
185 early_event.llm_response.content = Some(content);
186 yield Ok(early_event);
187
188 for after_callback in after_callbacks.as_ref() {
189 match after_callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
190 Ok(Some(after_content)) => {
191 let mut after_event = Event::new(&invocation_id);
192 after_event.author = agent_name.clone();
193 after_event.llm_response.content = Some(after_content);
194 yield Ok(after_event);
195 return;
196 }
197 Ok(None) => continue,
198 Err(e) => {
199 yield Err(e);
200 return;
201 }
202 }
203 }
204 return;
205 }
206 Ok(None) => continue,
207 Err(e) => {
208 yield Err(e);
209 return;
210 }
211 }
212 }
213
214 let target_agent = if condition(run_ctx.as_ref()) {
215 Some(if_agent)
216 } else {
217 else_agent
218 };
219
220 if let Some(agent) = target_agent {
221 let mut stream = match agent.run(run_ctx.clone()).await {
222 Ok(stream) => stream,
223 Err(e) => {
224 yield Err(e);
225 return;
226 }
227 };
228
229 while let Some(result) = stream.next().await {
230 match result {
231 Ok(event) => yield Ok(event),
232 Err(e) => {
233 yield Err(e);
234 return;
235 }
236 }
237 }
238 }
239
240 for callback in after_callbacks.as_ref() {
241 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
242 Ok(Some(content)) => {
243 let mut after_event = Event::new(&invocation_id);
244 after_event.author = agent_name.clone();
245 after_event.llm_response.content = Some(content);
246 yield Ok(after_event);
247 break;
248 }
249 Ok(None) => continue,
250 Err(e) => {
251 yield Err(e);
252 return;
253 }
254 }
255 }
256 };
257
258 Ok(Box::pin(s))
259 }
260}