adk_agent/
custom_agent.rs1use adk_core::{
2 AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
3 InvocationContext, Result,
4};
5use async_stream::stream;
6use async_trait::async_trait;
7use futures::StreamExt;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12type RunHandler = Box<
13 dyn Fn(Arc<dyn InvocationContext>) -> Pin<Box<dyn Future<Output = Result<EventStream>> + Send>>
14 + Send
15 + Sync,
16>;
17
18pub struct CustomAgent {
24 name: String,
25 description: String,
26 sub_agents: Vec<Arc<dyn Agent>>,
27 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
28 after_callbacks: Arc<Vec<AfterAgentCallback>>,
29 handler: RunHandler,
30}
31
32impl CustomAgent {
33 pub fn builder(name: impl Into<String>) -> CustomAgentBuilder {
35 CustomAgentBuilder::new(name)
36 }
37}
38
39#[async_trait]
40impl Agent for CustomAgent {
41 fn name(&self) -> &str {
42 &self.name
43 }
44
45 fn description(&self) -> &str {
46 &self.description
47 }
48
49 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
50 &self.sub_agents
51 }
52
53 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
54 let handler = &self.handler;
55 let before_callbacks = self.before_callbacks.clone();
56 let after_callbacks = self.after_callbacks.clone();
57 let agent_name = self.name.clone();
58
59 for callback in before_callbacks.as_ref() {
61 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
62 Ok(Some(content)) => {
63 let invocation_id = ctx.invocation_id().to_string();
64 let s = stream! {
65 let mut early_event = Event::new(&invocation_id);
66 early_event.author = agent_name.clone();
67 early_event.llm_response.content = Some(content);
68 yield Ok(early_event);
69
70 for after_cb in after_callbacks.as_ref() {
71 match after_cb(ctx.clone() as Arc<dyn CallbackContext>).await {
72 Ok(Some(after_content)) => {
73 let mut after_event = Event::new(&invocation_id);
74 after_event.author = agent_name.clone();
75 after_event.llm_response.content = Some(after_content);
76 yield Ok(after_event);
77 return;
78 }
79 Ok(None) => continue,
80 Err(e) => { yield Err(e); return; }
81 }
82 }
83 };
84 return Ok(Box::pin(s));
85 }
86 Ok(None) => continue,
87 Err(e) => return Err(e),
88 }
89 }
90
91 let mut inner_stream = (handler)(ctx.clone()).await?;
93
94 let s = stream! {
95 while let Some(result) = inner_stream.next().await {
96 yield result;
97 }
98
99 for callback in after_callbacks.as_ref() {
101 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
102 Ok(Some(content)) => {
103 let mut after_event = Event::new(ctx.invocation_id());
104 after_event.author = agent_name.clone();
105 after_event.llm_response.content = Some(content);
106 yield Ok(after_event);
107 break;
108 }
109 Ok(None) => continue,
110 Err(e) => { yield Err(e); return; }
111 }
112 }
113 };
114
115 Ok(Box::pin(s))
116 }
117}
118
119pub struct CustomAgentBuilder {
121 name: String,
122 description: String,
123 sub_agents: Vec<Arc<dyn Agent>>,
124 before_callbacks: Vec<BeforeAgentCallback>,
125 after_callbacks: Vec<AfterAgentCallback>,
126 handler: Option<RunHandler>,
127}
128
129impl CustomAgentBuilder {
130 pub fn new(name: impl Into<String>) -> Self {
132 Self {
133 name: name.into(),
134 description: String::new(),
135 sub_agents: Vec::new(),
136 before_callbacks: Vec::new(),
137 after_callbacks: Vec::new(),
138 handler: None,
139 }
140 }
141
142 pub fn description(mut self, description: impl Into<String>) -> Self {
144 self.description = description.into();
145 self
146 }
147
148 pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
150 self.sub_agents.push(agent);
151 self
152 }
153
154 pub fn sub_agents(mut self, agents: Vec<Arc<dyn Agent>>) -> Self {
156 self.sub_agents = agents;
157 self
158 }
159
160 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
162 self.before_callbacks.push(callback);
163 self
164 }
165
166 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
168 self.after_callbacks.push(callback);
169 self
170 }
171
172 pub fn handler<F, Fut>(mut self, handler: F) -> Self
174 where
175 F: Fn(Arc<dyn InvocationContext>) -> Fut + Send + Sync + 'static,
176 Fut: Future<Output = Result<EventStream>> + Send + 'static,
177 {
178 self.handler = Some(Box::new(move |ctx| Box::pin(handler(ctx))));
179 self
180 }
181
182 pub fn build(self) -> Result<CustomAgent> {
184 let handler = self
185 .handler
186 .ok_or_else(|| adk_core::AdkError::agent("CustomAgent requires a handler"))?;
187
188 let mut seen_names = std::collections::HashSet::new();
190 for agent in &self.sub_agents {
191 if !seen_names.insert(agent.name()) {
192 return Err(adk_core::AdkError::agent(format!(
193 "Duplicate sub-agent name: {}",
194 agent.name()
195 )));
196 }
197 }
198
199 Ok(CustomAgent {
200 name: self.name,
201 description: self.description,
202 sub_agents: self.sub_agents,
203 before_callbacks: Arc::new(self.before_callbacks),
204 after_callbacks: Arc::new(self.after_callbacks),
205 handler,
206 })
207 }
208}