autoagents_core/agent/
base.rs

1use super::{
2    error::AgentBuildError, output::AgentOutputT, AgentExecutor, IntoRunnable, RunnableAgent,
3};
4use crate::{error::Error, memory::MemoryProvider, protocol::AgentID, runtime::Runtime};
5use async_trait::async_trait;
6use autoagents_llm::{chat::StructuredOutputFormat, LLMProvider, ToolT};
7use serde_json::Value;
8use std::{fmt::Debug, sync::Arc};
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
12/// Core trait that defines agent metadata and behavior
13/// This trait is implemented via the #[agent] macro
14#[async_trait]
15pub trait AgentDeriveT: Send + Sync + 'static + AgentExecutor + Debug {
16    /// The output type this agent produces
17    type Output: AgentOutputT;
18
19    /// Get the agent's description
20    fn description(&self) -> &'static str;
21
22    fn output_schema(&self) -> Option<Value>;
23
24    /// Get the agent's name
25    fn name(&self) -> &'static str;
26
27    /// Get the tools available to this agent
28    fn tools(&self) -> Vec<Box<dyn ToolT>>;
29}
30
31pub struct AgentConfig {
32    /// The agent's name
33    pub name: String,
34    /// The agent's description
35    pub description: String,
36    /// The Agent ID
37    pub id: AgentID,
38    /// The output schema for the agent
39    pub output_schema: Option<StructuredOutputFormat>,
40}
41
42/// Base agent type that wraps an AgentDeriveT implementation with additional runtime components
43#[derive(Clone)]
44pub struct BaseAgent<T: AgentDeriveT> {
45    /// The inner agent implementation (from macro)
46    pub inner: Arc<T>,
47    /// LLM provider for this agent
48    pub llm: Arc<dyn LLMProvider>,
49    // Agent ID
50    pub id: AgentID,
51    /// Optional memory provider
52    pub memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
53}
54
55impl<T: AgentDeriveT> Debug for BaseAgent<T> {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.write_str(self.inner().name())
58    }
59}
60
61impl<T: AgentDeriveT> BaseAgent<T> {
62    /// Create a new BaseAgent wrapping an AgentDeriveT implementation
63    pub fn new(
64        inner: T,
65        llm: Arc<dyn LLMProvider>,
66        memory: Option<Box<dyn MemoryProvider>>,
67    ) -> Self {
68        // Convert tools to Arc for efficient sharing
69        Self {
70            inner: Arc::new(inner),
71            id: Uuid::new_v4(),
72            llm,
73            memory: memory.map(|m| Arc::new(RwLock::new(m))),
74        }
75    }
76
77    pub fn inner(&self) -> Arc<T> {
78        self.inner.clone()
79    }
80
81    /// Get the agent's name
82    pub fn name(&self) -> &'static str {
83        self.inner.name()
84    }
85
86    /// Get the agent's description
87    pub fn description(&self) -> &'static str {
88        self.inner.description()
89    }
90
91    /// Get the tools as Arc-wrapped references
92    pub fn tools(&self) -> Vec<Box<dyn ToolT>> {
93        self.inner.tools()
94    }
95
96    pub fn agent_config(&self) -> AgentConfig {
97        let output_schema = self.inner().output_schema();
98        let structured_schema = output_schema.map(|schema| serde_json::from_value(schema).unwrap());
99        AgentConfig {
100            name: self.name().into(),
101            description: self.description().into(),
102            id: self.id,
103            output_schema: structured_schema,
104        }
105    }
106
107    /// Get the LLM provider
108    pub fn llm(&self) -> Arc<dyn LLMProvider> {
109        self.llm.clone()
110    }
111
112    /// Get the memory provider if available
113    pub fn memory(&self) -> Option<Arc<RwLock<Box<dyn MemoryProvider>>>> {
114        self.memory.clone()
115    }
116}
117
118/// Builder for creating BaseAgent instances from AgentDeriveT implementations
119pub struct AgentBuilder<T: AgentDeriveT + AgentExecutor> {
120    inner: T,
121    llm: Option<Arc<dyn LLMProvider>>,
122    memory: Option<Box<dyn MemoryProvider>>,
123    runtime: Option<Arc<dyn Runtime>>,
124    subscribed_topics: Vec<String>,
125}
126
127impl<T: AgentDeriveT + AgentExecutor> AgentBuilder<T> {
128    /// Create a new builder with an AgentDeriveT implementation
129    pub fn new(inner: T) -> Self {
130        Self {
131            inner,
132            llm: None,
133            memory: None,
134            runtime: None,
135            subscribed_topics: vec![],
136        }
137    }
138
139    /// Set the LLM provider
140    pub fn with_llm(mut self, llm: Arc<dyn LLMProvider>) -> Self {
141        self.llm = Some(llm);
142        self
143    }
144
145    /// Set the memory provider
146    pub fn with_memory(mut self, memory: Box<dyn MemoryProvider>) -> Self {
147        self.memory = Some(memory);
148        self
149    }
150
151    pub fn subscribe_topic<S: Into<String>>(mut self, topic: S) -> Self {
152        self.subscribed_topics.push(topic.into());
153        self
154    }
155
156    /// Build the BaseAgent
157    pub async fn build(self) -> Result<Arc<dyn RunnableAgent>, Error> {
158        let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
159            "LLM provider is required".to_string(),
160        ))?;
161        let runnable = BaseAgent::new(self.inner, llm, self.memory).into_runnable();
162        if let Some(runtime) = self.runtime {
163            runtime.register_agent(runnable.clone()).await?;
164            for topic in self.subscribed_topics {
165                runtime.subscribe(runnable.id(), topic).await?;
166            }
167        } else {
168            return Err(AgentBuildError::BuildFailure("Runtime should be defined".into()).into());
169        }
170        Ok(runnable)
171    }
172
173    pub fn runtime(mut self, runtime: Arc<dyn Runtime>) -> Self {
174        self.runtime = Some(runtime);
175        self
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{AgentDeriveT, AgentState, ExecutorConfig};
    use crate::memory::MemoryProvider;
    use crate::protocol::Event;
    use crate::runtime::Task;
    use async_trait::async_trait;
    use autoagents_llm::{chat::StructuredOutputFormat, LLMProvider, ToolT};
    use autoagents_test_utils::agent::{MockAgentImpl, TestAgentOutput, TestError};
    use autoagents_test_utils::llm::MockLLMProvider;
    use std::sync::Arc;
    use tokio::sync::mpsc;

    impl AgentOutputT for TestAgentOutput {
        fn output_schema() -> &'static str {
            r#"{"type":"object","properties":{"result":{"type":"string"}},"required":["result"]}"#
        }

        fn structured_output_format() -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "result": {"type": "string"}
                },
                "required": ["result"]
            })
        }
    }

    #[async_trait]
    impl AgentDeriveT for MockAgentImpl {
        type Output = TestAgentOutput;

        // The trait demands `&'static str` while the mock stores owned `String`s,
        // so the strings are leaked; acceptable in test-only code.
        fn name(&self) -> &'static str {
            Box::leak(self.name.clone().into_boxed_str())
        }

        fn description(&self) -> &'static str {
            Box::leak(self.description.clone().into_boxed_str())
        }

        // `BaseAgent::agent_config` deserializes this value into a
        // `StructuredOutputFormat` (which requires `name`), so the raw JSON
        // schema is wrapped in that shape rather than returned directly.
        fn output_schema(&self) -> Option<Value> {
            Some(serde_json::json!({
                "name": "TestAgentOutput",
                "description": "Output of the mock test agent",
                "schema": TestAgentOutput::structured_output_format(),
                "strict": true
            }))
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentExecutor for MockAgentImpl {
        type Output = TestAgentOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig::default()
        }

        async fn execute(
            &self,
            _llm: Arc<dyn LLMProvider>,
            _memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
            _tools: Vec<Box<dyn ToolT>>,
            _agent_config: &AgentConfig,
            task: Task,
            _state: Arc<RwLock<AgentState>>,
            _tx_event: mpsc::Sender<Event>,
        ) -> Result<Self::Output, Self::Error> {
            if self.should_fail {
                return Err(TestError::TestError("Mock execution failed".to_string()));
            }

            Ok(TestAgentOutput {
                result: format!("Processed: {}", task.prompt),
            })
        }
    }

    #[test]
    fn test_agent_config_creation() {
        let config = AgentConfig {
            name: "test_agent".to_string(),
            id: Uuid::new_v4(),
            description: "A test agent".to_string(),
            output_schema: None,
        };

        assert_eq!(config.name, "test_agent");
        assert_eq!(config.description, "A test agent");
        assert!(config.output_schema.is_none());
    }

    #[test]
    fn test_agent_config_with_schema() {
        let schema = StructuredOutputFormat {
            name: "TestSchema".to_string(),
            description: Some("Test schema".to_string()),
            schema: Some(serde_json::json!({"type": "object"})),
            strict: Some(true),
        };

        let config = AgentConfig {
            name: "test_agent".to_string(),
            id: Uuid::new_v4(),
            description: "A test agent".to_string(),
            output_schema: Some(schema.clone()),
        };

        assert_eq!(config.name, "test_agent");
        assert_eq!(config.description, "A test agent");
        assert!(config.output_schema.is_some());
        assert_eq!(config.output_schema.unwrap().name, "TestSchema");
    }

    #[test]
    fn test_base_agent_agent_config() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent = BaseAgent::new(mock_agent, llm, None);

        let config = base_agent.agent_config();
        assert_eq!(config.name, "test");
        assert_eq!(config.description, "test description");
        assert_eq!(config.id, base_agent.id);

        let schema = config
            .output_schema
            .expect("mock agent provides an output schema");
        assert_eq!(schema.name, "TestAgentOutput");
        assert_eq!(schema.strict, Some(true));
        assert_eq!(
            schema.schema,
            Some(TestAgentOutput::structured_output_format())
        );
    }

    #[test]
    fn test_base_agent_creation() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent = BaseAgent::new(mock_agent, llm, None);

        assert_eq!(base_agent.name(), "test");
        assert_eq!(base_agent.description(), "test description");
        assert!(base_agent.memory().is_none());
    }

    #[test]
    fn test_base_agent_debug_uses_name() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent = BaseAgent::new(mock_agent, llm, None);

        assert_eq!(format!("{:?}", base_agent), "test");
    }

    #[test]
    fn test_base_agent_with_memory() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let memory = Box::new(crate::memory::SlidingWindowMemory::new(5));
        let base_agent = BaseAgent::new(mock_agent, llm, Some(memory));

        assert_eq!(base_agent.name(), "test");
        assert_eq!(base_agent.description(), "test description");
        assert!(base_agent.memory().is_some());
    }

    #[test]
    fn test_base_agent_inner() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent = BaseAgent::new(mock_agent, llm, None);

        let inner = base_agent.inner();
        assert_eq!(inner.name(), "test");
        assert_eq!(inner.description(), "test description");
    }

    #[test]
    fn test_base_agent_tools() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent = BaseAgent::new(mock_agent, llm, None);

        let tools = base_agent.tools();
        assert!(tools.is_empty());
    }

    #[test]
    fn test_base_agent_llm() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent = BaseAgent::new(mock_agent, llm.clone(), None);

        let agent_llm = base_agent.llm();
        // `llm`, the clone stored in `base_agent`, and `agent_llm` all point at
        // the same allocation, so they share one strong count.
        assert_eq!(Arc::strong_count(&llm), 3);
        drop(agent_llm);
        assert_eq!(Arc::strong_count(&llm), 2);
    }
}