use crate::llm::{Message, TokenUsage, ToolDefinition};
use crate::planning::Task;
use crate::session::{ContextUsage, SessionConfig, SessionState};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use tokio::fs;
use tokio::io::AsyncWriteExt;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionData {
pub id: String,
pub config: SessionConfig,
pub state: SessionState,
pub messages: Vec<Message>,
pub context_usage: ContextUsage,
pub total_usage: TokenUsage,
#[serde(default)]
pub total_cost: f64,
#[serde(skip_serializing_if = "Option::is_none")]
pub model_name: Option<String>,
#[serde(default)]
pub cost_records: Vec<crate::telemetry::LlmCostRecord>,
pub tool_names: Vec<String>,
pub thinking_enabled: bool,
pub thinking_budget: Option<usize>,
pub created_at: i64,
pub updated_at: i64,
#[serde(skip_serializing_if = "Option::is_none")]
pub llm_config: Option<LlmConfigData>,
#[serde(default, alias = "todos")]
pub tasks: Vec<Task>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parent_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfigData {
pub provider: String,
pub model: String,
#[serde(skip_serializing, default)]
pub api_key: Option<String>,
pub base_url: Option<String>,
}
impl SessionData {
pub fn tool_names_from_definitions(tools: &[ToolDefinition]) -> Vec<String> {
tools.iter().map(|t| t.name.clone()).collect()
}
}
#[async_trait::async_trait]
pub trait SessionStore: Send + Sync {
async fn save(&self, session: &SessionData) -> Result<()>;
async fn load(&self, id: &str) -> Result<Option<SessionData>>;
async fn delete(&self, id: &str) -> Result<()>;
async fn list(&self) -> Result<Vec<String>>;
async fn exists(&self, id: &str) -> Result<bool>;
async fn health_check(&self) -> Result<()> {
Ok(())
}
fn backend_name(&self) -> &str {
"unknown"
}
}
pub struct FileSessionStore {
dir: PathBuf,
}
impl FileSessionStore {
pub async fn new<P: AsRef<Path>>(dir: P) -> Result<Self> {
let dir = dir.as_ref().to_path_buf();
fs::create_dir_all(&dir)
.await
.with_context(|| format!("Failed to create session directory: {}", dir.display()))?;
Ok(Self { dir })
}
fn session_path(&self, id: &str) -> PathBuf {
let safe_id = id.replace(['/', '\\'], "_").replace("..", "_");
self.dir.join(format!("{}.json", safe_id))
}
}
#[async_trait::async_trait]
impl SessionStore for FileSessionStore {
async fn save(&self, session: &SessionData) -> Result<()> {
let path = self.session_path(&session.id);
let json = serde_json::to_string_pretty(session)
.with_context(|| format!("Failed to serialize session: {}", session.id))?;
let unique_suffix = format!(
"{}.{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos(),
std::process::id()
);
let temp_path = path.with_extension(format!("json.{}.tmp", unique_suffix));
let mut file = fs::File::create(&temp_path)
.await
.with_context(|| format!("Failed to create temp file: {}", temp_path.display()))?;
file.write_all(json.as_bytes())
.await
.with_context(|| format!("Failed to write session data: {}", session.id))?;
file.sync_all()
.await
.with_context(|| format!("Failed to sync session file: {}", session.id))?;
fs::rename(&temp_path, &path)
.await
.with_context(|| format!("Failed to rename session file: {}", session.id))?;
tracing::debug!("Saved session {} to {}", session.id, path.display());
Ok(())
}
async fn load(&self, id: &str) -> Result<Option<SessionData>> {
let path = self.session_path(id);
if !path.exists() {
return Ok(None);
}
let json = fs::read_to_string(&path)
.await
.with_context(|| format!("Failed to read session file: {}", path.display()))?;
let session: SessionData = serde_json::from_str(&json)
.with_context(|| format!("Failed to parse session file: {}", path.display()))?;
tracing::debug!("Loaded session {} from {}", id, path.display());
Ok(Some(session))
}
async fn delete(&self, id: &str) -> Result<()> {
let path = self.session_path(id);
if path.exists() {
fs::remove_file(&path)
.await
.with_context(|| format!("Failed to delete session file: {}", path.display()))?;
tracing::debug!("Deleted session {} from {}", id, path.display());
}
Ok(())
}
async fn list(&self) -> Result<Vec<String>> {
let mut session_ids = Vec::new();
let mut entries = fs::read_dir(&self.dir)
.await
.with_context(|| format!("Failed to read session directory: {}", self.dir.display()))?;
while let Some(entry) = entries.next_entry().await? {
let path = entry.path();
if path.extension().is_some_and(|ext| ext == "json") {
if let Some(stem) = path.file_stem() {
if let Some(id) = stem.to_str() {
session_ids.push(id.to_string());
}
}
}
}
Ok(session_ids)
}
async fn exists(&self, id: &str) -> Result<bool> {
let path = self.session_path(id);
Ok(path.exists())
}
async fn health_check(&self) -> Result<()> {
let probe = self.dir.join(".health_check");
fs::write(&probe, b"ok")
.await
.with_context(|| format!("Store directory not writable: {}", self.dir.display()))?;
let _ = fs::remove_file(&probe).await;
Ok(())
}
fn backend_name(&self) -> &str {
"file"
}
}
pub struct MemorySessionStore {
sessions: tokio::sync::RwLock<HashMap<String, SessionData>>,
}
impl MemorySessionStore {
pub fn new() -> Self {
Self {
sessions: tokio::sync::RwLock::new(HashMap::new()),
}
}
}
impl Default for MemorySessionStore {
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl SessionStore for MemorySessionStore {
async fn save(&self, session: &SessionData) -> Result<()> {
let mut sessions = self.sessions.write().await;
sessions.insert(session.id.clone(), session.clone());
Ok(())
}
async fn load(&self, id: &str) -> Result<Option<SessionData>> {
let sessions = self.sessions.read().await;
Ok(sessions.get(id).cloned())
}
async fn delete(&self, id: &str) -> Result<()> {
let mut sessions = self.sessions.write().await;
sessions.remove(id);
Ok(())
}
async fn list(&self) -> Result<Vec<String>> {
let sessions = self.sessions.read().await;
Ok(sessions.keys().cloned().collect())
}
async fn exists(&self, id: &str) -> Result<bool> {
let sessions = self.sessions.read().await;
Ok(sessions.contains_key(id))
}
fn backend_name(&self) -> &str {
"memory"
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::hitl::ConfirmationPolicy;
use crate::permissions::PermissionPolicy;
use crate::prompts::PlanningMode;
use crate::queue::SessionQueueConfig;
use tempfile::tempdir;
fn create_test_session_data() -> SessionData {
SessionData {
id: "test-session-1".to_string(),
config: SessionConfig {
name: "Test Session".to_string(),
workspace: "/tmp/workspace".to_string(),
system_prompt: Some("You are helpful.".to_string()),
max_context_length: 200000,
auto_compact: false,
auto_compact_threshold: crate::session::DEFAULT_AUTO_COMPACT_THRESHOLD,
storage_type: crate::config::StorageBackend::File,
queue_config: None,
confirmation_policy: None,
permission_policy: None,
parent_id: None,
security_config: None,
hook_engine: None,
planning_mode: PlanningMode::default(),
goal_tracking: false,
},
state: SessionState::Active,
messages: vec![
Message::user("Hello"),
Message {
role: "assistant".to_string(),
content: vec![crate::llm::ContentBlock::Text {
text: "Hi there!".to_string(),
}],
reasoning_content: None,
},
],
context_usage: ContextUsage {
used_tokens: 100,
max_tokens: 200000,
percent: 0.0005,
turns: 2,
},
total_usage: TokenUsage {
prompt_tokens: 50,
completion_tokens: 50,
total_tokens: 100,
cache_read_tokens: None,
cache_write_tokens: None,
},
tool_names: vec!["bash".to_string(), "read".to_string()],
thinking_enabled: false,
thinking_budget: None,
created_at: 1700000000,
updated_at: 1700000100,
llm_config: None,
tasks: vec![],
parent_id: None,
total_cost: 0.0,
model_name: None,
cost_records: Vec::new(),
}
}
#[tokio::test]
async fn test_file_store_save_and_load() {
let dir = tempdir().unwrap();
let store = FileSessionStore::new(dir.path()).await.unwrap();
let session = create_test_session_data();
store.save(&session).await.unwrap();
let loaded = store.load(&session.id).await.unwrap();
assert!(loaded.is_some());
let loaded = loaded.unwrap();
assert_eq!(loaded.id, session.id);
assert_eq!(loaded.config.name, session.config.name);
assert_eq!(loaded.messages.len(), 2);
assert_eq!(loaded.state, SessionState::Active);
}
#[tokio::test]
async fn test_file_store_load_nonexistent() {
let dir = tempdir().unwrap();
let store = FileSessionStore::new(dir.path()).await.unwrap();
let loaded = store.load("nonexistent").await.unwrap();
assert!(loaded.is_none());
}
#[tokio::test]
async fn test_file_store_delete() {
let dir = tempdir().unwrap();
let store = FileSessionStore::new(dir.path()).await.unwrap();
let session = create_test_session_data();
store.save(&session).await.unwrap();
assert!(store.exists(&session.id).await.unwrap());
store.delete(&session.id).await.unwrap();
assert!(!store.exists(&session.id).await.unwrap());
assert!(store.load(&session.id).await.unwrap().is_none());
}
#[tokio::test]
async fn test_file_store_list() {
let dir = tempdir().unwrap();
let store = FileSessionStore::new(dir.path()).await.unwrap();
let list = store.list().await.unwrap();
assert!(list.is_empty());
for i in 1..=3 {
let mut session = create_test_session_data();
session.id = format!("session-{}", i);
store.save(&session).await.unwrap();
}
let list = store.list().await.unwrap();
assert_eq!(list.len(), 3);
assert!(list.contains(&"session-1".to_string()));
assert!(list.contains(&"session-2".to_string()));
assert!(list.contains(&"session-3".to_string()));
}
#[tokio::test]
async fn test_file_store_overwrite() {
let dir = tempdir().unwrap();
let store = FileSessionStore::new(dir.path()).await.unwrap();
let mut session = create_test_session_data();
store.save(&session).await.unwrap();
session.messages.push(Message::user("Another message"));
session.updated_at = 1700000200;
store.save(&session).await.unwrap();
let loaded = store.load(&session.id).await.unwrap().unwrap();
assert_eq!(loaded.messages.len(), 3);
assert_eq!(loaded.updated_at, 1700000200);
}
#[tokio::test]
async fn test_file_store_path_traversal_prevention() {
let dir = tempdir().unwrap();
let store = FileSessionStore::new(dir.path()).await.unwrap();
let mut session = create_test_session_data();
session.id = "../../../etc/passwd".to_string();
store.save(&session).await.unwrap();
let files: Vec<_> = std::fs::read_dir(dir.path())
.unwrap()
.filter_map(|e| e.ok())
.collect();
assert_eq!(files.len(), 1);
let loaded = store.load(&session.id).await.unwrap();
assert!(loaded.is_some());
}
#[tokio::test]
async fn test_file_store_with_policies() {
let dir = tempdir().unwrap();
let store = FileSessionStore::new(dir.path()).await.unwrap();
let mut session = create_test_session_data();
session.config.confirmation_policy = Some(ConfirmationPolicy::enabled());
session.config.permission_policy = Some(PermissionPolicy::new().allow("Bash(cargo:*)"));
session.config.queue_config = Some(SessionQueueConfig::default());
store.save(&session).await.unwrap();
let loaded = store.load(&session.id).await.unwrap().unwrap();
assert!(loaded.config.confirmation_policy.is_some());
assert!(loaded.config.permission_policy.is_some());
assert!(loaded.config.queue_config.is_some());
}
#[tokio::test]
async fn test_file_store_with_llm_config() {
let dir = tempdir().unwrap();
let store = FileSessionStore::new(dir.path()).await.unwrap();
let mut session = create_test_session_data();
session.llm_config = Some(LlmConfigData {
provider: "anthropic".to_string(),
model: "claude-3-5-sonnet-20241022".to_string(),
api_key: Some("secret".to_string()), base_url: None,
});
store.save(&session).await.unwrap();
let loaded = store.load(&session.id).await.unwrap().unwrap();
let llm_config = loaded.llm_config.unwrap();
assert_eq!(llm_config.provider, "anthropic");
assert_eq!(llm_config.model, "claude-3-5-sonnet-20241022");
assert!(llm_config.api_key.is_none());
}
#[tokio::test]
async fn test_memory_store_save_and_load() {
let store = MemorySessionStore::new();
let session = create_test_session_data();
store.save(&session).await.unwrap();
let loaded = store.load(&session.id).await.unwrap();
assert!(loaded.is_some());
assert_eq!(loaded.unwrap().id, session.id);
}
#[tokio::test]
async fn test_memory_store_delete() {
let store = MemorySessionStore::new();
let session = create_test_session_data();
store.save(&session).await.unwrap();
assert!(store.exists(&session.id).await.unwrap());
store.delete(&session.id).await.unwrap();
assert!(!store.exists(&session.id).await.unwrap());
}
#[tokio::test]
async fn test_memory_store_list() {
let store = MemorySessionStore::new();
for i in 1..=3 {
let mut session = create_test_session_data();
session.id = format!("session-{}", i);
store.save(&session).await.unwrap();
}
let list = store.list().await.unwrap();
assert_eq!(list.len(), 3);
}
#[test]
fn test_session_data_serialization() {
let session = create_test_session_data();
let json = serde_json::to_string(&session).unwrap();
let parsed: SessionData = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.id, session.id);
assert_eq!(parsed.messages.len(), session.messages.len());
}
#[test]
fn test_tool_names_from_definitions() {
let tools = vec![
crate::llm::ToolDefinition {
name: "bash".to_string(),
description: "Execute bash".to_string(),
parameters: serde_json::json!({}),
},
crate::llm::ToolDefinition {
name: "read".to_string(),
description: "Read file".to_string(),
parameters: serde_json::json!({}),
},
];
let names = SessionData::tool_names_from_definitions(&tools);
assert_eq!(names, vec!["bash", "read"]);
}
#[tokio::test]
async fn test_file_store_backslash_sanitization() {
let dir = tempdir().unwrap();
let store = FileSessionStore::new(dir.path()).await.unwrap();
let mut session = create_test_session_data();
session.id = r"foo\bar\baz".to_string();
store.save(&session).await.unwrap();
let loaded = store.load(&session.id).await.unwrap();
assert!(loaded.is_some());
let loaded = loaded.unwrap();
assert_eq!(loaded.id, session.id);
let expected_path = dir.path().join("foo_bar_baz.json");
assert!(expected_path.exists());
}
#[tokio::test]
async fn test_file_store_mixed_separator_sanitization() {
let dir = tempdir().unwrap();
let store = FileSessionStore::new(dir.path()).await.unwrap();
let mut session = create_test_session_data();
session.id = r"foo/bar\baz..qux".to_string();
store.save(&session).await.unwrap();
let loaded = store.load(&session.id).await.unwrap();
assert!(loaded.is_some());
let loaded = loaded.unwrap();
assert_eq!(loaded.id, session.id);
let expected_path = dir.path().join("foo_bar_baz_qux.json");
assert!(expected_path.exists());
}
#[tokio::test]
async fn test_file_store_corrupted_json_recovery() {
let dir = tempdir().unwrap();
let store = FileSessionStore::new(dir.path()).await.unwrap();
let corrupted_path = dir.path().join("test-id.json");
tokio::fs::write(&corrupted_path, b"not valid json {{{")
.await
.unwrap();
let result = store.load("test-id").await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_file_store_exists() {
let dir = tempdir().unwrap();
let store = FileSessionStore::new(dir.path()).await.unwrap();
let session = create_test_session_data();
assert!(!store.exists(&session.id).await.unwrap());
store.save(&session).await.unwrap();
assert!(store.exists(&session.id).await.unwrap());
store.delete(&session.id).await.unwrap();
assert!(!store.exists(&session.id).await.unwrap());
}
#[tokio::test]
async fn test_memory_store_exists() {
let store = MemorySessionStore::new();
assert!(!store.exists("unknown-id").await.unwrap());
let session = create_test_session_data();
store.save(&session).await.unwrap();
assert!(store.exists(&session.id).await.unwrap());
}
#[tokio::test]
async fn test_file_store_health_check() {
let dir = tempfile::tempdir().unwrap();
let store = FileSessionStore::new(dir.path()).await.unwrap();
assert!(store.health_check().await.is_ok());
assert_eq!(store.backend_name(), "file");
}
#[tokio::test]
async fn test_file_store_health_check_bad_dir() {
let store = FileSessionStore {
dir: std::path::PathBuf::from("/nonexistent/path/that/does/not/exist"),
};
assert!(store.health_check().await.is_err());
}
#[tokio::test]
async fn test_memory_store_health_check() {
let store = MemorySessionStore::new();
assert!(store.health_check().await.is_ok());
assert_eq!(store.backend_name(), "memory");
}
#[tokio::test]
async fn test_file_store_load_empty_file() {
let dir = tempdir().unwrap();
let store = FileSessionStore::new(dir.path()).await.unwrap();
let empty_path = dir.path().join("empty-session.json");
tokio::fs::write(&empty_path, b"").await.unwrap();
let result = store.load("empty-session").await;
assert!(
result.is_err(),
"Empty file must return error, not Ok(None)"
);
}
#[tokio::test]
async fn test_file_store_load_partial_json() {
let dir = tempdir().unwrap();
let store = FileSessionStore::new(dir.path()).await.unwrap();
let partial_path = dir.path().join("partial-session.json");
tokio::fs::write(&partial_path, b"{\"id\":\"partial-session\",\"message")
.await
.unwrap();
let result = store.load("partial-session").await;
assert!(result.is_err(), "Partial JSON must return error");
}
#[tokio::test]
async fn test_file_store_concurrent_save() {
let dir = tempdir().unwrap();
let store = std::sync::Arc::new(FileSessionStore::new(dir.path()).await.unwrap());
let session = create_test_session_data();
let id = session.id.clone();
store.save(&session).await.unwrap();
let mut handles = Vec::new();
for _ in 0..5 {
let s = store.clone();
let sess = session.clone();
handles.push(tokio::spawn(async move { s.save(&sess).await }));
}
for h in handles {
h.await.unwrap().unwrap();
}
let loaded = store.load(&id).await.unwrap();
assert!(loaded.is_some());
assert_eq!(loaded.unwrap().id, id);
}
#[tokio::test]
async fn test_file_store_load_nonexistent_returns_none() {
let dir = tempdir().unwrap();
let store = FileSessionStore::new(dir.path()).await.unwrap();
let result = store.load("does-not-exist-at-all").await.unwrap();
assert!(result.is_none(), "Missing session must return Ok(None)");
}
}