1use crate::agent::config::AgentConfig;
2use crate::agent::memory::MemoryProvider;
3use crate::agent::{AgentExecutor, Context, output::AgentOutputT};
4use crate::protocol::Event;
5use crate::{protocol::ActorID, tool::ToolT};
6use async_trait::async_trait;
7use autoagents_llm::LLMProvider;
8
9use serde_json::Value;
10use std::marker::PhantomData;
11use std::{fmt::Debug, sync::Arc};
12
13#[cfg(target_arch = "wasm32")]
14pub use futures::lock::Mutex;
15#[cfg(not(target_arch = "wasm32"))]
16pub use tokio::sync::Mutex;
17
18#[cfg(target_arch = "wasm32")]
19use futures::channel::mpsc::Sender;
20
21#[cfg(not(target_arch = "wasm32"))]
22use tokio::sync::mpsc::Sender;
23
24use crate::agent::error::RunnableAgentError;
25use crate::agent::hooks::AgentHooks;
26use uuid::Uuid;
27
28#[async_trait]
31pub trait AgentDeriveT: Send + Sync + 'static + Debug {
32 type Output: AgentOutputT;
34
35 fn description(&self) -> &'static str;
37
38 fn output_schema(&self) -> Option<Value>;
40
41 fn name(&self) -> &'static str;
43
44 fn tools(&self) -> Vec<Box<dyn ToolT>>;
46}
47
/// Compile-time marker for a kind of agent, identified by a static type name.
pub trait AgentType: Send + Sync + 'static {
    /// Returns the static name of this agent type.
    fn type_name() -> &'static str;
}
51
52#[derive(Clone)]
54pub struct BaseAgent<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync, A: AgentType> {
55 pub(crate) inner: Arc<T>,
57 pub(crate) llm: Arc<dyn LLMProvider>,
59 pub id: ActorID,
61 pub(crate) memory: Option<Arc<Mutex<Box<dyn MemoryProvider>>>>,
63 pub(crate) tx: Option<Sender<Event>>,
65 pub(crate) stream: bool,
67 pub(crate) marker: PhantomData<A>,
68}
69
70impl<T: AgentDeriveT + AgentExecutor + AgentHooks, A: AgentType> Debug for BaseAgent<T, A> {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 f.write_str(format!("A: {} - T: {}", self.inner().name(), A::type_name()).as_str())
73 }
74}
75
76impl<T: AgentDeriveT + AgentExecutor + AgentHooks, A: AgentType> BaseAgent<T, A> {
77 pub async fn new(
79 inner: T,
80 llm: Arc<dyn LLMProvider>,
81 memory: Option<Box<dyn MemoryProvider>>,
82 tx: Sender<Event>,
83 stream: bool,
84 ) -> Result<Self, RunnableAgentError> {
85 let agent = Self {
86 inner: Arc::new(inner),
87 id: Uuid::new_v4(),
88 llm,
89 tx: Some(tx),
90 memory: memory.map(|m| Arc::new(Mutex::new(m))),
91 stream,
92 marker: PhantomData,
93 };
94
95 agent.inner().on_agent_create().await;
97
98 Ok(agent)
99 }
100
101 pub fn inner(&self) -> Arc<T> {
102 self.inner.clone()
103 }
104
105 pub fn name(&self) -> &'static str {
107 self.inner.name()
108 }
109
110 pub fn description(&self) -> &'static str {
112 self.inner.description()
113 }
114
115 pub fn tools(&self) -> Vec<Box<dyn ToolT>> {
117 self.inner.tools()
118 }
119
120 pub fn stream(&self) -> bool {
121 self.stream
122 }
123
124 pub(crate) fn create_context(&self) -> Arc<Context> {
125 Arc::new(
126 Context::new(self.llm(), self.tx.clone())
127 .with_memory(self.memory())
128 .with_tools(self.tools())
129 .with_config(self.agent_config())
130 .with_stream(self.stream()),
131 )
132 }
133
134 pub fn agent_config(&self) -> AgentConfig {
135 let output_schema = self.inner().output_schema();
136 let structured_schema =
137 output_schema.and_then(|schema| serde_json::from_value(schema).ok());
138 AgentConfig {
139 name: self.name().into(),
140 description: self.description().into(),
141 id: self.id,
142 output_schema: structured_schema,
143 }
144 }
145
146 pub fn llm(&self) -> Arc<dyn LLMProvider> {
148 self.llm.clone()
149 }
150
151 pub fn memory(&self) -> Option<Arc<Mutex<Box<dyn MemoryProvider>>>> {
153 self.memory.clone()
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{AgentConfig, DirectAgent};
    use crate::tests::agent::MockAgentImpl;
    use autoagents_llm::chat::StructuredOutputFormat;
    use autoagents_test_utils::llm::MockLLMProvider;
    use std::sync::Arc;
    use tokio::sync::mpsc::channel;
    use uuid::Uuid;

    /// Returns an event sender whose receiving half is dropped right away;
    /// none of these tests send events through it.
    fn event_sender() -> Sender<Event> {
        let (tx, _rx) = channel(32);
        tx
    }

    /// A config built without a schema keeps its fields and reports no schema.
    #[test]
    fn test_agent_config_creation() {
        let config = AgentConfig {
            name: "test_agent".to_string(),
            id: Uuid::new_v4(),
            description: "A test agent".to_string(),
            output_schema: None,
        };

        assert_eq!(config.name, "test_agent");
        assert_eq!(config.description, "A test agent");
        assert!(config.output_schema.is_none());
    }

    /// A config built with a schema exposes that schema unchanged.
    #[test]
    fn test_agent_config_with_schema() {
        let schema = StructuredOutputFormat {
            name: "TestSchema".to_string(),
            description: Some("Test schema".to_string()),
            schema: Some(serde_json::json!({"type": "object"})),
            strict: Some(true),
        };

        let config = AgentConfig {
            name: "test_agent".to_string(),
            id: Uuid::new_v4(),
            description: "A test agent".to_string(),
            output_schema: Some(schema.clone()),
        };

        assert_eq!(config.name, "test_agent");
        assert_eq!(config.description, "A test agent");
        assert!(config.output_schema.is_some());
        assert_eq!(config.output_schema.unwrap().name, "TestSchema");
    }

    /// An agent built without memory or streaming reports exactly that.
    #[tokio::test]
    async fn test_base_agent_creation() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent =
            BaseAgent::<_, DirectAgent>::new(mock_agent, llm, None, event_sender(), false)
                .await
                .unwrap();

        assert_eq!(base_agent.name(), "test");
        assert_eq!(base_agent.description(), "test description");
        assert!(base_agent.memory().is_none());
        assert!(!base_agent.stream());
    }

    /// A memory provider passed to `new` is retained by the agent.
    #[tokio::test]
    async fn test_base_agent_with_memory() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let memory = Box::new(crate::agent::memory::SlidingWindowMemory::new(5));
        let base_agent =
            BaseAgent::<_, DirectAgent>::new(mock_agent, llm, Some(memory), event_sender(), false)
                .await
                .unwrap();

        assert_eq!(base_agent.name(), "test");
        assert_eq!(base_agent.description(), "test description");
        assert!(base_agent.memory().is_some());
    }

    /// `inner()` hands back the wrapped agent.
    #[tokio::test]
    async fn test_base_agent_inner() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent =
            BaseAgent::<_, DirectAgent>::new(mock_agent, llm, None, event_sender(), false)
                .await
                .unwrap();

        let inner = base_agent.inner();
        assert_eq!(inner.name(), "test");
        assert_eq!(inner.description(), "test description");
    }

    /// The mock agent exposes no tools, and the wrapper reports the same.
    #[tokio::test]
    async fn test_base_agent_tools() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent =
            BaseAgent::<_, DirectAgent>::new(mock_agent, llm, None, event_sender(), false)
                .await
                .unwrap();

        let tools = base_agent.tools();
        assert!(tools.is_empty());
    }

    /// `llm()` returns a handle to the very provider that was passed to `new`.
    #[tokio::test]
    async fn test_base_agent_llm() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let base_agent =
            BaseAgent::<_, DirectAgent>::new(mock_agent, llm.clone(), None, event_sender(), false)
                .await
                .unwrap();

        let agent_llm = base_agent.llm();
        // Compare thin data pointers: `strong_count > 0` holds for every live
        // `Arc` and would not show that the two handles share one allocation.
        assert_eq!(
            Arc::as_ptr(&agent_llm).cast::<()>(),
            Arc::as_ptr(&llm).cast::<()>()
        );
    }

    /// The streaming flag passed to `new` is visible through `stream()`.
    #[tokio::test]
    async fn test_base_agent_with_streaming() {
        let mock_agent = MockAgentImpl::new("streaming_agent", "test streaming agent");
        let llm = Arc::new(MockLLMProvider);
        let base_agent =
            BaseAgent::<_, DirectAgent>::new(mock_agent, llm, None, event_sender(), true)
                .await
                .unwrap();

        assert_eq!(base_agent.name(), "streaming_agent");
        assert_eq!(base_agent.description(), "test streaming agent");
        assert!(base_agent.memory().is_none());
        assert!(base_agent.stream());
    }
}