Skip to main content

aster/execution/
manager.rs

1use crate::agents::extension::PlatformExtensionContext;
2use crate::agents::Agent;
3use crate::config::paths::Paths;
4use crate::config::Config;
5use crate::scheduler::Scheduler;
6use crate::scheduler_trait::SchedulerTrait;
7use anyhow::Result;
8use lru::LruCache;
9use std::num::NonZeroUsize;
10use std::sync::Arc;
11use tokio::sync::{OnceCell, RwLock};
12use tracing::{debug, info};
13
14const DEFAULT_MAX_SESSION: usize = 100;
15
16static AGENT_MANAGER: OnceCell<Arc<AgentManager>> = OnceCell::const_new();
17
18pub struct AgentManager {
19    sessions: Arc<RwLock<LruCache<String, Arc<Agent>>>>,
20    scheduler: Arc<dyn SchedulerTrait>,
21    default_provider: Arc<RwLock<Option<Arc<dyn crate::providers::base::Provider>>>>,
22}
23
24impl AgentManager {
25    #[cfg(test)]
26    pub fn reset_for_test() {
27        unsafe {
28            // Cast away the const to get mutable access
29            // This is safe in test context where we control execution with #[serial]
30            let cell_ptr = &AGENT_MANAGER as *const OnceCell<Arc<AgentManager>>
31                as *mut OnceCell<Arc<AgentManager>>;
32            let _ = (*cell_ptr).take();
33        }
34    }
35
36    async fn new(max_sessions: Option<usize>) -> Result<Self> {
37        let schedule_file_path = Paths::data_dir().join("schedule.json");
38
39        let scheduler = Scheduler::new(schedule_file_path).await?;
40
41        let capacity = NonZeroUsize::new(max_sessions.unwrap_or(DEFAULT_MAX_SESSION))
42            .unwrap_or_else(|| NonZeroUsize::new(100).unwrap());
43
44        let manager = Self {
45            sessions: Arc::new(RwLock::new(LruCache::new(capacity))),
46            scheduler,
47            default_provider: Arc::new(RwLock::new(None)),
48        };
49
50        Ok(manager)
51    }
52
53    pub async fn instance() -> Result<Arc<Self>> {
54        AGENT_MANAGER
55            .get_or_try_init(|| async {
56                let max_sessions = Config::global()
57                    .get_aster_max_active_agents()
58                    .unwrap_or(DEFAULT_MAX_SESSION);
59                let manager = Self::new(Some(max_sessions)).await?;
60                Ok(Arc::new(manager))
61            })
62            .await
63            .cloned()
64    }
65
66    pub fn scheduler(&self) -> Arc<dyn SchedulerTrait> {
67        Arc::clone(&self.scheduler)
68    }
69
70    pub async fn set_default_provider(&self, provider: Arc<dyn crate::providers::base::Provider>) {
71        debug!("Setting default provider on AgentManager");
72        *self.default_provider.write().await = Some(provider);
73    }
74
75    pub async fn get_or_create_agent(&self, session_id: String) -> Result<Arc<Agent>> {
76        {
77            let mut sessions = self.sessions.write().await;
78            if let Some(existing) = sessions.get(&session_id) {
79                return Ok(Arc::clone(existing));
80            }
81        }
82
83        let agent = Arc::new(Agent::new());
84        agent.set_scheduler(Arc::clone(&self.scheduler)).await;
85        agent
86            .extension_manager
87            .set_context(PlatformExtensionContext {
88                session_id: Some(session_id.clone()),
89                extension_manager: Some(Arc::downgrade(&agent.extension_manager)),
90            })
91            .await;
92        if let Some(provider) = &*self.default_provider.read().await {
93            agent
94                .update_provider(Arc::clone(provider), &session_id)
95                .await?;
96        }
97
98        let mut sessions = self.sessions.write().await;
99        if let Some(existing) = sessions.get(&session_id) {
100            Ok(Arc::clone(existing))
101        } else {
102            sessions.put(session_id, agent.clone());
103            Ok(agent)
104        }
105    }
106
107    pub async fn remove_session(&self, session_id: &str) -> Result<()> {
108        let mut sessions = self.sessions.write().await;
109        sessions
110            .pop(session_id)
111            .ok_or_else(|| anyhow::anyhow!("Session {} not found", session_id))?;
112        info!("Removed session {}", session_id);
113        Ok(())
114    }
115
116    pub async fn has_session(&self, session_id: &str) -> bool {
117        self.sessions.read().await.contains(session_id)
118    }
119
120    pub async fn session_count(&self) -> usize {
121        self.sessions.read().await.len()
122    }
123}
124
#[cfg(test)]
mod tests {
    use serial_test::serial;
    use std::sync::Arc;

