autoagents_core/agent/
base.rs

1use crate::agent::config::AgentConfig;
2use crate::agent::memory::MemoryProvider;
3use crate::agent::{output::AgentOutputT, AgentExecutor, Context};
4use crate::protocol::Event;
5use crate::{protocol::ActorID, tool::ToolT};
6use async_trait::async_trait;
7use autoagents_llm::LLMProvider;
8
9use serde_json::Value;
10use std::marker::PhantomData;
11use std::{fmt::Debug, sync::Arc};
12
13#[cfg(target_arch = "wasm32")]
14pub use futures::lock::Mutex;
15#[cfg(not(target_arch = "wasm32"))]
16pub use tokio::sync::Mutex;
17
18#[cfg(target_arch = "wasm32")]
19use futures::channel::mpsc::Sender;
20
21#[cfg(not(target_arch = "wasm32"))]
22use tokio::sync::mpsc::Sender;
23
24use crate::agent::error::RunnableAgentError;
25use crate::agent::hooks::AgentHooks;
26use uuid::Uuid;
27
28/// Core trait that defines agent metadata and behavior
29/// This trait is implemented via the #[agent] macro
30#[async_trait]
31pub trait AgentDeriveT: Send + Sync + 'static + Debug {
32    /// The output type this agent produces
33    type Output: AgentOutputT;
34
35    /// Get the agent's description
36    fn description(&self) -> &'static str;
37
38    fn output_schema(&self) -> Option<Value>;
39
40    /// Get the agent's name
41    fn name(&self) -> &'static str;
42
43    /// Get the tools available to this agent
44    fn tools(&self) -> Vec<Box<dyn ToolT>>;
45}
46
/// Type-level tag describing which kind of agent a `BaseAgent` is.
///
/// Used as the `A` type parameter of `BaseAgent` (stored only as `PhantomData`);
/// its [`type_name`](AgentType::type_name) appears in `BaseAgent`'s `Debug` output.
pub trait AgentType: 'static + Send + Sync {
    /// Human-readable name of this agent type.
    fn type_name() -> &'static str;
}
50
51/// Base agent type that wraps an AgentDeriveT implementation with additional runtime components
52#[derive(Clone)]
53pub struct BaseAgent<T: AgentDeriveT + AgentExecutor + AgentHooks, A: AgentType> {
54    /// The inner agent implementation (from macro)
55    pub(crate) inner: Arc<T>,
56    /// LLM provider for this agent
57    pub(crate) llm: Arc<dyn LLMProvider>,
58    /// Agent ID
59    pub id: ActorID,
60    /// Optional memory provider
61    pub(crate) memory: Option<Arc<Mutex<Box<dyn MemoryProvider>>>>,
62    /// Tx sender
63    pub(crate) tx: Option<Sender<Event>>,
64    //Stream
65    pub(crate) stream: bool,
66    pub(crate) marker: PhantomData<A>,
67}
68
69impl<T: AgentDeriveT + AgentExecutor + AgentHooks, A: AgentType> Debug for BaseAgent<T, A> {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        f.write_str(format!("A: {} - T: {}", self.inner().name(), A::type_name()).as_str())
72    }
73}
74
75impl<T: AgentDeriveT + AgentExecutor + AgentHooks, A: AgentType> BaseAgent<T, A> {
76    /// Create a new BaseAgent wrapping an AgentDeriveT implementation
77    pub async fn new(
78        inner: T,
79        llm: Arc<dyn LLMProvider>,
80        memory: Option<Box<dyn MemoryProvider>>,
81        tx: Sender<Event>,
82        stream: bool,
83    ) -> Result<Self, RunnableAgentError> {
84        let agent = Self {
85            inner: Arc::new(inner),
86            id: Uuid::new_v4(),
87            llm,
88            tx: Some(tx),
89            memory: memory.map(|m| Arc::new(Mutex::new(m))),
90            stream,
91            marker: PhantomData,
92        };
93
94        //Run Hook
95        agent.inner().on_agent_create().await;
96
97        Ok(agent)
98    }
99
100    pub fn inner(&self) -> Arc<T> {
101        self.inner.clone()
102    }
103
104    /// Get the agent's name
105    pub fn name(&self) -> &'static str {
106        self.inner.name()
107    }
108
109    /// Get the agent's description
110    pub fn description(&self) -> &'static str {
111        self.inner.description()
112    }
113
114    /// Get the tools as Arc-wrapped references
115    pub fn tools(&self) -> Vec<Box<dyn ToolT>> {
116        self.inner.tools()
117    }
118
119    pub fn stream(&self) -> bool {
120        self.stream
121    }
122
123    pub(crate) fn create_context(&self) -> Arc<Context> {
124        Arc::new(
125            Context::new(self.llm(), self.tx.clone())
126                .with_memory(self.memory())
127                .with_tools(self.tools())
128                .with_config(self.agent_config())
129                .with_stream(self.stream()),
130        )
131    }
132
133    pub fn agent_config(&self) -> AgentConfig {
134        let output_schema = self.inner().output_schema();
135        let structured_schema = output_schema.map(|schema| serde_json::from_value(schema).unwrap());
136        AgentConfig {
137            name: self.name().into(),
138            description: self.description().into(),
139            id: self.id,
140            output_schema: structured_schema,
141        }
142    }
143
144    /// Get the LLM provider
145    pub fn llm(&self) -> Arc<dyn LLMProvider> {
146        self.llm.clone()
147    }
148
149    /// Get the memory provider if available
150    pub fn memory(&self) -> Option<Arc<Mutex<Box<dyn MemoryProvider>>>> {
151        self.memory.clone()
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{AgentConfig, DirectAgent};
    use crate::tests::agent::MockAgentImpl;
    use autoagents_llm::chat::StructuredOutputFormat;
    use autoagents_test_utils::llm::MockLLMProvider;
    use std::sync::Arc;
    use tokio::sync::mpsc::{channel, Receiver};
    use uuid::Uuid;

    /// Build a `BaseAgent<_, DirectAgent>` around a `MockAgentImpl`, wired to a
    /// throwaway event channel whose receiver is dropped immediately.
    async fn build_agent(
        name: &'static str,
        description: &'static str,
        llm: Arc<dyn LLMProvider>,
        memory: Option<Box<dyn MemoryProvider>>,
        stream: bool,
    ) -> BaseAgent<MockAgentImpl, DirectAgent> {
        let mock_agent = MockAgentImpl::new(name, description);
        let (tx, _): (Sender<Event>, Receiver<Event>) = channel(32);
        BaseAgent::<_, DirectAgent>::new(mock_agent, llm, memory, tx, stream)
            .await
            .unwrap()
    }

    #[test]
    fn test_agent_config_creation() {
        let config = AgentConfig {
            name: "test_agent".to_string(),
            id: Uuid::new_v4(),
            description: "A test agent".to_string(),
            output_schema: None,
        };

        assert_eq!(config.name, "test_agent");
        assert_eq!(config.description, "A test agent");
        assert!(config.output_schema.is_none());
    }

    #[test]
    fn test_agent_config_with_schema() {
        let schema = StructuredOutputFormat {
            name: "TestSchema".to_string(),
            description: Some("Test schema".to_string()),
            schema: Some(serde_json::json!({"type": "object"})),
            strict: Some(true),
        };

        let config = AgentConfig {
            name: "test_agent".to_string(),
            id: Uuid::new_v4(),
            description: "A test agent".to_string(),
            output_schema: Some(schema.clone()),
        };

        assert_eq!(config.name, "test_agent");
        assert_eq!(config.description, "A test agent");
        assert!(config.output_schema.is_some());
        assert_eq!(config.output_schema.unwrap().name, "TestSchema");
    }

    #[tokio::test]
    async fn test_base_agent_creation() {
        let llm = Arc::new(MockLLMProvider);
        let base_agent = build_agent("test", "test description", llm, None, false).await;

        assert_eq!(base_agent.name(), "test");
        assert_eq!(base_agent.description(), "test description");
        assert!(base_agent.memory().is_none());
        assert!(!base_agent.stream());
    }

    #[tokio::test]
    async fn test_base_agent_with_memory() {
        let llm = Arc::new(MockLLMProvider);
        let memory: Box<dyn MemoryProvider> =
            Box::new(crate::agent::memory::SlidingWindowMemory::new(5));
        let base_agent = build_agent("test", "test description", llm, Some(memory), false).await;

        assert_eq!(base_agent.name(), "test");
        assert_eq!(base_agent.description(), "test description");
        assert!(base_agent.memory().is_some());
    }

    #[tokio::test]
    async fn test_base_agent_inner() {
        let llm = Arc::new(MockLLMProvider);
        let base_agent = build_agent("test", "test description", llm, None, false).await;

        let inner = base_agent.inner();
        assert_eq!(inner.name(), "test");
        assert_eq!(inner.description(), "test description");
    }

    #[tokio::test]
    async fn test_base_agent_tools() {
        let llm = Arc::new(MockLLMProvider);
        let base_agent = build_agent("test", "test description", llm, None, false).await;

        let tools = base_agent.tools();
        assert!(tools.is_empty());
    }

    #[tokio::test]
    async fn test_base_agent_llm() {
        let llm = Arc::new(MockLLMProvider);
        let base_agent = build_agent("test", "test description", llm.clone(), None, false).await;

        // The agent must hand back the very provider it was built with: compare the
        // data pointers (thin-cast, since one side is a trait object).
        let agent_llm = base_agent.llm();
        assert_eq!(
            Arc::as_ptr(&agent_llm).cast::<()>(),
            Arc::as_ptr(&llm).cast::<()>()
        );
    }

    #[tokio::test]
    async fn test_base_agent_with_streaming() {
        let llm = Arc::new(MockLLMProvider);
        let base_agent =
            build_agent("streaming_agent", "test streaming agent", llm, None, true).await;

        assert_eq!(base_agent.name(), "streaming_agent");
        assert_eq!(base_agent.description(), "test streaming agent");
        assert!(base_agent.memory().is_none());
        assert!(base_agent.stream());
    }
}