adk_agent/workflow/conditional_agent.rs

1use adk_core::{
2    AfterAgentCallback, Agent, BeforeAgentCallback, EventStream, InvocationContext, Result,
3};
4use async_trait::async_trait;
5use std::sync::Arc;
6
7type ConditionFn = Box<dyn Fn(&dyn InvocationContext) -> bool + Send + Sync>;
8
9/// Conditional agent runs different sub-agents based on a condition
10pub struct ConditionalAgent {
11    name: String,
12    description: String,
13    condition: ConditionFn,
14    if_agent: Arc<dyn Agent>,
15    else_agent: Option<Arc<dyn Agent>>,
16    before_callbacks: Vec<BeforeAgentCallback>,
17    after_callbacks: Vec<AfterAgentCallback>,
18}
19
20impl ConditionalAgent {
21    pub fn new<F>(name: impl Into<String>, condition: F, if_agent: Arc<dyn Agent>) -> Self
22    where
23        F: Fn(&dyn InvocationContext) -> bool + Send + Sync + 'static,
24    {
25        Self {
26            name: name.into(),
27            description: String::new(),
28            condition: Box::new(condition),
29            if_agent,
30            else_agent: None,
31            before_callbacks: Vec::new(),
32            after_callbacks: Vec::new(),
33        }
34    }
35
36    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
37        self.description = desc.into();
38        self
39    }
40
41    pub fn with_else(mut self, else_agent: Arc<dyn Agent>) -> Self {
42        self.else_agent = Some(else_agent);
43        self
44    }
45
46    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
47        self.before_callbacks.push(callback);
48        self
49    }
50
51    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
52        self.after_callbacks.push(callback);
53        self
54    }
55}
56
57#[async_trait]
58impl Agent for ConditionalAgent {
59    fn name(&self) -> &str {
60        &self.name
61    }
62
63    fn description(&self) -> &str {
64        &self.description
65    }
66
67    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
68        &[]
69    }
70
71    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
72        let agent = if (self.condition)(ctx.as_ref()) {
73            self.if_agent.clone()
74        } else if let Some(else_agent) = &self.else_agent {
75            else_agent.clone()
76        } else {
77            return Ok(Box::pin(futures::stream::empty()));
78        };
79
80        agent.run(ctx).await
81    }
82}