pub mod history;
pub mod media;
pub mod repair;
pub mod types;
pub use history::ConversationHistory;
pub use repair::{repair_messages, RepairStats};
pub use types::{ContentPart, ImageSource, Message, Role, Session, ToolCall};
use crate::config::Config;
use crate::error::Result;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::warn;
/// Caches [`Session`]s in memory and, optionally, persists each one as a JSON file.
///
/// Cloning a `SessionManager` (see its `Clone` impl) shares the same in-memory cache via `Arc`.
pub struct SessionManager {
    /// In-memory cache of sessions, keyed by session key; the `Arc` is shared between clones.
    sessions: Arc<RwLock<HashMap<String, Session>>>,
    /// Directory holding one `<sanitized key>.json` file per session; `None` means memory-only.
    storage_path: Option<PathBuf>,
}
impl SessionManager {
    /// Creates a manager that persists sessions under `<config dir>/sessions`.
    ///
    /// # Errors
    ///
    /// Returns an error if the sessions directory cannot be created.
    pub fn new() -> Result<Self> {
        Self::with_path(Config::dir().join("sessions"))
    }

    /// Creates a manager that keeps sessions only in memory; nothing is read from or written to disk.
    pub fn new_memory() -> Self {
        Self {
            sessions: Arc::new(RwLock::new(HashMap::new())),
            storage_path: None,
        }
    }

    /// Creates a manager that persists sessions as JSON files in `path`, creating it if needed.
    ///
    /// # Errors
    ///
    /// Returns an error if the directory cannot be created.
    pub fn with_path(path: PathBuf) -> Result<Self> {
        std::fs::create_dir_all(&path)?;
        Ok(Self {
            sessions: Arc::new(RwLock::new(HashMap::new())),
            storage_path: Some(path),
        })
    }

    /// Returns the session for `key`, loading it from disk or creating an empty one if needed.
    ///
    /// Lookup order: in-memory cache, then the session file (when persistence is enabled), then a
    /// fresh `Session::new(key)`. A freshly created session is cached but not written to disk
    /// until [`save`](Self::save) is called.
    ///
    /// # Errors
    ///
    /// Returns an error if the session file exists but cannot be read or parsed.
    pub async fn get_or_create(&self, key: &str) -> Result<Session> {
        if let Some(session) = self.load(key, "get_or_create").await? {
            return Ok(session);
        }
        // Check-and-insert under one write lock: if another task cached (or saved) this key after
        // our lookup, return its session instead of replacing it with an empty one.
        let mut sessions = self.sessions.write().await;
        let session = sessions
            .entry(key.to_string())
            .or_insert_with(|| Session::new(key));
        Ok(session.clone())
    }

    /// Returns the session for `key` from the cache or disk, or `None` if it does not exist.
    ///
    /// A session loaded from disk is added to the cache.
    ///
    /// # Errors
    ///
    /// Returns an error if the session file exists but cannot be read or parsed.
    pub async fn get(&self, key: &str) -> Result<Option<Session>> {
        self.load(key, "get").await
    }

    /// Stores `session` in the cache and, when persistence is enabled, writes it to disk.
    ///
    /// The JSON is first written to a uniquely named temporary file in the same directory and
    /// then renamed over `<key>.json`. An interrupted or concurrent write therefore cannot leave
    /// a truncated session file (rename is atomic on POSIX filesystems). No fsync is performed,
    /// so this is not a durability guarantee across power loss.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization, writing, or renaming fails. The cache is updated before
    /// the write is attempted, so it may be ahead of disk after an error.
    pub async fn save(&self, session: &Session) -> Result<()> {
        self.sessions
            .write()
            .await
            .insert(session.key.clone(), session.clone());
        if let Some(file_path) = self.session_file(&session.key) {
            let content = serde_json::to_string_pretty(session)?;
            Self::write_atomically(&file_path, content).await?;
        }
        Ok(())
    }

