autoagents_core/agent/context.rs

1#[cfg(not(target_arch = "wasm32"))]
2use crate::actor::{ActorMessage, Topic};
3use crate::agent::memory::MemoryProvider;
4use crate::agent::state::AgentState;
5use crate::agent::AgentConfig;
6use crate::protocol::Event;
7use crate::tool::ToolT;
8use autoagents_llm::chat::ChatMessage;
9use autoagents_llm::LLMProvider;
10use std::any::Any;
11use std::sync::Arc;
12#[cfg(not(target_arch = "wasm32"))]
13use tokio::sync::{mpsc, Mutex};
14
15#[cfg(target_arch = "wasm32")]
16use futures::channel::mpsc;
17#[cfg(target_arch = "wasm32")]
18use futures::lock::Mutex;
19
/// Context for running an agent: bundles the LLM provider with the chat
/// messages, optional memory, tools, configuration, shared state, optional
/// event sender and streaming flag. Built with [`Context::new`] and refined
/// with the `with_*` builder methods.
pub struct Context {
    /// LLM provider used by this context.
    llm: Arc<dyn LLMProvider>,
    /// Chat messages carried by the context (empty unless set via `with_messages`).
    messages: Vec<ChatMessage>,
    /// Optional memory provider behind an async-aware mutex
    /// (`tokio::sync::Mutex` natively, `futures::lock::Mutex` on wasm32).
    memory: Option<Arc<Mutex<Box<dyn MemoryProvider>>>>,
    /// Tool objects attached to this context (empty unless set via `with_tools`).
    tools: Vec<Box<dyn ToolT>>,
    /// Agent configuration; `AgentConfig::default()` unless overridden via `with_config`.
    config: AgentConfig,
    /// Shared, mutex-guarded agent state; `state()` hands out clones of this `Arc`.
    state: Arc<Mutex<AgentState>>,
    /// Event sender; per the `ContextError::EmptyTx` message it is only set for
    /// actor agents, so it is `None` otherwise.
    tx: Option<mpsc::Sender<Event>>,
    /// Whether streaming is requested; `false` unless set via `with_stream`.
    stream: bool,
}
30
31#[derive(Clone, Debug, thiserror::Error)]
32pub enum ContextError {
33    #[error("Tx value is None, Tx is only set for Actor agents")]
34    EmptyTx,
35    /// Error when sending events
36    #[error("Failed to send event: {0}")]
37    EventSendError(String),
38}
39
40impl Context {
41    pub fn new(llm: Arc<dyn LLMProvider>, tx: Option<mpsc::Sender<Event>>) -> Self {
42        Self {
43            llm,
44            messages: vec![],
45            memory: None,
46            tools: vec![],
47            config: AgentConfig::default(),
48            state: Arc::new(Mutex::new(AgentState::new())),
49            stream: false,
50            tx,
51        }
52    }
53
54    #[cfg(not(target_arch = "wasm32"))]
55    pub async fn publish<M: ActorMessage>(
56        &self,
57        topic: Topic<M>,
58        message: M,
59    ) -> Result<(), ContextError> {
60        self.tx
61            .as_ref()
62            .ok_or(ContextError::EmptyTx)?
63            .send(Event::PublishMessage {
64                topic_name: topic.name().to_string(),
65                message: Arc::new(message) as Arc<dyn Any + Send + Sync>,
66                topic_type: topic.type_id(),
67            })
68            .await
69            .map_err(|e| ContextError::EventSendError(e.to_string()))
70    }
71
72    pub fn with_memory(mut self, memory: Option<Arc<Mutex<Box<dyn MemoryProvider>>>>) -> Self {
73        self.memory = memory;
74        self
75    }
76
77    pub fn with_tools(mut self, tools: Vec<Box<dyn ToolT>>) -> Self {
78        self.tools = tools;
79        self
80    }
81
82    pub fn with_config(mut self, config: AgentConfig) -> Self {
83        self.config = config;
84        self
85    }
86
87    pub fn with_messages(mut self, messages: Vec<ChatMessage>) -> Self {
88        self.messages = messages;
89        self
90    }
91
92    pub fn with_stream(mut self, stream: bool) -> Self {
93        self.stream = stream;
94        self
95    }
96
97    // Getters
98    pub fn llm(&self) -> &Arc<dyn LLMProvider> {
99        &self.llm
100    }
101
102    pub fn messages(&self) -> &[ChatMessage] {
103        &self.messages
104    }
105
106    pub fn memory(&self) -> Option<Arc<Mutex<Box<dyn MemoryProvider>>>> {
107        self.memory.clone()
108    }
109
110    pub fn tools(&self) -> &[Box<dyn ToolT>] {
111        &self.tools
112    }
113
114    pub fn config(&self) -> &AgentConfig {
115        &self.config
116    }
117
118    pub fn state(&self) -> Arc<Mutex<AgentState>> {
119        self.state.clone()
120    }
121
122    pub fn tx(&self) -> Result<mpsc::Sender<Event>, ContextError> {
123        Ok(self.tx.as_ref().ok_or(ContextError::EmptyTx)?.clone())
124    }
125
126    pub fn stream(&self) -> bool {
127        self.stream
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::memory::SlidingWindowMemory;
    use autoagents_llm::chat::{ChatMessage, ChatMessageBuilder, ChatRole};
    use autoagents_test_utils::llm::MockLLMProvider;
    use std::sync::Arc;

    #[test]
    fn test_context_creation() {
        let llm = Arc::new(MockLLMProvider);
        let context = Context::new(llm, None);

        assert!(context.messages.is_empty());
        assert!(context.memory.is_none());
        assert!(context.tools.is_empty());
        assert!(!context.stream);
    }

    #[test]
    fn test_context_without_tx_reports_empty_tx() {
        let llm = Arc::new(MockLLMProvider);
        let context = Context::new(llm, None);

        assert!(matches!(context.tx(), Err(ContextError::EmptyTx)));
    }

    #[test]
    fn test_context_with_llm_provider() {
        let llm = Arc::new(MockLLMProvider);
        let context = Context::new(llm.clone(), None);

        // The context must hold the very provider passed in, not another instance.
        // Compare data addresses, casting away the trait-object metadata.
        let context_llm = context.llm();
        assert_eq!(
            Arc::as_ptr(context_llm) as *const (),
            Arc::as_ptr(&llm) as *const ()
        );
        // One strong reference is held by `llm` and one by the context.
        assert_eq!(Arc::strong_count(&llm), 2);
    }

    #[test]
    fn test_context_with_memory() {
        let llm = Arc::new(MockLLMProvider);
        let memory = Box::new(SlidingWindowMemory::new(5));
        let context = Context::new(llm, None).with_memory(Some(Arc::new(Mutex::new(memory))));

        assert!(context.memory().is_some());
    }

    #[test]
    fn test_context_with_messages() {
        let llm = Arc::new(MockLLMProvider);
        let message = ChatMessage::user().content("Hello".to_string()).build();
        let context = Context::new(llm, None).with_messages(vec![message]);

        assert_eq!(context.messages().len(), 1);
        assert_eq!(context.messages()[0].role, ChatRole::User);
        assert_eq!(context.messages()[0].content, "Hello");
    }

    #[test]
    fn test_context_streaming_flag() {
        let llm = Arc::new(MockLLMProvider);
        let context = Context::new(llm, None).with_stream(true);
        assert!(context.stream());
    }

    #[test]
    fn test_context_fluent_interface() {
        let llm = Arc::new(MockLLMProvider);
        let memory = Box::new(SlidingWindowMemory::new(3));
        let message = ChatMessageBuilder::new(ChatRole::System)
            .content("System prompt".to_string())
            .build();

        let context = Context::new(llm, None)
            .with_memory(Some(Arc::new(Mutex::new(memory))))
            .with_messages(vec![message])
            .with_stream(true);

        assert!(context.memory().is_some());
        assert_eq!(context.messages().len(), 1);
        assert!(context.stream());
    }
}