// agent_core_runtime/controller/session/manager.rs
//
// Implements a session manager that creates, tracks, and shuts down multiple LLM sessions.

3use std::collections::HashMap;
4use std::sync::Arc;
5
6use tokio::sync::{mpsc, RwLock};
7use tokio_util::sync::CancellationToken;
8
9use super::config::LLMSessionConfig;
10use super::LLMSession;
11use crate::client::error::LlmError;
12use crate::controller::types::FromLLMPayload;
13
14/// Manages multiple LLM sessions
15pub struct LLMSessionManager {
16    /// Map of session ID to session instance
17    sessions: RwLock<HashMap<i64, Arc<LLMSession>>>,
18}
19
20impl LLMSessionManager {
21    /// Creates a new session manager
22    pub fn new() -> Self {
23        Self {
24            sessions: RwLock::new(HashMap::new()),
25        }
26    }
27
28    /// Creates a new LLM session and starts it.
29    ///
30    /// # Arguments
31    /// * `config` - Session configuration (includes model, API key, etc.)
32    /// * `from_llm` - Channel sender for responses from the LLM
33    /// * `channel_size` - Buffer size for the session's input channel
34    ///
35    /// # Returns
36    /// The session ID of the newly created session
37    ///
38    /// # Errors
39    /// Returns an error if the session fails to initialize (e.g., TLS setup failure)
40    pub async fn create_session(
41        &self,
42        config: LLMSessionConfig,
43        from_llm: mpsc::Sender<FromLLMPayload>,
44        channel_size: usize,
45    ) -> Result<i64, LlmError> {
46        let cancel_token = CancellationToken::new();
47        let session = Arc::new(LLMSession::new(config, from_llm, cancel_token, channel_size)?);
48        let session_id = session.id();
49
50        // Store the session
51        {
52            let mut sessions = self.sessions.write().await;
53            sessions.insert(session_id, Arc::clone(&session));
54        }
55
56        // Spawn the session's processing loop
57        let session_clone = Arc::clone(&session);
58        tokio::spawn(async move {
59            session_clone.start().await;
60        });
61
62        tracing::info!(session_id, "Session created and started");
63        Ok(session_id)
64    }
65
66    /// Retrieves a session by its ID.
67    ///
68    /// # Returns
69    /// The session if found, None otherwise
70    pub async fn get_session_by_id(&self, session_id: i64) -> Option<Arc<LLMSession>> {
71        let sessions = self.sessions.read().await;
72        sessions.get(&session_id).cloned()
73    }
74
75    /// Removes and shuts down a specific session.
76    ///
77    /// # Arguments
78    /// * `session_id` - The ID of the session to remove
79    ///
80    /// # Returns
81    /// true if the session was found and removed, false otherwise
82    pub async fn remove_session(&self, session_id: i64) -> bool {
83        let session = {
84            let mut sessions = self.sessions.write().await;
85            sessions.remove(&session_id)
86        };
87
88        if let Some(session) = session {
89            session.shutdown();
90            tracing::info!(session_id, "Session removed");
91            true
92        } else {
93            false
94        }
95    }
96
97    /// Shuts down all sessions managed by this manager.
98    /// This is idempotent and safe to call multiple times.
99    pub async fn shutdown(&self) {
100        let sessions: Vec<Arc<LLMSession>> = {
101            let mut sessions = self.sessions.write().await;
102            sessions.drain().map(|(_, s)| s).collect()
103        };
104
105        for session in sessions {
106            session.shutdown();
107        }
108
109        tracing::info!("Session manager shutdown complete");
110    }
111
112    /// Returns the number of active sessions
113    pub async fn session_count(&self) -> usize {
114        let sessions = self.sessions.read().await;
115        sessions.len()
116    }
117}
118
119impl Default for LLMSessionManager {
120    fn default() -> Self {
121        Self::new()
122    }
123}