Skip to main content

adk_agent/
custom_agent.rs

1use adk_core::{
2    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
3    InvocationContext, Result,
4};
5use async_stream::stream;
6use async_trait::async_trait;
7use futures::StreamExt;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12type RunHandler = Box<
13    dyn Fn(Arc<dyn InvocationContext>) -> Pin<Box<dyn Future<Output = Result<EventStream>> + Send>>
14        + Send
15        + Sync,
16>;
17
18pub struct CustomAgent {
19    name: String,
20    description: String,
21    sub_agents: Vec<Arc<dyn Agent>>,
22    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
23    after_callbacks: Arc<Vec<AfterAgentCallback>>,
24    handler: RunHandler,
25}
26
27impl CustomAgent {
28    pub fn builder(name: impl Into<String>) -> CustomAgentBuilder {
29        CustomAgentBuilder::new(name)
30    }
31}
32
33#[async_trait]
34impl Agent for CustomAgent {
35    fn name(&self) -> &str {
36        &self.name
37    }
38
39    fn description(&self) -> &str {
40        &self.description
41    }
42
43    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
44        &self.sub_agents
45    }
46
47    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
48        let handler = &self.handler;
49        let before_callbacks = self.before_callbacks.clone();
50        let after_callbacks = self.after_callbacks.clone();
51        let agent_name = self.name.clone();
52
53        // Execute before callbacks — if any returns content, short-circuit
54        for callback in before_callbacks.as_ref() {
55            match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
56                Ok(Some(content)) => {
57                    let invocation_id = ctx.invocation_id().to_string();
58                    let s = stream! {
59                        let mut early_event = Event::new(&invocation_id);
60                        early_event.author = agent_name.clone();
61                        early_event.llm_response.content = Some(content);
62                        yield Ok(early_event);
63
64                        for after_cb in after_callbacks.as_ref() {
65                            match after_cb(ctx.clone() as Arc<dyn CallbackContext>).await {
66                                Ok(Some(after_content)) => {
67                                    let mut after_event = Event::new(&invocation_id);
68                                    after_event.author = agent_name.clone();
69                                    after_event.llm_response.content = Some(after_content);
70                                    yield Ok(after_event);
71                                    return;
72                                }
73                                Ok(None) => continue,
74                                Err(e) => { yield Err(e); return; }
75                            }
76                        }
77                    };
78                    return Ok(Box::pin(s));
79                }
80                Ok(None) => continue,
81                Err(e) => return Err(e),
82            }
83        }
84
85        // Run the actual handler
86        let mut inner_stream = (handler)(ctx.clone()).await?;
87
88        let s = stream! {
89            while let Some(result) = inner_stream.next().await {
90                yield result;
91            }
92
93            // Execute after callbacks
94            for callback in after_callbacks.as_ref() {
95                match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
96                    Ok(Some(content)) => {
97                        let mut after_event = Event::new(ctx.invocation_id());
98                        after_event.author = agent_name.clone();
99                        after_event.llm_response.content = Some(content);
100                        yield Ok(after_event);
101                        break;
102                    }
103                    Ok(None) => continue,
104                    Err(e) => { yield Err(e); return; }
105                }
106            }
107        };
108
109        Ok(Box::pin(s))
110    }
111}
112
113pub struct CustomAgentBuilder {
114    name: String,
115    description: String,
116    sub_agents: Vec<Arc<dyn Agent>>,
117    before_callbacks: Vec<BeforeAgentCallback>,
118    after_callbacks: Vec<AfterAgentCallback>,
119    handler: Option<RunHandler>,
120}
121
122impl CustomAgentBuilder {
123    pub fn new(name: impl Into<String>) -> Self {
124        Self {
125            name: name.into(),
126            description: String::new(),
127            sub_agents: Vec::new(),
128            before_callbacks: Vec::new(),
129            after_callbacks: Vec::new(),
130            handler: None,
131        }
132    }
133
134    pub fn description(mut self, description: impl Into<String>) -> Self {
135        self.description = description.into();
136        self
137    }
138
139    pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
140        self.sub_agents.push(agent);
141        self
142    }
143
144    pub fn sub_agents(mut self, agents: Vec<Arc<dyn Agent>>) -> Self {
145        self.sub_agents = agents;
146        self
147    }
148
149    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
150        self.before_callbacks.push(callback);
151        self
152    }
153
154    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
155        self.after_callbacks.push(callback);
156        self
157    }
158
159    pub fn handler<F, Fut>(mut self, handler: F) -> Self
160    where
161        F: Fn(Arc<dyn InvocationContext>) -> Fut + Send + Sync + 'static,
162        Fut: Future<Output = Result<EventStream>> + Send + 'static,
163    {
164        self.handler = Some(Box::new(move |ctx| Box::pin(handler(ctx))));
165        self
166    }
167
168    pub fn build(self) -> Result<CustomAgent> {
169        let handler = self.handler.ok_or_else(|| {
170            adk_core::AdkError::Agent("CustomAgent requires a handler".to_string())
171        })?;
172
173        // Validate sub-agents have unique names
174        let mut seen_names = std::collections::HashSet::new();
175        for agent in &self.sub_agents {
176            if !seen_names.insert(agent.name()) {
177                return Err(adk_core::AdkError::Agent(format!(
178                    "Duplicate sub-agent name: {}",
179                    agent.name()
180                )));
181            }
182        }
183
184        Ok(CustomAgent {
185            name: self.name,
186            description: self.description,
187            sub_agents: self.sub_agents,
188            before_callbacks: Arc::new(self.before_callbacks),
189            after_callbacks: Arc::new(self.after_callbacks),
190            handler,
191        })
192    }
193}