autoagents_core/agent/
base.rs1use crate::agent::config::AgentConfig;
2use crate::agent::memory::MemoryProvider;
3use crate::agent::{AgentExecutor, Context, output::AgentOutputT};
4use crate::tool::{ToolT, to_llm_tool};
5use async_trait::async_trait;
6use autoagents_llm::LLMProvider;
7use autoagents_llm::chat::Tool;
8use autoagents_protocol::{ActorID, Event};
9
10use serde_json::Value;
11use std::marker::PhantomData;
12use std::{fmt::Debug, sync::Arc};
13
14#[cfg(target_arch = "wasm32")]
15pub use futures::lock::Mutex;
16#[cfg(not(target_arch = "wasm32"))]
17pub use tokio::sync::Mutex;
18
19#[cfg(target_arch = "wasm32")]
20use futures::channel::mpsc::Sender;
21
22#[cfg(not(target_arch = "wasm32"))]
23use tokio::sync::mpsc::Sender;
24
25use crate::agent::error::RunnableAgentError;
26use crate::agent::hooks::AgentHooks;
27use uuid::Uuid;
28
29#[async_trait]
32pub trait AgentDeriveT: Send + Sync + 'static + Debug {
33 type Output: AgentOutputT;
35
36 fn description(&self) -> &str;
38
39 fn output_schema(&self) -> Option<Value>;
41
42 fn name(&self) -> &str;
44
45 fn tools(&self) -> Vec<Box<dyn ToolT>>;
47}
48
/// Type-level marker naming a kind of agent; used as the `A` parameter of
/// `BaseAgent`.
pub trait AgentType: Send + Sync + 'static {
    /// Static name of this agent type, shown in `BaseAgent`'s `Debug` output.
    fn type_name() -> &'static str;
}
52
53#[derive(Clone)]
55pub struct BaseAgent<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync, A: AgentType> {
56 pub(crate) inner: Arc<T>,
58 pub(crate) llm: Arc<dyn LLMProvider>,
60 pub id: ActorID,
62 pub(crate) memory: Option<Arc<Mutex<Box<dyn MemoryProvider>>>>,
64 pub(crate) serialized_tools: Option<Arc<Vec<Tool>>>,
66 pub(crate) tx: Option<Sender<Event>>,
68 pub(crate) stream: bool,
70 pub(crate) marker: PhantomData<A>,
71}
72
73impl<T: AgentDeriveT + AgentExecutor + AgentHooks, A: AgentType> Debug for BaseAgent<T, A> {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 f.write_str(format!("A: {} - T: {}", self.inner().name(), A::type_name()).as_str())
76 }
77}
78
79impl<T: AgentDeriveT + AgentExecutor + AgentHooks, A: AgentType> BaseAgent<T, A> {
80 pub async fn new(
82 inner: T,
83 llm: Arc<dyn LLMProvider>,
84 memory: Option<Box<dyn MemoryProvider>>,
85 tx: Sender<Event>,
86 stream: bool,
87 ) -> Result<Self, RunnableAgentError> {
88 let tool_defs = inner.tools();
89 let serialized_tools = if tool_defs.is_empty() {
90 None
91 } else {
92 Some(Arc::new(
93 tool_defs.iter().map(to_llm_tool).collect::<Vec<_>>(),
94 ))
95 };
96 let agent = Self {
97 inner: Arc::new(inner),
98 id: Uuid::new_v4(),
99 llm,
100 tx: Some(tx),
101 memory: memory.map(|m| Arc::new(Mutex::new(m))),
102 serialized_tools,
103 stream,
104 marker: PhantomData,
105 };
106
107 agent.inner().on_agent_create().await;
109
110 Ok(agent)
111 }
112
113 pub fn inner(&self) -> Arc<T> {
114 self.inner.clone()
115 }
116
117 pub fn name(&self) -> &str {
119 self.inner.name()
120 }
121
122 pub fn description(&self) -> &str {
124 self.inner.description()
125 }
126
127 pub fn tools(&self) -> Vec<Box<dyn ToolT>> {
129 self.inner.tools()
130 }
131
132 pub fn serialized_tools(&self) -> Option<Arc<Vec<Tool>>> {
133 self.serialized_tools.clone()
134 }
135
136 pub fn stream(&self) -> bool {
137 self.stream
138 }
139
140 pub(crate) fn create_context(&self) -> Arc<Context> {
141 let tools = self.tools();
142 let cached_tools = self
143 .serialized_tools()
144 .filter(|cached| tools_match_cached(&tools, cached));
145 Arc::new(
146 Context::new(self.llm(), self.tx.clone())
147 .with_memory(self.memory())
148 .with_serialized_tools(cached_tools)
149 .with_tools(tools)
150 .with_config(self.agent_config())
151 .with_stream(self.stream()),
152 )
153 }
154
155 pub fn agent_config(&self) -> AgentConfig {
156 let output_schema = self.inner().output_schema();
157 let structured_schema =
158 output_schema.and_then(|schema| serde_json::from_value(schema).ok());
159 AgentConfig {
160 name: self.name().into(),
161 description: self.description().into(),
162 id: self.id,
163 output_schema: structured_schema,
164 }
165 }
166
167 pub fn llm(&self) -> Arc<dyn LLMProvider> {
169 self.llm.clone()
170 }
171
172 pub fn memory(&self) -> Option<Arc<Mutex<Box<dyn MemoryProvider>>>> {
174 self.memory.clone()
175 }
176}
177
178fn tools_match_cached(tools: &[Box<dyn ToolT>], cached: &[Tool]) -> bool {
179 if tools.len() != cached.len() {
180 return false;
181 }
182
183 tools.iter().zip(cached.iter()).all(|(tool, cached_tool)| {
184 cached_tool.tool_type == "function"
185 && cached_tool.function.name == tool.name()
186 && cached_tool.function.description == tool.description()
187 && cached_tool.function.parameters == tool.args_schema()
188 })
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::memory::SlidingWindowMemory;
    use crate::agent::{AgentConfig, DirectAgent};
    use crate::tests::{MockAgentImpl, MockLLMProvider};
    use autoagents_llm::chat::StructuredOutputFormat;
    use std::sync::Arc;
    use tokio::sync::mpsc::{Receiver, channel};
    use uuid::Uuid;

    /// Returns an event sender whose receiving half has already been dropped.
    /// None of these tests consume events; they only need a `Sender` to build
    /// a `BaseAgent`.
    fn detached_event_sender() -> Sender<Event> {
        let (sender, _): (Sender<Event>, Receiver<Event>) = channel(32);
        sender
    }

    #[test]
    fn test_agent_config_with_schema() {
        let format = StructuredOutputFormat {
            name: "TestSchema".to_string(),
            description: Some("Test schema".to_string()),
            schema: Some(serde_json::json!({"type": "object"})),
            strict: Some(true),
        };

        let config = AgentConfig {
            name: "test_agent".to_string(),
            id: Uuid::new_v4(),
            description: "A test agent".to_string(),
            output_schema: Some(format),
        };

        assert_eq!(config.name, "test_agent");
        assert_eq!(config.description, "A test agent");
        let stored = config
            .output_schema
            .expect("output schema should have been stored");
        assert_eq!(stored.name, "TestSchema");
    }

    #[tokio::test]
    async fn test_base_agent_creation_with_memory_and_stream() {
        let agent_impl = MockAgentImpl::new("test", "test description");
        let provider = Arc::new(MockLLMProvider);
        let window_memory = Box::new(SlidingWindowMemory::new(5));

        let agent = BaseAgent::<_, DirectAgent>::new(
            agent_impl,
            provider,
            Some(window_memory),
            detached_event_sender(),
            true,
        )
        .await
        .expect("agent construction should succeed");

        assert_eq!(agent.name(), "test");
        assert_eq!(agent.description(), "test description");
        assert!(agent.memory().is_some());
        assert!(agent.stream());
    }

    #[tokio::test]
    async fn test_base_agent_create_context_populates_config() {
        let agent_impl = MockAgentImpl::new("ctx_agent", "context agent");
        let provider = Arc::new(MockLLMProvider);

        let agent = BaseAgent::<_, DirectAgent>::new(
            agent_impl,
            provider,
            None,
            detached_event_sender(),
            false,
        )
        .await
        .expect("agent construction should succeed");

        let ctx = agent.create_context();
        let cfg = ctx.config();
        assert_eq!(cfg.name, "ctx_agent");
        assert_eq!(cfg.description, "context agent");
    }
}