adk_agent/workflow/
conditional_agent.rs1use adk_core::{
2 AfterAgentCallback, Agent, BeforeAgentCallback, EventStream, InvocationContext, Result,
3};
4use async_trait::async_trait;
5use std::sync::Arc;
6
7type ConditionFn = Box<dyn Fn(&dyn InvocationContext) -> bool + Send + Sync>;
8
9pub struct ConditionalAgent {
11 name: String,
12 description: String,
13 condition: ConditionFn,
14 if_agent: Arc<dyn Agent>,
15 else_agent: Option<Arc<dyn Agent>>,
16 before_callbacks: Vec<BeforeAgentCallback>,
17 after_callbacks: Vec<AfterAgentCallback>,
18}
19
20impl ConditionalAgent {
21 pub fn new<F>(name: impl Into<String>, condition: F, if_agent: Arc<dyn Agent>) -> Self
22 where
23 F: Fn(&dyn InvocationContext) -> bool + Send + Sync + 'static,
24 {
25 Self {
26 name: name.into(),
27 description: String::new(),
28 condition: Box::new(condition),
29 if_agent,
30 else_agent: None,
31 before_callbacks: Vec::new(),
32 after_callbacks: Vec::new(),
33 }
34 }
35
36 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
37 self.description = desc.into();
38 self
39 }
40
41 pub fn with_else(mut self, else_agent: Arc<dyn Agent>) -> Self {
42 self.else_agent = Some(else_agent);
43 self
44 }
45
46 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
47 self.before_callbacks.push(callback);
48 self
49 }
50
51 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
52 self.after_callbacks.push(callback);
53 self
54 }
55}
56
57#[async_trait]
58impl Agent for ConditionalAgent {
59 fn name(&self) -> &str {
60 &self.name
61 }
62
63 fn description(&self) -> &str {
64 &self.description
65 }
66
67 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
68 &[]
69 }
70
71 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
72 let agent = if (self.condition)(ctx.as_ref()) {
73 self.if_agent.clone()
74 } else if let Some(else_agent) = &self.else_agent {
75 else_agent.clone()
76 } else {
77 return Ok(Box::pin(futures::stream::empty()));
78 };
79
80 agent.run(ctx).await
81 }
82}