// adk_agent/custom_agent.rs

1use adk_core::{
2    AfterAgentCallback, Agent, BeforeAgentCallback, EventStream, InvocationContext, Result,
3};
4use async_trait::async_trait;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8
9type RunHandler = Box<
10    dyn Fn(Arc<dyn InvocationContext>) -> Pin<Box<dyn Future<Output = Result<EventStream>> + Send>>
11        + Send
12        + Sync,
13>;
14
15pub struct CustomAgent {
16    name: String,
17    description: String,
18    sub_agents: Vec<Arc<dyn Agent>>,
19    #[allow(dead_code)] // Part of public API, callbacks not yet implemented
20    before_callbacks: Vec<BeforeAgentCallback>,
21    #[allow(dead_code)] // Part of public API, callbacks not yet implemented
22    after_callbacks: Vec<AfterAgentCallback>,
23    handler: RunHandler,
24}
25
26impl CustomAgent {
27    pub fn builder(name: impl Into<String>) -> CustomAgentBuilder {
28        CustomAgentBuilder::new(name)
29    }
30}
31
32#[async_trait]
33impl Agent for CustomAgent {
34    fn name(&self) -> &str {
35        &self.name
36    }
37
38    fn description(&self) -> &str {
39        &self.description
40    }
41
42    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
43        &self.sub_agents
44    }
45
46    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
47        (self.handler)(ctx).await
48    }
49}
50
51pub struct CustomAgentBuilder {
52    name: String,
53    description: String,
54    sub_agents: Vec<Arc<dyn Agent>>,
55    before_callbacks: Vec<BeforeAgentCallback>,
56    after_callbacks: Vec<AfterAgentCallback>,
57    handler: Option<RunHandler>,
58}
59
60impl CustomAgentBuilder {
61    pub fn new(name: impl Into<String>) -> Self {
62        Self {
63            name: name.into(),
64            description: String::new(),
65            sub_agents: Vec::new(),
66            before_callbacks: Vec::new(),
67            after_callbacks: Vec::new(),
68            handler: None,
69        }
70    }
71
72    pub fn description(mut self, description: impl Into<String>) -> Self {
73        self.description = description.into();
74        self
75    }
76
77    pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
78        self.sub_agents.push(agent);
79        self
80    }
81
82    pub fn sub_agents(mut self, agents: Vec<Arc<dyn Agent>>) -> Self {
83        self.sub_agents = agents;
84        self
85    }
86
87    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
88        self.before_callbacks.push(callback);
89        self
90    }
91
92    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
93        self.after_callbacks.push(callback);
94        self
95    }
96
97    pub fn handler<F, Fut>(mut self, handler: F) -> Self
98    where
99        F: Fn(Arc<dyn InvocationContext>) -> Fut + Send + Sync + 'static,
100        Fut: Future<Output = Result<EventStream>> + Send + 'static,
101    {
102        self.handler = Some(Box::new(move |ctx| Box::pin(handler(ctx))));
103        self
104    }
105
106    pub fn build(self) -> Result<CustomAgent> {
107        let handler = self.handler.ok_or_else(|| {
108            adk_core::AdkError::Agent("CustomAgent requires a handler".to_string())
109        })?;
110
111        // Validate sub-agents have unique names
112        let mut seen_names = std::collections::HashSet::new();
113        for agent in &self.sub_agents {
114            if !seen_names.insert(agent.name()) {
115                return Err(adk_core::AdkError::Agent(format!(
116                    "Duplicate sub-agent name: {}",
117                    agent.name()
118                )));
119            }
120        }
121
122        Ok(CustomAgent {
123            name: self.name,
124            description: self.description,
125            sub_agents: self.sub_agents,
126            before_callbacks: self.before_callbacks,
127            after_callbacks: self.after_callbacks,
128            handler,
129        })
130    }
131}