    /// Removes `key` from the cache and deletes its session file, if any.
    ///
    /// Deleting a key that was never saved is not an error.
    ///
    /// # Errors
    ///
    /// Returns an error if the session file exists but cannot be removed.
    pub async fn delete(&self, key: &str) -> Result<()> {
        self.sessions.write().await.remove(key);
        if let Some(file_path) = self.session_file(key) {
            match tokio::fs::remove_file(&file_path).await {
                Ok(()) => {}
                // Never persisted, or removed concurrently: nothing left to delete.
                Err(err) if err.kind() == std::io::ErrorKind::NotFound => {}
                Err(err) => return Err(err.into()),
            }
        }
        Ok(())
    }

    /// Returns every known session key, sorted and without duplicates.
    ///
    /// Keys come from the in-memory cache plus, when persistence is enabled, the `key` field of
    /// every `*.json` file in the storage directory. Files that cannot be read or parsed are
    /// skipped.
    ///
    /// # Errors
    ///
    /// Returns an error if the storage directory cannot be listed.
    pub async fn list(&self) -> Result<Vec<String>> {
        // A BTreeSet yields sorted, deduplicated keys without a `Vec::contains` scan per file.
        let mut keys: std::collections::BTreeSet<String> =
            self.sessions.read().await.keys().cloned().collect();
        if let Some(dir) = &self.storage_path {
            let mut entries = tokio::fs::read_dir(dir).await?;
            while let Some(entry) = entries.next_entry().await? {
                let path = entry.path();
                // Only session files; this also skips `save`'s `*.tmp` files.
                if !matches!(path.extension(), Some(ext) if ext == "json") {
                    continue;
                }
                // Take the key from the file contents rather than decoding the file name, so the
                // listed key is exactly the one that was saved.
                if let Ok(content) = tokio::fs::read_to_string(&path).await {
                    if let Ok(session) = serde_json::from_str::<Session>(&content) {
                        keys.insert(session.key);
                    }
                }
            }
        }
        Ok(keys.into_iter().collect())
    }

    /// Returns `true` if `key` is cached or has a session file on disk.
    pub async fn exists(&self, key: &str) -> bool {
        if self.sessions.read().await.contains_key(key) {
            return true;
        }
        match self.session_file(key) {
            // Async equivalent of `Path::exists` (which is `metadata(..).is_ok()`), avoiding a
            // blocking syscall on the executor thread.
            Some(file_path) => tokio::fs::metadata(&file_path).await.is_ok(),
            None => false,
        }
    }

    /// Drops every cached session. Session files on disk are not touched.
    pub async fn clear_cache(&self) {
        self.sessions.write().await.clear();
    }

    /// Returns the number of sessions currently held in the in-memory cache.
    pub async fn cache_size(&self) -> usize {
        self.sessions.read().await.len()
    }

    /// Returns the storage directory, or `None` for a memory-only manager.
    pub fn sessions_dir(&self) -> Option<&std::path::Path> {
        self.storage_path.as_deref()
    }

    /// Returns the JSON file path for `key`, or `None` for a memory-only manager.
    fn session_file(&self, key: &str) -> Option<PathBuf> {
        self.storage_path
            .as_ref()
            .map(|dir| dir.join(format!("{}.json", Self::sanitize_key(key))))
    }

    /// Looks up `key` in the cache, falling back to its session file on disk.
    ///
    /// `source` names the public caller and is only used when logging history repairs.
    async fn load(&self, key: &str, source: &str) -> Result<Option<Session>> {
        let cached = self.sessions.read().await.get(key).cloned();
        if cached.is_some() {
            return Ok(cached);
        }
        let file_path = match self.session_file(key) {
            Some(path) => path,
            None => return Ok(None),
        };
        let content = match tokio::fs::read_to_string(&file_path).await {
            Ok(content) => content,
            // No file simply means the session was never persisted (or was just deleted).
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
            Err(err) => return Err(err.into()),
        };
        let mut session: Session = serde_json::from_str(&content)?;
        self.maybe_repair_loaded_session(&mut session, source);
        // Another task may have cached or saved this key while we were reading the file. Keep its
        // entry so a concurrent save is not clobbered by this (possibly older) disk snapshot.
        let mut sessions = self.sessions.write().await;
        let cached = sessions.entry(key.to_string()).or_insert(session);
        Ok(Some(cached.clone()))
    }

