use crate::agents::extension::PlatformExtensionContext;
use crate::agents::Agent;
use crate::config::paths::Paths;
use crate::config::Config;
use crate::scheduler::Scheduler;
use crate::scheduler_trait::SchedulerTrait;
use anyhow::Result;
use lru::LruCache;
use std::num::NonZeroUsize;
use std::sync::{Arc, Mutex, PoisonError};
use tokio::sync::{OnceCell, RwLock};
use tracing::{debug, info};

14const DEFAULT_MAX_SESSION: usize = 100;
15
16static AGENT_MANAGER: OnceCell<Arc<AgentManager>> = OnceCell::const_new();
17
18pub struct AgentManager {
19 sessions: Arc<RwLock<LruCache<String, Arc<Agent>>>>,
20 scheduler: Arc<dyn SchedulerTrait>,
21 default_provider: Arc<RwLock<Option<Arc<dyn crate::providers::base::Provider>>>>,
22}
23
24impl AgentManager {
25 #[cfg(test)]
26 pub fn reset_for_test() {
27 unsafe {
28 let cell_ptr = &AGENT_MANAGER as *const OnceCell<Arc<AgentManager>>
31 as *mut OnceCell<Arc<AgentManager>>;
32 let _ = (*cell_ptr).take();
33 }
34 }
35
36 async fn new(max_sessions: Option<usize>) -> Result<Self> {
37 let schedule_file_path = Paths::data_dir().join("schedule.json");
38
39 let scheduler = Scheduler::new(schedule_file_path).await?;
40
41 let capacity = NonZeroUsize::new(max_sessions.unwrap_or(DEFAULT_MAX_SESSION))
42 .unwrap_or_else(|| NonZeroUsize::new(100).unwrap());
43
44 let manager = Self {
45 sessions: Arc::new(RwLock::new(LruCache::new(capacity))),
46 scheduler,
47 default_provider: Arc::new(RwLock::new(None)),
48 };
49
50 Ok(manager)
51 }
52
53 pub async fn instance() -> Result<Arc<Self>> {
54 AGENT_MANAGER
55 .get_or_try_init(|| async {
56 let max_sessions = Config::global()
57 .get_aster_max_active_agents()
58 .unwrap_or(DEFAULT_MAX_SESSION);
59 let manager = Self::new(Some(max_sessions)).await?;
60 Ok(Arc::new(manager))
61 })
62 .await
63 .cloned()
64 }
65
66 pub fn scheduler(&self) -> Arc<dyn SchedulerTrait> {
67 Arc::clone(&self.scheduler)
68 }
69
70 pub async fn set_default_provider(&self, provider: Arc<dyn crate::providers::base::Provider>) {
71 debug!("Setting default provider on AgentManager");
72 *self.default_provider.write().await = Some(provider);
73 }
74
75 pub async fn get_or_create_agent(&self, session_id: String) -> Result<Arc<Agent>> {
76 {
77 let mut sessions = self.sessions.write().await;
78 if let Some(existing) = sessions.get(&session_id) {
79 return Ok(Arc::clone(existing));
80 }
81 }
82
83 let agent = Arc::new(Agent::new());
84 agent.set_scheduler(Arc::clone(&self.scheduler)).await;
85 agent
86 .extension_manager
87 .set_context(PlatformExtensionContext {
88 session_id: Some(session_id.clone()),
89 extension_manager: Some(Arc::downgrade(&agent.extension_manager)),
90 })
91 .await;
92 if let Some(provider) = &*self.default_provider.read().await {
93 agent
94 .update_provider(Arc::clone(provider), &session_id)
95 .await?;
96 }
97
98 let mut sessions = self.sessions.write().await;
99 if let Some(existing) = sessions.get(&session_id) {
100 Ok(Arc::clone(existing))
101 } else {
102 sessions.put(session_id, agent.clone());
103 Ok(agent)
104 }
105 }
106
107 pub async fn remove_session(&self, session_id: &str) -> Result<()> {
108 let mut sessions = self.sessions.write().await;
109 sessions
110 .pop(session_id)
111 .ok_or_else(|| anyhow::anyhow!("Session {} not found", session_id))?;
112 info!("Removed session {}", session_id);
113 Ok(())
114 }
115
116 pub async fn has_session(&self, session_id: &str) -> bool {
117 self.sessions.read().await.contains(session_id)
118 }
119
120 pub async fn session_count(&self) -> usize {
121 self.sessions.read().await.len()
122 }
123}
124
#[cfg(test)]
mod tests {
    use serial_test::serial;
    use std::sync::Arc;

    use crate::execution::manager::{AgentManager, DEFAULT_MAX_SESSION};
    use crate::execution::SessionExecutionMode;

    // NOTE: the capacity tests below assume the singleton is built with
    // `DEFAULT_MAX_SESSION`, i.e. that the global config does not override the
    // max-active-agents setting.

    #[test]
    fn test_execution_mode_constructors() {
        assert_eq!(
            SessionExecutionMode::chat(),
            SessionExecutionMode::Interactive
        );
        assert_eq!(
            SessionExecutionMode::scheduled(),
            SessionExecutionMode::Background
        );

        let parent = "parent-123".to_string();
        assert_eq!(
            SessionExecutionMode::task(parent.clone()),
            SessionExecutionMode::SubTask {
                parent_session: parent
            }
        );
    }

