use anyhow::{bail, Result};
use chrono::{DateTime, Utc};
use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize, Serializer};
use std::path::PathBuf;
use tokio::fs;
/// Opaque identifier of a [`Session`], held as a plain string.
///
/// Serializes as a bare string (not a wrapper object) and accepts any string
/// when deserializing.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SessionId(pub String);

impl SessionId {
    /// Generates a fresh id from a random UUID v4, rendered with `to_string`.
    pub fn new() -> Self {
        let fresh = uuid::Uuid::new_v4();
        SessionId(fresh.to_string())
    }
}

impl Default for SessionId {
    /// Same as [`SessionId::new`]: every default id is newly generated, not a fixed value.
    fn default() -> Self {
        SessionId::new()
    }
}

impl std::fmt::Display for SessionId {
    /// Writes the raw id string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl Serialize for SessionId {
    /// Emits the inner id as a plain string.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(self.0.as_str())
    }
}

impl<'de> Deserialize<'de> for SessionId {
    /// Reads a plain string and wraps it unchanged (no format validation).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        String::deserialize(deserializer).map(SessionId)
    }
}
/// One message sent by the user within a [`Session`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct UserMessage {
    /// Message text as received.
    pub content: String,
    /// When the message was recorded (UTC).
    pub timestamp: DateTime<Utc>,
}
/// One reply produced by the agent within a [`Session`].
///
/// The optional fields are reported by whoever builds the response; their
/// exact meaning is defined by the caller (NOTE(review): confirm semantics).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AgentResponse {
    /// Reply text.
    pub content: String,
    /// Session id reported alongside the response, if any.
    pub session_id: Option<String>,
    /// Seed id reported alongside the response, if any.
    pub seed_id: Option<String>,
    /// Name of the phase reached, if reported.
    pub phase_reached: Option<String>,
    /// Evaluation outcome, if an evaluation was reported.
    pub evaluation_passed: Option<bool>,
    /// When the response was produced (UTC).
    pub timestamp: DateTime<Utc>,
}
/// Free-form, JSON-valued key/value data attached to a [`Session`].
pub type SessionMetadata = std::collections::HashMap<String, serde_json::Value>;

/// A conversation between one user and the agent, persisted as JSON by [`StateStore`].
///
/// When loading, missing message lists and metadata default to empty. When saving,
/// the active seed/persona ids are omitted if they are `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Session {
    /// Unique id; [`StateStore`] also uses it as the on-disk file name.
    pub id: SessionId,
    /// Owner of the session.
    pub user_id: String,
    /// User messages in the order they were added.
    #[serde(default)]
    pub user_messages: Vec<UserMessage>,
    /// Agent responses in the order they were added.
    #[serde(default)]
    pub agent_responses: Vec<AgentResponse>,
    /// Currently selected seed, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub active_seed_id: Option<String>,
    /// Currently selected persona, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub active_persona_id: Option<String>,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
    /// Time of the last mutation made through the `Session` helper methods (UTC).
    pub updated_at: DateTime<Utc>,
    /// Arbitrary extra data.
    #[serde(default)]
    pub metadata: SessionMetadata,
}
impl Session {
    /// Creates an empty session for `user_id` with a freshly generated [`SessionId`].
    pub fn new(user_id: impl Into<String>) -> Self {
        Self::with_id(user_id, SessionId::new())
    }

    /// Creates an empty session for `user_id` that uses the caller-chosen `session_id`.
    ///
    /// `created_at` and `updated_at` are both set to the current time.
    pub fn with_id(user_id: impl Into<String>, session_id: SessionId) -> Self {
        let now = Utc::now();
        Self {
            id: session_id,
            user_id: user_id.into(),
            user_messages: Vec::new(),
            agent_responses: Vec::new(),
            active_seed_id: None,
            active_persona_id: None,
            created_at: now,
            updated_at: now,
            metadata: SessionMetadata::new(),
        }
    }

