use crate::error::{Result, RuvLLMError};
use crate::kv_cache::{KvCacheConfig, TwoTierKvCache};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;
/// Limits and storage settings shared by every session a manager creates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionConfig {
    /// Maximum session age in seconds, counted from `created_at`; a session
    /// older than this is reported as expired.
    pub max_lifetime_secs: u64,

    /// Maximum number of seconds since `last_active` before a session is
    /// reported as expired.
    pub idle_timeout_secs: u64,

    /// Turn count at which a session is reported as expired.
    pub max_turns: u32,

    /// Settings used to build each session's own KV cache.
    pub kv_cache: KvCacheConfig,

    /// Persistence flag.
    ///
    /// NOTE(review): not read anywhere in this file — confirm where it is honored.
    pub persist: bool,
}
impl Default for SessionConfig {
fn default() -> Self {
Self {
max_lifetime_secs: 3600, idle_timeout_secs: 300, max_turns: 100,
kv_cache: KvCacheConfig::default(),
persist: true,
}
}
}
/// Lifecycle state of a [`Session`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SessionStatus {
    /// Usable; set on creation and by `Session::resume`.
    Active,
    /// Temporarily suspended via `Session::pause`.
    Paused,
    /// Past one of its limits.
    ///
    /// NOTE(review): nothing in this file assigns this variant — confirm
    /// where (or whether) it is set.
    Expired,
    /// Ended via `Session::terminate`.
    Terminated,
}
impl Default for SessionStatus {
fn default() -> Self {
Self::Active
}
}
/// Optional descriptive details attached to a session.
///
/// `Default` is derived: an empty `custom` map and `None` for every optional
/// field, identical to the previous hand-written implementation.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionMetadata {
    /// Arbitrary key/value pairs with JSON values.
    pub custom: HashMap<String, serde_json::Value>,
    /// Client user-agent string, if recorded.
    pub user_agent: Option<String>,
    /// Client IP address as text, if recorded.
    pub client_ip: Option<String>,
    /// Client locale, if recorded.
    pub locale: Option<String>,
}
/// State for a single conversation session.
#[derive(Debug)]
pub struct Session {
    /// Session identifier (a v4 UUID string when built by `Session::new`).
    pub id: String,
    /// Owning user, if the session was created for one.
    pub user_id: Option<String>,
    /// Current lifecycle state.
    pub status: SessionStatus,
    /// Number of turns recorded with `increment_turn`.
    pub turn_count: u32,
    /// Creation time (UTC).
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Time of the most recent `touch` (UTC).
    pub last_active: chrono::DateTime<chrono::Utc>,
    /// Adapter ID selected with `set_adapter`, if any.
    pub active_adapter: Option<Uuid>,
    /// Descriptive metadata.
    pub metadata: SessionMetadata,
    /// Embedding stored with `set_context_embedding`, if any.
    pub context_embedding: Option<Vec<f32>>,
    /// This session's KV cache; copies of the session share it via the `Arc`.
    kv_cache: Arc<TwoTierKvCache>,
}
impl Session {
    /// Creates an active session with a fresh v4 UUID, zero turns, default
    /// metadata, and a new KV cache built from a clone of `config.kv_cache`.
    pub fn new(config: &SessionConfig, user_id: Option<&str>) -> Self {
        let now = chrono::Utc::now();
        Self {
            id: Uuid::new_v4().to_string(),
            user_id: user_id.map(String::from),
            status: SessionStatus::Active,
            turn_count: 0,
            created_at: now,
            last_active: now,
            active_adapter: None,
            metadata: SessionMetadata::default(),
            context_embedding: None,
            kv_cache: Arc::new(TwoTierKvCache::new(config.kv_cache.clone())),
        }
    }

    /// Returns `true` if `status` is [`SessionStatus::Active`].
    pub fn is_active(&self) -> bool {
        self.status == SessionStatus::Active
    }

    /// Sets `last_active` to the current UTC time.
    pub fn touch(&mut self) {
        self.last_active = chrono::Utc::now();
    }

    /// Records one more turn and refreshes `last_active`.
    ///
    /// The counter saturates at `u32::MAX` instead of overflowing, which would
    /// panic in debug builds.
    pub fn increment_turn(&mut self) {
        self.turn_count = self.turn_count.saturating_add(1);
        self.touch();
    }

    /// Returns `true` if any limit in `config` has been reached:
    ///
    /// * more than `max_lifetime_secs` seconds have passed since `created_at`,
    /// * more than `idle_timeout_secs` seconds have passed since `last_active`, or
    /// * `turn_count` has reached `max_turns`.
    ///
    /// Only the timestamps and the turn count are consulted; `status` is not.
    pub fn is_expired(&self, config: &SessionConfig) -> bool {
        let now = chrono::Utc::now();
        // `Utc::now()` follows the wall clock, which can step backwards, and the
        // timestamp fields are public, so a span can be negative. Clamp it to
        // zero: casting a negative `i64` straight to `u64` wraps to a huge value
        // and would expire the session immediately.
        let lifetime = (now - self.created_at).num_seconds().max(0) as u64;
        let idle = (now - self.last_active).num_seconds().max(0) as u64;
        lifetime > config.max_lifetime_secs
            || idle > config.idle_timeout_secs
            || self.turn_count >= config.max_turns
    }

