adk_agent/workflow/
conditional_agent.rs1use adk_core::{
30 AfterAgentCallback, Agent, BeforeAgentCallback, EventStream, InvocationContext, Result,
31};
32use async_trait::async_trait;
33use std::sync::Arc;
34
35type ConditionFn = Box<dyn Fn(&dyn InvocationContext) -> bool + Send + Sync>;
36
37pub struct ConditionalAgent {
53 name: String,
54 description: String,
55 condition: ConditionFn,
56 if_agent: Arc<dyn Agent>,
57 else_agent: Option<Arc<dyn Agent>>,
58 before_callbacks: Vec<BeforeAgentCallback>,
59 after_callbacks: Vec<AfterAgentCallback>,
60}
61
62impl ConditionalAgent {
63 pub fn new<F>(name: impl Into<String>, condition: F, if_agent: Arc<dyn Agent>) -> Self
64 where
65 F: Fn(&dyn InvocationContext) -> bool + Send + Sync + 'static,
66 {
67 Self {
68 name: name.into(),
69 description: String::new(),
70 condition: Box::new(condition),
71 if_agent,
72 else_agent: None,
73 before_callbacks: Vec::new(),
74 after_callbacks: Vec::new(),
75 }
76 }
77
78 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
79 self.description = desc.into();
80 self
81 }
82
83 pub fn with_else(mut self, else_agent: Arc<dyn Agent>) -> Self {
84 self.else_agent = Some(else_agent);
85 self
86 }
87
88 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
89 self.before_callbacks.push(callback);
90 self
91 }
92
93 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
94 self.after_callbacks.push(callback);
95 self
96 }
97}
98
99#[async_trait]
100impl Agent for ConditionalAgent {
101 fn name(&self) -> &str {
102 &self.name
103 }
104
105 fn description(&self) -> &str {
106 &self.description
107 }
108
109 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
110 &[]
111 }
112
113 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
114 let agent = if (self.condition)(ctx.as_ref()) {
115 self.if_agent.clone()
116 } else if let Some(else_agent) = &self.else_agent {
117 else_agent.clone()
118 } else {
119 return Ok(Box::pin(futures::stream::empty()));
120 };
121
122 agent.run(ctx).await
123 }
124}