    /// Writes `contents` to `path` via a temporary sibling file followed by a rename.
    async fn write_atomically(path: &std::path::Path, contents: String) -> Result<()> {
        // Unique per process and per call, so concurrent saves of the same key never share a
        // temporary file. The trailing `.tmp` extension keeps these files out of `list()`.
        static TMP_SEQ: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
        let seq = TMP_SEQ.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let mut tmp_name = path.file_name().unwrap_or_default().to_os_string();
        tmp_name.push(format!(".{}.{}.tmp", std::process::id(), seq));
        let tmp_path = path.with_file_name(tmp_name);

        let outcome = match tokio::fs::write(&tmp_path, contents).await {
            Ok(()) => tokio::fs::rename(&tmp_path, path).await,
            Err(err) => Err(err),
        };
        if let Err(err) = outcome {
            // Best effort: don't leave a stray temp file behind; the original error is what
            // the caller needs to see.
            let _ = tokio::fs::remove_file(&tmp_path).await;
            return Err(err.into());
        }
        Ok(())
    }

    /// Maps a session key to a file-name-safe string.
    ///
    /// Path separators, the characters Windows forbids in file names, and `%` itself are
    /// percent-encoded. Escaping `%` keeps the mapping injective, so distinct keys never share
    /// a file.
    fn sanitize_key(key: &str) -> String {
        let mut result = String::with_capacity(key.len() * 3);
        for c in key.chars() {
            match c {
                '/' => result.push_str("%2F"),
                '\\' => result.push_str("%5C"),
                ':' => result.push_str("%3A"),
                '*' => result.push_str("%2A"),
                '?' => result.push_str("%3F"),
                '"' => result.push_str("%22"),
                '<' => result.push_str("%3C"),
                '>' => result.push_str("%3E"),
                '|' => result.push_str("%7C"),
                '%' => result.push_str("%25"),
                c => result.push(c),
            }
        }
        result
    }

    /// Inverse of [`sanitize_key`](Self::sanitize_key).
    ///
    /// Decodes each `%XX` escape as a single byte-valued char. This is correct for everything
    /// `sanitize_key` emits, which only escapes ASCII. Malformed escapes are copied through
    /// unchanged.
    // Only exercised by tests today: `list()` reads keys from file contents instead of
    // decoding file names. Kept so the sanitize/unsanitize round trip stays verified.
    #[allow(dead_code)]
    fn unsanitize_key(sanitized: &str) -> String {
        let mut result = String::with_capacity(sanitized.len());
        let mut chars = sanitized.chars().peekable();
        while let Some(c) = chars.next() {
            if c == '%' {
                let hex: String = chars.by_ref().take(2).collect();
                if hex.len() == 2 {
                    if let Ok(byte) = u8::from_str_radix(&hex, 16) {
                        result.push(byte as char);
                        continue;
                    }
                }
                result.push('%');
                result.push_str(&hex);
            } else {
                result.push(c);
            }
        }
        result
    }

