Skip to main content

adk_agent/
custom_agent.rs

1use adk_core::{
2    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
3    InvocationContext, Result,
4};
5use async_stream::stream;
6use async_trait::async_trait;
7use futures::StreamExt;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12type RunHandler = Box<
13    dyn Fn(Arc<dyn InvocationContext>) -> Pin<Box<dyn Future<Output = Result<EventStream>> + Send>>
14        + Send
15        + Sync,
16>;
17
18/// An agent with a user-defined async handler function.
19///
20/// `CustomAgent` allows you to implement arbitrary agent logic without
21/// conforming to the LLM request/response loop. Use the builder to configure
22/// the handler, sub-agents, and lifecycle callbacks.
23pub struct CustomAgent {
24    name: String,
25    description: String,
26    sub_agents: Vec<Arc<dyn Agent>>,
27    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
28    after_callbacks: Arc<Vec<AfterAgentCallback>>,
29    handler: RunHandler,
30}
31
32impl CustomAgent {
33    /// Create a new builder for `CustomAgent` with the given name.
34    pub fn builder(name: impl Into<String>) -> CustomAgentBuilder {
35        CustomAgentBuilder::new(name)
36    }
37}
38
39#[async_trait]
40impl Agent for CustomAgent {
41    fn name(&self) -> &str {
42        &self.name
43    }
44
45    fn description(&self) -> &str {
46        &self.description
47    }
48
49    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
50        &self.sub_agents
51    }
52
53    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
54        let handler = &self.handler;
55        let before_callbacks = self.before_callbacks.clone();
56        let after_callbacks = self.after_callbacks.clone();
57        let agent_name = self.name.clone();
58
59        // Execute before callbacks — if any returns content, short-circuit
60        for callback in before_callbacks.as_ref() {
61            match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
62                Ok(Some(content)) => {
63                    let invocation_id = ctx.invocation_id().to_string();
64                    let s = stream! {
65                        let mut early_event = Event::new(&invocation_id);
66                        early_event.author = agent_name.clone();
67                        early_event.llm_response.content = Some(content);
68                        yield Ok(early_event);
69
70                        for after_cb in after_callbacks.as_ref() {
71                            match after_cb(ctx.clone() as Arc<dyn CallbackContext>).await {
72                                Ok(Some(after_content)) => {
73                                    let mut after_event = Event::new(&invocation_id);
74                                    after_event.author = agent_name.clone();
75                                    after_event.llm_response.content = Some(after_content);
76                                    yield Ok(after_event);
77                                    return;
78                                }
79                                Ok(None) => continue,
80                                Err(e) => { yield Err(e); return; }
81                            }
82                        }
83                    };
84                    return Ok(Box::pin(s));
85                }
86                Ok(None) => continue,
87                Err(e) => return Err(e),
88            }
89        }
90
91        // Run the actual handler
92        let mut inner_stream = (handler)(ctx.clone()).await?;
93
94        let s = stream! {
95            while let Some(result) = inner_stream.next().await {
96                yield result;
97            }
98
99            // Execute after callbacks
100            for callback in after_callbacks.as_ref() {
101                match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
102                    Ok(Some(content)) => {
103                        let mut after_event = Event::new(ctx.invocation_id());
104                        after_event.author = agent_name.clone();
105                        after_event.llm_response.content = Some(content);
106                        yield Ok(after_event);
107                        break;
108                    }
109                    Ok(None) => continue,
110                    Err(e) => { yield Err(e); return; }
111                }
112            }
113        };
114
115        Ok(Box::pin(s))
116    }
117}
118
119/// Builder for constructing a [`CustomAgent`].
120pub struct CustomAgentBuilder {
121    name: String,
122    description: String,
123    sub_agents: Vec<Arc<dyn Agent>>,
124    before_callbacks: Vec<BeforeAgentCallback>,
125    after_callbacks: Vec<AfterAgentCallback>,
126    handler: Option<RunHandler>,
127}
128
129impl CustomAgentBuilder {
130    /// Create a new builder with the given agent name.
131    pub fn new(name: impl Into<String>) -> Self {
132        Self {
133            name: name.into(),
134            description: String::new(),
135            sub_agents: Vec::new(),
136            before_callbacks: Vec::new(),
137            after_callbacks: Vec::new(),
138            handler: None,
139        }
140    }
141
142    /// Set the agent description.
143    pub fn description(mut self, description: impl Into<String>) -> Self {
144        self.description = description.into();
145        self
146    }
147
148    /// Add a sub-agent.
149    pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
150        self.sub_agents.push(agent);
151        self
152    }
153
154    /// Set all sub-agents, replacing any previously added.
155    pub fn sub_agents(mut self, agents: Vec<Arc<dyn Agent>>) -> Self {
156        self.sub_agents = agents;
157        self
158    }
159
160    /// Add a before-agent callback.
161    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
162        self.before_callbacks.push(callback);
163        self
164    }
165
166    /// Add an after-agent callback.
167    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
168        self.after_callbacks.push(callback);
169        self
170    }
171
172    /// Set the async handler function that implements the agent's logic.
173    pub fn handler<F, Fut>(mut self, handler: F) -> Self
174    where
175        F: Fn(Arc<dyn InvocationContext>) -> Fut + Send + Sync + 'static,
176        Fut: Future<Output = Result<EventStream>> + Send + 'static,
177    {
178        self.handler = Some(Box::new(move |ctx| Box::pin(handler(ctx))));
179        self
180    }
181
182    /// Build the [`CustomAgent`], returning an error if no handler was set.
183    pub fn build(self) -> Result<CustomAgent> {
184        let handler = self
185            .handler
186            .ok_or_else(|| adk_core::AdkError::agent("CustomAgent requires a handler"))?;
187
188        // Validate sub-agents have unique names
189        let mut seen_names = std::collections::HashSet::new();
190        for agent in &self.sub_agents {
191            if !seen_names.insert(agent.name()) {
192                return Err(adk_core::AdkError::agent(format!(
193                    "Duplicate sub-agent name: {}",
194                    agent.name()
195                )));
196            }
197        }
198
199        Ok(CustomAgent {
200            name: self.name,
201            description: self.description,
202            sub_agents: self.sub_agents,
203            before_callbacks: Arc::new(self.before_callbacks),
204            after_callbacks: Arc::new(self.after_callbacks),
205            handler,
206        })
207    }
208}