Skip to main content

agent_air_runtime/controller/session/
manager.rs

// Session manager that creates, tracks, and shuts down multiple LLM sessions.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use tokio::sync::{RwLock, mpsc};
7use tokio_util::sync::CancellationToken;
8
9use super::LLMSession;
10use super::config::LLMSessionConfig;
11use crate::client::error::LlmError;
12use crate::controller::types::FromLLMPayload;
13
14/// Manages multiple LLM sessions
15pub struct LLMSessionManager {
16    /// Map of session ID to session instance
17    sessions: RwLock<HashMap<i64, Arc<LLMSession>>>,
18}
19
20impl LLMSessionManager {
21    /// Creates a new session manager
22    pub fn new() -> Self {
23        Self {
24            sessions: RwLock::new(HashMap::new()),
25        }
26    }
27
28    /// Creates a new LLM session and starts it.
29    ///
30    /// # Arguments
31    /// * `config` - Session configuration (includes model, API key, etc.)
32    /// * `from_llm` - Channel sender for responses from the LLM
33    /// * `channel_size` - Buffer size for the session's input channel
34    ///
35    /// # Returns
36    /// The session ID of the newly created session
37    ///
38    /// # Errors
39    /// Returns an error if the session fails to initialize (e.g., TLS setup failure)
40    pub async fn create_session(
41        &self,
42        config: LLMSessionConfig,
43        from_llm: mpsc::Sender<FromLLMPayload>,
44        channel_size: usize,
45    ) -> Result<i64, LlmError> {
46        let cancel_token = CancellationToken::new();
47        let session = Arc::new(LLMSession::new(
48            config,
49            from_llm,
50            cancel_token,
51            channel_size,
52        )?);
53        let session_id = session.id();
54
55        // Store the session
56        {
57            let mut sessions = self.sessions.write().await;
58            sessions.insert(session_id, Arc::clone(&session));
59        }
60
61        // Spawn the session's processing loop
62        let session_clone = Arc::clone(&session);
63        tokio::spawn(async move {
64            session_clone.start().await;
65        });
66
67        tracing::info!(session_id, "Session created and started");
68        Ok(session_id)
69    }
70
71    /// Retrieves a session by its ID.
72    ///
73    /// # Returns
74    /// The session if found, None otherwise
75    pub async fn get_session_by_id(&self, session_id: i64) -> Option<Arc<LLMSession>> {
76        let sessions = self.sessions.read().await;
77        sessions.get(&session_id).cloned()
78    }
79
80    /// Removes and shuts down a specific session.
81    ///
82    /// # Arguments
83    /// * `session_id` - The ID of the session to remove
84    ///
85    /// # Returns
86    /// true if the session was found and removed, false otherwise
87    pub async fn remove_session(&self, session_id: i64) -> bool {
88        let session = {
89            let mut sessions = self.sessions.write().await;
90            sessions.remove(&session_id)
91        };
92
93        if let Some(session) = session {
94            session.shutdown();
95            tracing::info!(session_id, "Session removed");
96            true
97        } else {
98            false
99        }
100    }
101
102    /// Shuts down all sessions managed by this manager.
103    /// This is idempotent and safe to call multiple times.
104    pub async fn shutdown(&self) {
105        let sessions: Vec<Arc<LLMSession>> = {
106            let mut sessions = self.sessions.write().await;
107            sessions.drain().map(|(_, s)| s).collect()
108        };
109
110        for session in sessions {
111            session.shutdown();
112        }
113
114        tracing::info!("Session manager shutdown complete");
115    }
116
117    /// Returns the number of active sessions
118    pub async fn session_count(&self) -> usize {
119        let sessions = self.sessions.read().await;
120        sessions.len()
121    }
122}
123
124impl Default for LLMSessionManager {
125    fn default() -> Self {
126        Self::new()
127    }
128}