    #[tokio::test]
    #[serial]
    async fn test_session_isolation() {
        AgentManager::reset_for_test();
        let manager = AgentManager::instance().await.unwrap();

        let session1 = uuid::Uuid::new_v4().to_string();
        let session2 = uuid::Uuid::new_v4().to_string();

        let agent1 = manager.get_or_create_agent(session1.clone()).await.unwrap();
        let agent2 = manager.get_or_create_agent(session2.clone()).await.unwrap();
        assert!(!Arc::ptr_eq(&agent1, &agent2));

        let agent1_again = manager.get_or_create_agent(session1).await.unwrap();
        assert!(Arc::ptr_eq(&agent1, &agent1_again));

        AgentManager::reset_for_test();
    }

    #[tokio::test]
    #[serial]
    async fn test_session_limit() {
        AgentManager::reset_for_test();
        let manager = AgentManager::instance().await.unwrap();

        for i in 0..DEFAULT_MAX_SESSION {
            manager
                .get_or_create_agent(format!("session-{}", i))
                .await
                .unwrap();
        }

        let new_session = "new-session".to_string();
        manager
            .get_or_create_agent(new_session.clone())
            .await
            .unwrap();

        assert_eq!(manager.session_count().await, DEFAULT_MAX_SESSION);
        assert!(manager.has_session(&new_session).await);
        // The least recently used entry is the one evicted.
        assert!(!manager.has_session("session-0").await);
    }

    #[tokio::test]
    #[serial]
    async fn test_remove_session() {
        AgentManager::reset_for_test();
        let manager = AgentManager::instance().await.unwrap();
        let session = String::from("remove-test");

        manager.get_or_create_agent(session.clone()).await.unwrap();
        assert!(manager.has_session(&session).await);

        manager.remove_session(&session).await.unwrap();
        assert!(!manager.has_session(&session).await);

        assert!(manager.remove_session(&session).await.is_err());
    }

    #[tokio::test]
    #[serial]
    async fn test_concurrent_access() {
        AgentManager::reset_for_test();
        let manager = AgentManager::instance().await.unwrap();
        let session = String::from("concurrent-test");

        let mut handles = vec![];
        for _ in 0..10 {
            let mgr = Arc::clone(&manager);
            let sess = session.clone();
            handles.push(tokio::spawn(async move {
                mgr.get_or_create_agent(sess).await.unwrap()
            }));
        }

        let agents: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        for agent in &agents[1..] {
            assert!(Arc::ptr_eq(&agents[0], agent));
        }

        assert_eq!(manager.session_count().await, 1);
    }

    #[tokio::test]
    #[serial]
    async fn test_concurrent_session_creation_race_condition() {
        AgentManager::reset_for_test();
        let manager = AgentManager::instance().await.unwrap();
        let session_id = String::from("race-condition-test");

        let mut handles = vec![];
        for _ in 0..20 {
            let sess = session_id.clone();
            let mgr_clone = Arc::clone(&manager);
            handles.push(tokio::spawn(async move {
                mgr_clone.get_or_create_agent(sess).await.unwrap()
            }));
        }

        let agents: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        for agent in &agents[1..] {
            assert!(
                Arc::ptr_eq(&agents[0], agent),
                "All concurrent requests should get the same agent"
            );
        }
        assert_eq!(manager.session_count().await, 1);
    }

    #[tokio::test]
    #[serial]
    async fn test_set_default_provider() {
        use crate::providers::testprovider::TestProvider;

        AgentManager::reset_for_test();
        let manager = AgentManager::instance().await.unwrap();

        let temp_file = format!(
            "{}/test_provider_{}.json",
            std::env::temp_dir().display(),
            std::process::id()
        );

        let test_provider = TestProvider::new_replaying(&temp_file)
            .unwrap_or_else(|_| TestProvider::new_replaying("/tmp/dummy.json").unwrap());

        manager.set_default_provider(Arc::new(test_provider)).await;

        let session = String::from("provider-test");
        let _agent = manager.get_or_create_agent(session.clone()).await.unwrap();

        assert!(manager.has_session(&session).await);
    }

    #[tokio::test]
    #[serial]
    async fn test_eviction_updates_last_used() {
        AgentManager::reset_for_test();
        let manager = AgentManager::instance().await.unwrap();

        // The cache orders entries by access, not wall-clock time, so no sleeps are
        // needed between accesses.
        let sessions: Vec<_> = (0..DEFAULT_MAX_SESSION)
            .map(|i| format!("session-{}", i))
            .collect();

        for session in &sessions {
            manager.get_or_create_agent(session.clone()).await.unwrap();
        }

        // Touching session-0 promotes it, leaving session-1 as least recently used.
        manager
            .get_or_create_agent(sessions[0].clone())
            .await
            .unwrap();

        let session101 = String::from("session-101");
        manager
            .get_or_create_agent(session101.clone())
            .await
            .unwrap();

        assert!(manager.has_session(&sessions[0]).await);
        assert!(!manager.has_session(&sessions[1]).await);
        assert!(manager.has_session(&session101).await);
    }

    #[tokio::test]
    #[serial]
    async fn test_remove_nonexistent_session_error() {
        AgentManager::reset_for_test();
        let manager = AgentManager::instance().await.unwrap();
        let session = String::from("never-created");

        let result = manager.remove_session(&session).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }
}