autoagents_core/
environment.rs1use crate::error::Error;
2use crate::runtime::manager::RuntimeManager;
3use crate::runtime::{Runtime, RuntimeError};
4use crate::utils::BoxEventStream;
5use autoagents_protocol::{Event, RuntimeID};
6use std::path::PathBuf;
7use std::sync::Arc;
8use tokio::task::JoinHandle;
9
10#[derive(Debug, thiserror::Error)]
12pub enum EnvironmentError {
13 #[error("Runtime not found: {0}")]
14 RuntimeNotFound(RuntimeID),
15
16 #[error("Runtime error: {0}")]
17 RuntimeError(#[from] RuntimeError),
18
19 #[error("Error when consuming receiver")]
20 EventError,
21}
22
/// Configuration for an [`Environment`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnvironmentConfig {
    /// Working directory associated with the environment.
    pub working_dir: PathBuf,
}

impl Default for EnvironmentConfig {
    /// Uses the process's current working directory.
    ///
    /// If `std::env::current_dir` fails (e.g. the directory no longer exists or
    /// cannot be accessed), this falls back to an empty `PathBuf` rather than
    /// panicking.
    fn default() -> Self {
        Self {
            working_dir: std::env::current_dir().unwrap_or_default(),
        }
    }
}
36
37pub struct Environment {
41 config: EnvironmentConfig,
42 runtime_manager: Arc<RuntimeManager>,
43 default_runtime: Option<RuntimeID>,
44 handle: Option<JoinHandle<Result<(), RuntimeError>>>,
45}
46
47impl Environment {
48 pub fn new(config: Option<EnvironmentConfig>) -> Self {
50 let config = config.unwrap_or_default();
51 let runtime_manager = Arc::new(RuntimeManager::new());
52
53 Self {
54 config,
55 runtime_manager,
56 default_runtime: None,
57 handle: None,
58 }
59 }
60
61 pub async fn register_runtime(&mut self, runtime: Arc<dyn Runtime>) -> Result<(), Error> {
64 self.runtime_manager
65 .register_runtime(runtime.clone())
66 .await?;
67 if self.default_runtime.is_none() {
68 self.default_runtime = Some(runtime.id());
69 }
70 Ok(())
71 }
72
73 pub fn config(&self) -> &EnvironmentConfig {
75 &self.config
76 }
77
78 pub async fn get_runtime(&self, runtime_id: &RuntimeID) -> Option<Arc<dyn Runtime>> {
80 self.runtime_manager.get_runtime(runtime_id).await
81 }
82
83 pub async fn get_runtime_or_default(
85 &self,
86 runtime_id: Option<RuntimeID>,
87 ) -> Result<Arc<dyn Runtime>, Error> {
88 let rid = runtime_id.unwrap_or_else(|| self.default_runtime.unwrap());
89 self.get_runtime(&rid)
90 .await
91 .ok_or_else(|| EnvironmentError::RuntimeNotFound(rid).into())
92 }
93
94 pub fn run(&mut self) -> JoinHandle<Result<(), RuntimeError>> {
97 let manager = self.runtime_manager.clone();
98 tokio::spawn(async move { manager.run().await })
100 }
101
102 pub async fn run_background(&mut self) -> Result<(), RuntimeError> {
105 let manager = self.runtime_manager.clone();
106 manager.run_background().await
108 }
109
110 pub async fn take_event_receiver(
113 &mut self,
114 runtime_id: Option<RuntimeID>,
115 ) -> Result<BoxEventStream<Event>, EnvironmentError> {
116 if let Ok(runtime) = self.get_runtime_or_default(runtime_id).await {
117 runtime
118 .take_event_receiver()
119 .await
120 .ok_or_else(|| EnvironmentError::EventError)
121 } else {
122 Err(EnvironmentError::RuntimeNotFound(runtime_id.unwrap()))
123 }
124 }
125
126 pub async fn subscribe_events(
128 &self,
129 runtime_id: Option<RuntimeID>,
130 ) -> Result<BoxEventStream<Event>, EnvironmentError> {
131 let runtime = self
132 .get_runtime_or_default(runtime_id)
133 .await
134 .map_err(|err| match err {
135 Error::EnvironmentError(env_err) => env_err,
136 _ => EnvironmentError::EventError,
137 })?;
138 Ok(runtime.subscribe_events().await)
139 }
140
141 pub async fn shutdown(&mut self) {
143 let _ = self.runtime_manager.stop().await;
144
145 if let Some(handle) = self.handle.take() {
146 let _ = handle.await;
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::SingleThreadedRuntime;
    use tempfile::tempdir;
    use uuid::Uuid;

    #[test]
    fn test_environment_config_default() {
        let config = EnvironmentConfig::default();
        assert_eq!(
            config.working_dir,
            std::env::current_dir().unwrap_or_default()
        );
    }

    #[test]
    fn test_environment_config_custom() {
        let dir = tempdir().expect("Unable to create temp dir");
        let config = EnvironmentConfig {
            working_dir: dir.path().to_path_buf(),
        };
        assert_eq!(config.working_dir, dir.path().to_path_buf());
    }

    #[tokio::test]
    async fn test_environment_get_runtime() {
        let mut env = Environment::new(None);
        let runtime = SingleThreadedRuntime::new(None);
        let runtime_id = runtime.id;
        env.register_runtime(runtime).await.unwrap();

        assert!(env.get_runtime(&runtime_id).await.is_some());

        let non_existent_id = Uuid::new_v4();
        assert!(env.get_runtime(&non_existent_id).await.is_none());
    }

    #[tokio::test]
    async fn test_environment_take_event_receiver() {
        let mut env = Environment::new(None);
        let runtime = SingleThreadedRuntime::new(None);
        env.register_runtime(runtime).await.unwrap();

        assert!(env.take_event_receiver(None).await.is_ok());

        // The receiver can only be taken once; the second attempt must report
        // the receiver-specific error rather than a lookup failure.
        let second = env.take_event_receiver(None).await;
        assert!(matches!(second, Err(EnvironmentError::EventError)));
    }

    #[tokio::test]
    async fn test_environment_shutdown() {
        let mut env = Environment::new(None);
        env.shutdown().await;
    }

    #[tokio::test]
    async fn test_environment_error_runtime_not_found() {
        let mut env = Environment::new(None);
        let runtime = SingleThreadedRuntime::new(None);
        env.register_runtime(runtime).await.unwrap();
        let non_existent_id = Uuid::new_v4();

        let result = env.get_runtime_or_default(Some(non_existent_id)).await;

        // Pin both the error kind and the id it reports.
        assert!(matches!(
            result,
            Err(Error::EnvironmentError(EnvironmentError::RuntimeNotFound(id)))
                if id == non_existent_id
        ));
    }

    #[test]
    fn test_environment_error_display() {
        let runtime_id = Uuid::new_v4();
        let error = EnvironmentError::RuntimeNotFound(runtime_id);
        assert!(error.to_string().contains("Runtime not found"));
        assert!(error.to_string().contains(&runtime_id.to_string()));
    }
}