use crate::error::{Error, Result};
use crate::llm::Message;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
use std::time::{SystemTime, UNIX_EPOCH};
/// Directory used by [`FileSessionStore::new`]: `.praisonai/sessions` under the
/// user's home directory (as reported by `dirs::home_dir`), or under the current
/// directory (`.`) when no home directory can be determined.
fn default_session_dir() -> PathBuf {
    let root = match dirs::home_dir() {
        Some(home) => home,
        None => PathBuf::from("."),
    };
    root.join(".praisonai").join("sessions")
}
/// Number of most recent messages a [`FileSessionStore`] writes to disk unless
/// overridden with [`FileSessionStore::max_messages`].
const DEFAULT_MAX_MESSAGES: usize = 100;
/// A single chat turn recorded in a session.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct SessionMessage {
    /// Speaker role; [`SessionMessage::to_message`] recognises `"user"`,
    /// `"assistant"` and `"system"` and treats anything else as `"user"`.
    pub role: String,
    /// Message text.
    pub content: String,
    /// Creation time in seconds since the Unix epoch (fractional).
    pub timestamp: f64,
    /// Free-form extra data; defaults to empty when absent from the JSON.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}
impl SessionMessage {
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs_f64();
Self {
role: role.into(),
content: content.into(),
timestamp: now,
metadata: HashMap::new(),
}
}
pub fn user(content: impl Into<String>) -> Self {
Self::new("user", content)
}
pub fn assistant(content: impl Into<String>) -> Self {
Self::new("assistant", content)
}
pub fn system(content: impl Into<String>) -> Self {
Self::new("system", content)
}
pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
self.metadata.insert(key.into(), value);
self
}
pub fn to_message(&self) -> Message {
match self.role.as_str() {
"user" => Message::user(&self.content),
"assistant" => Message::assistant(&self.content),
"system" => Message::system(&self.content),
_ => Message::user(&self.content),
}
}
}
/// Full persisted state of one conversation session.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct SessionData {
    /// Identifier the session is stored under.
    pub session_id: String,
    /// Messages in insertion order; empty when missing from the JSON.
    #[serde(default)]
    pub messages: Vec<SessionMessage>,
    /// RFC 3339 timestamp set by [`SessionData::new`].
    pub created_at: String,
    /// RFC 3339 timestamp refreshed whenever messages are added or cleared.
    pub updated_at: String,
    /// Optional name of the agent that owns the session.
    #[serde(default)]
    pub agent_name: Option<String>,
    /// Optional id of the user the session belongs to.
    #[serde(default)]
    pub user_id: Option<String>,
    /// Free-form extra data; defaults to empty when absent from the JSON.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}
impl SessionData {
pub fn new(session_id: impl Into<String>) -> Self {
let now = chrono::Utc::now().to_rfc3339();
Self {
session_id: session_id.into(),
messages: Vec::new(),
created_at: now.clone(),
updated_at: now,
agent_name: None,
user_id: None,
metadata: HashMap::new(),
}
}
pub fn get_chat_history(&self, max_messages: Option<usize>) -> Vec<Message> {
let messages = if let Some(max) = max_messages {
if self.messages.len() > max {
&self.messages[self.messages.len() - max..]
} else {
&self.messages[..]
}
} else {
&self.messages[..]
};
messages.iter().map(|m| m.to_message()).collect()
}
pub fn add_message(&mut self, message: SessionMessage) {
self.messages.push(message);
self.updated_at = chrono::Utc::now().to_rfc3339();
}
pub fn clear(&mut self) {
self.messages.clear();
self.updated_at = chrono::Utc::now().to_rfc3339();
}
}
/// Backend that persists [`SessionData`] by session id.
///
/// Implementors must be shareable across threads (`Send + Sync`), since a
/// [`Session`] holds its store behind an `Arc<dyn SessionStore>`.
pub trait SessionStore: Send + Sync {
    /// Fetches the session stored under `session_id`.
    ///
    /// Both implementations in this module return a fresh, empty [`SessionData`]
    /// when nothing is stored under that id.
    fn load(&self, session_id: &str) -> Result<SessionData>;

    /// Stores `session` under its own `session_id`, replacing any previous copy.
    fn save(&self, session: &SessionData) -> Result<()>;

    /// Reports whether a session is currently stored under `session_id`.
    fn exists(&self, session_id: &str) -> bool;

    /// Removes the session stored under `session_id`; the implementations in
    /// this module succeed when it is already absent.
    fn delete(&self, session_id: &str) -> Result<()>;