    /// Sets `updated_at` to the current time.
    fn touch(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Appends a user message stamped with the current time.
    ///
    /// `updated_at` is set to the very same instant as the message timestamp so
    /// the two agree (previously two separate clock reads could differ).
    pub fn add_user_message(&mut self, content: impl Into<String>) {
        let now = Utc::now();
        self.user_messages.push(UserMessage {
            content: content.into(),
            timestamp: now,
        });
        self.updated_at = now;
    }

    /// Appends an agent response (its own `timestamp` is kept as given) and
    /// refreshes `updated_at`.
    pub fn add_agent_response(&mut self, response: AgentResponse) {
        self.agent_responses.push(response);
        self.touch();
    }

    /// Sets or clears the active seed and refreshes `updated_at`.
    pub fn set_active_seed(&mut self, seed_id: Option<String>) {
        self.active_seed_id = seed_id;
        self.touch();
    }

    /// Sets or clears the active persona and refreshes `updated_at`.
    pub fn set_active_persona(&mut self, persona_id: Option<String>) {
        self.active_persona_id = persona_id;
        self.touch();
    }

    /// Inserts or overwrites the metadata entry `key` and refreshes `updated_at`.
    pub fn set_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
        self.metadata.insert(key.into(), value);
        self.touch();
    }

