use anyhow::{bail, Result};
use chrono::{DateTime, Utc};
use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize, Serializer};
use std::path::PathBuf;
use tokio::fs;
/// Opaque identifier of a [`Session`], wrapping the raw id string.
///
/// The inner string is public and unvalidated: any string can be wrapped.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct SessionId(pub String);
impl SessionId {
    /// Generates a fresh id from a random (v4) UUID rendered as a string.
    pub fn new() -> Self {
        let raw = uuid::Uuid::new_v4();
        Self(raw.to_string())
    }
}
impl Default for SessionId {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for SessionId {
    /// Writes the raw id string verbatim (formatter width/fill flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
impl Serialize for SessionId {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&self.0)
}
}
impl<'de> Deserialize<'de> for SessionId {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Ok(Self(s))
}
}
/// A single message sent by the user, with the time it was recorded.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct UserMessage {
    /// Message text.
    pub content: String,
    /// When the message was recorded (UTC).
    pub timestamp: DateTime<Utc>,
}
/// A reply produced by the agent, plus optional details reported alongside it.
///
/// NOTE(review): the meaning of the optional fields is defined by whoever builds the
/// response; confirm against the producer before relying on them.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AgentResponse {
    /// Response text.
    pub content: String,
    /// Optional session id reported with the response.
    pub session_id: Option<String>,
    /// Optional seed id reported with the response.
    pub seed_id: Option<String>,
    /// Optional name of the phase the agent reached.
    pub phase_reached: Option<String>,
    /// Optional evaluation outcome, if one was run.
    pub evaluation_passed: Option<bool>,
    /// When the response was produced (UTC).
    pub timestamp: DateTime<Utc>,
}
/// Free-form per-session key/value metadata with arbitrary JSON values.
pub type SessionMetadata = std::collections::HashMap<String, serde_json::Value>;
/// A conversation between one user and the agent, persisted as JSON by [`StateStore`].
///
/// Fields marked `#[serde(default)]` may be absent in stored JSON and then load as empty.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Session {
    /// Unique id; also used as the on-disk file name.
    pub id: SessionId,
    /// Owner of the session.
    pub user_id: String,
    /// Messages from the user, in insertion order.
    #[serde(default)]
    pub user_messages: Vec<UserMessage>,
    /// Agent replies, in insertion order.
    #[serde(default)]
    pub agent_responses: Vec<AgentResponse>,
    /// Currently selected seed, omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub active_seed_id: Option<String>,
    /// Currently selected persona, omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub active_persona_id: Option<String>,
    /// When the session was created.
    pub created_at: DateTime<Utc>,
    /// Last time any field was changed through the `Session` mutator methods.
    pub updated_at: DateTime<Utc>,
    /// Arbitrary extra key/value data.
    #[serde(default)]
    pub metadata: SessionMetadata,
}
impl Session {
    /// Creates an empty session for `user_id` with a freshly generated [`SessionId`].
    pub fn new(user_id: impl Into<String>) -> Self {
        Self::with_id(user_id, SessionId::new())
    }

    /// Creates an empty session for `user_id` using the caller-supplied `session_id`.
    ///
    /// `created_at` and `updated_at` are both set to the same current instant.
    pub fn with_id(user_id: impl Into<String>, session_id: SessionId) -> Self {
        let now = Utc::now();
        Self {
            id: session_id,
            user_id: user_id.into(),
            user_messages: Vec::new(),
            agent_responses: Vec::new(),
            active_seed_id: None,
            active_persona_id: None,
            created_at: now,
            updated_at: now,
            metadata: SessionMetadata::new(),
        }
    }

    /// Appends a user message stamped with the current time and bumps `updated_at`.
    pub fn add_user_message(&mut self, content: impl Into<String>) {
        // Read the clock once so the message timestamp and `updated_at` agree exactly.
        let now = Utc::now();
        self.user_messages.push(UserMessage {
            content: content.into(),
            timestamp: now,
        });
        self.updated_at = now;
    }

    /// Appends an agent response (keeping its own timestamp) and bumps `updated_at`.
    pub fn add_agent_response(&mut self, response: AgentResponse) {
        self.agent_responses.push(response);
        self.touch();
    }

    /// Sets or clears the active seed and bumps `updated_at`.
    pub fn set_active_seed(&mut self, seed_id: Option<String>) {
        self.active_seed_id = seed_id;
        self.touch();
    }

    /// Sets or clears the active persona and bumps `updated_at`.
    pub fn set_active_persona(&mut self, persona_id: Option<String>) {
        self.active_persona_id = persona_id;
        self.touch();
    }