    /// Returns the shared handle to this session's KV cache.
    pub fn kv_cache(&self) -> &Arc<TwoTierKvCache> {
        &self.kv_cache
    }

    /// Stores `embedding` as the context embedding, replacing any previous one.
    pub fn set_context_embedding(&mut self, embedding: Vec<f32>) {
        self.context_embedding = Some(embedding);
    }

    /// Sets the active adapter ID, or clears it with `None`.
    pub fn set_adapter(&mut self, adapter_id: Option<Uuid>) {
        self.active_adapter = adapter_id;
    }

    /// Marks the session as [`SessionStatus::Paused`].
    pub fn pause(&mut self) {
        self.status = SessionStatus::Paused;
    }

    /// Marks the session as [`SessionStatus::Active`] and refreshes `last_active`.
    ///
    /// NOTE(review): the current status is not checked, so this also
    /// reactivates `Expired`/`Terminated` sessions — confirm that is intended.
    pub fn resume(&mut self) {
        self.status = SessionStatus::Active;
        self.touch();
    }

    /// Marks the session as [`SessionStatus::Terminated`].
    pub fn terminate(&mut self) {
        self.status = SessionStatus::Terminated;
    }
}
/// Registry of sessions keyed by ID, plus a per-user index of session IDs.
///
/// Both maps are `DashMap`s, and each session sits behind a
/// `parking_lot::RwLock` inside an `Arc`.
pub struct SessionManager {
    /// Settings applied to every session this manager creates.
    config: SessionConfig,
    /// Session ID -> shared, lock-protected session state.
    sessions: DashMap<String, Arc<parking_lot::RwLock<Session>>>,
    /// User ID -> IDs of that user's sessions, appended as they are created.
    user_sessions: DashMap<String, Vec<String>>,
}
impl SessionManager {
    /// Creates an empty manager that applies `config` to every new session.
    pub fn new(config: SessionConfig) -> Self {
        Self {
            config,
            sessions: DashMap::new(),
            user_sessions: DashMap::new(),
        }
    }

    /// Builds a detached copy of `session`.
    ///
    /// Plain fields are cloned. The KV cache `Arc` is shared with the original,
    /// so cache contents are visible through either copy.
    fn snapshot(session: &Session) -> Session {
        Session {
            id: session.id.clone(),
            user_id: session.user_id.clone(),
            status: session.status,
            turn_count: session.turn_count,
            created_at: session.created_at,
            last_active: session.last_active,
            active_adapter: session.active_adapter,
            metadata: session.metadata.clone(),
            context_embedding: session.context_embedding.clone(),
            kv_cache: Arc::clone(&session.kv_cache),
        }
    }

    /// Returns the shared handle for `session_id`, releasing the map's shard
    /// lock before returning.
    ///
    /// Holding a `DashMap` reference while locking the session, or while
    /// running caller code, can deadlock if that code touches the same shard
    /// again (e.g. `remove`). Cloning the `Arc` out avoids that.
    fn handle(&self, session_id: &str) -> Option<Arc<parking_lot::RwLock<Session>>> {
        self.sessions
            .get(session_id)
            .map(|entry| Arc::clone(entry.value()))
    }

    /// Creates and registers a new session and returns a snapshot of it.
    ///
    /// When `user_id` is given, the new ID is also appended to that user's
    /// index (see [`SessionManager::get_user_sessions`]).
    ///
    /// # Errors
    ///
    /// Never fails at present; the `Result` is kept for API compatibility.
    pub fn create_session(&self, user_id: Option<&str>) -> Result<Session> {
        let session = Session::new(&self.config, user_id);
        // Snapshot before publishing. The returned copy then never depends on a
        // second lookup that could race with a concurrent termination.
        let snapshot = Self::snapshot(&session);
        let session_id = session.id.clone();
        self.sessions.insert(
            session_id.clone(),
            Arc::new(parking_lot::RwLock::new(session)),
        );
        if let Some(uid) = user_id {
            self.user_sessions
                .entry(uid.to_string())
                .or_default()
                .push(session_id);
        }
        Ok(snapshot)
    }

    /// Returns a detached snapshot of the session, or `Ok(None)` if no session
    /// with that ID is registered.
    pub fn get_session(&self, session_id: &str) -> Result<Option<Session>> {
        Ok(self.handle(session_id).map(|session| {
            let guard = session.read();
            Self::snapshot(&guard)
        }))
    }

