autoagents_core/
environment.rs

1use crate::error::Error;
2use crate::protocol::{Event, RuntimeID};
3use crate::runtime::manager::RuntimeManager;
4use crate::runtime::{Runtime, RuntimeError};
5use std::path::PathBuf;
6use std::sync::Arc;
7use tokio::task::JoinHandle;
8use tokio_stream::wrappers::ReceiverStream;
9
10#[derive(Debug, thiserror::Error)]
11pub enum EnvironmentError {
12    #[error("Runtime not found: {0}")]
13    RuntimeNotFound(RuntimeID),
14
15    #[error("Runtime error: {0}")]
16    RuntimeError(#[from] RuntimeError),
17}
18
/// Configuration used when constructing an `Environment`.
///
/// Derives `Debug` so the configuration can be logged and shown in assertion
/// failures, and `Clone` so it can be shared between call sites.
#[derive(Debug, Clone)]
pub struct EnvironmentConfig {
    /// Directory associated with the environment. The `Default` impl uses the
    /// process's current directory.
    pub working_dir: PathBuf,
    /// Buffer size setting; presumably the capacity of the event channels.
    /// NOTE(review): not read anywhere in this file — confirm against users.
    pub channel_buffer: usize,
}
24
25impl Default for EnvironmentConfig {
26    fn default() -> Self {
27        Self {
28            working_dir: std::env::current_dir().unwrap_or_default(),
29            channel_buffer: 100,
30        }
31    }
32}
33
34pub struct Environment {
35    config: EnvironmentConfig,
36    runtime_manager: Arc<RuntimeManager>,
37    default_runtime: Option<RuntimeID>,
38    handle: Option<tokio::task::JoinHandle<Result<(), RuntimeError>>>,
39}
40
41impl Environment {
42    pub fn new(config: Option<EnvironmentConfig>) -> Self {
43        let config = config.unwrap_or_default();
44        let runtime_manager = Arc::new(RuntimeManager::new());
45
46        Self {
47            config,
48            runtime_manager,
49            default_runtime: None,
50            handle: None,
51        }
52    }
53
54    pub async fn register_runtime(&mut self, runtime: Arc<dyn Runtime>) -> Result<(), Error> {
55        self.runtime_manager
56            .register_runtime(runtime.clone())
57            .await?;
58        if self.default_runtime.is_none() {
59            self.default_runtime = Some(runtime.id());
60        }
61        Ok(())
62    }
63
64    pub fn config(&self) -> &EnvironmentConfig {
65        &self.config
66    }
67
68    pub async fn get_runtime(&self, runtime_id: Option<RuntimeID>) -> Option<Arc<dyn Runtime>> {
69        let rid = runtime_id.unwrap_or(self.default_runtime?);
70        self.runtime_manager.get_runtime(&rid).await
71    }
72
73    pub async fn get_runtime_or_default(
74        &self,
75        runtime_id: Option<RuntimeID>,
76    ) -> Result<Arc<dyn Runtime>, Error> {
77        let rid = runtime_id.unwrap_or(self.default_runtime.unwrap());
78        self.get_runtime(Some(rid))
79            .await
80            .ok_or_else(|| EnvironmentError::RuntimeNotFound(rid).into())
81    }
82
83    pub fn run(&mut self) -> JoinHandle<Result<(), RuntimeError>> {
84        let manager = self.runtime_manager.clone();
85        // Spawn background task to run the runtimes.
86        let handle = tokio::spawn(async move { manager.run().await });
87        handle
88    }
89
90    pub async fn take_event_receiver(
91        &mut self,
92        runtime_id: Option<RuntimeID>,
93    ) -> Option<ReceiverStream<Event>> {
94        if let Some(runtime) = self.get_runtime(runtime_id).await {
95            runtime.take_event_receiver().await
96        } else {
97            None
98        }
99    }
100
101    pub async fn shutdown(&mut self) {
102        let _ = self.runtime_manager.stop().await;
103
104        if let Some(handle) = self.handle.take() {
105            let _ = handle.await;
106        }
107    }
108}
109
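// NOTE: The tests below are disabled. They call items that the code above does
// not define (`Environment::new(..).await`, `get_session`, `register_agent`,
// `EnvironmentError::SessionNotFound`), so they must be ported to the
// runtime-based API before being re-enabled.
// #[cfg(test)]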
111// mod tests {
112//     use super::*;
113//     use crate::agent::AgentRunResult;
114//     use crate::agent::RunnableAgent;
115//     use crate::memory::MemoryProvider;
116//     use crate::protocol::Event;
117//     use async_trait::async_trait;
118//     use std::sync::Arc;
119//     use tokio::sync::mpsc;
120//     use uuid::Uuid;
121
122//     // Mock agent for testing
123//     #[derive(Debug)]
124//     struct MockAgent {
125//         id: Uuid,
126//         name: String,
127//         should_fail: bool,
128//     }
129
130//     impl MockAgent {
131//         fn new(name: &str) -> Self {
132//             Self {
133//                 id: Uuid::new_v4(),
134//                 name: name.to_string(),
135//                 should_fail: false,
136//             }
137//         }
138
139//         fn new_failing(name: &str) -> Self {
140//             Self {
141//                 id: Uuid::new_v4(),
142//                 name: name.to_string(),
143//                 should_fail: true,
144//             }
145//         }
146//     }
147
148//     #[async_trait]
149//     impl RunnableAgent for MockAgent {
150//         fn name(&self) -> &'static str {
151//             Box::leak(self.name.clone().into_boxed_str())
152//         }
153
154//         fn description(&self) -> &'static str {
155//             "Mock agent for testing"
156//         }
157
158//         fn id(&self) -> Uuid {
159//             self.id
160//         }
161
162//         async fn run(
163//             self: Arc<Self>,
164//             task: crate::runtime::Task,
165//             _tx_event: mpsc::Sender<Event>,
166//         ) -> Result<AgentRunResult, crate::error::Error> {
167//             if self.should_fail {
168//                 Err(crate::error::Error::SessionError(
169//                     crate::runtime::SessionError::EmptyTask,
170//                 ))
171//             } else {
172//                 Ok(AgentRunResult::success(serde_json::json!({
173//                     "response": format!("Processed: {}", task.prompt)
174//                 })))
175//             }
176//         }
177
178//         fn memory(&self) -> Option<Arc<tokio::sync::RwLock<Box<dyn MemoryProvider>>>> {
179//             None
180//         }
181//     }
182
183//     #[test]
184//     fn test_environment_config_default() {
185//         let config = EnvironmentConfig::default();
186//         assert_eq!(config.channel_buffer, 100);
187//         assert_eq!(
188//             config.working_dir,
189//             std::env::current_dir().unwrap_or_default()
190//         );
191//     }
192
193//     #[test]
194//     fn test_environment_config_custom() {
195//         let config = EnvironmentConfig {
196//             working_dir: std::path::PathBuf::from("/tmp"),
197//             channel_buffer: 50,
198//         };
199//         assert_eq!(config.channel_buffer, 50);
200//         assert_eq!(config.working_dir, std::path::PathBuf::from("/tmp"));
201//     }
202
203//     #[tokio::test]
204//     async fn test_environment_new_default() {
205//         let env = Environment::new(None).await;
206//         assert_eq!(env.config().channel_buffer, 100);
207//     }
208
209//     #[tokio::test]
210//     async fn test_environment_new_with_config() {
211//         let config = EnvironmentConfig {
212//             working_dir: std::path::PathBuf::from("/tmp"),
213//             channel_buffer: 50,
214//         };
215//         let env = Environment::new(Some(config)).await;
216//         assert_eq!(env.config().channel_buffer, 50);
217//     }
218
219//     #[tokio::test]
220//     async fn test_environment_get_session() {
221//         let env = Environment::new(None).await;
222
223//         // Test getting default session
224//         let session = env.get_session(None).await;
225//         assert!(session.is_some());
226
227//         // Test getting non-existent session
228//         let non_existent_id = Uuid::new_v4();
229//         let session = env.get_session(Some(non_existent_id)).await;
230//         assert!(session.is_none());
231//     }
232
233//     #[tokio::test]
234//     async fn test_environment_register_agent() {
235//         let env = Environment::new(None).await;
236//         let agent = Arc::new(MockAgent::new("test_agent"));
237
238//         let result = env.register_agent(agent, None).await;
239//         assert!(result.is_ok());
240//     }
241
242//     #[tokio::test]
243//     async fn test_environment_register_agent_with_id() {
244//         let env = Environment::new(None).await;
245//         let agent = Arc::new(MockAgent::new("test_agent"));
246//         let agent_id = Uuid::new_v4();
247
248//         // let result = env.register_agent_with_id(agent_id, agent, None).await;
249//         assert!(result.is_ok());
250//     }
251
252//     #[tokio::test]
253//     async fn test_environment_add_task() {
254//         let env = Environment::new(None).await;
255//         let agent = Arc::new(MockAgent::new("test_agent"));
256//         let agent_id = env.register_agent(agent, None).await.unwrap();
257
258//         let result = env.add_task(agent_id, "Test task").await;
259//         assert!(result.is_ok());
260//     }
261
262//     #[tokio::test]
263//     async fn test_environment_run_task() {
264//         let env = Environment::new(None).await;
265//         let agent = Arc::new(MockAgent::new("test_agent"));
266//         let agent_id = env.register_agent(agent, None).await.unwrap();
267
268//         let sub_id = env.add_task(agent_id, "Test task").await.unwrap();
269//         let result = env.run_task(agent_id, sub_id, None).await;
270//         assert!(result.is_ok());
271
272//         let result = result.unwrap();
273//         assert!(result.success);
274//         assert!(result.output.is_some());
275//     }
276
277//     #[tokio::test]
278//     async fn test_environment_run() {
279//         let env = Environment::new(None).await;
280//         let agent = Arc::new(MockAgent::new("test_agent"));
281//         let agent_id = env.register_agent(agent, None).await.unwrap();
282
283//         env.add_task(agent_id, "Test task").await.unwrap();
284//         // let result = env.run(agent_id, None).await;
285//         assert!(result.is_ok());
286//     }
287
288//     #[tokio::test]
289//     async fn test_environment_run_all() {
290//         let env = Environment::new(None).await;
291//         let agent = Arc::new(MockAgent::new("test_agent"));
292//         let agent_id = env.register_agent(agent, None).await.unwrap();
293
294//         // Add multiple tasks
295//         for i in 1..=3 {
296//             env.add_task(agent_id, format!("Task {}", i)).await.unwrap();
297//         }
298
299//         let results = env.run_all(agent_id, None).await;
300//         assert!(results.is_ok());
301//         assert_eq!(results.unwrap().len(), 3);
302//     }
303
304//     #[tokio::test]
305//     async fn test_environment_event_sender() {
306//         let env = Environment::new(None).await;
307//         let sender = env.event_sender(None).await;
308//         assert!(sender.is_ok());
309//     }
310
311//     #[tokio::test]
312//     async fn test_environment_take_event_receiver() {
313//         let mut env = Environment::new(None).await;
314//         let receiver = env.take_event_receiver(None).await;
315//         assert!(receiver.is_some());
316
317//         // Second call should return None
318//         let receiver2 = env.take_event_receiver(None).await;
319//         assert!(receiver2.is_none());
320//     }
321
322//     #[tokio::test]
323//     async fn test_environment_shutdown() {
324//         let mut env = Environment::new(None).await;
325//         env.shutdown().await;
326//         // Should not panic
327//     }
328
329//     #[tokio::test]
330//     async fn test_environment_with_failing_agent() {
331//         let env = Environment::new(None).await;
332//         let agent = Arc::new(MockAgent::new_failing("failing_agent"));
333//         let agent_id = env.register_agent(agent, None).await.unwrap();
334
335//         env.add_task(agent_id, "Test task").await.unwrap();
336//         // let result = env.run(agent_id, None).await;
337//         assert!(result.is_err());
338//     }
339
340//     #[tokio::test]
341//     async fn test_environment_error_session_not_found() {
342//         let env = Environment::new(None).await;
343//         let non_existent_id = Uuid::new_v4();
344
345//         let result = env.get_session_or_default(Some(non_existent_id)).await;
346//         assert!(result.is_err());
347
348//         assert!(result.is_err());
349//         // Just test that it's an error, not the specific variant
350//         assert!(result.is_err());
351//     }
352
353//     #[test]
354//     fn test_environment_error_display() {
355//         let session_id = Uuid::new_v4();
356//         let error = EnvironmentError::SessionNotFound(session_id);
357//         assert!(error.to_string().contains("Session not found"));
358//         assert!(error.to_string().contains(&session_id.to_string()));
359//     }
360// }