    /// Runs `repair_messages` over a freshly loaded session when `session.auto_repair` is enabled.
    ///
    /// Logs a warning with per-category repair counts when anything was changed. `source` names
    /// the caller and is included in the log.
    fn maybe_repair_loaded_session(&self, session: &mut Session, source: &str) {
        if !Config::get().session.auto_repair {
            return;
        }
        let (repaired, stats) =
            crate::session::repair::repair_messages(std::mem::take(&mut session.messages));
        if stats.total_repairs() > 0 {
            warn!(
                session_key = %session.key,
                source = source,
                orphan_tool_results_removed = stats.orphan_tool_results_removed,
                empty_messages_removed = stats.empty_messages_removed,
                role_alternation_fixes = stats.role_alternation_fixes,
                duplicate_messages_removed = stats.duplicate_messages_removed,
                truncation_repairs = stats.truncation_repairs,
                "Session history repaired on load"
            );
        }
        session.messages = repaired;
    }
}
impl Clone for SessionManager {
fn clone(&self) -> Self {
Self {
sessions: Arc::clone(&self.sessions),
storage_path: self.storage_path.clone(),
}
}
}
impl Default for SessionManager {
fn default() -> Self {
Self::new_memory()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[tokio::test]
    async fn test_session_create_and_retrieve() {
        let manager = SessionManager::new_memory();
        let session = manager.get_or_create("test-session").await.unwrap();
        assert!(session.messages.is_empty());
        assert_eq!(session.key, "test-session");
    }

    #[tokio::test]
    async fn test_session_save_and_load() {
        let manager = SessionManager::new_memory();
        let mut session = manager.get_or_create("test-session").await.unwrap();
        session.add_message(Message::user("Hello"));
        manager.save(&session).await.unwrap();
        let loaded = manager.get_or_create("test-session").await.unwrap();
        assert_eq!(loaded.messages.len(), 1);
        assert_eq!(loaded.messages[0].content, "Hello");
    }

    #[test]
    fn test_message_creation() {
        let user_msg = Message::user("Hello");
        assert_eq!(user_msg.role, Role::User);
        assert_eq!(user_msg.content, "Hello");
        let assistant_msg = Message::assistant("Hi there");
        assert_eq!(assistant_msg.role, Role::Assistant);
        assert_eq!(assistant_msg.content, "Hi there");
        let system_msg = Message::system("You are helpful");
        assert_eq!(system_msg.role, Role::System);
        let tool_msg = Message::tool_result("call_1", "Success");
        assert_eq!(tool_msg.role, Role::Tool);
        assert_eq!(tool_msg.tool_call_id, Some("call_1".to_string()));
    }

    #[tokio::test]
    async fn test_session_delete() {
        let manager = SessionManager::new_memory();
        manager.get_or_create("test-session").await.unwrap();
        assert!(manager.exists("test-session").await);
        manager.delete("test-session").await.unwrap();
        assert!(!manager.exists("test-session").await);
    }

    #[tokio::test]
    async fn test_session_list() {
        let manager = SessionManager::new_memory();
        manager.get_or_create("session-a").await.unwrap();
        manager.get_or_create("session-b").await.unwrap();
        manager.get_or_create("session-c").await.unwrap();
        let keys = manager.list().await.unwrap();
        assert_eq!(keys.len(), 3);
        assert!(keys.contains(&"session-a".to_string()));
        assert!(keys.contains(&"session-b".to_string()));
        assert!(keys.contains(&"session-c".to_string()));
    }

    #[tokio::test]
    async fn test_session_get_nonexistent() {
        let manager = SessionManager::new_memory();
        let result = manager.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_session_manager_clone() {
        let manager1 = SessionManager::new_memory();
        let manager2 = manager1.clone();
        let mut session = manager1.get_or_create("shared").await.unwrap();
        session.add_message(Message::user("Test"));
        manager1.save(&session).await.unwrap();
        let loaded = manager2.get("shared").await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().messages.len(), 1);
    }

    #[tokio::test]
    async fn test_session_clear_cache() {
        let manager = SessionManager::new_memory();
        manager.get_or_create("session1").await.unwrap();
        manager.get_or_create("session2").await.unwrap();
        assert_eq!(manager.cache_size().await, 2);
        manager.clear_cache().await;
        assert_eq!(manager.cache_size().await, 0);
    }

    #[tokio::test]
    async fn test_file_persistence() {
        let temp_dir = TempDir::new().unwrap();
        let storage_path = temp_dir.path().to_path_buf();
        {
            let manager = SessionManager::with_path(storage_path.clone()).unwrap();
            let mut session = manager.get_or_create("persist-test").await.unwrap();
            session.add_message(Message::user("Persisted message"));
            manager.save(&session).await.unwrap();
        }
        {
            let manager = SessionManager::with_path(storage_path.clone()).unwrap();
            let session = manager.get_or_create("persist-test").await.unwrap();
            assert_eq!(session.messages.len(), 1);
            assert_eq!(session.messages[0].content, "Persisted message");
        }
    }

    #[tokio::test]
    async fn test_file_persistence_delete() {
        let temp_dir = TempDir::new().unwrap();
        let storage_path = temp_dir.path().to_path_buf();
        let manager = SessionManager::with_path(storage_path.clone()).unwrap();
        let session = manager.get_or_create("delete-test").await.unwrap();
        manager.save(&session).await.unwrap();
        let file_path = storage_path.join("delete-test.json");
        assert!(file_path.exists(), "Session file should exist after save");
        manager.delete("delete-test").await.unwrap();
        assert!(!file_path.exists(), "Session file should be deleted");
    }

    #[tokio::test]
    async fn test_file_persistence_list() {
        let temp_dir = TempDir::new().unwrap();
        let storage_path = temp_dir.path().to_path_buf();
        let manager = SessionManager::with_path(storage_path).unwrap();
        for name in ["alpha", "beta", "gamma"] {
            let session = manager.get_or_create(name).await.unwrap();
            manager.save(&session).await.unwrap();
        }
        manager.clear_cache().await;
        let keys = manager.list().await.unwrap();
        assert_eq!(keys.len(), 3);
        assert!(keys.contains(&"alpha".to_string()));
        assert!(keys.contains(&"beta".to_string()));
        assert!(keys.contains(&"gamma".to_string()));
    }

    #[test]
    fn test_sanitize_key() {
        assert_eq!(SessionManager::sanitize_key("simple"), "simple");
        assert_eq!(
            SessionManager::sanitize_key("telegram:chat123"),
            "telegram%3Achat123"
        );
        assert_eq!(
            SessionManager::sanitize_key("path/to/session"),
            "path%2Fto%2Fsession"
        );
        assert_eq!(
            SessionManager::sanitize_key("a:b/c\\d*e?f\"g<h>i|j"),
            "a%3Ab%2Fc%5Cd%2Ae%3Ff%22g%3Ch%3Ei%7Cj"
        );
        assert_eq!(SessionManager::sanitize_key("100%done"), "100%25done");
    }

    #[test]
    fn test_unsanitize_key() {
        let keys = [
            "simple",
            "telegram:chat123",
            "path/to/session",
            "a:b/c\\d*e?f\"g<h>i|j",
            "100%done",
            "multi%percent%%test",
        ];
        for key in &keys {
            let sanitized = SessionManager::sanitize_key(key);
            let unsanitized = SessionManager::unsanitize_key(&sanitized);
            assert_eq!(
                unsanitized, *key,
                "Key '{}' should round-trip through sanitize/unsanitize",
                key
            );
        }
    }

    #[test]
    fn test_sanitize_key_no_collisions() {
        let key1 = "a:b";
        let key2 = "a/b";
        let key3 = "a_b";
        let sanitized1 = SessionManager::sanitize_key(key1);
        let sanitized2 = SessionManager::sanitize_key(key2);
        let sanitized3 = SessionManager::sanitize_key(key3);
        assert_ne!(sanitized1, sanitized2, "a:b and a/b should not collide");
        assert_ne!(sanitized1, sanitized3, "a:b and a_b should not collide");
        assert_ne!(sanitized2, sanitized3, "a/b and a_b should not collide");
        assert_eq!(sanitized1, "a%3Ab");
        assert_eq!(sanitized2, "a%2Fb");
        assert_eq!(sanitized3, "a_b");
    }

    #[tokio::test]
    async fn test_list_returns_original_keys_with_special_chars() {
        let temp_dir = TempDir::new().unwrap();
        let storage_path = temp_dir.path().to_path_buf();
        let manager = SessionManager::with_path(storage_path).unwrap();
        let keys = ["telegram:chat123", "discord/server456", "slack:channel:789"];
        for key in &keys {
            let session = manager.get_or_create(key).await.unwrap();
            manager.save(&session).await.unwrap();
        }
        manager.clear_cache().await;
        let listed_keys = manager.list().await.unwrap();
        assert_eq!(listed_keys.len(), 3);
        for key in &keys {
            assert!(
                listed_keys.contains(&key.to_string()),
                "list() should contain original key '{}', got {:?}",
                key,
                listed_keys
            );
        }
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        let manager = Arc::new(SessionManager::new_memory());
        let mut handles = Vec::new();
        for i in 0..10 {
            let manager_clone = Arc::clone(&manager);
            let handle = tokio::spawn(async move {
                let mut session = manager_clone.get_or_create("concurrent").await.unwrap();
                session.add_message(Message::user(&format!("Message {}", i)));
                manager_clone.save(&session).await.unwrap();
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.await.unwrap();
        }
        let session = manager.get("concurrent").await.unwrap().unwrap();
        assert!(!session.messages.is_empty());
    }

    #[tokio::test]
    async fn test_session_with_all_message_types() {
        let manager = SessionManager::new_memory();
        let mut session = manager.get_or_create("all-types").await.unwrap();
        session.add_message(Message::system("You are a helpful assistant"));
        session.add_message(Message::user("Search for rust programming"));
        session.add_message(Message::assistant_with_tools(
            "Let me search for that.",
            vec![ToolCall::new("call_1", "search", r#"{"q": "rust"}"#)],
        ));
        session.add_message(Message::tool_result("call_1", "Found 100 results"));
        session.add_message(Message::assistant("I found 100 results about Rust."));
        manager.save(&session).await.unwrap();
        let loaded = manager.get_or_create("all-types").await.unwrap();
        assert_eq!(loaded.messages.len(), 5);
        assert_eq!(loaded.messages[0].role, Role::System);
        assert_eq!(loaded.messages[1].role, Role::User);
        assert_eq!(loaded.messages[2].role, Role::Assistant);
        assert!(loaded.messages[2].has_tool_calls());
        assert_eq!(loaded.messages[3].role, Role::Tool);
        assert!(loaded.messages[3].is_tool_result());
        assert_eq!(loaded.messages[4].role, Role::Assistant);
    }

    #[tokio::test]
    async fn test_session_default() {
        let manager = SessionManager::default();
        let session = manager.get_or_create("test").await.unwrap();
        assert!(session.is_empty());
    }

    /// `get` must fall back to the session file once the cache has been cleared, and re-cache it.
    #[tokio::test]
    async fn test_get_loads_from_disk_after_cache_clear() {
        let temp_dir = TempDir::new().unwrap();
        let manager = SessionManager::with_path(temp_dir.path().to_path_buf()).unwrap();
        let mut session = manager.get_or_create("reload").await.unwrap();
        session.add_message(Message::user("From disk"));
        manager.save(&session).await.unwrap();
        manager.clear_cache().await;
        assert_eq!(manager.cache_size().await, 0);
        assert!(manager.exists("reload").await);
        let loaded = manager
            .get("reload")
            .await
            .unwrap()
            .expect("session should be loaded from disk");
        assert_eq!(loaded.messages.len(), 1);
        assert_eq!(loaded.messages[0].content, "From disk");
        assert_eq!(manager.cache_size().await, 1);
    }

    /// Deleting a key that never existed (in cache or on disk) is not an error.
    #[tokio::test]
    async fn test_delete_nonexistent_is_ok() {
        let temp_dir = TempDir::new().unwrap();
        let manager = SessionManager::with_path(temp_dir.path().to_path_buf()).unwrap();
        manager.delete("never-created").await.unwrap();
        assert!(!manager.exists("never-created").await);
    }

    /// Repeated saves leave exactly one `<key>.json` and no leftover temporary files.
    #[tokio::test]
    async fn test_save_leaves_only_session_file() {
        let temp_dir = TempDir::new().unwrap();
        let manager = SessionManager::with_path(temp_dir.path().to_path_buf()).unwrap();
        let mut session = manager.get_or_create("tidy").await.unwrap();
        for i in 0..3 {
            session.add_message(Message::user(&format!("msg {}", i)));
            manager.save(&session).await.unwrap();
        }
        let names: Vec<String> = std::fs::read_dir(temp_dir.path())
            .unwrap()
            .map(|entry| entry.unwrap().file_name().into_string().unwrap())
            .collect();
        assert_eq!(names, vec!["tidy.json".to_string()]);
    }

    /// A corrupt session file is reported by `get` and skipped by `list`.
    #[tokio::test]
    async fn test_corrupt_session_file() {
        let temp_dir = TempDir::new().unwrap();
        std::fs::write(temp_dir.path().join("broken.json"), "{ not json").unwrap();
        let manager = SessionManager::with_path(temp_dir.path().to_path_buf()).unwrap();
        assert!(manager.get("broken").await.is_err());
        assert!(manager.list().await.unwrap().is_empty());
    }

    /// `get_or_create` only caches a new session; nothing hits disk until `save`.
    #[tokio::test]
    async fn test_get_or_create_does_not_write_file() {
        let temp_dir = TempDir::new().unwrap();
        let manager = SessionManager::with_path(temp_dir.path().to_path_buf()).unwrap();
        manager.get_or_create("unsaved").await.unwrap();
        assert!(!temp_dir.path().join("unsaved.json").exists());
    }
}