use anyhow::{Context, Result};
use rusqlite::{Connection, OptionalExtension};
use std::path::Path;
use crate::models::SearchResult;
#[cfg(test)]
use crate::models::{Language, Span, SymbolKind};
pub struct SymbolCache {
db_path: std::path::PathBuf,
}
impl SymbolCache {
pub fn open(cache_dir: &Path) -> Result<Self> {
let db_path = cache_dir.join("meta.db");
if !db_path.exists() {
anyhow::bail!("Cache not initialized - run 'rfx index' first");
}
let cache = Self { db_path };
cache.init_schema()?;
Ok(cache)
}
fn init_schema(&self) -> Result<()> {
let conn = Connection::open(&self.db_path)
.context("Failed to open meta.db")?;
let uses_file_id: bool = conn
.query_row(
"SELECT COUNT(*) FROM pragma_table_info('symbols') WHERE name='file_id'",
[],
|row| row.get::<_, i64>(0),
)
.unwrap_or(0) > 0;
if !uses_file_id {
let table_exists: bool = conn
.query_row(
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='symbols'",
[],
|row| row.get::<_, i64>(0),
)
.unwrap_or(0) > 0;
if table_exists {
log::warn!("Symbol cache schema outdated - migrating to file_id-based schema");
conn.execute("DROP TABLE IF EXISTS symbols", [])?;
}
}
conn.execute(
"CREATE TABLE IF NOT EXISTS symbols (
file_id INTEGER NOT NULL,
file_hash TEXT NOT NULL,
symbols_json TEXT NOT NULL,
last_cached INTEGER NOT NULL,
PRIMARY KEY (file_id, file_hash),
FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE
)",
[],
)?;
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_symbols_file_id ON symbols(file_id)",
[],
)?;
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_symbols_hash ON symbols(file_hash)",
[],
)?;
log::debug!("Symbol cache schema initialized (file_id-based)");
Ok(())
}
pub fn get(&self, file_path: &str, file_hash: &str) -> Result<Option<Vec<SearchResult>>> {
let conn = Connection::open(&self.db_path)?;
let file_id: Option<i64> = conn
.query_row(
"SELECT id FROM files WHERE path = ?",
[file_path],
|row| row.get(0),
)
.optional()?;
let Some(file_id) = file_id else {
log::debug!("Symbol cache MISS: {} (file not in index)", file_path);
return Ok(None);
};
let symbols_json: Option<String> = conn
.query_row(
"SELECT symbols_json FROM symbols WHERE file_id = ? AND file_hash = ?",
[&file_id.to_string(), file_hash],
|row| row.get(0),
)
.optional()?;
match symbols_json {
Some(json) => {
let mut symbols: Vec<SearchResult> = serde_json::from_str(&json)
.context("Failed to deserialize cached symbols")?;
for symbol in &mut symbols {
symbol.path = file_path.to_string();
}
log::debug!("Symbol cache HIT: {} ({} symbols)", file_path, symbols.len());
Ok(Some(symbols))
}
None => {
log::debug!("Symbol cache MISS: {}", file_path);
Ok(None)
}
}
}
pub fn batch_get(&self, files: &[(String, String)]) -> Result<Vec<(String, Option<Vec<SearchResult>>)>> {
if files.is_empty() {
return Ok(Vec::new());
}
let conn = Connection::open(&self.db_path)?;
let mut file_id_stmt = conn.prepare("SELECT id FROM files WHERE path = ?")?;
let mut symbols_stmt = conn.prepare(
"SELECT symbols_json FROM symbols WHERE file_id = ? AND file_hash = ?"
)?;
let mut results = Vec::with_capacity(files.len());
let mut hits = 0;
let mut misses = 0;
for (file_path, file_hash) in files {
let file_id: Option<i64> = file_id_stmt
.query_row([file_path.as_str()], |row| row.get(0))
.optional()?;
let symbols = if let Some(file_id) = file_id {
let symbols_json: Option<String> = symbols_stmt
.query_row([&file_id.to_string(), file_hash.as_str()], |row| row.get(0))
.optional()?;
match symbols_json {
Some(json) => {
match serde_json::from_str::<Vec<SearchResult>>(&json) {
Ok(mut symbols) => {
for symbol in &mut symbols {
symbol.path = file_path.clone();
}
hits += 1;
Some(symbols)
}
Err(e) => {
log::warn!("Failed to deserialize cached symbols for {}: {}", file_path, e);
misses += 1;
None
}
}
}
None => {
misses += 1;
None
}
}
} else {
misses += 1;
None
};
results.push((file_path.clone(), symbols));
}
log::debug!("Batch symbol cache: {} hits, {} misses ({} total)", hits, misses, files.len());
Ok(results)
}
pub fn batch_get_with_kind(
&self,
file_ids: &[(i64, String, String)], kind_filter: Option<crate::models::SymbolKind>
) -> Result<std::collections::HashMap<i64, Vec<SearchResult>>> {
use std::collections::HashMap;
if file_ids.is_empty() {
return Ok(HashMap::new());
}
let conn = Connection::open(&self.db_path)?;
const BATCH_SIZE: usize = 900;
let file_info: HashMap<i64, (String, String)> = file_ids.iter()
.map(|(id, hash, path)| (*id, (hash.clone(), path.clone())))
.collect();
let kind_for_filtering = kind_filter.clone();
let mut cache_map: HashMap<i64, Vec<SearchResult>> = HashMap::new();
let mut hits = 0;
for chunk in file_ids.chunks(BATCH_SIZE) {
let id_placeholders = chunk.iter()
.map(|_| "?")
.collect::<Vec<_>>()
.join(", ");
let query = format!(
"SELECT file_id, symbols_json
FROM symbols
WHERE file_id IN ({})",
id_placeholders
);
let params: Vec<Box<dyn rusqlite::ToSql>> = chunk.iter()
.map(|(id, _, _)| Box::new(*id) as Box<dyn rusqlite::ToSql>)
.collect();
let mut stmt = conn.prepare(&query)?;
let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|p| p.as_ref()).collect();
let rows = stmt.query_map(param_refs.as_slice(), |row| {
Ok((
row.get::<_, i64>(0)?,
row.get::<_, String>(1)?
))
})?;
for row_result in rows {
let (file_id, symbols_json) = row_result?;
if let Some((_hash, file_path)) = file_info.get(&file_id) {
match serde_json::from_str::<Vec<SearchResult>>(&symbols_json) {
Ok(mut symbols) => {
for symbol in &mut symbols {
symbol.path = file_path.clone();
}
if let Some(ref filter_kind) = kind_for_filtering {
symbols.retain(|s| &s.kind == filter_kind);
}
cache_map.insert(file_id, symbols);
hits += 1;
}
Err(e) => {
log::warn!("Failed to deserialize cached symbols for file_id {}: {}", file_id, e);
}
}
}
}
}
let misses = file_ids.len() - hits;
if kind_for_filtering.is_some() {
log::debug!(
"Batch symbol cache with Rust-side kind filter: {} hits, {} misses ({} total, {} chunks)",
hits, misses, file_ids.len(), (file_ids.len() + BATCH_SIZE - 1) / BATCH_SIZE
);
} else {
log::debug!(
"Batch symbol cache: {} hits, {} misses ({} total, {} chunks)",
hits, misses, file_ids.len(), (file_ids.len() + BATCH_SIZE - 1) / BATCH_SIZE
);
}
Ok(cache_map)
}
pub fn set(&self, file_path: &str, file_hash: &str, symbols: &[SearchResult]) -> Result<()> {
let conn = Connection::open(&self.db_path)?;
let file_id: i64 = conn.query_row(
"SELECT id FROM files WHERE path = ?",
[file_path],
|row| row.get(0)
).context(format!("File not found in index: {}", file_path))?;
let symbols_without_path: Vec<_> = symbols
.iter()
.map(|s| {
let mut s = s.clone();
s.path = String::new(); s
})
.collect();
let symbols_json = serde_json::to_string(&symbols_without_path)
.context("Failed to serialize symbols")?;
let now = chrono::Utc::now().timestamp();
conn.execute(
"INSERT OR REPLACE INTO symbols (file_id, file_hash, symbols_json, last_cached)
VALUES (?, ?, ?, ?)",
[&file_id.to_string(), file_hash, &symbols_json, &now.to_string()],
)?;
log::debug!("Cached {} symbols for {}", symbols.len(), file_path);
Ok(())
}
pub fn batch_set(&self, entries: &[(String, String, Vec<SearchResult>)]) -> Result<()> {
let mut conn = Connection::open(&self.db_path)?;
let tx = conn.transaction()?;
let now = chrono::Utc::now().timestamp();
let now_str = now.to_string();
for (file_path, file_hash, symbols) in entries {
let file_id: i64 = tx.query_row(
"SELECT id FROM files WHERE path = ?",
[file_path.as_str()],
|row| row.get(0)
).context(format!("File not found in index: {}", file_path))?;
let symbols_without_path: Vec<_> = symbols
.iter()
.map(|s| {
let mut s = s.clone();
s.path = String::new();
s
})
.collect();
let symbols_json = serde_json::to_string(&symbols_without_path)
.context("Failed to serialize symbols")?;
tx.execute(
"INSERT OR REPLACE INTO symbols (file_id, file_hash, symbols_json, last_cached)
VALUES (?, ?, ?, ?)",
[&file_id.to_string(), file_hash.as_str(), &symbols_json, &now_str],
)?;
}
tx.commit()?;
log::debug!("Batch cached symbols for {} files", entries.len());
Ok(())
}
pub fn clear(&self) -> Result<()> {
let conn = Connection::open(&self.db_path)?;
conn.execute("DELETE FROM symbols", [])?;
log::info!("Cleared symbol cache");
Ok(())
}
pub fn stats(&self) -> Result<SymbolCacheStats> {
let conn = Connection::open(&self.db_path)?;
let total_files: usize = conn
.query_row("SELECT COUNT(DISTINCT file_id) FROM symbols", [], |row| {
row.get(0)
})
.unwrap_or(0);
let total_entries: usize = conn
.query_row("SELECT COUNT(*) FROM symbols", [], |row| row.get(0))
.unwrap_or(0);
let cache_size_bytes: u64 = conn
.query_row(
"SELECT SUM(LENGTH(symbols_json)) FROM symbols",
[],
|row| row.get(0),
)
.unwrap_or(0);
Ok(SymbolCacheStats {
total_files,
total_entries,
cache_size_bytes,
})
}
pub fn cleanup_stale(&self) -> Result<usize> {
let conn = Connection::open(&self.db_path)?;
let removed = conn.execute(
"DELETE FROM symbols WHERE file_id NOT IN (SELECT id FROM files)",
[],
)?;
if removed > 0 {
log::info!("Removed {} stale symbol cache entries", removed);
}
Ok(removed)
}
}
#[derive(Debug, Clone)]
pub struct SymbolCacheStats {
pub total_files: usize,
pub total_entries: usize,
pub cache_size_bytes: u64,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::cache::CacheManager;
use tempfile::TempDir;
#[test]
fn test_symbol_cache_init() {
let temp = TempDir::new().unwrap();
let cache_mgr = CacheManager::new(temp.path());
cache_mgr.init().unwrap();
let symbol_cache = SymbolCache::open(cache_mgr.path()).unwrap();
let stats = symbol_cache.stats().unwrap();
assert_eq!(stats.total_files, 0);
}
#[test]
fn test_symbol_cache_set_get() {
let temp = TempDir::new().unwrap();
let cache_mgr = CacheManager::new(temp.path());
cache_mgr.init().unwrap();
cache_mgr.update_file("test.rs", "rust", 100).unwrap();
let symbol_cache = SymbolCache::open(cache_mgr.path()).unwrap();
let symbols = vec![
SearchResult::new(
"test.rs".to_string(),
Language::Rust,
SymbolKind::Function,
Some("test_fn".to_string()),
Span::new(1, 0, 5, 0),
None,
"fn test_fn() {}".to_string(),
),
];
symbol_cache
.set("test.rs", "hash123", &symbols)
.unwrap();
let cached = symbol_cache.get("test.rs", "hash123").unwrap();
assert!(cached.is_some());
assert_eq!(cached.as_ref().unwrap().len(), 1);
assert_eq!(
cached.unwrap()[0].symbol.as_deref(),
Some("test_fn")
);
}
#[test]
fn test_symbol_cache_hash_mismatch() {
let temp = TempDir::new().unwrap();
let cache_mgr = CacheManager::new(temp.path());
cache_mgr.init().unwrap();
cache_mgr.update_file("test.rs", "rust", 100).unwrap();
let symbol_cache = SymbolCache::open(cache_mgr.path()).unwrap();
let symbols = vec![SearchResult::new(
"test.rs".to_string(),
Language::Rust,
SymbolKind::Function,
Some("test_fn".to_string()),
Span::new(1, 0, 5, 0),
None,
"fn test_fn() {}".to_string(),
)];
symbol_cache
.set("test.rs", "hash123", &symbols)
.unwrap();
let cached = symbol_cache.get("test.rs", "hash456").unwrap();
assert!(cached.is_none());
}
#[test]
fn test_symbol_cache_batch_set() {
let temp = TempDir::new().unwrap();
let cache_mgr = CacheManager::new(temp.path());
cache_mgr.init().unwrap();
cache_mgr.update_file("file1.rs", "rust", 100).unwrap();
cache_mgr.update_file("file2.rs", "rust", 200).unwrap();
let symbol_cache = SymbolCache::open(cache_mgr.path()).unwrap();
let entries = vec![
(
"file1.rs".to_string(),
"hash1".to_string(),
vec![SearchResult::new(
"file1.rs".to_string(),
Language::Rust,
SymbolKind::Function,
Some("fn1".to_string()),
Span::new(1, 0, 5, 0),
None,
"fn fn1() {}".to_string(),
)],
),
(
"file2.rs".to_string(),
"hash2".to_string(),
vec![SearchResult::new(
"file2.rs".to_string(),
Language::Rust,
SymbolKind::Function,
Some("fn2".to_string()),
Span::new(1, 0, 5, 0),
None,
"fn fn2() {}".to_string(),
)],
),
];
symbol_cache.batch_set(&entries).unwrap();
let stats = symbol_cache.stats().unwrap();
assert_eq!(stats.total_files, 2);
let cached1 = symbol_cache.get("file1.rs", "hash1").unwrap();
assert!(cached1.is_some());
let cached2 = symbol_cache.get("file2.rs", "hash2").unwrap();
assert!(cached2.is_some());
}
#[test]
fn test_symbol_cache_batch_get() {
let temp = TempDir::new().unwrap();
let cache_mgr = CacheManager::new(temp.path());
cache_mgr.init().unwrap();
cache_mgr.update_file("file1.rs", "rust", 100).unwrap();
cache_mgr.update_file("file2.rs", "rust", 200).unwrap();
cache_mgr.update_file("file3.rs", "rust", 300).unwrap();
let symbol_cache = SymbolCache::open(cache_mgr.path()).unwrap();
let entries = vec![
(
"file1.rs".to_string(),
"hash1".to_string(),
vec![SearchResult::new(
"file1.rs".to_string(),
Language::Rust,
SymbolKind::Function,
Some("fn1".to_string()),
Span::new(1, 0, 5, 0),
None,
"fn fn1() {}".to_string(),
)],
),
(
"file2.rs".to_string(),
"hash2".to_string(),
vec![SearchResult::new(
"file2.rs".to_string(),
Language::Rust,
SymbolKind::Struct,
Some("Struct2".to_string()),
Span::new(1, 0, 5, 0),
None,
"struct Struct2 {}".to_string(),
)],
),
(
"file3.rs".to_string(),
"hash3".to_string(),
vec![SearchResult::new(
"file3.rs".to_string(),
Language::Rust,
SymbolKind::Enum,
Some("Enum3".to_string()),
Span::new(1, 0, 5, 0),
None,
"enum Enum3 {}".to_string(),
)],
),
];
symbol_cache.batch_set(&entries).unwrap();
let lookup = vec![
("file1.rs".to_string(), "hash1".to_string()),
("file2.rs".to_string(), "hash2".to_string()),
("file3.rs".to_string(), "hash3".to_string()),
];
let results = symbol_cache.batch_get(&lookup).unwrap();
assert_eq!(results.len(), 3);
assert!(results[0].1.is_some());
assert_eq!(results[0].1.as_ref().unwrap()[0].symbol.as_deref(), Some("fn1"));
assert!(results[1].1.is_some());
assert_eq!(results[1].1.as_ref().unwrap()[0].symbol.as_deref(), Some("Struct2"));
assert!(results[2].1.is_some());
assert_eq!(results[2].1.as_ref().unwrap()[0].symbol.as_deref(), Some("Enum3"));
let mixed_lookup = vec![
("file1.rs".to_string(), "hash1".to_string()), ("nonexistent.rs".to_string(), "hash999".to_string()), ("file2.rs".to_string(), "wrong_hash".to_string()), ("file3.rs".to_string(), "hash3".to_string()), ];
let mixed_results = symbol_cache.batch_get(&mixed_lookup).unwrap();
assert_eq!(mixed_results.len(), 4);
assert!(mixed_results[0].1.is_some()); assert!(mixed_results[1].1.is_none()); assert!(mixed_results[2].1.is_none()); assert!(mixed_results[3].1.is_some());
let empty_results = symbol_cache.batch_get(&[]).unwrap();
assert_eq!(empty_results.len(), 0);
}
#[test]
fn test_symbol_cache_clear() {
let temp = TempDir::new().unwrap();
let cache_mgr = CacheManager::new(temp.path());
cache_mgr.init().unwrap();
cache_mgr.update_file("test.rs", "rust", 100).unwrap();
let symbol_cache = SymbolCache::open(cache_mgr.path()).unwrap();
let symbols = vec![SearchResult::new(
"test.rs".to_string(),
Language::Rust,
SymbolKind::Function,
Some("test_fn".to_string()),
Span::new(1, 0, 5, 0),
None,
"fn test_fn() {}".to_string(),
)];
symbol_cache.set("test.rs", "hash123", &symbols).unwrap();
let stats_before = symbol_cache.stats().unwrap();
assert_eq!(stats_before.total_files, 1);
symbol_cache.clear().unwrap();
let stats_after = symbol_cache.stats().unwrap();
assert_eq!(stats_after.total_files, 0);
}
#[test]
fn test_symbol_cache_cleanup_stale() {
let temp = TempDir::new().unwrap();
let cache_mgr = CacheManager::new(temp.path());
cache_mgr.init().unwrap();
cache_mgr.update_file("exists.rs", "rust", 100).unwrap();
cache_mgr.record_branch_file("exists.rs", "main", "hash1", None).unwrap();
cache_mgr.update_file("deleted.rs", "rust", 200).unwrap();
let symbol_cache = SymbolCache::open(cache_mgr.path()).unwrap();
let symbols = vec![SearchResult::new(
"test.rs".to_string(),
Language::Rust,
SymbolKind::Function,
Some("test_fn".to_string()),
Span::new(1, 0, 5, 0),
None,
"fn test_fn() {}".to_string(),
)];
symbol_cache.set("exists.rs", "hash1", &symbols).unwrap();
symbol_cache
.set("deleted.rs", "hash2", &symbols)
.unwrap();
let stats_before = symbol_cache.stats().unwrap();
assert_eq!(stats_before.total_files, 2);
let conn = rusqlite::Connection::open(cache_mgr.path().join("meta.db")).unwrap();
conn.execute("DELETE FROM files WHERE path = 'deleted.rs'", []).unwrap();
let removed = symbol_cache.cleanup_stale().unwrap();
assert_eq!(removed, 0);
let stats_after = symbol_cache.stats().unwrap();
assert_eq!(stats_after.total_files, 1);
let cached = symbol_cache.get("exists.rs", "hash1").unwrap();
assert!(cached.is_some());
let cached2 = symbol_cache.get("deleted.rs", "hash2").unwrap();
assert!(cached2.is_none());
}
}