use crate::builder::GraphBuilder;
use crate::graph::ArborGraph;
use arbor_core::CodeNode;
use sled::{Batch, Db};
use std::path::Path;
use thiserror::Error;
/// Version tag written into every cache database (`arbor-<crate version>`).
///
/// `GraphStore::open` refuses a cache whose stored tag differs from this one.
const CACHE_VERSION: &str = concat!("arbor", "-", env!("CARGO_PKG_VERSION"));
#[derive(Error, Debug)]
pub enum StoreError {
#[error("Database error: {0}")]
Sled(#[from] sled::Error),
#[error("Serialization error: {0}")]
Bincode(#[from] bincode::Error),
#[error("Corrupted data: {0}")]
Corrupted(String),
#[error("Cache version mismatch: expected {expected}, found {found}")]
VersionMismatch { expected: String, found: String },
}
pub struct GraphStore {
db: Db,
}
impl GraphStore {
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, StoreError> {
let db = sled::open(path)?;
let store = Self { db };
if let Some(version_bytes) = store.db.get("meta:version")? {
let version: String = bincode::deserialize(&version_bytes)?;
if version != CACHE_VERSION {
return Err(StoreError::VersionMismatch {
expected: CACHE_VERSION.to_string(),
found: version,
});
}
} else {
let version_bytes = bincode::serialize(&CACHE_VERSION.to_string())?;
store.db.insert("meta:version", version_bytes)?;
}
Ok(store)
}
pub fn open_or_reset<P: AsRef<Path>>(path: P) -> Result<Self, StoreError> {
match Self::open(path.as_ref()) {
Ok(store) => Ok(store),
Err(StoreError::VersionMismatch { .. }) => {
let db = sled::open(path.as_ref())?;
db.clear()?;
let version_bytes = bincode::serialize(&CACHE_VERSION.to_string())?;
db.insert("meta:version", version_bytes)?;
db.flush()?;
Ok(Self { db })
}
Err(e) => Err(e),
}
}
pub fn get_mtime(&self, file_path: &str) -> Result<Option<u64>, StoreError> {
let key = format!("m:{}", file_path);
match self.db.get(&key)? {
Some(bytes) => {
let mtime: u64 = bincode::deserialize(&bytes)?;
Ok(Some(mtime))
}
None => Ok(None),
}
}
pub fn get_file_nodes(&self, file_path: &str) -> Result<Option<Vec<CodeNode>>, StoreError> {
let file_key = format!("f:{}", file_path);
match self.db.get(&file_key)? {
Some(index_bytes) => {
let node_ids: Vec<String> = bincode::deserialize(&index_bytes)?;
let mut nodes = Vec::with_capacity(node_ids.len());
for id in node_ids {
let node_key = format!("n:{}", id);
if let Some(node_bytes) = self.db.get(&node_key)? {
let node: CodeNode = bincode::deserialize(&node_bytes)?;
nodes.push(node);
}
}
Ok(Some(nodes))
}
None => Ok(None),
}
}
pub fn update_file(
&self,
file_path: &str,
nodes: &[CodeNode],
mtime: u64,
) -> Result<(), StoreError> {
let file_key = format!("f:{}", file_path);
let mtime_key = format!("m:{}", file_path);
let mut batch = Batch::default();
if let Some(old_bytes) = self.db.get(&file_key)? {
let old_ids: Vec<String> = bincode::deserialize(&old_bytes)?;
for id in old_ids {
batch.remove(format!("n:{}", id).as_bytes());
}
}
let mut new_ids = Vec::with_capacity(nodes.len());
for node in nodes {
let node_key = format!("n:{}", node.id);
let bytes = bincode::serialize(node)?;
batch.insert(node_key.as_bytes(), bytes);
new_ids.push(node.id.clone());
}
let index_bytes = bincode::serialize(&new_ids)?;
batch.insert(file_key.as_bytes(), index_bytes);
let mtime_bytes = bincode::serialize(&mtime)?;
batch.insert(mtime_key.as_bytes(), mtime_bytes);
self.db.apply_batch(batch)?;
self.db.flush()?;
Ok(())
}
pub fn remove_file(&self, file_path: &str) -> Result<(), StoreError> {
let file_key = format!("f:{}", file_path);
let mtime_key = format!("m:{}", file_path);
let mut batch = Batch::default();
if let Some(old_bytes) = self.db.get(&file_key)? {
let old_ids: Vec<String> = bincode::deserialize(&old_bytes)?;
for id in old_ids {
batch.remove(format!("n:{}", id).as_bytes());
}
}
batch.remove(file_key.as_bytes());
batch.remove(mtime_key.as_bytes());
self.db.apply_batch(batch)?;
self.db.flush()?;
Ok(())
}
pub fn list_cached_files(&self) -> Result<Vec<String>, StoreError> {
let mut files = Vec::new();
let prefix = b"f:";
for item in self.db.scan_prefix(prefix) {
let (key, _) = item?;
let key_str = String::from_utf8_lossy(&key);
if let Some(file_path) = key_str.strip_prefix("f:") {
files.push(file_path.to_string());
}
}
Ok(files)
}
pub fn load_graph(&self) -> Result<ArborGraph, StoreError> {
let mut builder = GraphBuilder::new();
let mut nodes = Vec::new();
let prefix = b"n:";
for item in self.db.scan_prefix(prefix) {
let (_key, value) = item?;
let node: CodeNode = bincode::deserialize(&value)?;
nodes.push(node);
}
if nodes.is_empty() {
return Ok(ArborGraph::new());
}
builder.add_nodes(nodes);
let graph = builder.build();
Ok(graph)
}
pub fn clear(&self) -> Result<(), StoreError> {
self.db.clear()?;
let version_bytes = bincode::serialize(&CACHE_VERSION.to_string())?;
self.db.insert("meta:version", version_bytes)?;
self.db.flush()?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use arbor_core::NodeKind;
    use tempfile::tempdir;

    /// Overwrites the stored version marker, simulating a cache written by another build.
    fn stamp_version(store: &GraphStore, version: &str) {
        let bytes = bincode::serialize(&version.to_string()).unwrap();
        store.db.insert("meta:version", bytes).unwrap();
        store.db.flush().unwrap();
    }

    #[test]
    fn test_incremental_updates() {
        let dir = tempdir().unwrap();
        let store = GraphStore::open(dir.path()).unwrap();
        let node1 = CodeNode::new("foo", "foo", NodeKind::Function, "test.rs");
        let node2 = CodeNode::new("bar", "bar", NodeKind::Function, "test.rs");
        store
            .update_file("test.rs", &[node1.clone(), node2.clone()], 1000)
            .unwrap();
        let graph = store.load_graph().unwrap();
        assert_eq!(graph.node_count(), 2);
        assert_eq!(store.get_mtime("test.rs").unwrap(), Some(1000));

        // Re-indexing with fewer nodes must drop the stale one and keep the survivor.
        store
            .update_file("test.rs", std::slice::from_ref(&node1), 2000)
            .unwrap();
        let graph2 = store.load_graph().unwrap();
        assert_eq!(graph2.node_count(), 1);
        assert!(!graph2.find_by_name("foo").is_empty());
        assert!(graph2.find_by_name("bar").is_empty());
        assert_eq!(store.get_mtime("test.rs").unwrap(), Some(2000));
    }

    #[test]
    fn test_cache_version() {
        let dir = tempdir().unwrap();
        let store = GraphStore::open(dir.path()).unwrap();
        drop(store);
        // Reopening a cache written by the same build must succeed.
        let store2 = GraphStore::open(dir.path()).unwrap();
        drop(store2);
    }

    #[test]
    fn test_version_mismatch_rejected_then_reset() {
        let dir = tempdir().unwrap();
        let store = GraphStore::open(dir.path()).unwrap();
        let node = CodeNode::new("foo", "foo", NodeKind::Function, "a.rs");
        store.update_file("a.rs", &[node], 1000).unwrap();
        stamp_version(&store, "arbor-0.0.0-stale");
        drop(store);

        match GraphStore::open(dir.path()) {
            Err(StoreError::VersionMismatch { expected, found }) => {
                assert_eq!(expected, CACHE_VERSION);
                assert_eq!(found, "arbor-0.0.0-stale");
            }
            other => panic!("expected VersionMismatch, got {:?}", other.map(|_| ())),
        }

        let reset = GraphStore::open_or_reset(dir.path()).unwrap();
        assert!(reset.list_cached_files().unwrap().is_empty());
        assert!(reset.get_mtime("a.rs").unwrap().is_none());
        assert!(reset.get_file_nodes("a.rs").unwrap().is_none());
        assert_eq!(reset.load_graph().unwrap().node_count(), 0);
        drop(reset);

        // The reset cache carries the current version and opens normally.
        GraphStore::open(dir.path()).unwrap();
    }

    #[test]
    fn test_clear_keeps_version() {
        let dir = tempdir().unwrap();
        let store = GraphStore::open(dir.path()).unwrap();
        let node = CodeNode::new("foo", "foo", NodeKind::Function, "a.rs");
        store.update_file("a.rs", &[node], 1000).unwrap();
        store.clear().unwrap();
        assert!(store.list_cached_files().unwrap().is_empty());
        assert!(store.get_file_nodes("a.rs").unwrap().is_none());
        drop(store);
        GraphStore::open(dir.path()).unwrap();
    }

    #[test]
    fn test_remove_file() {
        let dir = tempdir().unwrap();
        let store = GraphStore::open(dir.path()).unwrap();
        let node = CodeNode::new("foo", "foo", NodeKind::Function, "test.rs");
        store.update_file("test.rs", &[node], 1000).unwrap();
        assert!(store.get_mtime("test.rs").unwrap().is_some());
        assert!(store.get_file_nodes("test.rs").unwrap().is_some());
        store.remove_file("test.rs").unwrap();
        assert!(store.get_mtime("test.rs").unwrap().is_none());
        assert!(store.get_file_nodes("test.rs").unwrap().is_none());
    }

    #[test]
    fn test_list_cached_files() {
        let dir = tempdir().unwrap();
        let store = GraphStore::open(dir.path()).unwrap();
        let node1 = CodeNode::new("foo", "foo", NodeKind::Function, "a.rs");
        let node2 = CodeNode::new("bar", "bar", NodeKind::Function, "b.rs");
        store.update_file("a.rs", &[node1], 1000).unwrap();
        store.update_file("b.rs", &[node2], 2000).unwrap();
        let files = store.list_cached_files().unwrap();
        assert_eq!(files.len(), 2);
        assert!(files.contains(&"a.rs".to_string()));
        assert!(files.contains(&"b.rs".to_string()));
    }
}