    use crate::execution::{manager::AgentManager, SessionExecutionMode};

    /// Explicit capacity for the eviction tests, so they do not depend on whatever
    /// `max_active_agents` value the local configuration happens to hold.
    const TEST_CAPACITY: usize = 100;

    #[test]
    fn test_execution_mode_constructors() {
        assert_eq!(
            SessionExecutionMode::chat(),
            SessionExecutionMode::Interactive
        );
        assert_eq!(
            SessionExecutionMode::scheduled(),
            SessionExecutionMode::Background
        );

        let parent = "parent-123".to_string();
        assert_eq!(
            SessionExecutionMode::task(parent.clone()),
            SessionExecutionMode::SubTask {
                parent_session: parent
            }
        );
    }

    #[tokio::test]
    #[serial]
    async fn test_session_isolation() {
        AgentManager::reset_for_test();
        let manager = AgentManager::instance().await.unwrap();

        let session1 = uuid::Uuid::new_v4().to_string();
        let session2 = uuid::Uuid::new_v4().to_string();

        let agent1 = manager.get_or_create_agent(session1.clone()).await.unwrap();

        let agent2 = manager.get_or_create_agent(session2.clone()).await.unwrap();

        // Different sessions should have different agents
        assert!(!Arc::ptr_eq(&agent1, &agent2));

        // Getting the same session should return the same agent
        let agent1_again = manager.get_or_create_agent(session1).await.unwrap();

        assert!(Arc::ptr_eq(&agent1, &agent1_again));

        AgentManager::reset_for_test();
    }

    #[tokio::test]
    #[serial]
    async fn test_session_limit() {
        let manager = AgentManager::new(Some(TEST_CAPACITY)).await.unwrap();

        let sessions: Vec<_> = (0..TEST_CAPACITY)
            .map(|i| format!("session-{}", i))
            .collect();
        for session in &sessions {
            manager.get_or_create_agent(session.clone()).await.unwrap();
        }
        assert_eq!(manager.session_count().await, TEST_CAPACITY);

        // One session past capacity: the least recently used entry (the first one inserted)
        // is evicted instead of the cache growing.
        let new_session = "new-session".to_string();
        manager
            .get_or_create_agent(new_session.clone())
            .await
            .unwrap();

        assert_eq!(manager.session_count().await, TEST_CAPACITY);
        assert!(manager.has_session(&new_session).await);
        assert!(!manager.has_session(&sessions[0]).await);
    }

    #[tokio::test]
    #[serial]
    async fn test_remove_session() {
        AgentManager::reset_for_test();
        let manager = AgentManager::instance().await.unwrap();
        let session = String::from("remove-test");

        manager.get_or_create_agent(session.clone()).await.unwrap();
        assert!(manager.has_session(&session).await);

        manager.remove_session(&session).await.unwrap();
        assert!(!manager.has_session(&session).await);

        assert!(manager.remove_session(&session).await.is_err());
    }

