1use super::{
2 error::AgentBuildError, output::AgentOutputT, AgentExecutor, IntoRunnable, RunnableAgent,
3};
4use crate::{
5 error::Error, memory::MemoryProvider, protocol::AgentID, runtime::Runtime, tool::ToolT,
6};
7use async_trait::async_trait;
8use autoagents_llm::{chat::StructuredOutputFormat, LLMProvider};
9use serde_json::Value;
10use std::{fmt::Debug, sync::Arc};
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
14#[async_trait]
17pub trait AgentDeriveT: Send + Sync + 'static + AgentExecutor + Debug {
18 type Output: AgentOutputT;
20
21 fn description(&self) -> &'static str;
23
24 fn output_schema(&self) -> Option<Value>;
25
26 fn name(&self) -> &'static str;
28
29 fn tools(&self) -> Vec<Box<dyn ToolT>>;
31}
32
33pub struct AgentConfig {
34 pub name: String,
36 pub description: String,
38 pub id: AgentID,
40 pub output_schema: Option<StructuredOutputFormat>,
42}
43
44#[derive(Clone)]
46pub struct BaseAgent<T: AgentDeriveT> {
47 pub inner: Arc<T>,
49 pub llm: Arc<dyn LLMProvider>,
51 pub id: AgentID,
53 pub memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
55}
56
57impl<T: AgentDeriveT> Debug for BaseAgent<T> {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.write_str(self.inner().name())
60 }
61}
62
63impl<T: AgentDeriveT> BaseAgent<T> {
64 pub fn new(
66 inner: T,
67 llm: Arc<dyn LLMProvider>,
68 memory: Option<Box<dyn MemoryProvider>>,
69 ) -> Self {
70 Self {
72 inner: Arc::new(inner),
73 id: Uuid::new_v4(),
74 llm,
75 memory: memory.map(|m| Arc::new(RwLock::new(m))),
76 }
77 }
78
79 pub fn inner(&self) -> Arc<T> {
80 self.inner.clone()
81 }
82
83 pub fn name(&self) -> &'static str {
85 self.inner.name()
86 }
87
88 pub fn description(&self) -> &'static str {
90 self.inner.description()
91 }
92
93 pub fn tools(&self) -> Vec<Box<dyn ToolT>> {
95 self.inner.tools()
96 }
97
98 pub fn agent_config(&self) -> AgentConfig {
99 let output_schema = self.inner().output_schema();
100 let structured_schema = output_schema.map(|schema| serde_json::from_value(schema).unwrap());
101 AgentConfig {
102 name: self.name().into(),
103 description: self.description().into(),
104 id: self.id,
105 output_schema: structured_schema,
106 }
107 }
108
109 pub fn llm(&self) -> Arc<dyn LLMProvider> {
111 self.llm.clone()
112 }
113
114 pub fn memory(&self) -> Option<Arc<RwLock<Box<dyn MemoryProvider>>>> {
116 self.memory.clone()
117 }
118}
119
120pub struct AgentBuilder<T: AgentDeriveT + AgentExecutor> {
122 inner: T,
123 llm: Option<Arc<dyn LLMProvider>>,
124 memory: Option<Box<dyn MemoryProvider>>,
125 runtime: Option<Arc<dyn Runtime>>,
126 subscribed_topics: Vec<String>,
127}
128
129impl<T: AgentDeriveT + AgentExecutor> AgentBuilder<T> {
130 pub fn new(inner: T) -> Self {
132 Self {
133 inner,
134 llm: None,
135 memory: None,
136 runtime: None,
137 subscribed_topics: vec![],
138 }
139 }
140
141 pub fn with_llm(mut self, llm: Arc<dyn LLMProvider>) -> Self {
143 self.llm = Some(llm);
144 self
145 }
146
147 pub fn with_memory(mut self, memory: Box<dyn MemoryProvider>) -> Self {
149 self.memory = Some(memory);
150 self
151 }
152
153 pub fn subscribe_topic<S: Into<String>>(mut self, topic: S) -> Self {
154 self.subscribed_topics.push(topic.into());
155 self
156 }
157
158 pub async fn build(self) -> Result<Arc<dyn RunnableAgent>, Error> {
160 let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
161 "LLM provider is required".to_string(),
162 ))?;
163 let runnable = BaseAgent::new(self.inner, llm, self.memory).into_runnable();
164 if let Some(runtime) = self.runtime {
165 runtime.register_agent(runnable.clone()).await?;
166 for topic in self.subscribed_topics {
167 runtime.subscribe(runnable.id(), topic).await?;
168 }
169 } else {
170 return Err(AgentBuildError::BuildFailure("Runtime should be defined".into()).into());
171 }
172 Ok(runnable)
173 }
174
175 pub fn runtime(mut self, runtime: Arc<dyn Runtime>) -> Self {
176 self.runtime = Some(runtime);
177 self
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{AgentDeriveT, AgentState, ExecutorConfig};
    use crate::memory::MemoryProvider;
    use crate::protocol::Event;
    use crate::runtime::Task;
    use async_trait::async_trait;
    use autoagents_llm::{chat::StructuredOutputFormat, LLMProvider};
    use autoagents_test_utils::agent::{MockAgentImpl, TestAgentOutput, TestError};
    use autoagents_test_utils::llm::MockLLMProvider;
    use std::sync::Arc;
    use tokio::sync::mpsc;

    impl AgentOutputT for TestAgentOutput {
        fn output_schema() -> &'static str {
            r#"{"type":"object","properties":{"result":{"type":"string"}},"required":["result"]}"#
        }

        fn structured_output_format() -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "result": {"type": "string"}
                },
                "required": ["result"]
            })
        }
    }

    #[async_trait]
    impl AgentDeriveT for MockAgentImpl {
        type Output = TestAgentOutput;

        // Leaking is acceptable in tests: the trait demands `&'static str`.
        fn name(&self) -> &'static str {
            Box::leak(self.name.clone().into_boxed_str())
        }

        fn description(&self) -> &'static str {
            Box::leak(self.description.clone().into_boxed_str())
        }

        fn output_schema(&self) -> Option<Value> {
            Some(TestAgentOutput::structured_output_format())
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentExecutor for MockAgentImpl {
        type Output = TestAgentOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig::default()
        }

        async fn execute(
            &self,
            _llm: Arc<dyn LLMProvider>,
            _memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
            _tools: Vec<Box<dyn ToolT>>,
            _agent_config: &AgentConfig,
            task: Task,
            _state: Arc<RwLock<AgentState>>,
            _tx_event: mpsc::Sender<Event>,
        ) -> Result<Self::Output, Self::Error> {
            if self.should_fail {
                return Err(TestError::TestError("Mock execution failed".to_string()));
            }

            Ok(TestAgentOutput {
                result: format!("Processed: {}", task.prompt),
            })
        }
    }

    #[test]
    fn test_agent_config_creation() {
        let config = AgentConfig {
            name: "test_agent".to_string(),
            id: Uuid::new_v4(),
            description: "A test agent".to_string(),
            output_schema: None,
        };

        assert_eq!(config.name, "test_agent");
        assert_eq!(config.description, "A test agent");
        assert!(config.output_schema.is_none());
    }

    #[test]
    fn test_agent_config_with_schema() {
        let schema = StructuredOutputFormat {
            name: "TestSchema".to_string(),
            description: Some("Test schema".to_string()),
            schema: Some(serde_json::json!({"type": "object"})),
            strict: Some(true),
        };

        let config = AgentConfig {
            name: "test_agent".to_string(),
            id: Uuid::new_v4(),
            description: "A test agent".to_string(),
            output_schema: Some(schema.clone()),
        };

        assert_eq!(config.name, "test_agent");
        assert_eq!(config.description, "A test agent");
        assert!(config.output_schema.is_some());
        assert_eq!(config.output_schema.unwrap().name, "TestSchema");
    }

    #[test]
    fn test_base_agent_creation() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent = BaseAgent::new(mock_agent, llm, None);

        assert_eq!(base_agent.name(), "test");
        assert_eq!(base_agent.description(), "test description");
        assert!(base_agent.memory().is_none());
    }

    #[test]
    fn test_base_agent_ids_are_unique() {
        let llm: Arc<dyn LLMProvider> = Arc::new(MockLLMProvider);
        let first = BaseAgent::new(MockAgentImpl::new("a", "first"), llm.clone(), None);
        let second = BaseAgent::new(MockAgentImpl::new("b", "second"), llm, None);

        assert_ne!(first.id, second.id);
    }

    #[test]
    fn test_base_agent_with_memory() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let memory = Box::new(crate::memory::SlidingWindowMemory::new(5));
        let base_agent = BaseAgent::new(mock_agent, llm, Some(memory));

        assert_eq!(base_agent.name(), "test");
        assert_eq!(base_agent.description(), "test description");
        assert!(base_agent.memory().is_some());
    }

    #[test]
    fn test_base_agent_inner() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent = BaseAgent::new(mock_agent, llm, None);

        let inner = base_agent.inner();
        assert_eq!(inner.name(), "test");
        assert_eq!(inner.description(), "test description");
    }

    #[test]
    fn test_base_agent_tools() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent = BaseAgent::new(mock_agent, llm, None);

        let tools = base_agent.tools();
        assert!(tools.is_empty());
    }

    #[test]
    fn test_base_agent_llm() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent = BaseAgent::new(mock_agent, llm.clone(), None);

        let agent_llm = base_agent.llm();
        // One allocation shared by three handles: `llm`, the agent's stored field, and
        // the handle returned by `llm()`. The old `> 0` check was always true.
        assert_eq!(Arc::strong_count(&llm), 3);
        assert_eq!(Arc::strong_count(&agent_llm), 3);
    }
}