// adk_agent/custom_agent.rs
use adk_core::{
    AfterAgentCallback, Agent, BeforeAgentCallback, EventStream, InvocationContext, Result,
};
use async_trait::async_trait;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

9type RunHandler = Box<
10 dyn Fn(Arc<dyn InvocationContext>) -> Pin<Box<dyn Future<Output = Result<EventStream>> + Send>>
11 + Send
12 + Sync,
13>;
14
15pub struct CustomAgent {
16 name: String,
17 description: String,
18 sub_agents: Vec<Arc<dyn Agent>>,
19 #[allow(dead_code)] before_callbacks: Vec<BeforeAgentCallback>,
21 #[allow(dead_code)] after_callbacks: Vec<AfterAgentCallback>,
23 handler: RunHandler,
24}
25
26impl CustomAgent {
27 pub fn builder(name: impl Into<String>) -> CustomAgentBuilder {
28 CustomAgentBuilder::new(name)
29 }
30}
31
32#[async_trait]
33impl Agent for CustomAgent {
34 fn name(&self) -> &str {
35 &self.name
36 }
37
38 fn description(&self) -> &str {
39 &self.description
40 }
41
42 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
43 &self.sub_agents
44 }
45
46 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
47 (self.handler)(ctx).await
48 }
49}
50
51pub struct CustomAgentBuilder {
52 name: String,
53 description: String,
54 sub_agents: Vec<Arc<dyn Agent>>,
55 before_callbacks: Vec<BeforeAgentCallback>,
56 after_callbacks: Vec<AfterAgentCallback>,
57 handler: Option<RunHandler>,
58}
59
60impl CustomAgentBuilder {
61 pub fn new(name: impl Into<String>) -> Self {
62 Self {
63 name: name.into(),
64 description: String::new(),
65 sub_agents: Vec::new(),
66 before_callbacks: Vec::new(),
67 after_callbacks: Vec::new(),
68 handler: None,
69 }
70 }
71
72 pub fn description(mut self, description: impl Into<String>) -> Self {
73 self.description = description.into();
74 self
75 }
76
77 pub fn sub_agent(mut self, agent: Arc<dyn Agent>) -> Self {
78 self.sub_agents.push(agent);
79 self
80 }
81
82 pub fn sub_agents(mut self, agents: Vec<Arc<dyn Agent>>) -> Self {
83 self.sub_agents = agents;
84 self
85 }
86
87 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
88 self.before_callbacks.push(callback);
89 self
90 }
91
92 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
93 self.after_callbacks.push(callback);
94 self
95 }
96
97 pub fn handler<F, Fut>(mut self, handler: F) -> Self
98 where
99 F: Fn(Arc<dyn InvocationContext>) -> Fut + Send + Sync + 'static,
100 Fut: Future<Output = Result<EventStream>> + Send + 'static,
101 {
102 self.handler = Some(Box::new(move |ctx| Box::pin(handler(ctx))));
103 self
104 }
105
106 pub fn build(self) -> Result<CustomAgent> {
107 let handler = self.handler.ok_or_else(|| {
108 adk_core::AdkError::Agent("CustomAgent requires a handler".to_string())
109 })?;
110
111 let mut seen_names = std::collections::HashSet::new();
113 for agent in &self.sub_agents {
114 if !seen_names.insert(agent.name()) {
115 return Err(adk_core::AdkError::Agent(format!(
116 "Duplicate sub-agent name: {}",
117 agent.name()
118 )));
119 }
120 }
121
122 Ok(CustomAgent {
123 name: self.name,
124 description: self.description,
125 sub_agents: self.sub_agents,
126 before_callbacks: self.before_callbacks,
127 after_callbacks: self.after_callbacks,
128 handler,
129 })
130 }
131}