use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
use tokio_util::sync::CancellationToken;
use super::LLMSession;
use super::config::LLMSessionConfig;
use crate::client::error::LlmError;
use crate::controller::types::FromLLMPayload;
pub struct LLMSessionManager {
sessions: RwLock<HashMap<i64, Arc<LLMSession>>>,
}
impl LLMSessionManager {
pub fn new() -> Self {
Self {
sessions: RwLock::new(HashMap::new()),
}
}
pub async fn create_session(
&self,
config: LLMSessionConfig,
from_llm: mpsc::Sender<FromLLMPayload>,
channel_size: usize,
) -> Result<i64, LlmError> {
let cancel_token = CancellationToken::new();
let session = Arc::new(LLMSession::new(
config,
from_llm,
cancel_token,
channel_size,
)?);
let session_id = session.id();
{
let mut sessions = self.sessions.write().await;
sessions.insert(session_id, Arc::clone(&session));
}
let session_clone = Arc::clone(&session);
tokio::spawn(async move {
session_clone.start().await;
});
tracing::info!(session_id, "Session created and started");
Ok(session_id)
}
pub async fn get_session_by_id(&self, session_id: i64) -> Option<Arc<LLMSession>> {
let sessions = self.sessions.read().await;
sessions.get(&session_id).cloned()
}
pub async fn remove_session(&self, session_id: i64) -> bool {
let session = {
let mut sessions = self.sessions.write().await;
sessions.remove(&session_id)
};
if let Some(session) = session {
session.shutdown();
tracing::info!(session_id, "Session removed");
true
} else {
false
}
}
pub async fn shutdown(&self) {
let sessions: Vec<Arc<LLMSession>> = {
let mut sessions = self.sessions.write().await;
sessions.drain().map(|(_, s)| s).collect()
};
for session in sessions {
session.shutdown();
}
tracing::info!("Session manager shutdown complete");
}
pub async fn session_count(&self) -> usize {
let sessions = self.sessions.read().await;
sessions.len()
}
}
impl Default for LLMSessionManager {
fn default() -> Self {
Self::new()
}
}