Skip to main content

autoagents_core/agent/
base.rs

1use crate::agent::config::AgentConfig;
2use crate::agent::memory::MemoryProvider;
3use crate::agent::{AgentExecutor, Context, output::AgentOutputT};
4use crate::protocol::Event;
5use crate::{protocol::ActorID, tool::ToolT};
6use async_trait::async_trait;
7use autoagents_llm::LLMProvider;
8
9use serde_json::Value;
10use std::marker::PhantomData;
11use std::{fmt::Debug, sync::Arc};
12
13#[cfg(target_arch = "wasm32")]
14pub use futures::lock::Mutex;
15#[cfg(not(target_arch = "wasm32"))]
16pub use tokio::sync::Mutex;
17
18#[cfg(target_arch = "wasm32")]
19use futures::channel::mpsc::Sender;
20
21#[cfg(not(target_arch = "wasm32"))]
22use tokio::sync::mpsc::Sender;
23
24use crate::agent::error::RunnableAgentError;
25use crate::agent::hooks::AgentHooks;
26use uuid::Uuid;
27
28/// Core trait that defines agent metadata and behavior
29/// This trait is implemented via the #[agent] macro
30#[async_trait]
31pub trait AgentDeriveT: Send + Sync + 'static + Debug {
32    /// The output type this agent produces
33    type Output: AgentOutputT;
34
35    /// Get the agent's description
36    fn description(&self) -> &'static str;
37
38    // If you provide None then its taken as String output
39    fn output_schema(&self) -> Option<Value>;
40
41    /// Get the agent's name
42    fn name(&self) -> &'static str;
43
44    /// Get the tools available to this agent
45    fn tools(&self) -> Vec<Box<dyn ToolT>>;
46}
47
/// Type-level tag selecting the kind of agent a `BaseAgent` represents.
///
/// The tag carries no data; it only supplies a static name, which the
/// `Debug` output of `BaseAgent` includes.
pub trait AgentType: Send + Sync + 'static {
    /// Returns the static name identifying this agent kind.
    fn type_name() -> &'static str;
}
51
52/// Base agent type that wraps an AgentDeriveT implementation with additional runtime components
53#[derive(Clone)]
54pub struct BaseAgent<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync, A: AgentType> {
55    /// The inner agent implementation (from macro)
56    pub(crate) inner: Arc<T>,
57    /// LLM provider for this agent
58    pub(crate) llm: Arc<dyn LLMProvider>,
59    /// Agent ID
60    pub id: ActorID,
61    /// Optional memory provider
62    pub(crate) memory: Option<Arc<Mutex<Box<dyn MemoryProvider>>>>,
63    /// Tx sender
64    pub(crate) tx: Option<Sender<Event>>,
65    //Stream
66    pub(crate) stream: bool,
67    pub(crate) marker: PhantomData<A>,
68}
69
70impl<T: AgentDeriveT + AgentExecutor + AgentHooks, A: AgentType> Debug for BaseAgent<T, A> {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        f.write_str(format!("A: {} - T: {}", self.inner().name(), A::type_name()).as_str())
73    }
74}
75
76impl<T: AgentDeriveT + AgentExecutor + AgentHooks, A: AgentType> BaseAgent<T, A> {
77    /// Create a new BaseAgent wrapping an AgentDeriveT implementation
78    pub async fn new(
79        inner: T,
80        llm: Arc<dyn LLMProvider>,
81        memory: Option<Box<dyn MemoryProvider>>,
82        tx: Sender<Event>,
83        stream: bool,
84    ) -> Result<Self, RunnableAgentError> {
85        let agent = Self {
86            inner: Arc::new(inner),
87            id: Uuid::new_v4(),
88            llm,
89            tx: Some(tx),
90            memory: memory.map(|m| Arc::new(Mutex::new(m))),
91            stream,
92            marker: PhantomData,
93        };
94
95        //Run Hook
96        agent.inner().on_agent_create().await;
97
98        Ok(agent)
99    }
100
101    pub fn inner(&self) -> Arc<T> {
102        self.inner.clone()
103    }
104
105    /// Get the agent's name
106    pub fn name(&self) -> &'static str {
107        self.inner.name()
108    }
109
110    /// Get the agent's description
111    pub fn description(&self) -> &'static str {
112        self.inner.description()
113    }
114
115    /// Get the tools as Arc-wrapped references
116    pub fn tools(&self) -> Vec<Box<dyn ToolT>> {
117        self.inner.tools()
118    }
119
120    pub fn stream(&self) -> bool {
121        self.stream
122    }
123
124    pub(crate) fn create_context(&self) -> Arc<Context> {
125        Arc::new(
126            Context::new(self.llm(), self.tx.clone())
127                .with_memory(self.memory())
128                .with_tools(self.tools())
129                .with_config(self.agent_config())
130                .with_stream(self.stream()),
131        )
132    }
133
134    pub fn agent_config(&self) -> AgentConfig {
135        let output_schema = self.inner().output_schema();
136        let structured_schema =
137            output_schema.and_then(|schema| serde_json::from_value(schema).ok());
138        AgentConfig {
139            name: self.name().into(),
140            description: self.description().into(),
141            id: self.id,
142            output_schema: structured_schema,
143        }
144    }
145
146    /// Get the LLM provider
147    pub fn llm(&self) -> Arc<dyn LLMProvider> {
148        self.llm.clone()
149    }
150
151    /// Get the memory provider if available
152    pub fn memory(&self) -> Option<Arc<Mutex<Box<dyn MemoryProvider>>>> {
153        self.memory.clone()
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{AgentConfig, DirectAgent};
    use crate::tests::agent::MockAgentImpl;
    use autoagents_llm::chat::StructuredOutputFormat;
    use autoagents_test_utils::llm::MockLLMProvider;
    use std::sync::Arc;
    use tokio::sync::mpsc::{Receiver, channel};
    use uuid::Uuid;

    /// Concrete agent type exercised by the `BaseAgent` tests below.
    type TestAgent = BaseAgent<MockAgentImpl, DirectAgent>;

    /// Builds a `TestAgent` around a mock agent and a mock LLM provider.
    ///
    /// The receiving half of the event channel is dropped immediately, as none
    /// of these tests read events.
    async fn build_agent(
        name: &'static str,
        description: &'static str,
        memory: Option<Box<dyn MemoryProvider>>,
        stream: bool,
    ) -> TestAgent {
        let (tx, _): (Sender<Event>, Receiver<Event>) = channel(32);
        BaseAgent::<_, DirectAgent>::new(
            MockAgentImpl::new(name, description),
            Arc::new(MockLLMProvider),
            memory,
            tx,
            stream,
        )
        .await
        .unwrap()
    }

    #[test]
    fn test_agent_config_creation() {
        let cfg = AgentConfig {
            name: "test_agent".to_string(),
            id: Uuid::new_v4(),
            description: "A test agent".to_string(),
            output_schema: None,
        };

        assert_eq!(cfg.name, "test_agent");
        assert_eq!(cfg.description, "A test agent");
        assert!(cfg.output_schema.is_none());
    }

    #[test]
    fn test_agent_config_with_schema() {
        let format = StructuredOutputFormat {
            name: "TestSchema".to_string(),
            description: Some("Test schema".to_string()),
            schema: Some(serde_json::json!({"type": "object"})),
            strict: Some(true),
        };

        let cfg = AgentConfig {
            name: "test_agent".to_string(),
            id: Uuid::new_v4(),
            description: "A test agent".to_string(),
            output_schema: Some(format.clone()),
        };

        assert_eq!(cfg.name, "test_agent");
        assert_eq!(cfg.description, "A test agent");
        assert!(cfg.output_schema.is_some());
        assert_eq!(cfg.output_schema.unwrap().name, "TestSchema");
    }

    #[tokio::test]
    async fn test_base_agent_creation() {
        let agent = build_agent("test", "test description", None, false).await;

        assert_eq!(agent.name(), "test");
        assert_eq!(agent.description(), "test description");
        assert!(agent.memory().is_none());
    }

    #[tokio::test]
    async fn test_base_agent_with_memory() {
        let memory: Box<dyn MemoryProvider> =
            Box::new(crate::agent::memory::SlidingWindowMemory::new(5));
        let agent = build_agent("test", "test description", Some(memory), false).await;

        assert_eq!(agent.name(), "test");
        assert_eq!(agent.description(), "test description");
        assert!(agent.memory().is_some());
    }

    #[tokio::test]
    async fn test_base_agent_inner() {
        let agent = build_agent("test", "test description", None, false).await;

        let wrapped = agent.inner();
        assert_eq!(wrapped.name(), "test");
        assert_eq!(wrapped.description(), "test description");
    }

    #[tokio::test]
    async fn test_base_agent_tools() {
        let agent = build_agent("test", "test description", None, false).await;

        assert!(agent.tools().is_empty());
    }

    #[tokio::test]
    async fn test_base_agent_llm() {
        let agent = build_agent("test", "test description", None, false).await;

        // `llm()` hands back an `Arc<dyn LLMProvider>`; this is only a smoke
        // check that a live handle is returned.
        let provider = agent.llm();
        assert!(Arc::strong_count(&provider) > 0);
    }

    #[tokio::test]
    async fn test_base_agent_with_streaming() {
        let agent = build_agent("streaming_agent", "test streaming agent", None, true).await;

        assert_eq!(agent.name(), "streaming_agent");
        assert_eq!(agent.description(), "test streaming agent");
        assert!(agent.memory().is_none());
        assert!(agent.stream);
    }
}
288}