autoagents_core/
environment.rs1use crate::agent::base::BaseAgent;
2use crate::agent::executor::AgentExecutor;
3use crate::agent::result::AgentRunResult;
4use crate::error::Error;
5use crate::protocol::{AgentID, Event, SessionId, SubmissionId};
6use crate::session::{Session, SessionManager, Task};
7use serde_json::Value;
8use std::path::PathBuf;
9use std::sync::Arc;
10use tokio::sync::{mpsc, Mutex};
11use tokio::task::JoinHandle;
12use tokio_stream::wrappers::ReceiverStream;
13use uuid::Uuid;
14
15#[derive(Debug, thiserror::Error)]
16pub enum EnvironmentError {
17 #[error("Session not found: {0}")]
18 SessionNotFound(SessionId),
19}
20
/// Settings used when constructing an [`Environment`].
///
/// `Debug` is derived alongside `Clone` so the configuration can be logged
/// and shown in assertion failures.
#[derive(Debug, Clone)]
pub struct EnvironmentConfig {
    /// Working directory associated with the environment. The default value
    /// is the process's current directory at construction time.
    pub working_dir: PathBuf,
    /// Buffer size handed to `SessionManager::create_session` when the
    /// environment creates its default session (presumably the capacity of
    /// the session's event channel — confirm against `SessionManager`).
    pub channel_buffer: usize,
}
26
27impl Default for EnvironmentConfig {
28 fn default() -> Self {
29 Self {
30 working_dir: std::env::current_dir().unwrap_or_default(),
31 channel_buffer: 100,
32 }
33 }
34}
35
36pub struct Environment {
37 config: EnvironmentConfig,
38 session_manager: SessionManager,
39 default_session: SessionId,
40 event_handler: Option<JoinHandle<()>>,
41}
42
43impl Environment {
44 pub async fn new(config: Option<EnvironmentConfig>) -> Self {
45 let config = config.unwrap_or_default();
46 let session_manager = SessionManager::new();
47 let (default_session_id, _) = session_manager.create_session(config.channel_buffer).await;
48
49 Self {
50 config,
51 session_manager,
52 default_session: default_session_id,
53 event_handler: None,
54 }
55 }
56
57 pub fn config(&self) -> &EnvironmentConfig {
58 &self.config
59 }
60
61 pub async fn get_session(&self, session_id: Option<SessionId>) -> Option<Arc<Mutex<Session>>> {
62 let sid = session_id.unwrap_or(self.default_session);
63 self.session_manager.get_session(&sid).await
64 }
65
66 pub async fn get_session_or_default(
67 &self,
68 session_id: Option<SessionId>,
69 ) -> Result<Arc<Mutex<Session>>, Error> {
70 let sid = session_id.unwrap_or(self.default_session);
71 self.get_session(Some(sid))
72 .await
73 .ok_or_else(|| EnvironmentError::SessionNotFound(sid).into())
74 }
75
76 pub async fn get_session_mut(
77 &self,
78 session_id: Option<SessionId>,
79 ) -> Option<Arc<Mutex<Session>>> {
80 let sid = session_id.unwrap_or(self.default_session);
81 self.session_manager.get_session_mut(&sid).await
82 }
83
84 pub async fn register_agent<E>(
85 &self,
86 agent: BaseAgent<E>,
87 session_id: Option<SessionId>,
88 ) -> Result<AgentID, Error>
89 where
90 E: AgentExecutor,
91 E::Output: Into<Value> + Send,
92 E::Error: std::error::Error + Send + Sync + 'static,
93 {
94 let agent_id = Uuid::new_v4();
95 self.register_agent_with_id(agent_id, agent, session_id)
96 .await?;
97 Ok(agent_id)
98 }
99
100 pub async fn register_agent_with_id<E>(
101 &self,
102 agent_id: AgentID,
103 agent: BaseAgent<E>,
104 session_id: Option<SessionId>,
105 ) -> Result<(), Error>
106 where
107 E: AgentExecutor,
108 E::Output: Into<Value> + Send,
109 E::Error: std::error::Error + Send + Sync + 'static,
110 {
111 let session_arc = self.get_session_or_default(session_id).await?;
112 let session = session_arc.lock().await;
113 session.register_agent_with_id(agent_id, agent).await;
114 Ok(())
115 }
116
117 pub async fn add_task<T: Into<String>>(
118 &self,
119 agent_id: Uuid,
120 task: T,
121 ) -> Result<SubmissionId, Error> {
122 let task = Task::new(task, Some(agent_id));
123 let sub_id = task.submission_id;
124 let session_arc = self.get_session_or_default(None).await?;
125 let session = session_arc.lock().await;
126 session.add_task(task).await;
127 Ok(sub_id)
128 }
129
130 pub async fn run_task(
131 &self,
132 agent_id: AgentID,
133 sub_id: SubmissionId,
134 session_id: Option<SessionId>,
135 ) -> Result<AgentRunResult, Error> {
136 let session_arc = self.get_session_or_default(session_id).await?;
137 let session = session_arc.lock().await;
138 let task = session.get_task(sub_id).await;
139 session.run_task(task, agent_id).await
140 }
141
142 pub async fn run(
143 &self,
144 agent_id: AgentID,
145 session_id: Option<SessionId>,
146 ) -> Result<AgentRunResult, Error> {
147 let session_arc = self.get_session_or_default(session_id).await?;
148 let session = session_arc.lock().await;
149 session.run(agent_id).await
150 }
151
152 pub async fn run_all(
153 &self,
154 agent_id: AgentID,
155 session_id: Option<SessionId>,
156 ) -> Result<Vec<AgentRunResult>, Error> {
157 let session_arc = self.get_session_or_default(session_id).await?;
158 let session = session_arc.lock().await;
159 session.run_all(agent_id).await
160 }
161
162 pub async fn event_sender(
163 &self,
164 session_id: Option<SessionId>,
165 ) -> Result<mpsc::Sender<Event>, Error> {
166 let session_arc = self.get_session_or_default(session_id).await?;
167 let session = session_arc.lock().await;
168 Ok(session.event_sender().clone())
169 }
170
171 pub async fn take_event_receiver(
172 &mut self,
173 session_id: Option<SessionId>,
174 ) -> Option<ReceiverStream<Event>> {
175 if let Some(session_arc_mutex) = self.get_session_mut(session_id).await {
176 let mut session = session_arc_mutex.lock().await;
177 session.take_event_receiver()
178 } else {
179 None
180 }
181 }
182
183 pub async fn shutdown(&mut self) {
184 if let Some(handler) = self.event_handler.take() {
185 handler.abort();
186 }
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration must use a channel buffer of 100.
    #[test]
    fn test_environment_config_default() {
        let EnvironmentConfig { channel_buffer, .. } = EnvironmentConfig::default();
        assert_eq!(channel_buffer, 100);
    }
}
199}