1use crate::agent::config::AgentConfig;
2use crate::agent::memory::MemoryProvider;
3use crate::agent::{output::AgentOutputT, AgentExecutor, Context};
4use crate::protocol::Event;
5use crate::{protocol::ActorID, tool::ToolT};
6use async_trait::async_trait;
7use autoagents_llm::LLMProvider;
8
9use serde_json::Value;
10use std::marker::PhantomData;
11use std::{fmt::Debug, sync::Arc};
12
13#[cfg(target_arch = "wasm32")]
14pub use futures::lock::Mutex;
15#[cfg(not(target_arch = "wasm32"))]
16pub use tokio::sync::Mutex;
17
18#[cfg(target_arch = "wasm32")]
19use futures::channel::mpsc::Sender;
20
21#[cfg(not(target_arch = "wasm32"))]
22use tokio::sync::mpsc::Sender;
23
24use crate::agent::error::RunnableAgentError;
25use crate::agent::hooks::AgentHooks;
26use uuid::Uuid;
27
28#[async_trait]
31pub trait AgentDeriveT: Send + Sync + 'static + Debug {
32 type Output: AgentOutputT;
34
35 fn description(&self) -> &'static str;
37
38 fn output_schema(&self) -> Option<Value>;
39
40 fn name(&self) -> &'static str;
42
43 fn tools(&self) -> Vec<Box<dyn ToolT>>;
45}
46
/// Type-level marker selecting the flavour of a `BaseAgent`.
///
/// Only carried as `PhantomData` inside `BaseAgent`; no value of the marker is stored.
pub trait AgentType: Send + Sync + 'static {
    /// Static label for this agent flavour, printed by `BaseAgent`'s `Debug` impl.
    fn type_name() -> &'static str;
}
50
51#[derive(Clone)]
53pub struct BaseAgent<T: AgentDeriveT + AgentExecutor + AgentHooks, A: AgentType> {
54 pub(crate) inner: Arc<T>,
56 pub(crate) llm: Arc<dyn LLMProvider>,
58 pub id: ActorID,
60 pub(crate) memory: Option<Arc<Mutex<Box<dyn MemoryProvider>>>>,
62 pub(crate) tx: Option<Sender<Event>>,
64 pub(crate) stream: bool,
66 pub(crate) marker: PhantomData<A>,
67}
68
69impl<T: AgentDeriveT + AgentExecutor + AgentHooks, A: AgentType> Debug for BaseAgent<T, A> {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 f.write_str(format!("A: {} - T: {}", self.inner().name(), A::type_name()).as_str())
72 }
73}
74
75impl<T: AgentDeriveT + AgentExecutor + AgentHooks, A: AgentType> BaseAgent<T, A> {
76 pub async fn new(
78 inner: T,
79 llm: Arc<dyn LLMProvider>,
80 memory: Option<Box<dyn MemoryProvider>>,
81 tx: Sender<Event>,
82 stream: bool,
83 ) -> Result<Self, RunnableAgentError> {
84 let agent = Self {
85 inner: Arc::new(inner),
86 id: Uuid::new_v4(),
87 llm,
88 tx: Some(tx),
89 memory: memory.map(|m| Arc::new(Mutex::new(m))),
90 stream,
91 marker: PhantomData,
92 };
93
94 agent.inner().on_agent_create().await;
96
97 Ok(agent)
98 }
99
100 pub fn inner(&self) -> Arc<T> {
101 self.inner.clone()
102 }
103
104 pub fn name(&self) -> &'static str {
106 self.inner.name()
107 }
108
109 pub fn description(&self) -> &'static str {
111 self.inner.description()
112 }
113
114 pub fn tools(&self) -> Vec<Box<dyn ToolT>> {
116 self.inner.tools()
117 }
118
119 pub fn stream(&self) -> bool {
120 self.stream
121 }
122
123 pub(crate) fn create_context(&self) -> Arc<Context> {
124 Arc::new(
125 Context::new(self.llm(), self.tx.clone())
126 .with_memory(self.memory())
127 .with_tools(self.tools())
128 .with_config(self.agent_config())
129 .with_stream(self.stream()),
130 )
131 }
132
133 pub fn agent_config(&self) -> AgentConfig {
134 let output_schema = self.inner().output_schema();
135 let structured_schema = output_schema.map(|schema| serde_json::from_value(schema).unwrap());
136 AgentConfig {
137 name: self.name().into(),
138 description: self.description().into(),
139 id: self.id,
140 output_schema: structured_schema,
141 }
142 }
143
144 pub fn llm(&self) -> Arc<dyn LLMProvider> {
146 self.llm.clone()
147 }
148
149 pub fn memory(&self) -> Option<Arc<Mutex<Box<dyn MemoryProvider>>>> {
151 self.memory.clone()
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{AgentConfig, DirectAgent};
    use crate::tests::agent::MockAgentImpl;
    use autoagents_llm::chat::StructuredOutputFormat;
    use autoagents_test_utils::llm::MockLLMProvider;
    use std::sync::Arc;
    use tokio::sync::mpsc::{channel, Receiver};
    use uuid::Uuid;

    #[test]
    fn test_agent_config_creation() {
        let config = AgentConfig {
            name: "test_agent".to_string(),
            id: Uuid::new_v4(),
            description: "A test agent".to_string(),
            output_schema: None,
        };

        assert_eq!(config.name, "test_agent");
        assert_eq!(config.description, "A test agent");
        assert!(config.output_schema.is_none());
    }

    #[test]
    fn test_agent_config_with_schema() {
        let schema = StructuredOutputFormat {
            name: "TestSchema".to_string(),
            description: Some("Test schema".to_string()),
            schema: Some(serde_json::json!({"type": "object"})),
            strict: Some(true),
        };

        let config = AgentConfig {
            name: "test_agent".to_string(),
            id: Uuid::new_v4(),
            description: "A test agent".to_string(),
            output_schema: Some(schema.clone()),
        };

        assert_eq!(config.name, "test_agent");
        assert_eq!(config.description, "A test agent");
        assert!(config.output_schema.is_some());
        assert_eq!(config.output_schema.unwrap().name, "TestSchema");
    }

    #[tokio::test]
    async fn test_base_agent_creation() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let (tx, _): (Sender<Event>, Receiver<Event>) = channel(32);
        let base_agent = BaseAgent::<_, DirectAgent>::new(mock_agent, llm, None, tx, false)
            .await
            .unwrap();

        assert_eq!(base_agent.name(), "test");
        assert_eq!(base_agent.description(), "test description");
        assert!(base_agent.memory().is_none());
    }

    #[tokio::test]
    async fn test_base_agent_with_memory() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let memory = Box::new(crate::agent::memory::SlidingWindowMemory::new(5));
        let (tx, _): (Sender<Event>, Receiver<Event>) = channel(32);
        let base_agent = BaseAgent::<_, DirectAgent>::new(mock_agent, llm, Some(memory), tx, false)
            .await
            .unwrap();

        assert_eq!(base_agent.name(), "test");
        assert_eq!(base_agent.description(), "test description");
        assert!(base_agent.memory().is_some());
    }

    #[tokio::test]
    async fn test_base_agent_inner() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let (tx, _): (Sender<Event>, Receiver<Event>) = channel(32);
        let base_agent = BaseAgent::<_, DirectAgent>::new(mock_agent, llm, None, tx, false)
            .await
            .unwrap();

        let inner = base_agent.inner();
        assert_eq!(inner.name(), "test");
        assert_eq!(inner.description(), "test description");
    }

    #[tokio::test]
    async fn test_base_agent_tools() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let (tx, _): (Sender<Event>, Receiver<Event>) = channel(32);
        let base_agent = BaseAgent::<_, DirectAgent>::new(mock_agent, llm, None, tx, false)
            .await
            .unwrap();

        let tools = base_agent.tools();
        assert!(tools.is_empty());
    }

    #[tokio::test]
    async fn test_base_agent_llm() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let (tx, _): (Sender<Event>, Receiver<Event>) = channel(32);
        let base_agent = BaseAgent::<_, DirectAgent>::new(mock_agent, llm.clone(), None, tx, false)
            .await
            .unwrap();

        // The agent must hand back the very provider it was built with. Compare data
        // addresses only (thin pointers) so vtable metadata cannot affect the result.
        let agent_llm = base_agent.llm();
        assert!(std::ptr::eq(
            Arc::as_ptr(&agent_llm).cast::<u8>(),
            Arc::as_ptr(&llm).cast::<u8>(),
        ));
    }

    #[tokio::test]
    async fn test_base_agent_config_reflects_agent() {
        let mock_agent = MockAgentImpl::new("config_agent", "config description");
        let llm = Arc::new(MockLLMProvider);
        let (tx, _): (Sender<Event>, Receiver<Event>) = channel(32);
        let base_agent = BaseAgent::<_, DirectAgent>::new(mock_agent, llm, None, tx, false)
            .await
            .unwrap();

        let config = base_agent.agent_config();
        assert_eq!(config.name, "config_agent");
        assert_eq!(config.description, "config description");
        assert_eq!(config.id, base_agent.id);
    }

    #[tokio::test]
    async fn test_base_agent_with_streaming() {
        let mock_agent = MockAgentImpl::new("streaming_agent", "test streaming agent");
        let llm = Arc::new(MockLLMProvider);
        let (tx, _): (Sender<Event>, Receiver<Event>) = channel(32);
        let base_agent = BaseAgent::<_, DirectAgent>::new(mock_agent, llm, None, tx, true)
            .await
            .unwrap();

        assert_eq!(base_agent.name(), "streaming_agent");
        assert_eq!(base_agent.description(), "test streaming agent");
        assert!(base_agent.memory().is_none());
        assert!(base_agent.stream());
    }
}