autoagents_core/agent/
base.rs

1use super::{
2    error::AgentBuildError, output::AgentOutputT, AgentExecutor, IntoRunnable, RunnableAgent,
3};
4use crate::{
5    error::Error, memory::MemoryProvider, protocol::AgentID, runtime::Runtime, tool::ToolT,
6};
7use async_trait::async_trait;
8use autoagents_llm::{chat::StructuredOutputFormat, LLMProvider};
9use serde_json::Value;
10use std::{fmt::Debug, sync::Arc};
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
/// Core trait that defines agent metadata and behavior.
///
/// This trait is implemented via the `#[agent]` macro. Implementors must also
/// implement `AgentExecutor` and `Debug`, and be `Send + Sync + 'static`.
#[async_trait]
pub trait AgentDeriveT: Send + Sync + 'static + AgentExecutor + Debug {
    /// The output type this agent produces
    type Output: AgentOutputT;

    /// Get the agent's description
    fn description(&self) -> &'static str;

    /// Get the agent's output schema as a raw JSON value, if it has one.
    ///
    /// `BaseAgent::agent_config` deserializes a returned `Some(value)` into a
    /// `StructuredOutputFormat`, so the value has to match that type's
    /// serialized shape; `None` results in an `AgentConfig` without a schema.
    fn output_schema(&self) -> Option<Value>;

    /// Get the agent's name
    fn name(&self) -> &'static str;

    /// Get the tools available to this agent
    fn tools(&self) -> Vec<Box<dyn ToolT>>;
}
32
/// Static metadata describing an agent, as produced by `BaseAgent::agent_config`.
pub struct AgentConfig {
    /// The agent's name
    pub name: String,
    /// The agent's description
    pub description: String,
    /// The Agent ID
    pub id: AgentID,
    /// The output schema for the agent, or `None` when the agent declares no schema
    pub output_schema: Option<StructuredOutputFormat>,
}
43
44/// Base agent type that wraps an AgentDeriveT implementation with additional runtime components
45#[derive(Clone)]
46pub struct BaseAgent<T: AgentDeriveT> {
47    /// The inner agent implementation (from macro)
48    pub inner: Arc<T>,
49    /// LLM provider for this agent
50    pub llm: Arc<dyn LLMProvider>,
51    // Agent ID
52    pub id: AgentID,
53    /// Optional memory provider
54    pub memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
55}
56
57impl<T: AgentDeriveT> Debug for BaseAgent<T> {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.write_str(self.inner().name())
60    }
61}
62
63impl<T: AgentDeriveT> BaseAgent<T> {
64    /// Create a new BaseAgent wrapping an AgentDeriveT implementation
65    pub fn new(
66        inner: T,
67        llm: Arc<dyn LLMProvider>,
68        memory: Option<Box<dyn MemoryProvider>>,
69    ) -> Self {
70        // Convert tools to Arc for efficient sharing
71        Self {
72            inner: Arc::new(inner),
73            id: Uuid::new_v4(),
74            llm,
75            memory: memory.map(|m| Arc::new(RwLock::new(m))),
76        }
77    }
78
79    pub fn inner(&self) -> Arc<T> {
80        self.inner.clone()
81    }
82
83    /// Get the agent's name
84    pub fn name(&self) -> &'static str {
85        self.inner.name()
86    }
87
88    /// Get the agent's description
89    pub fn description(&self) -> &'static str {
90        self.inner.description()
91    }
92
93    /// Get the tools as Arc-wrapped references
94    pub fn tools(&self) -> Vec<Box<dyn ToolT>> {
95        self.inner.tools()
96    }
97
98    pub fn agent_config(&self) -> AgentConfig {
99        let output_schema = self.inner().output_schema();
100        let structured_schema = output_schema.map(|schema| serde_json::from_value(schema).unwrap());
101        AgentConfig {
102            name: self.name().into(),
103            description: self.description().into(),
104            id: self.id,
105            output_schema: structured_schema,
106        }
107    }
108
109    /// Get the LLM provider
110    pub fn llm(&self) -> Arc<dyn LLMProvider> {
111        self.llm.clone()
112    }
113
114    /// Get the memory provider if available
115    pub fn memory(&self) -> Option<Arc<RwLock<Box<dyn MemoryProvider>>>> {
116        self.memory.clone()
117    }
118}
119
/// Builder for creating BaseAgent instances from AgentDeriveT implementations
pub struct AgentBuilder<T: AgentDeriveT + AgentExecutor> {
    /// The agent implementation that `build` wraps in a `BaseAgent`
    inner: T,
    /// LLM provider; `build` fails if this is still `None`
    llm: Option<Arc<dyn LLMProvider>>,
    /// Optional memory provider passed on to the built agent
    memory: Option<Box<dyn MemoryProvider>>,
    /// Runtime the agent gets registered with; `build` fails if this is still `None`
    runtime: Option<Arc<dyn Runtime>>,
    /// Topics the agent is subscribed to on the runtime during `build`
    subscribed_topics: Vec<String>,
}
128
129impl<T: AgentDeriveT + AgentExecutor> AgentBuilder<T> {
130    /// Create a new builder with an AgentDeriveT implementation
131    pub fn new(inner: T) -> Self {
132        Self {
133            inner,
134            llm: None,
135            memory: None,
136            runtime: None,
137            subscribed_topics: vec![],
138        }
139    }
140
141    /// Set the LLM provider
142    pub fn with_llm(mut self, llm: Arc<dyn LLMProvider>) -> Self {
143        self.llm = Some(llm);
144        self
145    }
146
147    /// Set the memory provider
148    pub fn with_memory(mut self, memory: Box<dyn MemoryProvider>) -> Self {
149        self.memory = Some(memory);
150        self
151    }
152
153    pub fn subscribe_topic<S: Into<String>>(mut self, topic: S) -> Self {
154        self.subscribed_topics.push(topic.into());
155        self
156    }
157
158    /// Build the BaseAgent
159    pub async fn build(self) -> Result<Arc<dyn RunnableAgent>, Error> {
160        let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
161            "LLM provider is required".to_string(),
162        ))?;
163        let runnable = BaseAgent::new(self.inner, llm, self.memory).into_runnable();
164        if let Some(runtime) = self.runtime {
165            runtime.register_agent(runnable.clone()).await?;
166            for topic in self.subscribed_topics {
167                runtime.subscribe(runnable.id(), topic).await?;
168            }
169        } else {
170            return Err(AgentBuildError::BuildFailure("Runtime should be defined".into()).into());
171        }
172        Ok(runnable)
173    }
174
175    pub fn runtime(mut self, runtime: Arc<dyn Runtime>) -> Self {
176        self.runtime = Some(runtime);
177        self
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{AgentDeriveT, AgentState, ExecutorConfig};
    use crate::memory::MemoryProvider;
    use crate::protocol::Event;
    use crate::runtime::Task;
    use async_trait::async_trait;
    use autoagents_llm::{chat::StructuredOutputFormat, LLMProvider};
    use autoagents_test_utils::agent::{MockAgentImpl, TestAgentOutput, TestError};
    use autoagents_test_utils::llm::MockLLMProvider;
    use std::sync::Arc;
    use tokio::sync::mpsc;

    /// Builds the `BaseAgent` fixture used by the `BaseAgent` tests: a mock
    /// agent named "test", a mock LLM provider and the given memory provider.
    fn sample_agent(memory: Option<Box<dyn MemoryProvider>>) -> BaseAgent<MockAgentImpl> {
        let agent_impl = MockAgentImpl::new("test", "test description");
        let provider = Arc::new(MockLLMProvider);
        BaseAgent::new(agent_impl, provider, memory)
    }

    impl AgentOutputT for TestAgentOutput {
        fn output_schema() -> &'static str {
            r#"{"type":"object","properties":{"result":{"type":"string"}},"required":["result"]}"#
        }

        fn structured_output_format() -> serde_json::Value {
            let result_property = serde_json::json!({"type": "string"});
            serde_json::json!({
                "type": "object",
                "properties": {
                    "result": result_property
                },
                "required": ["result"]
            })
        }
    }

    #[async_trait]
    impl AgentDeriveT for MockAgentImpl {
        type Output = TestAgentOutput;

        // The trait demands `&'static str`, so the mock leaks a copy of its
        // owned string on every call. Acceptable for short-lived tests only.
        fn name(&self) -> &'static str {
            let owned = self.name.clone().into_boxed_str();
            Box::leak(owned)
        }

        fn description(&self) -> &'static str {
            let owned = self.description.clone().into_boxed_str();
            Box::leak(owned)
        }

        fn output_schema(&self) -> Option<Value> {
            let schema = TestAgentOutput::structured_output_format();
            Some(schema)
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            Vec::new()
        }
    }

    #[async_trait]
    impl AgentExecutor for MockAgentImpl {
        type Output = TestAgentOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig::default()
        }

        /// Echoes the task prompt back, or fails when the mock is configured to.
        async fn execute(
            &self,
            _llm: Arc<dyn LLMProvider>,
            _memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
            _tools: Vec<Box<dyn ToolT>>,
            _agent_config: &AgentConfig,
            task: Task,
            _state: Arc<RwLock<AgentState>>,
            _tx_event: mpsc::Sender<Event>,
        ) -> Result<Self::Output, Self::Error> {
            if self.should_fail {
                Err(TestError::TestError("Mock execution failed".to_string()))
            } else {
                let result = format!("Processed: {}", task.prompt);
                Ok(TestAgentOutput { result })
            }
        }
    }

    #[test]
    fn test_agent_config_creation() {
        let cfg = AgentConfig {
            name: String::from("test_agent"),
            id: Uuid::new_v4(),
            description: String::from("A test agent"),
            output_schema: None,
        };

        assert_eq!(cfg.name, "test_agent");
        assert_eq!(cfg.description, "A test agent");
        assert!(cfg.output_schema.is_none());
    }

    #[test]
    fn test_agent_config_with_schema() {
        let cfg = AgentConfig {
            name: String::from("test_agent"),
            id: Uuid::new_v4(),
            description: String::from("A test agent"),
            output_schema: Some(StructuredOutputFormat {
                name: String::from("TestSchema"),
                description: Some(String::from("Test schema")),
                schema: Some(serde_json::json!({"type": "object"})),
                strict: Some(true),
            }),
        };

        assert_eq!(cfg.name, "test_agent");
        assert_eq!(cfg.description, "A test agent");
        // A present schema with the expected name; `None` would fail here too.
        let schema_name = cfg.output_schema.map(|format| format.name);
        assert_eq!(schema_name.as_deref(), Some("TestSchema"));
    }

    #[test]
    fn test_base_agent_creation() {
        let agent = sample_agent(None);

        assert_eq!(agent.name(), "test");
        assert_eq!(agent.description(), "test description");
        assert!(agent.memory().is_none());
    }

    #[test]
    fn test_base_agent_with_memory() {
        let window: Box<dyn MemoryProvider> =
            Box::new(crate::memory::SlidingWindowMemory::new(5));
        let agent = sample_agent(Some(window));

        assert_eq!(agent.name(), "test");
        assert_eq!(agent.description(), "test description");
        assert!(agent.memory().is_some());
    }

    #[test]
    fn test_base_agent_inner() {
        let wrapped = sample_agent(None).inner();

        assert_eq!(wrapped.name(), "test");
        assert_eq!(wrapped.description(), "test description");
    }

    #[test]
    fn test_base_agent_tools() {
        let agent = sample_agent(None);

        assert!(agent.tools().is_empty());
    }

    #[test]
    fn test_base_agent_llm() {
        let agent = sample_agent(None);

        let agent_llm = agent.llm();
        // NOTE(review): a live `Arc` always has a strong count of at least 1,
        // so this only verifies that `llm()` returns a handle at all.
        assert!(Arc::strong_count(&agent_llm) >= 1);
    }
}