    /// Returns the metadata value stored under `key`, if any.
    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
        self.metadata.get(key)
    }

    /// Number of complete exchanges: the smaller of the user-message and
    /// agent-response counts.
    pub fn exchange_count(&self) -> usize {
        self.user_messages.len().min(self.agent_responses.len())
    }

    /// `true` when no user message has been recorded. Agent responses are not
    /// considered.
    pub fn is_empty(&self) -> bool {
        self.user_messages.is_empty()
    }
}
/// File-backed document store that keeps Markdown and JSON files under
/// `base_path/<category>/<name>.{md,json}`.
#[derive(Clone)]
pub struct StateStore {
    /// Root directory; every category is a subdirectory of it.
    pub base_path: PathBuf,
}
impl StateStore {
pub fn new(base_path: PathBuf) -> Result<Self> {
Ok(Self { base_path })
}
fn validate_category(category: &str) -> Result<()> {
if category.contains("..") || category.contains('\\') {
bail!("invalid category name: '{category}'");
}
if category.is_empty()
|| category.starts_with('/')
|| category.ends_with('/')
|| category.contains("//")
{
bail!("invalid category name: '{category}'");
}
Ok(())
}
fn validate_name(name: &str) -> Result<()> {
if name.contains("..") || name.contains('/') || name.contains('\\') {
bail!("invalid file name: '{name}'");
}
Ok(())
}
pub async fn save_markdown(&self, category: &str, name: &str, content: &str) -> Result<()> {
Self::validate_category(category)?;
Self::validate_name(name)?;
let dir = self.base_path.join(category);
fs::create_dir_all(&dir).await?;
let path = dir.join(format!("{name}.md"));
let temp_path = dir.join(format!(
"{name}.{}.{}.tmp",
std::process::id(),
uuid::Uuid::new_v4()
));
fs::write(&temp_path, content).await?;
tokio::fs::rename(&temp_path, &path).await?;
Ok(())
}
pub async fn load_markdown(&self, category: &str, name: &str) -> Result<Option<String>> {
Self::validate_category(category)?;
Self::validate_name(name)?;
let path = self.base_path.join(category).join(format!("{name}.md"));
match fs::read_to_string(&path).await {
Ok(content) => Ok(Some(content)),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
Err(e) => Err(e.into()),
}
}
pub async fn list_category(&self, category: &str) -> Result<Vec<String>> {
Self::validate_category(category)?;
let dir = self.base_path.join(category);
if !dir.exists() {
return Ok(Vec::new());
}
let mut entries = fs::read_dir(&dir).await?;
let mut names = Vec::new();
while let Some(entry) = entries.next_entry().await? {
let path = entry.path();
if let Some(ext) = path.extension() {
if ext == "md" || ext == "json" {
if let Some(stem) = path.file_stem() {
names.push(stem.to_string_lossy().into_owned());
}
}
}
}
names.sort();
Ok(names)
}
pub async fn save_json<T: Serialize>(
&self,
category: &str,
name: &str,
data: &T,
) -> Result<()> {
Self::validate_category(category)?;
Self::validate_name(name)?;
let dir = self.base_path.join(category);
fs::create_dir_all(&dir).await?;
let path = dir.join(format!("{name}.json"));
let content = serde_json::to_string_pretty(data)?;
let temp_path = dir.join(format!(
"{name}.{}.{}.tmp",
std::process::id(),
uuid::Uuid::new_v4()
));
fs::write(&temp_path, &content).await?;
tokio::fs::rename(&temp_path, &path).await?;
Ok(())
}
pub async fn load_json<T: DeserializeOwned>(
&self,
category: &str,
name: &str,
) -> Result<Option<T>> {
Self::validate_category(category)?;
Self::validate_name(name)?;
let path = self.base_path.join(category).join(format!("{name}.json"));
match fs::read_to_string(&path).await {
Ok(content) => Ok(Some(serde_json::from_str(&content)?)),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
Err(e) => Err(e.into()),
}
}
pub async fn delete_file(&self, category: &str, name: &str) -> Result<bool> {
Self::validate_category(category)?;
Self::validate_name(name)?;
let path = self.base_path.join(category).join(format!("{name}.json"));
if path.exists() {
tokio::fs::remove_file(path).await?;
Ok(true)
} else {
let path = self.base_path.join(category).join(format!("{name}.md"));
if path.exists() {
tokio::fs::remove_file(path).await?;
Ok(true)
} else {
Ok(false)
}
}
}
}
impl std::fmt::Debug for StateStore {
    /// Formats as `StateStore { base_path: .. }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("StateStore");
        builder.field("base_path", &self.base_path);
        builder.finish()
    }
}
impl StateStore {
pub async fn save_session(&self, session: &Session) -> Result<()> {
self.save_json("sessions", &session.id.0, session).await
}
pub async fn save_session_with_prune(
&self,
session: &Session,
prune_config: &PruneConfig,
) -> Result<()> {
self.save_session(session).await?;
let store = self.clone();
let config = prune_config.clone();
tokio::spawn(async move {
if let Err(e) = store.prune_sessions(&config).await {
tracing::warn!(error = %e, "Background session pruning failed");
}
});
Ok(())
}
pub async fn load_session(&self, session_id: &SessionId) -> Result<Option<Session>> {
self.load_json("sessions", &session_id.0).await
}
pub async fn list_sessions(&self) -> Result<Vec<SessionSummary>> {
let mut sessions = Vec::new();
if let Ok(names) = self.list_category("sessions").await {
for name in names {
if let Ok(Some(session)) = self.load_json::<Session>("sessions", &name).await {
sessions.push(SessionSummary {
id: session.id.0.clone(),
user_id: session.user_id.clone(),
message_count: session.user_messages.len(),
active_seed_id: session.active_seed_id.clone(),
project_id: session
.metadata
.get("project_ids")
.and_then(|v| v.as_str())
.map(String::from),
created_at: session.created_at,
updated_at: session.updated_at,
});
}
}
}
sessions.sort_by_key(|b| std::cmp::Reverse(b.updated_at));
Ok(sessions)
}
pub async fn delete_session(&self, session_id: &SessionId) -> Result<bool> {
let path = self
.base_path
.join("sessions")
.join(format!("{}.json", session_id.0));
match fs::remove_file(&path).await {
Ok(()) => {
tracing::info!(session_id = %session_id, "Session deleted");
Ok(true)
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(false),
Err(e) => Err(e.into()),
}
}
pub async fn get_or_create_session(
&self,
user_id: &str,
session_id: Option<&SessionId>,
) -> Result<Session> {
if let Some(sid) = session_id {
if let Some(existing) = self.load_session(sid).await? {
return Ok(existing);
}
}
let session = match session_id {
Some(sid) => Session::with_id(user_id, sid.clone()),
None => Session::new(user_id),
};
self.save_session(&session).await?;
Ok(session)
}
pub async fn update_session(&self, session: &Session) -> Result<()> {
self.save_session(session).await
}
pub async fn prune_sessions(&self, config: &PruneConfig) -> Result<usize> {
let mut sessions = self.list_sessions().await?;
let mut pruned = 0;
if config.ttl_hours > 0 {
let cutoff = Utc::now() - chrono::Duration::hours(config.ttl_hours as i64);
let to_prune_ttl: Vec<String> = sessions
.iter()
.filter(|s| s.updated_at < cutoff)
.map(|s| s.id.clone())
.collect();
for id in &to_prune_ttl {
let sid = SessionId(id.clone());
if self.delete_session(&sid).await.is_ok() {
pruned += 1;
}
}
sessions.retain(|s| !to_prune_ttl.contains(&s.id));
}
if config.max_sessions > 0 && sessions.len() > config.max_sessions {
let excess = sessions.len() - config.max_sessions;
for session in sessions.into_iter().rev().take(excess) {
let sid = SessionId(session.id);
if self.delete_session(&sid).await.is_ok() {
pruned += 1;
}
}
}
if pruned > 0 {
tracing::info!(pruned = pruned, "Session pruning completed");
}
Ok(pruned)
}
}
/// Lightweight listing entry for a stored [`Session`], as produced by
/// [`StateStore::list_sessions`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SessionSummary {
    /// The session's id string.
    pub id: String,
    /// Owner of the session.
    pub user_id: String,
    /// Number of user messages in the session.
    pub message_count: usize,
    /// Active seed, omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub active_seed_id: Option<String>,
    /// Project id taken from session metadata; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub project_id: Option<String>,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
    /// Last update time (UTC); listings are sorted on this, newest first.
    pub updated_at: DateTime<Utc>,
}
/// Limits applied by [`StateStore::prune_sessions`]; `0` disables a limit.
#[derive(Debug, Clone)]
pub struct PruneConfig {
    /// Keep at most this many sessions (the most recently updated win); `0` = unlimited.
    pub max_sessions: usize,
    /// Delete sessions not updated within this many hours; `0` = no expiry.
    pub ttl_hours: u64,
}

