1use super::{
2 error::AgentBuildError, output::AgentOutputT, AgentExecutor, IntoRunnable, RunnableAgent,
3};
4use crate::{error::Error, memory::MemoryProvider, protocol::AgentID, runtime::Runtime};
5use async_trait::async_trait;
6use autoagents_llm::{chat::StructuredOutputFormat, LLMProvider, ToolT};
7use serde_json::Value;
8use std::{fmt::Debug, sync::Arc};
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
12#[async_trait]
15pub trait AgentDeriveT: Send + Sync + 'static + AgentExecutor + Debug {
16 type Output: AgentOutputT;
18
19 fn description(&self) -> &'static str;
21
22 fn output_schema(&self) -> Option<Value>;
23
24 fn name(&self) -> &'static str;
26
27 fn tools(&self) -> Vec<Box<dyn ToolT>>;
29}
30
31pub struct AgentConfig {
32 pub name: String,
34 pub description: String,
36 pub id: AgentID,
38 pub output_schema: Option<StructuredOutputFormat>,
40}
41
42#[derive(Clone)]
44pub struct BaseAgent<T: AgentDeriveT> {
45 pub inner: Arc<T>,
47 pub llm: Arc<dyn LLMProvider>,
49 pub id: AgentID,
51 pub memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
53}
54
55impl<T: AgentDeriveT> Debug for BaseAgent<T> {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.write_str(self.inner().name())
58 }
59}
60
61impl<T: AgentDeriveT> BaseAgent<T> {
62 pub fn new(
64 inner: T,
65 llm: Arc<dyn LLMProvider>,
66 memory: Option<Box<dyn MemoryProvider>>,
67 ) -> Self {
68 Self {
70 inner: Arc::new(inner),
71 id: Uuid::new_v4(),
72 llm,
73 memory: memory.map(|m| Arc::new(RwLock::new(m))),
74 }
75 }
76
77 pub fn inner(&self) -> Arc<T> {
78 self.inner.clone()
79 }
80
81 pub fn name(&self) -> &'static str {
83 self.inner.name()
84 }
85
86 pub fn description(&self) -> &'static str {
88 self.inner.description()
89 }
90
91 pub fn tools(&self) -> Vec<Box<dyn ToolT>> {
93 self.inner.tools()
94 }
95
96 pub fn agent_config(&self) -> AgentConfig {
97 let output_schema = self.inner().output_schema();
98 let structured_schema = output_schema.map(|schema| serde_json::from_value(schema).unwrap());
99 AgentConfig {
100 name: self.name().into(),
101 description: self.description().into(),
102 id: self.id,
103 output_schema: structured_schema,
104 }
105 }
106
107 pub fn llm(&self) -> Arc<dyn LLMProvider> {
109 self.llm.clone()
110 }
111
112 pub fn memory(&self) -> Option<Arc<RwLock<Box<dyn MemoryProvider>>>> {
114 self.memory.clone()
115 }
116}
117
/// Builder for a [`BaseAgent`] that is registered with a [`Runtime`] on `build`.
///
/// An LLM provider (`with_llm`) and a runtime (`runtime`) must be set before calling
/// `build`. Memory and topic subscriptions are optional.
pub struct AgentBuilder<T: AgentDeriveT + AgentExecutor> {
    /// The agent definition to wrap.
    inner: T,
    /// LLM provider; `build` fails if this is `None`.
    llm: Option<Arc<dyn LLMProvider>>,
    /// Optional memory provider handed to the agent.
    memory: Option<Box<dyn MemoryProvider>>,
    /// Runtime the agent is registered with; `build` fails if this is `None`.
    runtime: Option<Arc<dyn Runtime>>,
    /// Topics the agent is subscribed to after registration, in insertion order.
    subscribed_topics: Vec<String>,
}
126
127impl<T: AgentDeriveT + AgentExecutor> AgentBuilder<T> {
128 pub fn new(inner: T) -> Self {
130 Self {
131 inner,
132 llm: None,
133 memory: None,
134 runtime: None,
135 subscribed_topics: vec![],
136 }
137 }
138
139 pub fn with_llm(mut self, llm: Arc<dyn LLMProvider>) -> Self {
141 self.llm = Some(llm);
142 self
143 }
144
145 pub fn with_memory(mut self, memory: Box<dyn MemoryProvider>) -> Self {
147 self.memory = Some(memory);
148 self
149 }
150
151 pub fn subscribe_topic<S: Into<String>>(mut self, topic: S) -> Self {
152 self.subscribed_topics.push(topic.into());
153 self
154 }
155
156 pub async fn build(self) -> Result<Arc<dyn RunnableAgent>, Error> {
158 let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
159 "LLM provider is required".to_string(),
160 ))?;
161 let runnable = BaseAgent::new(self.inner, llm, self.memory).into_runnable();
162 if let Some(runtime) = self.runtime {
163 runtime.register_agent(runnable.clone()).await?;
164 for topic in self.subscribed_topics {
165 runtime.subscribe(runnable.id(), topic).await?;
166 }
167 } else {
168 return Err(AgentBuildError::BuildFailure("Runtime should be defined".into()).into());
169 }
170 Ok(runnable)
171 }
172
173 pub fn runtime(mut self, runtime: Arc<dyn Runtime>) -> Self {
174 self.runtime = Some(runtime);
175 self
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{AgentDeriveT, AgentState, ExecutorConfig};
    use crate::memory::MemoryProvider;
    use crate::protocol::Event;
    use crate::runtime::Task;
    use async_trait::async_trait;
    use autoagents_llm::{chat::StructuredOutputFormat, LLMProvider, ToolT};
    use autoagents_test_utils::agent::{MockAgentImpl, TestAgentOutput, TestError};
    use autoagents_test_utils::llm::MockLLMProvider;
    use std::sync::Arc;
    use tokio::sync::mpsc;

    impl AgentOutputT for TestAgentOutput {
        fn output_schema() -> &'static str {
            r#"{"type":"object","properties":{"result":{"type":"string"}},"required":["result"]}"#
        }

        fn structured_output_format() -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "result": {"type": "string"}
                },
                "required": ["result"]
            })
        }
    }

    #[async_trait]
    impl AgentDeriveT for MockAgentImpl {
        type Output = TestAgentOutput;

        // The mock stores owned strings but the trait returns `&'static str`, so every
        // call leaks a copy. This is acceptable only in tests.
        fn name(&self) -> &'static str {
            Box::leak(self.name.clone().into_boxed_str())
        }

        fn description(&self) -> &'static str {
            Box::leak(self.description.clone().into_boxed_str())
        }

        // NOTE(review): this is a raw JSON schema with no `name` key, so
        // `BaseAgent::agent_config` would likely fail to deserialize it into
        // `StructuredOutputFormat`. No test calls `agent_config` with this mock; confirm
        // before adding one.
        fn output_schema(&self) -> Option<Value> {
            Some(TestAgentOutput::structured_output_format())
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentExecutor for MockAgentImpl {
        type Output = TestAgentOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig::default()
        }

        async fn execute(
            &self,
            _llm: Arc<dyn LLMProvider>,
            _memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
            _tools: Vec<Box<dyn ToolT>>,
            _agent_config: &AgentConfig,
            task: Task,
            _state: Arc<RwLock<AgentState>>,
            _tx_event: mpsc::Sender<Event>,
        ) -> Result<Self::Output, Self::Error> {
            if self.should_fail {
                return Err(TestError::TestError("Mock execution failed".to_string()));
            }

            Ok(TestAgentOutput {
                result: format!("Processed: {}", task.prompt),
            })
        }
    }

    #[test]
    fn test_agent_config_creation() {
        let config = AgentConfig {
            name: "test_agent".to_string(),
            id: Uuid::new_v4(),
            description: "A test agent".to_string(),
            output_schema: None,
        };

        assert_eq!(config.name, "test_agent");
        assert_eq!(config.description, "A test agent");
        assert!(config.output_schema.is_none());
    }

    #[test]
    fn test_agent_config_with_schema() {
        let schema = StructuredOutputFormat {
            name: "TestSchema".to_string(),
            description: Some("Test schema".to_string()),
            schema: Some(serde_json::json!({"type": "object"})),
            strict: Some(true),
        };

        let config = AgentConfig {
            name: "test_agent".to_string(),
            id: Uuid::new_v4(),
            description: "A test agent".to_string(),
            output_schema: Some(schema.clone()),
        };

        assert_eq!(config.name, "test_agent");
        assert_eq!(config.description, "A test agent");
        assert!(config.output_schema.is_some());
        assert_eq!(config.output_schema.unwrap().name, "TestSchema");
    }

    #[test]
    fn test_base_agent_creation() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent = BaseAgent::new(mock_agent, llm, None);

        assert_eq!(base_agent.name(), "test");
        assert_eq!(base_agent.description(), "test description");
        assert!(base_agent.memory().is_none());
    }

    #[test]
    fn test_base_agent_with_memory() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let memory = Box::new(crate::memory::SlidingWindowMemory::new(5));
        let base_agent = BaseAgent::new(mock_agent, llm, Some(memory));

        assert_eq!(base_agent.name(), "test");
        assert_eq!(base_agent.description(), "test description");
        assert!(base_agent.memory().is_some());
    }

    #[test]
    fn test_base_agent_inner() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent = BaseAgent::new(mock_agent, llm, None);

        let inner = base_agent.inner();
        assert_eq!(inner.name(), "test");
        assert_eq!(inner.description(), "test description");
    }

    #[test]
    fn test_base_agent_tools() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent = BaseAgent::new(mock_agent, llm, None);

        let tools = base_agent.tools();
        assert!(tools.is_empty());
    }

    #[test]
    fn test_base_agent_debug_prints_name() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent = BaseAgent::new(mock_agent, llm, None);

        // `Debug` for `BaseAgent` writes only the agent name.
        assert_eq!(format!("{:?}", base_agent), "test");
    }

    #[test]
    fn test_base_agent_llm() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent = BaseAgent::new(mock_agent, llm.clone(), None);

        let agent_llm = base_agent.llm();
        // `llm()` must hand out the provider passed to `new`, not a different one. The local
        // `llm`, the agent's field and `agent_llm` all share one allocation.
        assert_eq!(Arc::strong_count(&llm), 3);
        assert_eq!(Arc::strong_count(&agent_llm), 3);
    }
}