    /// Returns summaries of at most `limit` stored sessions; the implementations
    /// in this module order them by `updated_at`, newest first.
    fn list(&self, limit: usize) -> Result<Vec<SessionInfo>>;
}
/// Lightweight summary of a stored session, as returned by [`SessionStore::list`].
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct SessionInfo {
    /// Identifier of the summarised session.
    pub session_id: String,
    /// Agent name recorded on the session, if any.
    pub agent_name: Option<String>,
    /// Creation timestamp copied from [`SessionData::created_at`].
    pub created_at: String,
    /// Last-update timestamp copied from [`SessionData::updated_at`].
    pub updated_at: String,
    /// Number of messages in the stored session.
    pub message_count: usize,
}
/// [`SessionStore`] that keeps one pretty-printed JSON file per session.
pub struct FileSessionStore {
    /// Directory holding the `<sanitized-id>.json` files.
    session_dir: PathBuf,
    /// Number of most recent messages written to disk on each save.
    max_messages: usize,
    /// Sessions loaded or saved through this instance, keyed by the raw
    /// (unsanitized) session id; entries are only removed by `delete`.
    cache: Arc<RwLock<HashMap<String, SessionData>>>,
}
impl FileSessionStore {
    /// Creates a store rooted at the default session directory, keeping at most
    /// `DEFAULT_MAX_MESSAGES` messages per saved session.
    pub fn new() -> Self {
        Self::with_dir(default_session_dir())
    }

    /// Creates a store rooted at `dir`, creating the directory if possible.
    ///
    /// Directory creation is best effort: a failure is ignored here and surfaces
    /// later as an error from `save` or `list`.
    pub fn with_dir(dir: impl Into<PathBuf>) -> Self {
        let session_dir: PathBuf = dir.into();
        let _ = fs::create_dir_all(&session_dir);
        Self {
            session_dir,
            max_messages: DEFAULT_MAX_MESSAGES,
            cache: Arc::default(),
        }
    }

    /// Builder-style setter for how many of the newest messages `save` persists.
    pub fn max_messages(self, max: usize) -> Self {
        Self {
            max_messages: max,
            ..self
        }
    }

    /// Maps a session id to its JSON file path. Every character other than an
    /// alphanumeric, `-` or `_` becomes `_`, so distinct ids such as `a.b` and
    /// `a_b` share one file.
    fn get_path(&self, session_id: &str) -> PathBuf {
        let mut file_name: String = session_id
            .chars()
            .map(|ch| match ch {
                '-' | '_' => ch,
                c if c.is_alphanumeric() => c,
                _ => '_',
            })
            .collect();
        file_name.push_str(".json");
        self.session_dir.join(file_name)
    }
}
impl Default for FileSessionStore {
fn default() -> Self {
Self::new()
}
}
impl SessionStore for FileSessionStore {
fn load(&self, session_id: &str) -> Result<SessionData> {
if let Ok(cache) = self.cache.read() {
if let Some(session) = cache.get(session_id) {
return Ok(session.clone());
}
}
let path = self.get_path(session_id);
if !path.exists() {
let session = SessionData::new(session_id);
if let Ok(mut cache) = self.cache.write() {
cache.insert(session_id.to_string(), session.clone());
}
return Ok(session);
}
let content = fs::read_to_string(&path)
.map_err(|e| Error::io(format!("Failed to read session file: {}", e)))?;
let session: SessionData = serde_json::from_str(&content)
.map_err(|e| Error::config(format!("Failed to parse session file: {}", e)))?;
if let Ok(mut cache) = self.cache.write() {
cache.insert(session_id.to_string(), session.clone());
}
Ok(session)
}
fn save(&self, session: &SessionData) -> Result<()> {
let path = self.get_path(&session.session_id);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.map_err(|e| Error::io(format!("Failed to create session directory: {}", e)))?;
}
let mut session = session.clone();
if session.messages.len() > self.max_messages {
session.messages =
session.messages[session.messages.len() - self.max_messages..].to_vec();
}
let content = serde_json::to_string_pretty(&session)
.map_err(|e| Error::config(format!("Failed to serialize session: {}", e)))?;
fs::write(&path, content)
.map_err(|e| Error::io(format!("Failed to write session file: {}", e)))?;
if let Ok(mut cache) = self.cache.write() {
cache.insert(session.session_id.clone(), session);
}
Ok(())
}
fn exists(&self, session_id: &str) -> bool {
self.get_path(session_id).exists()
}
fn delete(&self, session_id: &str) -> Result<()> {
let path = self.get_path(session_id);
if let Ok(mut cache) = self.cache.write() {
cache.remove(session_id);
}
if path.exists() {
fs::remove_file(&path)
.map_err(|e| Error::io(format!("Failed to delete session file: {}", e)))?;
}
Ok(())
}
fn list(&self, limit: usize) -> Result<Vec<SessionInfo>> {
let mut sessions = Vec::new();
let entries = fs::read_dir(&self.session_dir)
.map_err(|e| Error::io(format!("Failed to read session directory: {}", e)))?;
for entry in entries.flatten() {
let path: std::path::PathBuf = entry.path();
if path.extension().is_some_and(|ext| ext == "json") {
if let Ok(content) = fs::read_to_string(&path) {
if let Ok(data) = serde_json::from_str::<SessionData>(&content) {
sessions.push(SessionInfo {
session_id: data.session_id,
agent_name: data.agent_name,
created_at: data.created_at,
updated_at: data.updated_at,
message_count: data.messages.len(),
});
}
}
}
}
sessions.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
sessions.truncate(limit);
Ok(sessions)
}
}
/// [`SessionStore`] that keeps sessions in a process-local map. Nothing is
/// persisted, and saved sessions are stored whole (no message cap).
pub struct InMemorySessionStore {
    /// Stored sessions keyed by session id.
    sessions: Arc<RwLock<HashMap<String, SessionData>>>,
}
impl InMemorySessionStore {
pub fn new() -> Self {
Self {
sessions: Arc::new(RwLock::new(HashMap::new())),
}
}
}
impl Default for InMemorySessionStore {
fn default() -> Self {
Self::new()
}
}
impl SessionStore for InMemorySessionStore {
    /// Returns a copy of the stored session. If `session_id` is unknown, returns a
    /// fresh empty session, which is not inserted into the store.
    ///
    /// A poisoned lock (another thread panicked while holding it) is recovered
    /// rather than propagated as a panic. The map is never left half-updated by
    /// these methods, and `FileSessionStore` likewise tolerates poisoning.
    fn load(&self, session_id: &str) -> Result<SessionData> {
        let sessions = self
            .sessions
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        Ok(sessions
            .get(session_id)
            .cloned()
            .unwrap_or_else(|| SessionData::new(session_id)))
    }