    /// Inserts or overwrites a metadata entry and bumps `updated_at`.
    pub fn set_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
        self.metadata.insert(key.into(), value);
        self.touch();
    }

    /// Returns the metadata value stored under `key`, if any.
    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
        self.metadata.get(key)
    }

    /// Number of complete exchanges: the smaller of the user-message and response counts.
    pub fn exchange_count(&self) -> usize {
        self.user_messages.len().min(self.agent_responses.len())
    }

    /// True when no user message has been recorded (agent responses are not considered).
    pub fn is_empty(&self) -> bool {
        self.user_messages.is_empty()
    }

    /// Marks the session as modified now.
    fn touch(&mut self) {
        self.updated_at = Utc::now();
    }
}
/// File-backed store that keeps Markdown and JSON documents under `base_path`,
/// grouped into category sub-directories.
#[derive(Clone)]
pub struct StateStore {
    /// Root directory; each category is a sub-directory of it.
    pub base_path: PathBuf,
}
impl StateStore {
pub fn new(base_path: PathBuf) -> Result<Self> {
Ok(Self { base_path })
}
fn validate_category(category: &str) -> Result<()> {
if category.contains("..") || category.contains('\\') {
bail!("invalid category name: '{}'", category);
}
if category.is_empty()
|| category.starts_with('/')
|| category.ends_with('/')
|| category.contains("//")
{
bail!("invalid category name: '{}'", category);
}
Ok(())
}
fn validate_name(name: &str) -> Result<()> {
if name.contains("..") || name.contains('/') || name.contains('\\') {
bail!("invalid file name: '{}'", name);
}
Ok(())
}
pub async fn save_markdown(&self, category: &str, name: &str, content: &str) -> Result<()> {
Self::validate_category(category)?;
Self::validate_name(name)?;
let dir = self.base_path.join(category);
fs::create_dir_all(&dir).await?;
let path = dir.join(format!("{name}.md"));
let temp_path = dir.join(format!("{name}.{}.tmp", std::process::id()));
fs::write(&temp_path, content).await?;
tokio::fs::rename(&temp_path, &path).await?;
Ok(())
}
pub async fn load_markdown(&self, category: &str, name: &str) -> Result<Option<String>> {
Self::validate_category(category)?;
Self::validate_name(name)?;
let path = self.base_path.join(category).join(format!("{name}.md"));
match fs::read_to_string(&path).await {
Ok(content) => Ok(Some(content)),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
Err(e) => Err(e.into()),
}
}
pub async fn list_category(&self, category: &str) -> Result<Vec<String>> {
Self::validate_category(category)?;
let dir = self.base_path.join(category);
if !dir.exists() {
return Ok(Vec::new());
}
let mut entries = fs::read_dir(&dir).await?;
let mut names = Vec::new();
while let Some(entry) = entries.next_entry().await? {
let path = entry.path();
if let Some(ext) = path.extension() {
if ext == "md" || ext == "json" {
if let Some(stem) = path.file_stem() {
names.push(stem.to_string_lossy().into_owned());
}
}
}
}
names.sort();
Ok(names)
}
pub async fn save_json<T: Serialize>(
&self,
category: &str,
name: &str,
data: &T,
) -> Result<()> {
Self::validate_category(category)?;
Self::validate_name(name)?;
let dir = self.base_path.join(category);
fs::create_dir_all(&dir).await?;
let path = dir.join(format!("{name}.json"));
let content = serde_json::to_string_pretty(data)?;
let temp_path = dir.join(format!("{name}.{}.tmp", std::process::id()));
fs::write(&temp_path, &content).await?;
tokio::fs::rename(&temp_path, &path).await?;
Ok(())
}
pub async fn load_json<T: DeserializeOwned>(
&self,
category: &str,
name: &str,
) -> Result<Option<T>> {
Self::validate_category(category)?;
Self::validate_name(name)?;
let path = self.base_path.join(category).join(format!("{name}.json"));
match fs::read_to_string(&path).await {
Ok(content) => Ok(Some(serde_json::from_str(&content)?)),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
Err(e) => Err(e.into()),
}
}
pub async fn delete_file(&self, category: &str, name: &str) -> Result<bool> {
Self::validate_category(category)?;
Self::validate_name(name)?;
let path = self.base_path.join(category).join(format!("{name}.json"));
if path.exists() {
tokio::fs::remove_file(path).await?;
Ok(true)
} else {
let path = self.base_path.join(category).join(format!("{name}.md"));
if path.exists() {
tokio::fs::remove_file(path).await?;
Ok(true)
} else {
Ok(false)
}
}
}
}
impl std::fmt::Debug for StateStore {
    // Renders as `StateStore { base_path: ... }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("StateStore");
        builder.field("base_path", &self.base_path);
        builder.finish()
    }
}
impl StateStore {
pub async fn save_session(&self, session: &Session) -> Result<()> {
self.save_json("sessions", &session.id.0, session).await
}
pub async fn load_session(&self, session_id: &SessionId) -> Result<Option<Session>> {
self.load_json("sessions", &session_id.0).await
}
pub async fn list_sessions(&self) -> Result<Vec<SessionSummary>> {
let mut sessions = Vec::new();
if let Ok(names) = self.list_category("sessions").await {
for name in names {
if let Ok(Some(session)) = self.load_json::<Session>("sessions", &name).await {
sessions.push(SessionSummary {
id: session.id.0.clone(),
user_id: session.user_id.clone(),
message_count: session.user_messages.len(),
active_seed_id: session.active_seed_id.clone(),
created_at: session.created_at,
updated_at: session.updated_at,
});
}
}
}
sessions.sort_by_key(|b| std::cmp::Reverse(b.updated_at));
Ok(sessions)
}
pub async fn delete_session(&self, session_id: &SessionId) -> Result<bool> {
let path = self
.base_path
.join("sessions")
.join(format!("{}.json", session_id.0));
match fs::remove_file(&path).await {
Ok(()) => {
tracing::info!(session_id = %session_id, "Session deleted");
Ok(true)
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(false),
Err(e) => Err(e.into()),
}
}
pub async fn get_or_create_session(
&self,
user_id: &str,
session_id: Option<&SessionId>,
) -> Result<Session> {
if let Some(sid) = session_id {
if let Some(existing) = self.load_session(sid).await? {
return Ok(existing);
}
}
let session = match session_id {
Some(sid) => Session::with_id(user_id, sid.clone()),
None => Session::new(user_id),
};
self.save_session(&session).await?;
Ok(session)
}
pub async fn update_session(&self, session: &Session) -> Result<()> {
self.save_session(session).await
}
}
/// Lightweight overview of a stored [`Session`], as returned by
/// [`StateStore::list_sessions`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SessionSummary {
    /// The session id string.
    pub id: String,
    /// Owner of the session.
    pub user_id: String,
    /// Number of user messages in the session (agent responses are not counted).
    pub message_count: usize,
    /// Currently selected seed, omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub active_seed_id: Option<String>,
    /// When the session was created.
    pub created_at: DateTime<Utc>,
    /// When the session was last modified.
    pub updated_at: DateTime<Utc>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store rooted in a fresh temp directory.
    ///
    /// The returned `TempDir` guard must stay bound (e.g. `_guard`, not `_`) for the whole
    /// test: dropping it removes the directory.
    fn fresh_store() -> (tempfile::TempDir, StateStore) {
        let dir = tempfile::tempdir().unwrap();
        let store = StateStore::new(dir.path().to_path_buf()).unwrap();
        (dir, store)
    }

