use std::path::Path;
use std::path::PathBuf;
use chrono::{DateTime, Utc};
use rusqlite::{params, Connection, OpenFlags};
use serde::{Deserialize, Serialize};
use crate::error::{Result, SqzError};
use crate::types::{CompressedContent, SessionId, SessionState};
/// Lightweight, searchable view of a stored session: the denormalized
/// columns of the `sessions` table, without the full JSON `data` blob.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionSummary {
    /// Unique session identifier (primary key of the `sessions` table).
    pub id: SessionId,
    /// Project directory the session was recorded for.
    pub project_dir: PathBuf,
    /// Compressed summary text used for full-text search and display.
    pub compressed_summary: String,
    /// When the session was first created.
    pub created_at: DateTime<Utc>,
    /// When the session was last saved.
    pub updated_at: DateTime<Utc>,
}
/// Persistent store for session state and the compression cache, backed
/// by a single SQLite database (see `SCHEMA`).
pub struct SessionStore {
    // Owned database connection; every method issues queries through it.
    db: Connection,
}
/// Database schema, applied on every open; idempotent via `IF NOT EXISTS`.
///
/// - `sessions`: one row per session — full state as a JSON blob in
///   `data`, plus a few columns denormalized for searching.
/// - `sessions_fts`: external-content FTS5 index over the denormalized
///   columns, kept in sync with `sessions` by the three triggers below.
///   External-content tables are maintained with the special
///   `INSERT INTO fts(fts, rowid, ...) VALUES ('delete', ...)` command,
///   which is why the update trigger does a delete+insert pair.
/// - `cache_entries`: compression cache keyed by content hash, with
///   `accessed_at` driving LRU eviction.
///
/// NOTE(review): `PRAGMA journal_mode = WAL` returns a result row; this
/// assumes the pinned rusqlite version's `execute_batch` tolerates
/// row-returning statements — confirm against the lockfile version.
const SCHEMA: &str = r#"
PRAGMA journal_mode = WAL;
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
project_dir TEXT NOT NULL,
compressed_summary TEXT NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
data BLOB NOT NULL
);
CREATE VIRTUAL TABLE IF NOT EXISTS sessions_fts USING fts5(
id,
project_dir,
compressed_summary,
content='sessions',
content_rowid='rowid',
tokenize='porter ascii'
);
CREATE TRIGGER IF NOT EXISTS sessions_ai AFTER INSERT ON sessions BEGIN
INSERT INTO sessions_fts(rowid, id, project_dir, compressed_summary)
VALUES (new.rowid, new.id, new.project_dir, new.compressed_summary);
END;
CREATE TRIGGER IF NOT EXISTS sessions_ad AFTER DELETE ON sessions BEGIN
INSERT INTO sessions_fts(sessions_fts, rowid, id, project_dir, compressed_summary)
VALUES ('delete', old.rowid, old.id, old.project_dir, old.compressed_summary);
END;
CREATE TRIGGER IF NOT EXISTS sessions_au AFTER UPDATE ON sessions BEGIN
INSERT INTO sessions_fts(sessions_fts, rowid, id, project_dir, compressed_summary)
VALUES ('delete', old.rowid, old.id, old.project_dir, old.compressed_summary);
INSERT INTO sessions_fts(rowid, id, project_dir, compressed_summary)
VALUES (new.rowid, new.id, new.project_dir, new.compressed_summary);
END;
CREATE TABLE IF NOT EXISTS cache_entries (
hash TEXT PRIMARY KEY,
data TEXT NOT NULL,
accessed_at TEXT NOT NULL
);
"#;
/// Applies `SCHEMA` to `conn`. Safe to run repeatedly: every DDL
/// statement in the schema uses `IF NOT EXISTS`.
pub(crate) fn apply_schema(conn: &Connection) -> rusqlite::Result<()> {
    conn.execute_batch(SCHEMA)
}
/// Opens (creating if absent) the database file at `path` and makes sure
/// the schema is in place before handing the connection back.
fn open_connection(path: &Path) -> rusqlite::Result<Connection> {
    let conn = Connection::open(path)?;
    apply_schema(&conn).map(|()| conn)
}
/// Converts the five raw string columns of a summary row into a typed
/// `SessionSummary`, parsing both RFC 3339 timestamps.
fn row_to_summary(
    id: String,
    project_dir: String,
    compressed_summary: String,
    created_at: String,
    updated_at: String,
) -> Result<SessionSummary> {
    // Both timestamp columns parse the same way; only the field name in
    // the error message differs.
    let parse_ts = |raw: &str, field: &str| -> Result<DateTime<Utc>> {
        raw.parse::<DateTime<Utc>>()
            .map_err(|e| SqzError::Other(format!("invalid {field} timestamp: {e}")))
    };
    Ok(SessionSummary {
        created_at: parse_ts(&created_at, "created_at")?,
        updated_at: parse_ts(&updated_at, "updated_at")?,
        id,
        project_dir: PathBuf::from(project_dir),
        compressed_summary,
    })
}
impl SessionStore {
#[cfg(test)]
pub(crate) fn from_connection(conn: Connection) -> Self {
Self { db: conn }
}
pub fn open(path: &Path) -> Result<Self> {
let conn = Connection::open_with_flags(path, OpenFlags::SQLITE_OPEN_READ_WRITE)?;
apply_schema(&conn)?;
Ok(Self { db: conn })
}
pub fn open_or_create(path: &Path) -> Result<Self> {
match open_connection(path) {
Ok(conn) => Ok(Self { db: conn }),
Err(e) => {
eprintln!(
"sqz warning: session store at '{}' is corrupted or inaccessible ({e}). \
Creating a new database. Prior session data has been lost.",
path.display()
);
if path.exists() {
let _ = std::fs::remove_file(path);
}
let conn = open_connection(path)
.map_err(|e2| SqzError::Other(format!("failed to create new session store: {e2}")))?;
Ok(Self { db: conn })
}
}
}
pub fn save_session(&self, session: &SessionState) -> Result<SessionId> {
let data = serde_json::to_vec(session)?;
let project_dir = session.project_dir.to_string_lossy().to_string();
let created_at = session.created_at.to_rfc3339();
let updated_at = session.updated_at.to_rfc3339();
self.db.execute(
r#"INSERT INTO sessions (id, project_dir, compressed_summary, created_at, updated_at, data)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
ON CONFLICT(id) DO UPDATE SET
project_dir = excluded.project_dir,
compressed_summary = excluded.compressed_summary,
created_at = excluded.created_at,
updated_at = excluded.updated_at,
data = excluded.data"#,
params![
session.id,
project_dir,
session.compressed_summary,
created_at,
updated_at,
data,
],
)?;
Ok(session.id.clone())
}
pub fn load_session(&self, id: SessionId) -> Result<SessionState> {
let data: Vec<u8> = self.db.query_row(
"SELECT data FROM sessions WHERE id = ?1",
params![id],
|row| row.get(0),
)?;
let session: SessionState = serde_json::from_slice(&data)?;
Ok(session)
}
pub fn search(&self, query: &str) -> Result<Vec<SessionSummary>> {
let mut stmt = self.db.prepare(
r#"SELECT s.id, s.project_dir, s.compressed_summary, s.created_at, s.updated_at
FROM sessions s
JOIN sessions_fts f ON s.rowid = f.rowid
WHERE sessions_fts MATCH ?1
ORDER BY rank"#,
)?;
let rows = stmt.query_map(params![query], |row| {
Ok((
row.get::<_, String>(0)?,
row.get::<_, String>(1)?,
row.get::<_, String>(2)?,
row.get::<_, String>(3)?,
row.get::<_, String>(4)?,
))
})?;
let mut results = Vec::new();
for row in rows {
let (id, project_dir, compressed_summary, created_at, updated_at) = row?;
results.push(row_to_summary(id, project_dir, compressed_summary, created_at, updated_at)?);
}
Ok(results)
}
pub fn search_by_date(
&self,
from: DateTime<Utc>,
to: DateTime<Utc>,
) -> Result<Vec<SessionSummary>> {
let mut stmt = self.db.prepare(
r#"SELECT id, project_dir, compressed_summary, created_at, updated_at
FROM sessions
WHERE updated_at >= ?1 AND updated_at <= ?2
ORDER BY updated_at DESC"#,
)?;
let rows = stmt.query_map(params![from.to_rfc3339(), to.to_rfc3339()], |row| {
Ok((
row.get::<_, String>(0)?,
row.get::<_, String>(1)?,
row.get::<_, String>(2)?,
row.get::<_, String>(3)?,
row.get::<_, String>(4)?,
))
})?;
let mut results = Vec::new();
for row in rows {
let (id, project_dir, compressed_summary, created_at, updated_at) = row?;
results.push(row_to_summary(id, project_dir, compressed_summary, created_at, updated_at)?);
}
Ok(results)
}
pub fn search_by_project(&self, dir: &Path) -> Result<Vec<SessionSummary>> {
let dir_str = dir.to_string_lossy().to_string();
let mut stmt = self.db.prepare(
r#"SELECT id, project_dir, compressed_summary, created_at, updated_at
FROM sessions
WHERE project_dir = ?1
ORDER BY updated_at DESC"#,
)?;
let rows = stmt.query_map(params![dir_str], |row| {
Ok((
row.get::<_, String>(0)?,
row.get::<_, String>(1)?,
row.get::<_, String>(2)?,
row.get::<_, String>(3)?,
row.get::<_, String>(4)?,
))
})?;
let mut results = Vec::new();
for row in rows {
let (id, project_dir, compressed_summary, created_at, updated_at) = row?;
results.push(row_to_summary(id, project_dir, compressed_summary, created_at, updated_at)?);
}
Ok(results)
}
pub fn save_cache_entry(&self, hash: &str, compressed: &CompressedContent) -> Result<()> {
let data = serde_json::to_string(compressed)?;
let now = Utc::now().to_rfc3339();
self.db.execute(
r#"INSERT INTO cache_entries (hash, data, accessed_at)
VALUES (?1, ?2, ?3)
ON CONFLICT(hash) DO UPDATE SET data = excluded.data, accessed_at = excluded.accessed_at"#,
params![hash, data, now],
)?;
Ok(())
}
pub fn delete_cache_entry(&self, hash: &str) -> Result<()> {
self.db.execute(
"DELETE FROM cache_entries WHERE hash = ?1",
params![hash],
)?;
Ok(())
}
pub fn list_cache_entries_lru(&self) -> Result<Vec<(String, u64)>> {
let mut stmt = self.db.prepare(
"SELECT hash, length(data) FROM cache_entries ORDER BY accessed_at ASC",
)?;
let rows = stmt.query_map([], |row| {
Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)?))
})?;
let mut entries = Vec::new();
for row in rows {
let (hash, size) = row?;
entries.push((hash, size as u64));
}
Ok(entries)
}
pub fn get_cache_entry(&self, hash: &str) -> Result<Option<CompressedContent>> {
let result: rusqlite::Result<String> = self.db.query_row(
"SELECT data FROM cache_entries WHERE hash = ?1",
params![hash],
|row| row.get(0),
);
match result {
Ok(data) => {
let now = Utc::now().to_rfc3339();
let _ = self.db.execute(
"UPDATE cache_entries SET accessed_at = ?1 WHERE hash = ?2",
params![now, hash],
);
let entry: CompressedContent = serde_json::from_str(&data)?;
Ok(Some(entry))
}
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(SqzError::SessionStore(e)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{BudgetState, CorrectionLog, ModelFamily, SessionState};
    use chrono::Utc;
    use proptest::prelude::*;
    use std::path::PathBuf;

    /// Builds a minimal `SessionState` fixture; both timestamps are set
    /// to the current time.
    fn make_session(id: &str, project_dir: &str, summary: &str) -> SessionState {
        let now = Utc::now();
        SessionState {
            id: id.to_string(),
            project_dir: PathBuf::from(project_dir),
            conversation: vec![],
            corrections: CorrectionLog::default(),
            pins: vec![],
            learnings: vec![],
            compressed_summary: summary.to_string(),
            budget: BudgetState {
                window_size: 200_000,
                consumed: 0,
                pinned: 0,
                model_family: ModelFamily::AnthropicClaude,
            },
            tool_usage: vec![],
            created_at: now,
            updated_at: now,
        }
    }

    /// Store backed by an in-memory SQLite database — no filesystem.
    fn in_memory_store() -> SessionStore {
        let conn = Connection::open_in_memory().unwrap();
        apply_schema(&conn).unwrap();
        SessionStore { db: conn }
    }

    // Round trip: saving then loading returns the same session fields.
    #[test]
    fn test_save_and_load_session() {
        let store = in_memory_store();
        let session = make_session("sess-1", "/home/user/project", "REST API refactor");
        let id = store.save_session(&session).unwrap();
        assert_eq!(id, "sess-1");
        let loaded = store.load_session("sess-1".to_string()).unwrap();
        assert_eq!(loaded.id, session.id);
        assert_eq!(loaded.compressed_summary, session.compressed_summary);
        assert_eq!(loaded.project_dir, session.project_dir);
    }

    // Saving the same id twice overwrites (ON CONFLICT upsert path).
    #[test]
    fn test_save_session_upsert() {
        let store = in_memory_store();
        let mut session = make_session("sess-2", "/proj", "initial summary");
        store.save_session(&session).unwrap();
        session.compressed_summary = "updated summary".to_string();
        store.save_session(&session).unwrap();
        let loaded = store.load_session("sess-2".to_string()).unwrap();
        assert_eq!(loaded.compressed_summary, "updated summary");
    }

    // Unknown ids surface as an error (QueryReturnedNoRows), not a panic.
    #[test]
    fn test_load_nonexistent_session_errors() {
        let store = in_memory_store();
        let result = store.load_session("does-not-exist".to_string());
        assert!(result.is_err());
    }

    // FTS search matches only the session whose summary has the keyword.
    #[test]
    fn test_search_fts() {
        let store = in_memory_store();
        store.save_session(&make_session("s1", "/proj", "REST API refactor with authentication")).unwrap();
        store.save_session(&make_session("s2", "/proj", "database migration postgres")).unwrap();
        let results = store.search("authentication").unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "s1");
    }

    // A session saved "now" falls inside a +/- 2h window around now.
    #[test]
    fn test_search_by_date() {
        let store = in_memory_store();
        let now = Utc::now();
        let past = now - chrono::Duration::hours(2);
        let future = now + chrono::Duration::hours(2);
        store.save_session(&make_session("s1", "/proj", "recent session")).unwrap();
        let results = store.search_by_date(past, future).unwrap();
        assert!(!results.is_empty());
        assert!(results.iter().any(|r| r.id == "s1"));
    }

    // Project search is exact string equality on the stored path.
    #[test]
    fn test_search_by_project() {
        let store = in_memory_store();
        store.save_session(&make_session("s1", "/home/user/alpha", "alpha project")).unwrap();
        store.save_session(&make_session("s2", "/home/user/beta", "beta project")).unwrap();
        let results = store.search_by_project(Path::new("/home/user/alpha")).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "s1");
    }

    // Cache entry survives a save/get round trip with fields intact.
    #[test]
    fn test_cache_entry_round_trip() {
        let store = in_memory_store();
        let entry = CompressedContent {
            data: "compressed data".to_string(),
            tokens_compressed: 10,
            tokens_original: 50,
            stages_applied: vec!["strip_nulls".to_string()],
            compression_ratio: 0.2,
        };
        store.save_cache_entry("abc123", &entry).unwrap();
        let loaded = store.get_cache_entry("abc123").unwrap().unwrap();
        assert_eq!(loaded.data, entry.data);
        assert_eq!(loaded.tokens_compressed, entry.tokens_compressed);
        assert_eq!(loaded.tokens_original, entry.tokens_original);
    }

    // Missing hash is Ok(None), not an error.
    #[test]
    fn test_get_cache_entry_missing_returns_none() {
        let store = in_memory_store();
        let result = store.get_cache_entry("nonexistent").unwrap();
        assert!(result.is_none());
    }

    // Recovery path: a garbage file is replaced by a working database
    // (with a stderr warning) and the new store is usable.
    #[test]
    fn test_open_or_create_corrupted_db() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("store.db");
        std::fs::write(&path, b"this is not a valid sqlite database").unwrap();
        let store = SessionStore::open_or_create(&path).unwrap();
        let session = make_session("s1", "/proj", "after corruption");
        store.save_session(&session).unwrap();
        let loaded = store.load_session("s1".to_string()).unwrap();
        assert_eq!(loaded.id, "s1");
    }

    /// Like `make_session`, but with a caller-chosen `updated_at` so
    /// date-range tests can place sessions precisely.
    fn make_session_at(id: &str, summary: &str, updated_at: DateTime<Utc>) -> SessionState {
        let now = Utc::now();
        SessionState {
            id: id.to_string(),
            project_dir: PathBuf::from("/proj"),
            conversation: vec![],
            corrections: CorrectionLog::default(),
            pins: vec![],
            learnings: vec![],
            compressed_summary: summary.to_string(),
            budget: BudgetState {
                window_size: 200_000,
                consumed: 0,
                pinned: 0,
                model_family: ModelFamily::AnthropicClaude,
            },
            tool_usage: vec![],
            created_at: now,
            updated_at,
        }
    }

    proptest! {
        // Property: FTS search returns exactly the sessions whose summary
        // contains the generated keyword, and none of the others.
        #[test]
        fn prop_search_correctness(
            // Consonant-only keyword — presumably chosen so the porter
            // stemmer can't collide it with the filler text; confirm.
            keyword in "[b-df-hj-np-tv-z]{5,8}",
            matching_suffixes in proptest::collection::vec("[a-z ]{4,20}", 1..=6usize),
            non_matching in proptest::collection::vec("[a-z ]{8,30}", 1..=6usize),
        ) {
            // Discard cases where the filler accidentally contains the keyword.
            for s in &non_matching {
                prop_assume!(!s.contains(keyword.as_str()));
            }
            let store = in_memory_store();
            let mut matching_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
            for (i, suffix) in matching_suffixes.iter().enumerate() {
                let id = format!("match-{i}");
                let summary = format!("{} {} end", suffix, keyword);
                store.save_session(&make_session(&id, "/proj", &summary)).unwrap();
                matching_ids.insert(id);
            }
            let mut non_matching_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
            for (i, summary) in non_matching.iter().enumerate() {
                let id = format!("nomatch-{i}");
                store.save_session(&make_session(&id, "/proj", summary)).unwrap();
                non_matching_ids.insert(id);
            }
            let results = store.search(&keyword).unwrap();
            let result_ids: std::collections::HashSet<String> =
                results.iter().map(|r| r.id.clone()).collect();
            for id in &matching_ids {
                prop_assert!(
                    result_ids.contains(id),
                    "matching session '{}' not found in search results for keyword '{}'",
                    id, keyword
                );
            }
            for id in &non_matching_ids {
                prop_assert!(
                    !result_ids.contains(id),
                    "non-matching session '{}' incorrectly appeared in search results for keyword '{}'",
                    id, keyword
                );
            }
        }
    }

    proptest! {
        // Property: search_by_date returns exactly the sessions whose
        // updated_at lies inside the [from, to] window.
        #[test]
        fn prop_search_by_date_correctness(
            // Random second offsets within a year, plus a window whose
            // start is always before its end (delta ranges don't overlap
            // except at exactly 3600, where the window is empty-but-valid).
            offsets in proptest::collection::vec(0i64..=86400i64 * 365, 2..=8usize),
            window_start_delta in 0i64..=3600i64,
            window_end_delta in 3600i64..=7200i64,
        ) {
            use chrono::TimeZone;
            // Dedup so each session gets a distinct timestamp.
            let mut unique_offsets: Vec<i64> = offsets.clone();
            unique_offsets.sort_unstable();
            unique_offsets.dedup();
            prop_assume!(unique_offsets.len() >= 2);
            let base_offset = unique_offsets[0];
            let from_offset = base_offset + window_start_delta;
            let to_offset = base_offset + window_end_delta;
            let from = Utc.timestamp_opt(from_offset, 0).unwrap();
            let to = Utc.timestamp_opt(to_offset, 0).unwrap();
            let store = in_memory_store();
            let mut in_range_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
            let mut out_range_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
            for (i, &offset) in unique_offsets.iter().enumerate() {
                let ts = Utc.timestamp_opt(offset, 0).unwrap();
                let id = format!("sess-{i}");
                let session = make_session_at(&id, "some summary", ts);
                store.save_session(&session).unwrap();
                // Oracle: classify with chrono's own comparison.
                if ts >= from && ts <= to {
                    in_range_ids.insert(id);
                } else {
                    out_range_ids.insert(id);
                }
            }
            let results = store.search_by_date(from, to).unwrap();
            let result_ids: std::collections::HashSet<String> =
                results.iter().map(|r| r.id.clone()).collect();
            for id in &in_range_ids {
                prop_assert!(
                    result_ids.contains(id),
                    "in-range session '{}' missing from search_by_date results",
                    id
                );
            }
            for id in &out_range_ids {
                prop_assert!(
                    !result_ids.contains(id),
                    "out-of-range session '{}' incorrectly appeared in search_by_date results",
                    id
                );
            }
        }
    }
}