    /// Runs `f` on the stored session while holding its write lock.
    ///
    /// The map's shard lock is released before `f` runs, so `f` may call back
    /// into this manager for other sessions. Calling back for the same session
    /// would still block, because its write lock is held.
    ///
    /// # Errors
    ///
    /// `RuvLLMError::NotFound` if no session with that ID is registered.
    pub fn update_session<F>(&self, session_id: &str, f: F) -> Result<()>
    where
        F: FnOnce(&mut Session),
    {
        let session = self.handle(session_id).ok_or_else(|| {
            RuvLLMError::NotFound(format!("Session not found: {}", session_id))
        })?;
        let mut guard = session.write();
        f(&mut guard);
        Ok(())
    }

    /// Unregisters the session, marks it terminated, and drops it from its
    /// user's index.
    ///
    /// Unknown IDs are ignored and still return `Ok(())`. A user's index entry
    /// is removed entirely once it no longer lists any session.
    pub fn terminate_session(&self, session_id: &str) -> Result<()> {
        // Remove first, so no shard lock is held while locking the session.
        let session = match self.sessions.remove(session_id) {
            Some((_, session)) => session,
            None => return Ok(()),
        };
        let user_id = {
            let mut guard = session.write();
            guard.terminate();
            guard.user_id.clone()
        };
        if let Some(uid) = user_id {
            if let Some(mut ids) = self.user_sessions.get_mut(&uid) {
                ids.retain(|s| s != session_id);
            }
            // `remove_if` re-checks emptiness under the shard lock, so an ID
            // pushed concurrently by `create_session` is never discarded.
            self.user_sessions.remove_if(&uid, |_, ids| ids.is_empty());
        }
        Ok(())
    }

    /// Returns the session IDs indexed for `user_id`; empty if there are none.
    pub fn get_user_sessions(&self, user_id: &str) -> Vec<String> {
        self.user_sessions
            .get(user_id)
            .map(|ids| ids.value().clone())
            .unwrap_or_default()
    }

    /// Terminates every session for which `Session::is_expired` is `true` and
    /// returns how many were found expired.
    pub fn cleanup_expired(&self) -> usize {
        // Collect IDs first. Terminating removes map entries, which must not
        // happen while `iter()` holds the shard locks.
        let expired: Vec<String> = self
            .sessions
            .iter()
            .filter(|entry| entry.value().read().is_expired(&self.config))
            .map(|entry| entry.key().clone())
            .collect();
        for session_id in &expired {
            let _ = self.terminate_session(session_id);
        }
        expired.len()
    }

    /// Number of registered sessions.
    pub fn session_count(&self) -> usize {
        self.sessions.len()
    }

    /// IDs of all registered sessions, in no particular order.
    pub fn list_sessions(&self) -> Vec<String> {
        self.sessions.iter().map(|e| e.key().clone()).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_creation() {
        let config = SessionConfig::default();
        let session = Session::new(&config, Some("user-123"));
        assert!(session.is_active());
        assert_eq!(session.turn_count, 0);
        assert_eq!(session.user_id, Some("user-123".to_string()));
    }

    #[test]
    fn test_session_lifecycle() {
        let config = SessionConfig::default();
        let mut session = Session::new(&config, None);
        session.increment_turn();
        assert_eq!(session.turn_count, 1);
        session.pause();
        assert_eq!(session.status, SessionStatus::Paused);
        session.resume();
        assert_eq!(session.status, SessionStatus::Active);
        session.terminate();
        assert_eq!(session.status, SessionStatus::Terminated);
    }

    #[test]
    fn test_session_manager() {
        let config = SessionConfig::default();
        let manager = SessionManager::new(config);
        let session = manager.create_session(Some("user-1")).unwrap();
        let session_id = session.id.clone();
        assert!(manager.get_session(&session_id).unwrap().is_some());
        let user_sessions = manager.get_user_sessions("user-1");
        assert_eq!(user_sessions.len(), 1);
        manager.terminate_session(&session_id).unwrap();
        assert!(manager.get_session(&session_id).unwrap().is_none());
    }

    #[test]
    fn test_missing_session_ids() {
        let manager = SessionManager::new(SessionConfig::default());
        assert!(manager.update_session("no-such-session", |s| s.touch()).is_err());
        // Terminating an unknown ID is a no-op, not an error.
        assert!(manager.terminate_session("no-such-session").is_ok());
    }

    #[test]
    fn test_update_session_mutates_stored_session() {
        let manager = SessionManager::new(SessionConfig::default());
        let id = manager.create_session(None).unwrap().id;
        manager.update_session(&id, |s| s.increment_turn()).unwrap();
        assert_eq!(manager.get_session(&id).unwrap().unwrap().turn_count, 1);
    }

    #[test]
    fn test_turn_limit_expiry_and_cleanup() {
        let config = SessionConfig {
            max_turns: 1,
            ..SessionConfig::default()
        };
        let manager = SessionManager::new(config);
        let id = manager.create_session(Some("user-1")).unwrap().id;
        assert_eq!(manager.cleanup_expired(), 0);
        manager.update_session(&id, |s| s.increment_turn()).unwrap();
        assert_eq!(manager.cleanup_expired(), 1);
        assert_eq!(manager.session_count(), 0);
        assert!(manager.get_user_sessions("user-1").is_empty());
    }
}