    #[tokio::test]
    async fn test_session_creation_and_persistence() {
        let (_guard, store) = fresh_store();
        let mut session = Session::new("user-123");
        session.add_user_message("Hello");
        store.save_session(&session).await.unwrap();

        let loaded = store
            .load_session(&session.id)
            .await
            .unwrap()
            .expect("saved session should load back");
        assert_eq!(loaded.user_id, "user-123");
        assert_eq!(loaded.user_messages.len(), 1);
    }

    #[tokio::test]
    async fn test_session_list_sorts_by_updated() {
        let (_guard, store) = fresh_store();
        for i in 0..3 {
            let mut session = Session::new(format!("user-{i}"));
            session.add_user_message(format!("Message {i}"));
            store.save_session(&session).await.unwrap();
            // Space out `updated_at` so the expected ordering is unambiguous.
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }

        let sessions = store.list_sessions().await.unwrap();
        assert_eq!(sessions.len(), 3);
        assert_eq!(sessions[0].user_id, "user-2", "most recently updated comes first");
    }

    #[tokio::test]
    async fn test_delete_session() {
        let (_guard, store) = fresh_store();
        let session = Session::new("user-123");
        store.save_session(&session).await.unwrap();

        assert!(store.delete_session(&session.id).await.unwrap());
        assert!(store.load_session(&session.id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_get_or_create_session_existing() {
        let (_guard, store) = fresh_store();
        let mut existing = Session::new("user-123");
        existing.add_user_message("Original message");
        store.save_session(&existing).await.unwrap();

        let retrieved = store
            .get_or_create_session("user-123", Some(&existing.id))
            .await
            .unwrap();
        assert_eq!(retrieved.id, existing.id);
        assert_eq!(retrieved.user_messages.len(), 1);
    }

    #[tokio::test]
    async fn test_get_or_create_session_new() {
        let (_guard, store) = fresh_store();
        let created = store.get_or_create_session("user-456", None).await.unwrap();
        assert_eq!(created.user_id, "user-456");
        assert!(created.user_messages.is_empty());
    }
}