    #[tokio::test]
    #[serial]
    async fn test_concurrent_access() {
        AgentManager::reset_for_test();
        let manager = AgentManager::instance().await.unwrap();
        let session = String::from("concurrent-test");

        let mut handles = vec![];
        for _ in 0..10 {
            let mgr = Arc::clone(&manager);
            let sess = session.clone();
            handles.push(tokio::spawn(async move {
                mgr.get_or_create_agent(sess).await.unwrap()
            }));
        }

        let agents: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        for agent in &agents[1..] {
            assert!(Arc::ptr_eq(&agents[0], agent));
        }

        assert_eq!(manager.session_count().await, 1);
    }

    #[tokio::test]
    #[serial]
    async fn test_concurrent_session_creation_race_condition() {
        // Concurrent attempts to create the same new session id must all end up with the
        // single agent that was cached first (exercises the re-check after construction).
        AgentManager::reset_for_test();
        let manager = AgentManager::instance().await.unwrap();
        let session_id = String::from("race-condition-test");

        // Spawn multiple tasks trying to create the same NEW session simultaneously
        let mut handles = vec![];
        for _ in 0..20 {
            let sess = session_id.clone();
            let mgr_clone = Arc::clone(&manager);
            handles.push(tokio::spawn(async move {
                mgr_clone.get_or_create_agent(sess).await.unwrap()
            }));
        }

        // Collect all agents
        let agents: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        for agent in &agents[1..] {
            assert!(
                Arc::ptr_eq(&agents[0], agent),
                "All concurrent requests should get the same agent"
            );
        }
        assert_eq!(manager.session_count().await, 1);
    }

    #[tokio::test]
    #[serial]
    async fn test_set_default_provider() {
        use crate::providers::testprovider::TestProvider;

        AgentManager::reset_for_test();
        let manager = AgentManager::instance().await.unwrap();

        // Replay-mode test provider (doesn't need an inner provider). Both candidate paths live
        // under the platform temp dir rather than a hard-coded `/tmp`.
        let temp_file = format!(
            "{}/test_provider_{}.json",
            std::env::temp_dir().display(),
            std::process::id()
        );
        let fallback_file = format!("{}/dummy.json", std::env::temp_dir().display());

        // The provider is never actually invoked here; it only has to be installable.
        let test_provider = TestProvider::new_replaying(&temp_file)
            .unwrap_or_else(|_| TestProvider::new_replaying(&fallback_file).unwrap());

        manager.set_default_provider(Arc::new(test_provider)).await;

        let session = String::from("provider-test");
        let _agent = manager.get_or_create_agent(session.clone()).await.unwrap();

        assert!(manager.has_session(&session).await);
    }

    #[tokio::test]
    #[serial]
    async fn test_eviction_updates_last_used() {
        // Eviction follows access order, not wall-clock time: re-reading a session marks it
        // most recently used, so it survives the next eviction. No sleeps are needed.
        let manager = AgentManager::new(Some(TEST_CAPACITY)).await.unwrap();

        let sessions: Vec<_> = (0..TEST_CAPACITY)
            .map(|i| format!("session-{}", i))
            .collect();
        for session in &sessions {
            manager.get_or_create_agent(session.clone()).await.unwrap();
        }

        // Touch the oldest session so it becomes the most recently used.
        manager
            .get_or_create_agent(sessions[0].clone())
            .await
            .unwrap();

        // Exceeding capacity now evicts sessions[1], the least recently used entry.
        let session101 = String::from("session-101");
        manager
            .get_or_create_agent(session101.clone())
            .await
            .unwrap();

        assert!(manager.has_session(&sessions[0]).await);
        assert!(!manager.has_session(&sessions[1]).await);
        assert!(manager.has_session(&session101).await);
    }

    #[tokio::test]
    #[serial]
    async fn test_remove_nonexistent_session_error() {
        // Removing a session that was never created returns a "not found" error.
        AgentManager::reset_for_test();
        let manager = AgentManager::instance().await.unwrap();
        let session = String::from("never-created");

        let result = manager.remove_session(&session).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }
}