Skip to main content

aster_server/
state.rs

1use aster::execution::manager::AgentManager;
2use aster::scheduler_trait::SchedulerTrait;
3use axum::http::StatusCode;
4use std::collections::{HashMap, HashSet};
5use std::path::PathBuf;
6use std::sync::atomic::AtomicUsize;
7use std::sync::Arc;
8use tokio::sync::Mutex;
9
10use crate::tunnel::TunnelManager;
11
12#[derive(Clone)]
13pub struct AppState {
14    pub(crate) agent_manager: Arc<AgentManager>,
15    pub recipe_file_hash_map: Arc<Mutex<HashMap<String, PathBuf>>>,
16    pub session_counter: Arc<AtomicUsize>,
17    /// Tracks sessions that have already emitted recipe telemetry to prevent double counting.
18    recipe_session_tracker: Arc<Mutex<HashSet<String>>>,
19    pub tunnel_manager: Arc<TunnelManager>,
20}
21
22impl AppState {
23    pub async fn new() -> anyhow::Result<Arc<AppState>> {
24        let agent_manager = AgentManager::instance().await?;
25        let tunnel_manager = Arc::new(TunnelManager::new());
26
27        Ok(Arc::new(Self {
28            agent_manager,
29            recipe_file_hash_map: Arc::new(Mutex::new(HashMap::new())),
30            session_counter: Arc::new(AtomicUsize::new(0)),
31            recipe_session_tracker: Arc::new(Mutex::new(HashSet::new())),
32            tunnel_manager,
33        }))
34    }
35
36    pub fn scheduler(&self) -> Arc<dyn SchedulerTrait> {
37        self.agent_manager.scheduler()
38    }
39
40    pub async fn set_recipe_file_hash_map(&self, hash_map: HashMap<String, PathBuf>) {
41        let mut map = self.recipe_file_hash_map.lock().await;
42        *map = hash_map;
43    }
44
45    pub async fn mark_recipe_run_if_absent(&self, session_id: &str) -> bool {
46        let mut sessions = self.recipe_session_tracker.lock().await;
47        if sessions.contains(session_id) {
48            false
49        } else {
50            sessions.insert(session_id.to_string());
51            true
52        }
53    }
54
55    pub async fn get_agent(&self, session_id: String) -> anyhow::Result<Arc<aster::agents::Agent>> {
56        self.agent_manager.get_or_create_agent(session_id).await
57    }
58
59    pub async fn get_agent_for_route(
60        &self,
61        session_id: String,
62    ) -> Result<Arc<aster::agents::Agent>, StatusCode> {
63        self.get_agent(session_id).await.map_err(|e| {
64            tracing::error!("Failed to get agent: {}", e);
65            StatusCode::INTERNAL_SERVER_ERROR
66        })
67    }
68}