Skip to main content

autoagents_core/agent/
base.rs

1use crate::agent::config::AgentConfig;
2use crate::agent::memory::MemoryProvider;
3use crate::agent::{AgentExecutor, Context, output::AgentOutputT};
4use crate::tool::{ToolT, to_llm_tool};
5use async_trait::async_trait;
6use autoagents_llm::LLMProvider;
7use autoagents_llm::chat::Tool;
8use autoagents_protocol::{ActorID, Event};
9
10use serde_json::Value;
11use std::marker::PhantomData;
12use std::{fmt::Debug, sync::Arc};
13
14#[cfg(target_arch = "wasm32")]
15pub use futures::lock::Mutex;
16#[cfg(not(target_arch = "wasm32"))]
17pub use tokio::sync::Mutex;
18
19#[cfg(target_arch = "wasm32")]
20use futures::channel::mpsc::Sender;
21
22#[cfg(not(target_arch = "wasm32"))]
23use tokio::sync::mpsc::Sender;
24
25use crate::agent::error::RunnableAgentError;
26use crate::agent::hooks::AgentHooks;
27use uuid::Uuid;
28
29/// Core trait that defines agent metadata and behavior
30/// This trait is implemented via the #[agent] macro
31#[async_trait]
32pub trait AgentDeriveT: Send + Sync + 'static + Debug {
33    /// The output type this agent produces
34    type Output: AgentOutputT;
35
36    /// Get the agent's description
37    fn description(&self) -> &str;
38
39    // If you provide None then its taken as String output
40    fn output_schema(&self) -> Option<Value>;
41
42    /// Get the agent's name
43    fn name(&self) -> &str;
44
45    /// Get the tools available to this agent
46    fn tools(&self) -> Vec<Box<dyn ToolT>>;
47}
48
/// Marker describing the kind of agent a [`BaseAgent`] is parameterised with.
///
/// Implementors carry no data; they only supply a static type name (shown,
/// for example, in `BaseAgent`'s `Debug` output).
pub trait AgentType: Send + Sync + 'static {
    /// Static name identifying this agent type.
    fn type_name() -> &'static str;
}
52
53/// Base agent type that wraps an AgentDeriveT implementation with additional runtime components
54#[derive(Clone)]
55pub struct BaseAgent<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync, A: AgentType> {
56    /// The inner agent implementation (from macro)
57    pub(crate) inner: Arc<T>,
58    /// LLM provider for this agent
59    pub(crate) llm: Arc<dyn LLMProvider>,
60    /// Agent ID
61    pub id: ActorID,
62    /// Optional memory provider
63    pub(crate) memory: Option<Arc<Mutex<Box<dyn MemoryProvider>>>>,
64    /// Cached serialized tool definitions
65    pub(crate) serialized_tools: Option<Arc<Vec<Tool>>>,
66    /// Tx sender
67    pub(crate) tx: Option<Sender<Event>>,
68    //Stream
69    pub(crate) stream: bool,
70    pub(crate) marker: PhantomData<A>,
71}
72
73impl<T: AgentDeriveT + AgentExecutor + AgentHooks, A: AgentType> Debug for BaseAgent<T, A> {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        f.write_str(format!("A: {} - T: {}", self.inner().name(), A::type_name()).as_str())
76    }
77}
78
79impl<T: AgentDeriveT + AgentExecutor + AgentHooks, A: AgentType> BaseAgent<T, A> {
80    /// Create a new BaseAgent wrapping an AgentDeriveT implementation
81    pub async fn new(
82        inner: T,
83        llm: Arc<dyn LLMProvider>,
84        memory: Option<Box<dyn MemoryProvider>>,
85        tx: Sender<Event>,
86        stream: bool,
87    ) -> Result<Self, RunnableAgentError> {
88        let tool_defs = inner.tools();
89        let serialized_tools = if tool_defs.is_empty() {
90            None
91        } else {
92            Some(Arc::new(
93                tool_defs.iter().map(to_llm_tool).collect::<Vec<_>>(),
94            ))
95        };
96        let agent = Self {
97            inner: Arc::new(inner),
98            id: Uuid::new_v4(),
99            llm,
100            tx: Some(tx),
101            memory: memory.map(|m| Arc::new(Mutex::new(m))),
102            serialized_tools,
103            stream,
104            marker: PhantomData,
105        };
106
107        //Run Hook
108        agent.inner().on_agent_create().await;
109
110        Ok(agent)
111    }
112
113    pub fn inner(&self) -> Arc<T> {
114        self.inner.clone()
115    }
116
117    /// Get the agent's name
118    pub fn name(&self) -> &str {
119        self.inner.name()
120    }
121
122    /// Get the agent's description
123    pub fn description(&self) -> &str {
124        self.inner.description()
125    }
126
127    /// Get the tools as Arc-wrapped references
128    pub fn tools(&self) -> Vec<Box<dyn ToolT>> {
129        self.inner.tools()
130    }
131
132    pub fn serialized_tools(&self) -> Option<Arc<Vec<Tool>>> {
133        self.serialized_tools.clone()
134    }
135
136    pub fn stream(&self) -> bool {
137        self.stream
138    }
139
140    pub(crate) fn create_context(&self) -> Arc<Context> {
141        let tools = self.tools();
142        let cached_tools = self
143            .serialized_tools()
144            .filter(|cached| tools_match_cached(&tools, cached));
145        Arc::new(
146            Context::new(self.llm(), self.tx.clone())
147                .with_memory(self.memory())
148                .with_serialized_tools(cached_tools)
149                .with_tools(tools)
150                .with_config(self.agent_config())
151                .with_stream(self.stream()),
152        )
153    }
154
155    pub fn agent_config(&self) -> AgentConfig {
156        let output_schema = self.inner().output_schema();
157        let structured_schema =
158            output_schema.and_then(|schema| serde_json::from_value(schema).ok());
159        AgentConfig {
160            name: self.name().into(),
161            description: self.description().into(),
162            id: self.id,
163            output_schema: structured_schema,
164        }
165    }
166
167    /// Get the LLM provider
168    pub fn llm(&self) -> Arc<dyn LLMProvider> {
169        self.llm.clone()
170    }
171
172    /// Get the memory provider if available
173    pub fn memory(&self) -> Option<Arc<Mutex<Box<dyn MemoryProvider>>>> {
174        self.memory.clone()
175    }
176}
177
178fn tools_match_cached(tools: &[Box<dyn ToolT>], cached: &[Tool]) -> bool {
179    if tools.len() != cached.len() {
180        return false;
181    }
182
183    tools.iter().zip(cached.iter()).all(|(tool, cached_tool)| {
184        cached_tool.tool_type == "function"
185            && cached_tool.function.name == tool.name()
186            && cached_tool.function.description == tool.description()
187            && cached_tool.function.parameters == tool.args_schema()
188    })
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::memory::SlidingWindowMemory;
    use crate::agent::{AgentConfig, DirectAgent};
    use crate::tests::{MockAgentImpl, MockLLMProvider};
    use autoagents_llm::chat::StructuredOutputFormat;
    use std::sync::Arc;
    use tokio::sync::mpsc::{Receiver, channel};
    use uuid::Uuid;

    #[test]
    fn test_agent_config_with_schema() {
        let schema = StructuredOutputFormat {
            name: "TestSchema".to_string(),
            description: Some("Test schema".to_string()),
            schema: Some(serde_json::json!({"type": "object"})),
            strict: Some(true),
        };

        let config = AgentConfig {
            name: "test_agent".to_string(),
            id: Uuid::new_v4(),
            description: "A test agent".to_string(),
            output_schema: Some(schema.clone()),
        };

        assert_eq!(config.name, "test_agent");
        assert_eq!(config.description, "A test agent");
        assert!(config.output_schema.is_some());
        assert_eq!(config.output_schema.unwrap().name, "TestSchema");
    }

    #[tokio::test]
    async fn test_base_agent_creation_with_memory_and_stream() {
        let mock_agent = MockAgentImpl::new("test", "test description");
        let llm = Arc::new(MockLLMProvider);
        let memory = Box::new(SlidingWindowMemory::new(5));
        let (tx, _): (Sender<Event>, Receiver<Event>) = channel(32);
        let base_agent = BaseAgent::<_, DirectAgent>::new(mock_agent, llm, Some(memory), tx, true)
            .await
            .unwrap();

        assert_eq!(base_agent.name(), "test");
        assert_eq!(base_agent.description(), "test description");
        assert!(base_agent.memory().is_some());
        assert!(base_agent.stream);
    }

    #[tokio::test]
    async fn test_base_agent_create_context_populates_config() {
        let mock_agent = MockAgentImpl::new("ctx_agent", "context agent");
        let llm = Arc::new(MockLLMProvider);
        let (tx, _): (Sender<Event>, Receiver<Event>) = channel(32);
        let base_agent = BaseAgent::<_, DirectAgent>::new(mock_agent, llm, None, tx, false)
            .await
            .unwrap();

        let context = base_agent.create_context();
        let config = context.config();
        assert_eq!(config.name, "ctx_agent");
        assert_eq!(config.description, "context agent");
        // The context must carry this agent's identity, not a fresh id.
        assert_eq!(config.id, base_agent.id);
    }

    #[test]
    fn test_tools_match_cached_empty_inputs() {
        // No live tools and no cached definitions is a (trivially) valid cache.
        assert!(tools_match_cached(&[], &[]));
    }
}