autoagents_core/
environment.rs

1use crate::agent::base::BaseAgent;
2use crate::agent::executor::AgentExecutor;
3use crate::agent::result::AgentRunResult;
4use crate::error::Error;
5use crate::protocol::{AgentID, Event, SessionId, SubmissionId};
6use crate::session::{Session, SessionManager, Task};
7use serde_json::Value;
8use std::path::PathBuf;
9use std::sync::Arc;
10use tokio::sync::{mpsc, Mutex};
11use tokio::task::JoinHandle;
12use tokio_stream::wrappers::ReceiverStream;
13use uuid::Uuid;
14
15#[derive(Debug, thiserror::Error)]
16pub enum EnvironmentError {
17    #[error("Session not found: {0}")]
18    SessionNotFound(SessionId),
19}
20
/// Settings used when constructing an [`Environment`].
///
/// `Debug` is derived so a configuration can be logged and can appear in
/// assertion failure messages.
#[derive(Debug, Clone)]
pub struct EnvironmentConfig {
    /// Working directory associated with the environment. The `Default`
    /// implementation fills this with the process's current directory.
    pub working_dir: PathBuf,
    /// Value handed to `SessionManager::create_session` when the
    /// environment creates its default session.
    pub channel_buffer: usize,
}
26
27impl Default for EnvironmentConfig {
28    fn default() -> Self {
29        Self {
30            working_dir: std::env::current_dir().unwrap_or_default(),
31            channel_buffer: 100,
32        }
33    }
34}
35
36pub struct Environment {
37    config: EnvironmentConfig,
38    session_manager: SessionManager,
39    default_session: SessionId,
40    event_handler: Option<JoinHandle<()>>,
41}
42
43impl Environment {
44    pub async fn new(config: Option<EnvironmentConfig>) -> Self {
45        let config = config.unwrap_or_default();
46        let session_manager = SessionManager::new();
47        let (default_session_id, _) = session_manager.create_session(config.channel_buffer).await;
48
49        Self {
50            config,
51            session_manager,
52            default_session: default_session_id,
53            event_handler: None,
54        }
55    }
56
57    pub fn config(&self) -> &EnvironmentConfig {
58        &self.config
59    }
60
61    pub async fn get_session(&self, session_id: Option<SessionId>) -> Option<Arc<Mutex<Session>>> {
62        let sid = session_id.unwrap_or(self.default_session);
63        self.session_manager.get_session(&sid).await
64    }
65
66    pub async fn get_session_or_default(
67        &self,
68        session_id: Option<SessionId>,
69    ) -> Result<Arc<Mutex<Session>>, Error> {
70        let sid = session_id.unwrap_or(self.default_session);
71        self.get_session(Some(sid))
72            .await
73            .ok_or_else(|| EnvironmentError::SessionNotFound(sid).into())
74    }
75
76    pub async fn get_session_mut(
77        &self,
78        session_id: Option<SessionId>,
79    ) -> Option<Arc<Mutex<Session>>> {
80        let sid = session_id.unwrap_or(self.default_session);
81        self.session_manager.get_session_mut(&sid).await
82    }
83
84    pub async fn register_agent<E>(
85        &self,
86        agent: BaseAgent<E>,
87        session_id: Option<SessionId>,
88    ) -> Result<AgentID, Error>
89    where
90        E: AgentExecutor,
91        E::Output: Into<Value> + Send,
92        E::Error: std::error::Error + Send + Sync + 'static,
93    {
94        let agent_id = Uuid::new_v4();
95        self.register_agent_with_id(agent_id, agent, session_id)
96            .await?;
97        Ok(agent_id)
98    }
99
100    pub async fn register_agent_with_id<E>(
101        &self,
102        agent_id: AgentID,
103        agent: BaseAgent<E>,
104        session_id: Option<SessionId>,
105    ) -> Result<(), Error>
106    where
107        E: AgentExecutor,
108        E::Output: Into<Value> + Send,
109        E::Error: std::error::Error + Send + Sync + 'static,
110    {
111        let session_arc = self.get_session_or_default(session_id).await?;
112        let session = session_arc.lock().await;
113        session.register_agent_with_id(agent_id, agent).await;
114        Ok(())
115    }
116
117    pub async fn add_task<T: Into<String>>(
118        &self,
119        agent_id: Uuid,
120        task: T,
121    ) -> Result<SubmissionId, Error> {
122        let task = Task::new(task, Some(agent_id));
123        let sub_id = task.submission_id;
124        let session_arc = self.get_session_or_default(None).await?;
125        let session = session_arc.lock().await;
126        session.add_task(task).await;
127        Ok(sub_id)
128    }
129
130    pub async fn run_task(
131        &self,
132        agent_id: AgentID,
133        sub_id: SubmissionId,
134        session_id: Option<SessionId>,
135    ) -> Result<AgentRunResult, Error> {
136        let session_arc = self.get_session_or_default(session_id).await?;
137        let session = session_arc.lock().await;
138        let task = session.get_task(sub_id).await;
139        session.run_task(task, agent_id).await
140    }
141
142    pub async fn run(
143        &self,
144        agent_id: AgentID,
145        session_id: Option<SessionId>,
146    ) -> Result<AgentRunResult, Error> {
147        let session_arc = self.get_session_or_default(session_id).await?;
148        let session = session_arc.lock().await;
149        session.run(agent_id).await
150    }
151
152    pub async fn run_all(
153        &self,
154        agent_id: AgentID,
155        session_id: Option<SessionId>,
156    ) -> Result<Vec<AgentRunResult>, Error> {
157        let session_arc = self.get_session_or_default(session_id).await?;
158        let session = session_arc.lock().await;
159        session.run_all(agent_id).await
160    }
161
162    pub async fn event_sender(
163        &self,
164        session_id: Option<SessionId>,
165    ) -> Result<mpsc::Sender<Event>, Error> {
166        let session_arc = self.get_session_or_default(session_id).await?;
167        let session = session_arc.lock().await;
168        Ok(session.event_sender().clone())
169    }
170
171    pub async fn take_event_receiver(
172        &mut self,
173        session_id: Option<SessionId>,
174    ) -> Option<ReceiverStream<Event>> {
175        if let Some(session_arc_mutex) = self.get_session_mut(session_id).await {
176            let mut session = session_arc_mutex.lock().await;
177            session.take_event_receiver()
178        } else {
179            None
180        }
181    }
182
183    pub async fn shutdown(&mut self) {
184        if let Some(handler) = self.event_handler.take() {
185            handler.abort();
186        }
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses a channel buffer of 100.
    #[test]
    fn test_environment_config_default() {
        let EnvironmentConfig { channel_buffer, .. } = EnvironmentConfig::default();
        assert_eq!(channel_buffer, 100);
    }
}
199}