autoagents_core/
environment.rs

1use crate::error::Error;
2use crate::protocol::{Event, RuntimeID};
3use crate::runtime::manager::RuntimeManager;
4use crate::runtime::{Runtime, RuntimeError};
5use crate::utils::BoxEventStream;
6use std::path::PathBuf;
7use std::sync::Arc;
8use tokio::task::JoinHandle;
9
10#[derive(Debug, thiserror::Error)]
11pub enum EnvironmentError {
12    #[error("Runtime not found: {0}")]
13    RuntimeNotFound(RuntimeID),
14
15    #[error("Runtime error: {0}")]
16    RuntimeError(#[from] RuntimeError),
17
18    #[error("Error when consuming receiver")]
19    EventError,
20}
21
/// Settings supplied when constructing an [`Environment`].
#[derive(Debug, Clone)]
pub struct EnvironmentConfig {
    /// Working directory associated with the environment.
    /// Defaults to the process's current directory (see the `Default` impl).
    pub working_dir: PathBuf,
}

impl Default for EnvironmentConfig {
    /// Uses [`std::env::current_dir`].
    ///
    /// If the current directory cannot be read, this falls back to an empty `PathBuf`
    /// (`PathBuf::default()`) instead of failing.
    fn default() -> Self {
        Self {
            working_dir: std::env::current_dir().unwrap_or_default(),
        }
    }
}
34
35pub struct Environment {
36    config: EnvironmentConfig,
37    runtime_manager: Arc<RuntimeManager>,
38    default_runtime: Option<RuntimeID>,
39    handle: Option<JoinHandle<Result<(), RuntimeError>>>,
40}
41
42impl Environment {
43    pub fn new(config: Option<EnvironmentConfig>) -> Self {
44        let config = config.unwrap_or_default();
45        let runtime_manager = Arc::new(RuntimeManager::new());
46
47        Self {
48            config,
49            runtime_manager,
50            default_runtime: None,
51            handle: None,
52        }
53    }
54
55    pub async fn register_runtime(&mut self, runtime: Arc<dyn Runtime>) -> Result<(), Error> {
56        self.runtime_manager
57            .register_runtime(runtime.clone())
58            .await?;
59        if self.default_runtime.is_none() {
60            self.default_runtime = Some(runtime.id());
61        }
62        Ok(())
63    }
64
65    pub fn config(&self) -> &EnvironmentConfig {
66        &self.config
67    }
68
69    pub async fn get_runtime(&self, runtime_id: &RuntimeID) -> Option<Arc<dyn Runtime>> {
70        self.runtime_manager.get_runtime(runtime_id).await
71    }
72
73    pub async fn get_runtime_or_default(
74        &self,
75        runtime_id: Option<RuntimeID>,
76    ) -> Result<Arc<dyn Runtime>, Error> {
77        let rid = runtime_id.unwrap_or(self.default_runtime.unwrap());
78        self.get_runtime(&rid)
79            .await
80            .ok_or_else(|| EnvironmentError::RuntimeNotFound(rid).into())
81    }
82
83    pub fn run(&mut self) -> JoinHandle<Result<(), RuntimeError>> {
84        let manager = self.runtime_manager.clone();
85        // Spawn background task to run the runtimes. This will wait indefinitely
86        let handle = tokio::spawn(async move { manager.run().await });
87        handle
88    }
89
90    pub async fn run_background(&mut self) -> Result<(), RuntimeError> {
91        let manager = self.runtime_manager.clone();
92        // Spawn background task to run the runtimes.
93        manager.run_background().await
94    }
95
96    pub async fn take_event_receiver(
97        &mut self,
98        runtime_id: Option<RuntimeID>,
99    ) -> Result<BoxEventStream<Event>, EnvironmentError> {
100        if let Ok(runtime) = self.get_runtime_or_default(runtime_id).await {
101            runtime
102                .take_event_receiver()
103                .await
104                .ok_or_else(|| EnvironmentError::EventError)
105        } else {
106            Err(EnvironmentError::RuntimeNotFound(runtime_id.unwrap()))
107        }
108    }
109
110    pub async fn shutdown(&mut self) {
111        let _ = self.runtime_manager.stop().await;
112
113        if let Some(handle) = self.handle.take() {
114            let _ = handle.await;
115        }
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::SingleThreadedRuntime;
    use uuid::Uuid;

    #[test]
    fn test_environment_config_default() {
        let config = EnvironmentConfig::default();
        assert_eq!(
            config.working_dir,
            std::env::current_dir().unwrap_or_default()
        );
    }

    #[test]
    fn test_environment_config_custom() {
        let config = EnvironmentConfig {
            working_dir: std::path::PathBuf::from("/tmp"),
        };
        assert_eq!(config.working_dir, std::path::PathBuf::from("/tmp"));
    }

    #[tokio::test]
    async fn test_environment_get_runtime() {
        let mut env = Environment::new(None);
        let runtime = SingleThreadedRuntime::new(None);
        let runtime_id = runtime.id;
        env.register_runtime(runtime).await.unwrap();

        // A registered runtime can be retrieved by its ID.
        assert!(env.get_runtime(&runtime_id).await.is_some());

        // The first registered runtime becomes the default.
        let default = env.get_runtime_or_default(None).await.unwrap();
        assert_eq!(default.id(), runtime_id);

        // An unknown ID yields `None`.
        let non_existent_id = Uuid::new_v4();
        assert!(env.get_runtime(&non_existent_id).await.is_none());
    }

    #[tokio::test]
    async fn test_environment_take_event_receiver() {
        let mut env = Environment::new(None);
        env.register_runtime(SingleThreadedRuntime::new(None))
            .await
            .unwrap();

        let receiver = env.take_event_receiver(None).await;
        assert!(receiver.is_ok());

        // The receiver can only be taken once; a second call must fail with `EventError`.
        let receiver2 = env.take_event_receiver(None).await;
        assert!(matches!(receiver2, Err(EnvironmentError::EventError)));
    }

    #[tokio::test]
    async fn test_environment_shutdown() {
        let mut env = Environment::new(None);
        // Shutting down an environment that never ran must not panic.
        env.shutdown().await;
    }

    #[tokio::test]
    async fn test_environment_error_runtime_not_found() {
        let mut env = Environment::new(None);
        env.register_runtime(SingleThreadedRuntime::new(None))
            .await
            .unwrap();
        let non_existent_id = Uuid::new_v4();

        let result = env.get_runtime_or_default(Some(non_existent_id)).await;
        // Only the error-ness is checked; the crate-level `Error` variant is not inspected.
        assert!(result.is_err());
    }

    #[test]
    fn test_environment_error_display() {
        let runtime_id = Uuid::new_v4();
        let error = EnvironmentError::RuntimeNotFound(runtime_id);
        let message = error.to_string();
        assert!(message.contains("Runtime not found"));
        assert!(message.contains(&runtime_id.to_string()));
    }
}
205}