impl Default for PruneConfig {
    /// Keeps at most 100 sessions and expires sessions idle for one week.
    fn default() -> Self {
        const HOURS_PER_WEEK: u64 = 24 * 7;
        PruneConfig {
            ttl_hours: HOURS_PER_WEEK,
            max_sessions: 100,
        }
    }
}
/// Rate limiter that allows a prune pass at most once per cooldown window.
///
/// The first call to [`PruneThrottle::should_prune`] always succeeds. Later calls
/// succeed only once at least `cooldown_secs` seconds have elapsed since the last
/// successful call.
pub struct PruneThrottle {
    /// Instant of the last call that returned `true`; `None` until the first one.
    last_prune: std::sync::Mutex<Option<std::time::Instant>>,
    /// Minimum spacing between successful calls, in seconds.
    cooldown_secs: u64,
}

impl PruneThrottle {
    /// Creates a throttle that has never fired, with the given cooldown in seconds.
    pub fn new(cooldown_secs: u64) -> Self {
        let last_prune = std::sync::Mutex::new(None);
        Self {
            last_prune,
            cooldown_secs,
        }
    }

    /// Returns `true` (and records the current instant) if no prune has happened
    /// yet or the cooldown has fully elapsed; otherwise returns `false`.
    ///
    /// A poisoned mutex is recovered (with a warning) rather than panicking.
    pub fn should_prune(&self) -> bool {
        let mut last = self.last_prune.lock().unwrap_or_else(|e| {
            tracing::warn!("PruneThrottle mutex poisoned, recovering: {e}");
            e.into_inner()
        });
        let now = std::time::Instant::now();
        // Comparing full Durations against a whole-second cooldown is equivalent
        // to comparing truncated `as_secs()` values.
        let cooldown = std::time::Duration::from_secs(self.cooldown_secs);
        let ready = match *last {
            None => true,
            Some(previous) => now.duration_since(previous) >= cooldown,
        };
        if ready {
            *last = Some(now);
        }
        ready
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store rooted in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the store. Keep it bound for the
    /// whole test, because dropping it deletes the directory.
    fn temp_store() -> (tempfile::TempDir, StateStore) {
        let temp_dir = tempfile::tempdir().unwrap();
        let store = StateStore::new(temp_dir.path().to_path_buf()).unwrap();
        (temp_dir, store)
    }

    #[tokio::test]
    async fn test_session_creation_and_persistence() {
        let (_dir, store) = temp_store();
        let mut session = Session::new("user-123");
        session.add_user_message("Hello");
        store.save_session(&session).await.unwrap();
        let loaded = store
            .load_session(&session.id)
            .await
            .unwrap()
            .expect("session was just saved");
        assert_eq!(loaded.user_id, "user-123");
        assert_eq!(loaded.user_messages.len(), 1);
    }

    #[tokio::test]
    async fn test_session_list_sorts_by_updated() {
        let (_dir, store) = temp_store();
        for i in 0..3 {
            let mut session = Session::new(format!("user-{i}"));
            session.add_user_message(format!("Message {i}"));
            store.save_session(&session).await.unwrap();
            // Keep `updated_at` values distinct and increasing.
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
        let sessions = store.list_sessions().await.unwrap();
        assert_eq!(sessions.len(), 3);
        assert_eq!(sessions[0].user_id, "user-2");
    }

    #[tokio::test]
    async fn test_delete_session() {
        let (_dir, store) = temp_store();
        let session = Session::new("user-123");
        store.save_session(&session).await.unwrap();
        let deleted = store.delete_session(&session.id).await.unwrap();
        assert!(deleted);
        let loaded = store.load_session(&session.id).await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn test_get_or_create_session_existing() {
        let (_dir, store) = temp_store();
        let mut existing = Session::new("user-123");
        existing.add_user_message("Original message");
        store.save_session(&existing).await.unwrap();
        let retrieved = store
            .get_or_create_session("user-123", Some(&existing.id))
            .await
            .unwrap();
        assert_eq!(retrieved.id, existing.id);
        assert_eq!(retrieved.user_messages.len(), 1);
    }

    #[tokio::test]
    async fn test_get_or_create_session_new() {
        let (_dir, store) = temp_store();
        let session = store.get_or_create_session("user-456", None).await.unwrap();
        assert_eq!(session.user_id, "user-456");
        assert!(session.user_messages.is_empty());
    }

    #[tokio::test]
    async fn test_prune_sessions_by_count() {
        let (_dir, store) = temp_store();
        for i in 0..5 {
            let mut session = Session::new(format!("user-{i}"));
            session.add_user_message(format!("Message {i}"));
            store.save_session(&session).await.unwrap();
            // Keep `updated_at` values distinct and increasing.
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
        let config = PruneConfig {
            max_sessions: 3,
            ttl_hours: 0,
        };
        let pruned = store.prune_sessions(&config).await.unwrap();
        assert_eq!(pruned, 2);
        let remaining = store.list_sessions().await.unwrap();
        assert_eq!(remaining.len(), 3);
        let remaining_ids: Vec<&str> = remaining.iter().map(|s| s.user_id.as_str()).collect();
        assert!(remaining_ids.contains(&"user-2"));
        assert!(remaining_ids.contains(&"user-3"));
        assert!(remaining_ids.contains(&"user-4"));
    }

    #[tokio::test]
    async fn test_prune_sessions_by_ttl() {
        let (_dir, store) = temp_store();
        let mut old_session = Session::new("old-user");
        old_session.updated_at = Utc::now() - chrono::Duration::hours(48);
        store.save_session(&old_session).await.unwrap();
        let mut recent_session = Session::new("recent-user");
        recent_session.add_user_message("Hello");
        store.save_session(&recent_session).await.unwrap();
        let config = PruneConfig {
            max_sessions: 0,
            ttl_hours: 24,
        };
        let pruned = store.prune_sessions(&config).await.unwrap();
        assert_eq!(pruned, 1);
        let remaining = store.list_sessions().await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].user_id, "recent-user");
    }

    #[tokio::test]
    async fn test_invalid_names_are_rejected() {
        let (_dir, store) = temp_store();
        assert!(store.save_json("sessions", "../escape", &1).await.is_err());
        assert!(store.load_markdown("../notes", "a").await.is_err());
        assert!(store.save_markdown("a//b", "note", "x").await.is_err());
        assert!(store.list_category("/abs").await.is_err());
    }

    #[tokio::test]
    async fn test_list_category_missing_dir_is_empty() {
        let (_dir, store) = temp_store();
        assert!(store.list_category("nothing-here").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_markdown_roundtrip_and_delete_file() {
        let (_dir, store) = temp_store();
        store.save_markdown("notes", "todo", "- item").await.unwrap();
        assert_eq!(
            store.load_markdown("notes", "todo").await.unwrap().as_deref(),
            Some("- item")
        );
        assert_eq!(
            store.list_category("notes").await.unwrap(),
            vec!["todo".to_string()]
        );
        assert!(store.delete_file("notes", "todo").await.unwrap());
        assert!(!store.delete_file("notes", "todo").await.unwrap());
        assert!(store.load_markdown("notes", "todo").await.unwrap().is_none());
    }

    #[test]
    fn test_prune_throttle_cooldown() {
        let throttle = PruneThrottle::new(3600);
        assert!(throttle.should_prune());
        assert!(!throttle.should_prune());
        let no_cooldown = PruneThrottle::new(0);
        assert!(no_cooldown.should_prune());
        assert!(no_cooldown.should_prune());
    }
}