adk_agent/
custom_agent.rs1use adk_core::{
2 AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
3 InvocationContext, Result,
4};
5use async_stream::stream;
6use async_trait::async_trait;
7use futures::StreamExt;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12type RunHandler = Box<
13 dyn Fn(Arc<dyn InvocationContext>) -> Pin<Box<dyn Future<Output = Result<EventStream>> + Send>>
14 + Send
15 + Sync,
16>;
17
18pub struct CustomAgent {
19 name: String,
20 description: String,
21 sub_agents: Vec<Arc<dyn Agent>>,
22 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
23 after_callbacks: Arc<Vec<AfterAgentCallback>>,
24 handler: RunHandler,
25}
26
27impl CustomAgent {
28 pub fn builder(name: impl Into<String>) -> CustomAgentBuilder {
29 CustomAgentBuilder::new(name)
30 }
31}
32
33#[async_trait]
34impl Agent for CustomAgent {
35 fn name(&self) -> &str {
36 &self.name
37 }
38
39 fn description(&self) -> &str {
40 &self.description
41 }
42
43 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
44 &self.sub_agents
45 }
46
47 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
48 let handler = &self.handler;
49 let before_callbacks = self.before_callbacks.clone();
50 let after_callbacks = self.after_callbacks.clone();
51 let agent_name = self.name.clone();
52
53 for callback in before_callbacks.as_ref() {
55 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
56 Ok(Some(content)) => {
57 let invocation_id = ctx.invocation_id().to_string();
58 let s = stream! {
59 let mut early_event = Event::new(&invocation_id);
60 early_event.author = agent_name.clone();
61 early_event.llm_response.content = Some(content);
62 yield Ok(early_event);
63
64 for after_cb in after_callbacks.as_ref() {
65 match after_cb(ctx.clone() as Arc<dyn CallbackContext>).await {
66 Ok(Some(after_content)) => {
67 let mut after_event = Event::new(&invocation_id);
68 after_event.author = agent_name.clone();
69 after_event.llm_response.content = Some(after_content);
70 yield Ok(after_event);
71 return;
72 }
73 Ok(None) => continue,
74 Err(e) => { yield Err(e); return; }
75 }
76 }
77 };
78 return Ok(Box::pin(s));
79 }
80 Ok(None) => continue,
81 Err(e) => return Err(e),
82 }
83 }
84
85 let mut inner_stream = (handler)(ctx.clone()).await?;
87
88 let s = stream! {
89 while let Some(result) = inner_stream.next().await {
90 yield result;
91 }
92
93 for callback in after_callbacks.as_ref() {
95 match callback(ctx.clone() as Arc<dyn CallbackContext>).await {
96 Ok(Some(content)) => {
97 let mut after_event = Event::new(ctx.invocation_id());
98 after_event.author = agent_name.clone();
99 after_event.llm_response.content = Some(content);
100 yield Ok(after_event);
101 break;
102 }
103 Ok(None) => continue,
104 Err(e) => { yield Err(e); return; }
105 }
106 }
107 };
108
109 Ok(Box::pin(s))
110 }
111}
112
113pub struct CustomAgentBuilder {
114 name: String,
115 description: String,
116 sub_agents: Vec<Arc<dyn Agent>>,
117 before_callbacks: Vec<BeforeAgentCallback>,
118 after_callbacks: Vec<AfterAgentCallback>,
119 handler: Option<RunHandler>,
120}
121
122impl CustomAgentBuilder {
123 pub fn new(name: impl Into<String>) -> Self {
124 Self {
125 name: name.into(),
126 description: String::new(),
127 sub_agents: Vec::new(),
128 before_callbacks: Vec::new(),
129 after_callbacks: Vec::new(),
130 handler: None,
131 }
132 }
133
134 pub fn description(mut self, description: impl Into<String>) -> Self {
135 self.description = description.into();
136 self
137 }
138
139 pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
140 self.sub_agents.push(agent);
141 self
142 }
143
144 pub fn sub_agents(mut self, agents: Vec<Arc<dyn Agent>>) -> Self {
145 self.sub_agents = agents;
146 self
147 }
148
149 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
150 self.before_callbacks.push(callback);
151 self
152 }
153
154 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
155 self.after_callbacks.push(callback);
156 self
157 }
158
159 pub fn handler<F, Fut>(mut self, handler: F) -> Self
160 where
161 F: Fn(Arc<dyn InvocationContext>) -> Fut + Send + Sync + 'static,
162 Fut: Future<Output = Result<EventStream>> + Send + 'static,
163 {
164 self.handler = Some(Box::new(move |ctx| Box::pin(handler(ctx))));
165 self
166 }
167
168 pub fn build(self) -> Result<CustomAgent> {
169 let handler = self.handler.ok_or_else(|| {
170 adk_core::AdkError::Agent("CustomAgent requires a handler".to_string())
171 })?;
172
173 let mut seen_names = std::collections::HashSet::new();
175 for agent in &self.sub_agents {
176 if !seen_names.insert(agent.name()) {
177 return Err(adk_core::AdkError::Agent(format!(
178 "Duplicate sub-agent name: {}",
179 agent.name()
180 )));
181 }
182 }
183
184 Ok(CustomAgent {
185 name: self.name,
186 description: self.description,
187 sub_agents: self.sub_agents,
188 before_callbacks: Arc::new(self.before_callbacks),
189 after_callbacks: Arc::new(self.after_callbacks),
190 handler,
191 })
192 }
193}