use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex};
use anyhow::{Context as AnyhowContext, Result};
use chrono::{DateTime, Utc};
use rusqlite::{params, Connection, OptionalExtension};
use serde_json::Value;
use super::invalidation::InvalidationChecker;
use super::types::*;
/// DDL for the single `context` table and its secondary indexes.
///
/// Run on every open (`CREATE ... IF NOT EXISTS`), so it is safe against an
/// existing database but does not migrate an older table layout.
/// Timestamp columns are TEXT; this module writes RFC 3339 strings into them
/// and compares them as strings. `metadata` is nullable here even though
/// `ContextStore::set` always writes a value for it.
const SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS context (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
namespace TEXT NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
expires_at TEXT,
git_commit TEXT,
file_path TEXT,
file_mtime INTEGER,
metadata TEXT
);
CREATE INDEX IF NOT EXISTS idx_context_namespace ON context(namespace);
CREATE INDEX IF NOT EXISTS idx_context_file_path ON context(file_path);
CREATE INDEX IF NOT EXISTS idx_context_expires_at ON context(expires_at);
"#;
/// SQLite-backed key/value store for context entries.
///
/// The connection and the invalidation checker each sit behind
/// `Arc<Mutex<_>>`. Methods that hold both (e.g. `get`, `list`) lock `conn`
/// first and `invalidation` second; keep that order to avoid deadlocks.
pub struct ContextStore {
    // SQLite connection; each operation locks it for its duration.
    conn: Arc<Mutex<Connection>>,
    // Decides whether a stored entry is still usable (`is_valid`) and supplies
    // file mtimes / HEAD commit when entries are written.
    invalidation: Arc<Mutex<InvalidationChecker>>,
}
impl ContextStore {
/// Opens (creating if needed) a file-backed context store at `path`.
///
/// The connection is given a 5000 ms busy timeout, WAL journaling and
/// `synchronous = NORMAL`, and then the schema is created if missing.
///
/// # Errors
/// Returns an error if the file cannot be opened, a pragma fails, or the
/// schema cannot be created.
pub fn open(path: impl AsRef<Path>) -> Result<Self> {
    let path = path.as_ref();
    let conn = Connection::open(path)
        .with_context(|| format!("Failed to open context store: {:?}", path))?;
    // Configure the connection before touching the schema, with
    // `busy_timeout` first: switching to WAL and `CREATE TABLE` both need
    // database locks, and without a timeout they fail immediately with
    // SQLITE_BUSY when another process is using the file.
    conn.execute_batch(
        "PRAGMA busy_timeout = 5000;
PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;"
    )?;
    conn.execute_batch(SCHEMA)
        .with_context(|| format!("Failed to initialize context store schema: {:?}", path))?;
    Ok(Self {
        conn: Arc::new(Mutex::new(conn)),
        invalidation: Arc::new(Mutex::new(InvalidationChecker::new())),
    })
}
/// Creates a store backed by a private in-memory SQLite database.
///
/// Only the schema is applied (no pragmas), and a default
/// `InvalidationChecker` is attached.
pub fn in_memory() -> Result<Self> {
    let connection = Connection::open_in_memory()?;
    connection.execute_batch(SCHEMA)?;
    let shared_conn = Arc::new(Mutex::new(connection));
    let checker = InvalidationChecker::new();
    Ok(Self {
        conn: shared_conn,
        invalidation: Arc::new(Mutex::new(checker)),
    })
}
/// Builder step: swaps in an invalidation checker bound to the git
/// repository at `repo_path` and returns the store.
///
/// # Errors
/// Propagates any error from `InvalidationChecker::from_git_repo`.
pub fn with_git_repo(mut self, repo_path: impl AsRef<Path>) -> Result<Self> {
    self.invalidation = Arc::new(Mutex::new(InvalidationChecker::from_git_repo(repo_path)?));
    Ok(self)
}
pub fn refresh_git_state(&self) -> Result<()> {
let mut invalidation = self.invalidation.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
invalidation.refresh()
}
pub fn get(&self, key: &str) -> Result<Option<ContextEntry>> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
let entry = conn.query_row(
"SELECT key, value, namespace, created_at, updated_at, expires_at,
git_commit, file_path, file_mtime, metadata
FROM context WHERE key = ?",
[key],
|row| row_to_entry(row),
).optional()?;
if let Some(ref entry) = entry {
let invalidation = self.invalidation.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
if !invalidation.is_valid(entry) {
drop(invalidation);
drop(conn);
self.delete(key)?;
return Ok(None);
}
}
Ok(entry)
}
/// Like `get`, but never deletes anything: an entry rejected by the
/// invalidation checker is simply reported as `None` and left in the table.
///
/// # Errors
/// Returns an error if a lock is poisoned or the query fails.
pub fn get_if_valid(&self, key: &str) -> Result<Option<ContextEntry>> {
    let conn = self.conn.lock()
        .map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
    let found = conn.query_row(
        "SELECT key, value, namespace, created_at, updated_at, expires_at,
git_commit, file_path, file_mtime, metadata
FROM context WHERE key = ?",
        [key],
        |row| row_to_entry(row),
    ).optional()?;
    match found {
        None => Ok(None),
        Some(entry) => {
            let checker = self.invalidation.lock()
                .map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
            Ok(if checker.is_valid(&entry) { Some(entry) } else { None })
        }
    }
}
pub fn set(&self, entry: ContextEntry) -> Result<()> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
let (namespace, _) = Namespace::from_key(&entry.key);
let metadata_json = serde_json::to_string(&entry.metadata)?;
conn.execute(
"INSERT OR REPLACE INTO context
(key, value, namespace, created_at, updated_at, expires_at,
git_commit, file_path, file_mtime, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
params![
entry.key,
entry.value.to_string(),
namespace.prefix(),
entry.created_at.to_rfc3339(),
entry.updated_at.to_rfc3339(),
entry.expires_at.map(|t| t.to_rfc3339()),
entry.git_commit,
entry.file_path,
entry.file_mtime,
metadata_json,
],
)?;
Ok(())
}
/// Stores `value` under `key` in an entry built by `ContextEntry::new`,
/// replacing any existing entry for that key.
pub fn set_value(&self, key: &str, value: Value) -> Result<()> {
    self.set(ContextEntry::new(key, value))
}
pub fn delete(&self, key: &str) -> Result<bool> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
let rows = conn.execute("DELETE FROM context WHERE key = ?", [key])?;
Ok(rows > 0)
}
/// Deletes every row whose key starts with `prefix` (exact, case-sensitive
/// match) and returns the number of rows removed. An empty prefix deletes
/// every row.
///
/// # Errors
/// Returns an error if the lock is poisoned or the delete fails.
pub fn delete_prefix(&self, prefix: &str) -> Result<usize> {
    let conn = self.conn.lock()
        .map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
    // `key LIKE 'prefix%'` is wrong here. SQLite's LIKE treats `_` and `%`
    // inside `prefix` as wildcards and compares ASCII letters
    // case-insensitively, so "file:a_b" would also delete "file:aXb..." and
    // "FILE:A_B...". Compare the leading characters literally instead
    // (`length`/`substr` both count characters for TEXT).
    let rows = conn.execute(
        "DELETE FROM context WHERE substr(key, 1, length(?1)) = ?1",
        [prefix],
    )?;
    Ok(rows)
}
/// Raw existence check for `key`.
///
/// Neither expiry nor the invalidation checker is consulted, so this can
/// return `true` for a key that `get` would report as missing.
///
/// # Errors
/// Returns an error if the lock is poisoned or the query fails.
pub fn exists(&self, key: &str) -> Result<bool> {
    let conn = self.conn.lock()
        .map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
    let hit = conn
        .query_row("SELECT 1 FROM context WHERE key = ?", [key], |_| Ok(true))
        .optional()?;
    Ok(hit.is_some())
}
/// Lists entries matching `query`, most recently updated first.
///
/// Filters applied in SQL:
/// * `namespace`: exact match on the stored namespace prefix;
/// * `prefix`: exact, case-sensitive key prefix;
/// * unless `include_expired` is set, rows whose `expires_at` has passed.
///
/// `limit`/`offset` are applied in SQL *before* the invalidation checker
/// filters the rows, so fewer than `limit` entries may come back. Rows that
/// fail to decode are skipped.
///
/// # Errors
/// Returns an error if a lock is poisoned or the query cannot be prepared/run.
pub fn list(&self, query: &ContextQuery) -> Result<Vec<ContextEntry>> {
    let conn = self.conn.lock()
        .map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
    let mut sql = String::from(
        "SELECT key, value, namespace, created_at, updated_at, expires_at,
git_commit, file_path, file_mtime, metadata
FROM context WHERE 1=1"
    );
    let mut params: Vec<String> = Vec::new();
    if let Some(ref ns) = query.namespace {
        sql.push_str(" AND namespace = ?");
        params.push(ns.prefix().to_string());
    }
    if let Some(ref prefix) = query.prefix {
        // Literal prefix comparison. `LIKE 'p%'` would treat `_`/`%` inside
        // `p` as wildcards and ignore ASCII case (e.g. "symbol:get_" would
        // also match "symbol:getX" and "SYMBOL:GET_").
        let prefix = prefix.to_string();
        sql.push_str(" AND substr(key, 1, length(?)) = ?");
        params.push(prefix.clone());
        params.push(prefix);
    }
    if !query.include_expired {
        sql.push_str(" AND (expires_at IS NULL OR expires_at > ?)");
        params.push(Utc::now().to_rfc3339());
    }
    sql.push_str(" ORDER BY updated_at DESC");
    // SQLite only accepts OFFSET as part of a LIMIT clause; a bare
    // "... OFFSET n" is a syntax error. A negative LIMIT means "no limit".
    match (query.limit, query.offset) {
        (Some(limit), Some(offset)) => {
            sql.push_str(&format!(" LIMIT {} OFFSET {}", limit, offset));
        }
        (Some(limit), None) => sql.push_str(&format!(" LIMIT {}", limit)),
        (None, Some(offset)) => sql.push_str(&format!(" LIMIT -1 OFFSET {}", offset)),
        (None, None) => {}
    }
    let mut stmt = conn.prepare(&sql)?;
    let params_refs: Vec<&dyn rusqlite::ToSql> = params.iter()
        .map(|s| s as &dyn rusqlite::ToSql)
        .collect();
    let entries: Vec<ContextEntry> = stmt.query_map(params_refs.as_slice(), |row| row_to_entry(row))?
        .filter_map(|r| r.ok())
        .collect();
    // Lock order conn -> invalidation, matching `get`.
    let invalidation = self.invalidation.lock()
        .map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
    let valid_entries: Vec<ContextEntry> = entries
        .into_iter()
        .filter(|e| invalidation.is_valid(e))
        .collect();
    Ok(valid_entries)
}
/// Returns all stored keys in `namespace`, sorted ascending.
///
/// Expiry and invalidation are not consulted, and rows that fail to decode
/// are skipped.
///
/// # Errors
/// Returns an error if the lock is poisoned or the query fails.
pub fn keys(&self, namespace: Namespace) -> Result<Vec<String>> {
    let conn = self.conn.lock()
        .map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
    let mut stmt = conn.prepare(
        "SELECT key FROM context WHERE namespace = ? ORDER BY key"
    )?;
    let found: Vec<String> = stmt
        .query_map([namespace.prefix()], |row| row.get::<_, String>(0))?
        .flatten()
        .collect();
    Ok(found)
}
/// Loads the `FileContext` stored under `file:<path>` (via `get`, so invalid
/// entries are evicted and reported as `None`).
///
/// # Errors
/// Returns an error on storage failure or if the stored JSON does not
/// deserialize into a `FileContext`.
pub fn get_file_context(&self, path: &str) -> Result<Option<FileContext>> {
    match self.get(&format!("file:{}", path))? {
        Some(entry) => Ok(Some(serde_json::from_value(entry.value)?)),
        None => Ok(None),
    }
}
pub fn set_file_context(&self, path: &str, ctx: &FileContext) -> Result<()> {
let key = format!("file:{}", path);
let value = serde_json::to_value(ctx)?;
let mtime = self.invalidation.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {}", e))?
.get_mtime(path);
let mut entry = ContextEntry::new(&key, value)
.with_metadata("type", "file_context");
entry.file_path = Some(path.to_string());
entry.file_mtime = mtime;
if let Some(commit) = self.invalidation.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {}", e))?
.head_commit()
{
entry.git_commit = Some(commit.to_string());
}
self.set(entry)
}
/// Returns the raw JSON value stored under `file:<path>:<attr>`, if any.
pub fn get_file_attr(&self, path: &str, attr: &str) -> Result<Option<Value>> {
    Ok(self.get(&format!("file:{}:{}", path, attr))?.map(|entry| entry.value))
}
pub fn set_file_attr(&self, path: &str, attr: &str, value: Value) -> Result<()> {
let key = format!("file:{}:{}", path, attr);
let mtime = self.invalidation.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {}", e))?
.get_mtime(path);
let mut entry = ContextEntry::new(&key, value);
entry.file_path = Some(path.to_string());
entry.file_mtime = mtime;
self.set(entry)
}
pub fn get_symbol(&self, name: &str) -> Result<Option<SymbolInfo>> {
let key = format!("symbol:{}", name);
if let Some(entry) = self.get(&key)? {
let info: SymbolInfo = serde_json::from_value(entry.value)?;
Ok(Some(info))
} else {
Ok(None)
}
}
/// Stores `info` under `symbol:<info.name>`. When `file_path` is given, the
/// entry also records that path and the checker's mtime for it.
///
/// # Errors
/// Returns an error if `info` cannot be serialized, the checker lock is
/// poisoned, or the write fails.
pub fn set_symbol(&self, info: &SymbolInfo, file_path: Option<&str>) -> Result<()> {
    let mut entry = ContextEntry::new(
        &format!("symbol:{}", info.name),
        serde_json::to_value(info)?,
    );
    if let Some(source) = file_path {
        let mtime = self.invalidation.lock()
            .map_err(|e| anyhow::anyhow!("Lock error: {}", e))?
            .get_mtime(source);
        entry.file_path = Some(source.to_owned());
        entry.file_mtime = mtime;
    }
    self.set(entry)
}
/// Returns every valid symbol whose name starts with `prefix` (via `list`).
/// Entries whose JSON does not deserialize into `SymbolInfo` are skipped.
pub fn find_symbols(&self, prefix: &str) -> Result<Vec<SymbolInfo>> {
    let key_prefix = format!("symbol:{}", prefix);
    let query = ContextQuery::new()
        .namespace(Namespace::Symbol)
        .prefix(&key_prefix);
    let mut symbols = Vec::new();
    for entry in self.list(&query)? {
        if let Ok(info) = serde_json::from_value::<SymbolInfo>(entry.value) {
            symbols.push(info);
        }
    }
    Ok(symbols)
}
/// Loads the `ProjectContext` stored under `project:info`, if any.
///
/// # Errors
/// Returns an error on storage failure or if the stored JSON does not
/// deserialize into a `ProjectContext`.
pub fn get_project_context(&self) -> Result<Option<ProjectContext>> {
    let stored = self.get("project:info")?;
    let parsed = match stored {
        Some(entry) => Some(serde_json::from_value::<ProjectContext>(entry.value)?),
        None => None,
    };
    Ok(parsed)
}
/// Stores `ctx` under `project:info`, tagging it with the checker's HEAD
/// commit when one is available.
///
/// # Errors
/// Returns an error if `ctx` cannot be serialized, the checker lock is
/// poisoned, or the write fails.
pub fn set_project_context(&self, ctx: &ProjectContext) -> Result<()> {
    let mut entry = ContextEntry::new("project:info", serde_json::to_value(ctx)?);
    let head = self.invalidation.lock()
        .map_err(|e| anyhow::anyhow!("Lock error: {}", e))?
        .head_commit()
        .map(|c| c.to_string());
    if let Some(commit) = head {
        entry.git_commit = Some(commit);
    }
    self.set(entry)
}
/// Returns the raw JSON value stored under `project:<attr>`, if any.
pub fn get_project_attr(&self, attr: &str) -> Result<Option<Value>> {
    let found = self.get(&format!("project:{}", attr))?;
    Ok(found.map(|entry| entry.value))
}
/// Stores `value` under `project:<attr>` with default entry fields.
pub fn set_project_attr(&self, attr: &str, value: Value) -> Result<()> {
    self.set(ContextEntry::new(&format!("project:{}", attr), value))
}
/// Loads the `SessionContext` stored under `session:<session_id>`, if any.
///
/// # Errors
/// Returns an error on storage failure or if the stored JSON does not
/// deserialize into a `SessionContext`.
pub fn get_session(&self, session_id: &str) -> Result<Option<SessionContext>> {
    if let Some(entry) = self.get(&format!("session:{}", session_id))? {
        return Ok(Some(serde_json::from_value(entry.value)?));
    }
    Ok(None)
}
/// Stores `ctx` under `session:<ctx.session_id>`, replacing any previous value.
pub fn set_session(&self, ctx: &SessionContext) -> Result<()> {
    let serialized = serde_json::to_value(ctx)?;
    self.set(ContextEntry::new(&format!("session:{}", ctx.session_id), serialized))
}
/// Replaces the working-file list of `session_id` and bumps `last_activity`.
/// If the session does not exist yet, a new one is created with these files.
///
/// Read and write are separate calls, so concurrent updates to the same
/// session can overwrite each other.
pub fn update_working_files(&self, session_id: &str, files: Vec<String>) -> Result<()> {
    let ctx = match self.get_session(session_id)? {
        Some(mut existing) => {
            existing.working_files = files;
            existing.last_activity = Utc::now();
            existing
        }
        None => SessionContext {
            session_id: session_id.to_string(),
            working_files: files,
            started_at: Utc::now(),
            last_activity: Utc::now(),
            ..Default::default()
        },
    };
    self.set_session(&ctx)
}
/// Appends a `Decision` to `session_id` (creating the session if needed) and
/// bumps `last_activity`.
///
/// `context` is stored as-is on the decision; `rationale` is optional.
pub fn add_decision(
    &self,
    session_id: &str,
    decision: &str,
    rationale: Option<&str>,
    context: Vec<String>,
) -> Result<()> {
    let mut ctx = match self.get_session(session_id)? {
        Some(existing) => existing,
        None => SessionContext {
            session_id: session_id.to_string(),
            started_at: Utc::now(),
            last_activity: Utc::now(),
            ..Default::default()
        },
    };
    let record = Decision {
        decision: decision.to_owned(),
        rationale: rationale.map(str::to_owned),
        timestamp: Utc::now(),
        context,
    };
    ctx.decisions.push(record);
    ctx.last_activity = Utc::now();
    self.set_session(&ctx)
}
/// Looks up each key with `get` (so invalid entries are evicted) and returns
/// the ones that were found. Stops at the first storage error.
pub fn get_batch(&self, keys: &[String]) -> Result<HashMap<String, ContextEntry>> {
    keys.iter()
        .filter_map(|key| {
            self.get(key)
                .transpose()
                .map(|looked_up| looked_up.map(|entry| (key.clone(), entry)))
        })
        .collect()
}
/// Inserts or replaces all `entries` in one transaction.
///
/// If any entry fails (metadata serialization or SQL error), the function
/// returns early and the transaction is rolled back when dropped, so either
/// all entries are written or none are.
///
/// # Errors
/// Returns an error if the lock is poisoned, metadata cannot be serialized,
/// or any statement fails.
pub fn set_batch(&self, entries: Vec<ContextEntry>) -> Result<()> {
    let mut conn = self.conn.lock()
        .map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
    // We own the guard, so the checked `transaction()` (which needs
    // `&mut Connection`) can replace `unchecked_transaction()`.
    let tx = conn.transaction()?;
    {
        // Prepare the INSERT once and re-bind it per entry. Calling
        // `execute` on the transaction would re-parse the SQL for every row.
        let mut stmt = tx.prepare(
            "INSERT OR REPLACE INTO context
(key, value, namespace, created_at, updated_at, expires_at,
git_commit, file_path, file_mtime, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
        )?;
        for entry in entries {
            let (namespace, _) = Namespace::from_key(&entry.key);
            let metadata_json = serde_json::to_string(&entry.metadata)?;
            stmt.execute(params![
                entry.key,
                entry.value.to_string(),
                namespace.prefix(),
                entry.created_at.to_rfc3339(),
                entry.updated_at.to_rfc3339(),
                entry.expires_at.map(|t| t.to_rfc3339()),
                entry.git_commit,
                entry.file_path,
                entry.file_mtime,
                metadata_json,
            ])?;
        }
    } // `stmt` borrows `tx`; drop it before committing.
    tx.commit()?;
    Ok(())
}
/// Deletes rows whose `expires_at` is set and earlier than now; returns the
/// number of rows removed. Invalidation rules are not consulted.
///
/// # Errors
/// Returns an error if the lock is poisoned or the delete fails.
pub fn cleanup_expired(&self) -> Result<usize> {
    let guard = self.conn.lock()
        .map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
    let cutoff = Utc::now().to_rfc3339();
    let removed = guard.execute(
        "DELETE FROM context WHERE expires_at IS NOT NULL AND expires_at < ?",
        [cutoff.as_str()],
    )?;
    Ok(removed)
}
pub fn cleanup_invalid(&self) -> Result<usize> {
let entries = self.list(&ContextQuery::new().include_expired())?;
let invalidation = self.invalidation.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
let invalid_keys: Vec<String> = entries
.iter()
.filter(|e| !invalidation.is_valid(e))
.map(|e| e.key.clone())
.collect();
drop(invalidation);
let mut deleted = 0;
for key in invalid_keys {
if self.delete(&key)? {
deleted += 1;
}
}
Ok(deleted)
}
/// Returns raw row counts: the whole table, per `namespace` value, and rows
/// whose `expires_at` has already passed.
///
/// The invalidation checker is not consulted. Namespace rows that fail to
/// decode are skipped.
///
/// # Errors
/// Returns an error if the lock is poisoned or any query fails.
pub fn stats(&self) -> Result<ContextStats> {
    let conn = self.conn.lock()
        .map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;

    let total_entries: usize =
        conn.query_row("SELECT COUNT(*) FROM context", [], |row| row.get(0))?;

    let mut per_namespace = conn.prepare(
        "SELECT namespace, COUNT(*) FROM context GROUP BY namespace"
    )?;
    let by_namespace: HashMap<String, usize> = per_namespace
        .query_map([], |row| Ok((row.get::<_, String>(0)?, row.get::<_, usize>(1)?)))?
        .filter_map(|r| r.ok())
        .collect();
    drop(per_namespace);

    let now = Utc::now().to_rfc3339();
    let expired_entries: usize = conn.query_row(
        "SELECT COUNT(*) FROM context WHERE expires_at IS NOT NULL AND expires_at < ?",
        [now],
        |row| row.get(0),
    )?;

    Ok(ContextStats {
        total_entries,
        by_namespace,
        expired_entries,
    })
}
pub fn clear(&self) -> Result<()> {
let conn = self.conn.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
conn.execute("DELETE FROM context", [])?;
Ok(())
}
/// Deletes every row in the store and returns how many rows were removed.
///
/// # Errors
/// Returns an error if the lock is poisoned or the delete fails.
pub fn clear_all(&self) -> Result<usize> {
    let conn = self.conn.lock()
        .map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
    // Use the change count reported by the DELETE itself. A separate COUNT(*)
    // beforehand is not atomic with the delete: another connection (e.g.
    // another process on the same WAL database) could write in between, and
    // the returned number would not match what was actually removed.
    let deleted = conn.execute("DELETE FROM context", [])?;
    Ok(deleted)
}
/// Deletes every row stored in `namespace` and returns how many were removed.
///
/// # Errors
/// Returns an error if the lock is poisoned or the delete fails.
pub fn clear_namespace(&self, namespace: Namespace) -> Result<usize> {
    let conn = self.conn.lock()
        .map_err(|e| anyhow::anyhow!("Lock error: {}", e))?;
    // The DELETE's own change count is exact. A COUNT(*) issued first could
    // disagree with what the DELETE removes if another connection writes in
    // between.
    let deleted = conn.execute(
        "DELETE FROM context WHERE namespace = ?",
        [namespace.prefix()],
    )?;
    Ok(deleted)
}
/// Returns the invalidation checker's mtime for `path`.
///
/// Yields `None` both when the checker has no mtime and when its mutex is
/// poisoned (the lock error is swallowed).
pub fn get_file_mtime(&self, path: &str) -> Option<i64> {
    self.invalidation
        .lock()
        .ok()
        .and_then(|checker| checker.get_mtime(path))
}
pub fn refresh_invalidation(&self) -> Result<()> {
self.refresh_git_state()
}
/// Convenience wrapper around `list` taking an optional namespace and key
/// prefix. All other query options keep `ContextQuery::new()` defaults.
pub fn list_simple(&self, namespace: Option<Namespace>, prefix: Option<&str>) -> Result<Vec<ContextEntry>> {
    let base = match namespace {
        Some(ns) => ContextQuery::new().namespace(ns),
        None => ContextQuery::new(),
    };
    let query = match prefix {
        Some(p) => base.prefix(p),
        None => base,
    };
    self.list(&query)
}
}
/// Raw row counts produced by `ContextStore::stats`.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct ContextStats {
    /// Total rows in the `context` table (expired/invalid rows included).
    pub total_entries: usize,
    /// Row count per stored `namespace` column value.
    pub by_namespace: HashMap<String, usize>,
    /// Rows whose `expires_at` is set and earlier than the time `stats` ran.
    pub expired_entries: usize,
}
/// Converts a row selected as `key, value, namespace, created_at, updated_at,
/// expires_at, git_commit, file_path, file_mtime, metadata` (indices 0-9)
/// into a `ContextEntry`. The `namespace` column (index 2) is not read.
///
/// Decoding is lenient:
/// * unparsable `value` JSON becomes `Value::Null`;
/// * NULL or unparsable `metadata` becomes the default;
/// * bad timestamps fall back to the current time via `parse_datetime`.
///
/// # Errors
/// Returns a rusqlite error if a column has an incompatible SQL type.
fn row_to_entry(row: &rusqlite::Row) -> rusqlite::Result<ContextEntry> {
    let value_str: String = row.get(1)?;
    // `metadata` is nullable in the schema. Reading it as `String` would fail
    // with InvalidColumnType on NULL and make the whole row unreadable.
    let metadata_str: Option<String> = row.get(9)?;
    Ok(ContextEntry {
        key: row.get(0)?,
        value: serde_json::from_str(&value_str).unwrap_or(Value::Null),
        created_at: parse_datetime(&row.get::<_, String>(3)?),
        updated_at: parse_datetime(&row.get::<_, String>(4)?),
        expires_at: row.get::<_, Option<String>>(5)?.map(|s| parse_datetime(&s)),
        git_commit: row.get(6)?,
        file_path: row.get(7)?,
        file_mtime: row.get(8)?,
        metadata: metadata_str
            .and_then(|s| serde_json::from_str(&s).ok())
            .unwrap_or_default(),
    })
}
/// Parses an RFC 3339 timestamp and converts it to UTC.
///
/// Malformed input does not error: it silently falls back to the current time.
fn parse_datetime(s: &str) -> DateTime<Utc> {
    match DateTime::parse_from_rfc3339(s) {
        Ok(parsed) => parsed.with_timezone(&Utc),
        Err(_) => Utc::now(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Fresh, empty in-memory store for each test.
    fn fresh_store() -> ContextStore {
        ContextStore::in_memory().unwrap()
    }

    #[test]
    fn test_basic_kv_operations() {
        let store = fresh_store();
        let payload = json!({"data": "value1"});
        store.set_value("test:key1", payload.clone()).unwrap();
        let fetched = store.get("test:key1").unwrap().unwrap();
        assert_eq!(fetched.value, payload);
        let removed = store.delete("test:key1").unwrap();
        assert!(removed);
        assert!(store.get("test:key1").unwrap().is_none());
    }

    #[test]
    fn test_file_context() {
        let store = fresh_store();
        let path = "src/main.rs";
        let original = FileContext {
            path: path.to_string(),
            summary: Some("Main entry point".to_string()),
            language: Some("rust".to_string()),
            ..Default::default()
        };
        store.set_file_context(path, &original).unwrap();
        let loaded = store.get_file_context(path).unwrap().unwrap();
        assert_eq!(loaded.summary.as_deref(), Some("Main entry point"));
    }

    #[test]
    fn test_session_context() {
        let store = fresh_store();
        let session_id = "session-1";
        store
            .update_working_files(session_id, vec!["file1.rs".to_string()])
            .unwrap();
        store
            .add_decision(
                session_id,
                "Use async/await for IO",
                Some("Better concurrency"),
                vec!["src/io.rs".to_string()],
            )
            .unwrap();
        let session = store.get_session(session_id).unwrap().unwrap();
        assert_eq!(session.working_files, vec!["file1.rs"]);
        assert_eq!(session.decisions.len(), 1);
        assert_eq!(session.decisions[0].decision, "Use async/await for IO");
    }

    #[test]
    fn test_namespace_listing() {
        let store = fresh_store();
        for key in ["file:a.rs", "file:b.rs", "project:info"].iter() {
            store.set_value(key, json!({})).unwrap();
        }
        assert_eq!(store.keys(Namespace::File).unwrap().len(), 2);
        assert_eq!(store.keys(Namespace::Project).unwrap().len(), 1);
    }

    #[test]
    fn test_ttl_expiration() {
        let store = fresh_store();
        let mut stale = ContextEntry::new("test:expired", json!({}));
        stale.expires_at = Some(Utc::now() - chrono::Duration::hours(1));
        store.set(stale).unwrap();
        assert!(store.get("test:expired").unwrap().is_none());
    }
}