    /// Stores a full copy of `session`, replacing any previous entry.
    fn save(&self, session: &SessionData) -> Result<()> {
        let mut sessions = self
            .sessions
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        sessions.insert(session.session_id.clone(), session.clone());
        Ok(())
    }

    /// Reports whether a session has been saved under `session_id`.
    fn exists(&self, session_id: &str) -> bool {
        self.sessions
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .contains_key(session_id)
    }

    /// Removes the session if present; removing an unknown id is a no-op.
    fn delete(&self, session_id: &str) -> Result<()> {
        self.sessions
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .remove(session_id);
        Ok(())
    }

    /// Summarises up to `limit` stored sessions, newest `updated_at` first.
    fn list(&self, limit: usize) -> Result<Vec<SessionInfo>> {
        let sessions = self
            .sessions
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let mut infos: Vec<SessionInfo> = sessions
            .values()
            .map(|s| SessionInfo {
                session_id: s.session_id.clone(),
                agent_name: s.agent_name.clone(),
                created_at: s.created_at.clone(),
                updated_at: s.updated_at.clone(),
                message_count: s.messages.len(),
            })
            .collect();
        infos.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
        infos.truncate(limit);
        Ok(infos)
    }
}
/// High-level handle on one conversation: an in-memory [`SessionData`] that is
/// written through a [`SessionStore`] after every mutation.
///
/// The in-memory message list is never truncated. A `FileSessionStore` only caps
/// what it writes to disk, so `message_count` can exceed the number of persisted
/// messages.
pub struct Session {
    /// Id this session is loaded from and saved under.
    session_id: String,
    /// Current in-memory state.
    data: SessionData,
    /// Backend that persists `data`.
    store: Arc<dyn SessionStore>,
}
impl Session {
pub fn new(session_id: impl Into<String>) -> Self {
let session_id = session_id.into();
let store = Arc::new(FileSessionStore::new());
let data = store
.load(&session_id)
.unwrap_or_else(|_| SessionData::new(&session_id));
Self {
session_id,
data,
store,
}
}
pub fn with_store(session_id: impl Into<String>, store: Arc<dyn SessionStore>) -> Self {
let session_id = session_id.into();
let data = store
.load(&session_id)
.unwrap_or_else(|_| SessionData::new(&session_id));
Self {
session_id,
data,
store,
}
}
pub fn load(session_id: impl Into<String>) -> Result<Self> {
let session_id = session_id.into();
let store = Arc::new(FileSessionStore::new());
let data = store.load(&session_id)?;
Ok(Self {
session_id,
data,
store,
})
}
pub fn id(&self) -> &str {
&self.session_id
}
pub fn add_user_message(&mut self, content: impl Into<String>) -> Result<()> {
self.data.add_message(SessionMessage::user(content));
self.store.save(&self.data)
}
pub fn add_assistant_message(&mut self, content: impl Into<String>) -> Result<()> {
self.data.add_message(SessionMessage::assistant(content));
self.store.save(&self.data)
}
pub fn add_message(&mut self, role: &str, content: impl Into<String>) -> Result<()> {
self.data.add_message(SessionMessage::new(role, content));
self.store.save(&self.data)
}
pub fn get_history(&self, max_messages: Option<usize>) -> Vec<Message> {
self.data.get_chat_history(max_messages)
}
pub fn messages(&self) -> &[SessionMessage] {
&self.data.messages
}
pub fn message_count(&self) -> usize {
self.data.messages.len()
}
pub fn set_agent_name(&mut self, name: impl Into<String>) -> Result<()> {
self.data.agent_name = Some(name.into());
self.store.save(&self.data)
}
pub fn set_user_id(&mut self, user_id: impl Into<String>) -> Result<()> {
self.data.user_id = Some(user_id.into());
self.store.save(&self.data)
}
pub fn clear(&mut self) -> Result<()> {
self.data.clear();
self.store.save(&self.data)
}
pub fn delete(self) -> Result<()> {
self.store.delete(&self.session_id)
}
pub fn exists(&self) -> bool {
self.store.exists(&self.session_id)
}
pub fn save(&self) -> Result<()> {
self.store.save(&self.data)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_message_creation() {
        let msg = SessionMessage::user("Hello");
        assert_eq!(msg.role, "user");
        assert_eq!(msg.content, "Hello");
        assert!(msg.timestamp > 0.0);
    }

    #[test]
    fn test_session_data() {
        let mut data = SessionData::new("test-session");
        assert_eq!(data.session_id, "test-session");
        assert!(data.messages.is_empty());
        data.add_message(SessionMessage::user("Hello"));
        data.add_message(SessionMessage::assistant("Hi there!"));
        assert_eq!(data.messages.len(), 2);
        let history = data.get_chat_history(None);
        assert_eq!(history.len(), 2);
    }

    #[test]
    fn test_in_memory_store() {
        let store = InMemorySessionStore::new();
        let mut session = SessionData::new("test");
        session.add_message(SessionMessage::user("Hello"));
        store.save(&session).unwrap();
        assert!(store.exists("test"));
        let loaded = store.load("test").unwrap();
        assert_eq!(loaded.messages.len(), 1);
        store.delete("test").unwrap();
        assert!(!store.exists("test"));
    }

    #[test]
    fn test_in_memory_load_missing_returns_empty_without_storing() {
        let store = InMemorySessionStore::new();
        let loaded = store.load("missing").unwrap();
        assert_eq!(loaded.session_id, "missing");
        assert!(loaded.messages.is_empty());
        assert!(!store.exists("missing"));
    }

    #[test]
    fn test_session_api() {
        let store = Arc::new(InMemorySessionStore::new());
        let mut session = Session::with_store("test-api", store);
        session.add_user_message("Hello").unwrap();
        session.add_assistant_message("Hi!").unwrap();
        assert_eq!(session.message_count(), 2);
        let history = session.get_history(None);
        assert_eq!(history.len(), 2);
        session.clear().unwrap();
        assert_eq!(session.message_count(), 0);
    }

    #[test]
    fn test_session_history_limit() {
        let mut data = SessionData::new("test");
        for i in 0..10 {
            data.add_message(SessionMessage::user(format!("Message {}", i)));
        }
        let history = data.get_chat_history(Some(5));
        assert_eq!(history.len(), 5);
        assert_eq!(data.get_chat_history(Some(0)).len(), 0);
        assert_eq!(data.get_chat_history(Some(50)).len(), 10);
    }

    #[test]
    fn test_file_store_roundtrip_truncation_and_delete() {
        // Unique per process so parallel test binaries don't collide.
        let dir = std::env::temp_dir()
            .join(format!("praisonai-session-test-{}", std::process::id()));
        let _ = std::fs::remove_dir_all(&dir);

        let store = FileSessionStore::with_dir(dir.clone()).max_messages(3);
        // The '/' must be sanitized into the file name, not create a subdirectory.
        let mut data = SessionData::new("file/test");
        for i in 0..5 {
            data.add_message(SessionMessage::user(format!("m{}", i)));
        }
        store.save(&data).unwrap();
        assert!(store.exists("file/test"));

        // A fresh store has an empty cache, so this load reads the file on disk.
        let reloaded = FileSessionStore::with_dir(dir.clone())
            .load("file/test")
            .unwrap();
        let contents: Vec<&str> = reloaded.messages.iter().map(|m| m.content.as_str()).collect();
        assert_eq!(contents, vec!["m2", "m3", "m4"]);

        let listed = store.list(10).unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].session_id, "file/test");
        assert_eq!(listed[0].message_count, 3);

        store.delete("file/test").unwrap();
        assert!(!store.exists("file/test"));
        // Deleting an already-absent session is not an error.
        store.delete("file/test").unwrap();

        let _ = std::fs::remove_dir_all(&dir);
    }
}