use std::collections::HashMap;
use std::fs::{self, File};
use std::io::{BufReader, BufWriter};
use std::path::Path;
use std::time::SystemTime;
use fs2::FileExt;
use serde::{Deserialize, Serialize};
use crate::semantic::types::{CacheConfig, CacheStats, CodeChunk, EmbeddingModel};
use crate::TldrResult;
/// Identity of one cached embedding.
///
/// Two chunks share a cache slot only if they agree on file path, function
/// name, content hash *and* embedding model, so identical source text in two
/// different functions is cached twice.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
struct CacheKey {
    /// `CodeChunk::content_hash` of the chunk.
    content_hash: String,
    /// The chunk's file path, converted lossily to UTF-8.
    file_path: String,
    /// Enclosing function name, if the chunk has one.
    function_name: Option<String>,
    /// `Debug` rendering of the `EmbeddingModel` that produced the vector.
    model: String,
}

impl CacheKey {
    /// Derives the key for `chunk` as embedded by `model`.
    ///
    /// NOTE(review): the model is identified by its `Debug` output, so renaming
    /// an `EmbeddingModel` variant would orphan entries already persisted on disk.
    fn from_chunk(chunk: &CodeChunk, model: EmbeddingModel) -> Self {
        let model = format!("{:?}", model);
        let file_path = chunk.file_path.to_string_lossy().into_owned();
        Self {
            content_hash: chunk.content_hash.clone(),
            file_path,
            function_name: chunk.function_name.clone(),
            model,
        }
    }

    /// Flattens the key into the string used as the map / JSON key:
    /// `<file_path>:<function_name or empty>:<content_hash>:<model>`.
    ///
    /// NOTE(review): the parts are joined with `:` without escaping, so a `:`
    /// inside a path (e.g. a Windows drive letter) or a function name could make
    /// two distinct keys render identically, and a `None` function name renders
    /// the same as `Some("")`. Changing the format would orphan existing cache
    /// files, since these strings are what gets persisted.
    fn to_key_string(&self) -> String {
        let function = self.function_name.as_deref().unwrap_or_default();
        format!(
            "{}:{}:{}:{}",
            self.file_path, function, self.content_hash, self.model
        )
    }
}
/// One cached embedding plus the metadata used to decide whether it is still valid.
///
/// Serialized as a value of the `cache.json` map, keyed by `CacheKey::to_key_string()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CacheEntry {
    /// The embedding vector handed back to callers of `EmbeddingCache::get`.
    embedding: Vec<f32>,
    /// When the entry was stored, in seconds since the Unix epoch; used for TTL checks.
    cached_at: u64,
    /// Modification time (seconds since the Unix epoch) of the chunk's source file
    /// when the entry was stored, or `None` if it could not be read. A newer
    /// on-disk mtime makes `get` treat the entry as a miss.
    file_mtime: Option<u64>,
}
/// Embedding cache held in memory and persisted as JSON (`cache.json`) inside
/// `CacheConfig::cache_dir`. Pending changes are written by `flush`, which is
/// also attempted when the cache is dropped.
pub struct EmbeddingCache {
    /// Configuration passed to `open` (cache directory, TTL, size limit).
    config: CacheConfig,
    /// Entries keyed by `CacheKey::to_key_string()`.
    entries: HashMap<String, CacheEntry>,
    /// Counters exposed through `stats()`.
    stats: CacheStats,
    /// Set by mutations (`put`, `evict_stale`); `flush` skips writing while false.
    dirty: bool,
}
impl EmbeddingCache {
    /// Name of the persisted cache file inside `CacheConfig::cache_dir`.
    const CACHE_FILE_NAME: &'static str = "cache.json";
    /// Scratch file written by `flush` and then renamed over the cache file.
    const TEMP_FILE_NAME: &'static str = "cache.json.tmp";
    /// Smoothing factor of the hit-rate moving average (see `calculate_hit_rate`).
    const HIT_RATE_ALPHA: f64 = 0.1;
    const SECS_PER_DAY: u64 = 24 * 60 * 60;

    /// Opens the cache stored under `config.cache_dir`, creating the directory if needed.
    ///
    /// Leftover `*.tmp` files from interrupted flushes are deleted first. If
    /// `cache.json` exists but cannot be read or parsed, it is ignored and the
    /// cache starts empty, so a corrupted cache never blocks startup.
    ///
    /// # Errors
    /// Returns an error only if the cache directory cannot be created.
    pub fn open(config: CacheConfig) -> TldrResult<Self> {
        fs::create_dir_all(&config.cache_dir)?;
        Self::cleanup_temp_files(&config.cache_dir);

        let cache_file = config.cache_dir.join(Self::CACHE_FILE_NAME);
        // Best effort: an unreadable or corrupted cache is simply rebuilt.
        let entries = if cache_file.exists() {
            Self::load_with_lock(&cache_file).unwrap_or_default()
        } else {
            HashMap::new()
        };

        let stats = CacheStats {
            entries: entries.len(),
            size_bytes: Self::total_size_bytes(&entries),
            hit_rate: 0.0,
        };
        Ok(Self {
            config,
            entries,
            stats,
            dirty: false,
        })
    }

    /// Deletes every `*.tmp` file directly inside `cache_dir`, ignoring errors.
    ///
    /// NOTE(review): if another process is in the middle of `flush` on the same
    /// directory, this can delete its in-progress temp file, making that
    /// process's rename fail.
    fn cleanup_temp_files(cache_dir: &Path) {
        let dir = match fs::read_dir(cache_dir) {
            Ok(dir) => dir,
            Err(_) => return,
        };
        for path in dir.flatten().map(|entry| entry.path()) {
            if matches!(path.extension(), Some(ext) if ext == "tmp") {
                let _ = fs::remove_file(&path);
            }
        }
    }

    /// Reads and deserializes the cache file at `path` under a shared lock.
    ///
    /// # Errors
    /// I/O errors from opening or locking the file, or `TldrError::ParseError`
    /// if the contents are not valid cache JSON. On a parse error the explicit
    /// `unlock` is skipped; fs2 locks are released when the handle is closed,
    /// which happens when `file` is dropped on return.
    fn load_with_lock(path: &Path) -> TldrResult<HashMap<String, CacheEntry>> {
        let file = File::open(path)?;
        file.lock_shared()?;
        let reader = BufReader::new(&file);
        let entries: HashMap<String, CacheEntry> =
            serde_json::from_reader(reader).map_err(|e| crate::TldrError::ParseError {
                file: path.to_path_buf(),
                line: None,
                message: format!("Cache file corrupted: {}", e),
            })?;
        file.unlock()?;
        Ok(entries)
    }

    /// Returns the cached embedding for `chunk` under `model`, if it is still valid.
    ///
    /// A stored entry counts as a miss, but is left in place, when it is older
    /// than `ttl_days`, or when the chunk's source file now has a newer
    /// modification time than the one recorded by `put`. Every call updates
    /// `stats().hit_rate`.
    pub fn get(&mut self, chunk: &CodeChunk, model: EmbeddingModel) -> Option<Vec<f32>> {
        let key_str = CacheKey::from_chunk(chunk, model).to_key_string();
        let now = Self::now_secs();
        let ttl_secs = self.ttl_secs();

        let found = self
            .entries
            .get(&key_str)
            .filter(|entry| {
                // Age is compared in seconds, exactly as in `evict_stale`, so the
                // two methods agree on which entries have expired.
                Self::within_ttl(entry.cached_at, now, ttl_secs)
                    && !Self::source_changed(entry, &chunk.file_path)
            })
            .map(|entry| entry.embedding.clone());

        self.stats.hit_rate = self.calculate_hit_rate(found.is_some());
        found
    }

    /// Folds one lookup outcome into the running hit rate.
    ///
    /// Despite the field name, `CacheStats::hit_rate` is an exponentially
    /// weighted moving average (alpha = `HIT_RATE_ALPHA`, starting at 0.0),
    /// not hits divided by lookups.
    fn calculate_hit_rate(&self, hit: bool) -> f64 {
        let decayed = self.stats.hit_rate * (1.0 - Self::HIT_RATE_ALPHA);
        if hit {
            decayed + Self::HIT_RATE_ALPHA
        } else {
            decayed
        }
    }

    /// Stores `embedding` for `chunk` under `model`, replacing any previous entry.
    ///
    /// The entry is stamped with the current time and with the source file's
    /// modification time (if readable). The cache is marked dirty; call `flush`
    /// to persist.
    ///
    /// NOTE(review): `config.max_size_mb` is not enforced anywhere in this module.
    pub fn put(&mut self, chunk: &CodeChunk, embedding: Vec<f32>, model: EmbeddingModel) {
        let key_str = CacheKey::from_chunk(chunk, model).to_key_string();
        let new_size = Self::embedding_bytes(&embedding);
        let entry = CacheEntry {
            embedding,
            cached_at: Self::now_secs(),
            file_mtime: Self::mtime_secs(&chunk.file_path),
        };

        match self.entries.insert(key_str, entry) {
            // Overwrite: swap the old vector's size for the new one so the total
            // stays accurate even if the embedding length changed.
            Some(old) => {
                self.stats.size_bytes = self
                    .stats
                    .size_bytes
                    .saturating_sub(Self::embedding_bytes(&old.embedding))
                    + new_size;
            }
            None => {
                self.stats.entries += 1;
                self.stats.size_bytes += new_size;
            }
        }
        self.dirty = true;
    }

    /// Writes all entries to `cache.json` if anything changed since the last flush.
    ///
    /// The data is serialized to `cache.json.tmp`, flushed, fsynced, and only
    /// then renamed over `cache.json`, so a failure mid-write leaves the
    /// previous cache file untouched.
    ///
    /// NOTE(review): the exclusive lock is taken on the temp file, not on
    /// `cache.json`, so it does not exclude `load_with_lock` readers. Two
    /// processes flushing concurrently would also share the same temp path.
    ///
    /// # Errors
    /// I/O errors while creating, writing, syncing, locking or renaming the
    /// file, or `TldrError::ParseError` if serialization fails.
    pub fn flush(&mut self) -> TldrResult<()> {
        if !self.dirty {
            return Ok(());
        }
        let cache_file = self.config.cache_dir.join(Self::CACHE_FILE_NAME);
        let temp_file = self.config.cache_dir.join(Self::TEMP_FILE_NAME);
        {
            let file = File::create(&temp_file)?;
            file.lock_exclusive()?;
            let mut writer = BufWriter::new(&file);
            serde_json::to_writer(&mut writer, &self.entries).map_err(|e| {
                crate::TldrError::ParseError {
                    file: temp_file.clone(),
                    line: None,
                    message: format!("Failed to serialize cache: {}", e),
                }
            })?;
            // Flush explicitly: BufWriter's Drop discards write errors, which could
            // otherwise let a truncated file be renamed over a good cache.
            writer.into_inner().map_err(std::io::Error::from)?;
            file.sync_all()?;
            file.unlock()?;
        }
        fs::rename(&temp_file, &cache_file)?;
        self.dirty = false;
        Ok(())
    }

    /// Removes entries older than `ttl_days` and returns how many were removed.
    ///
    /// When anything is removed, the stats are recomputed and the cache is
    /// marked dirty so the next `flush` persists the removal.
    pub fn evict_stale(&mut self) -> usize {
        let now = Self::now_secs();
        let ttl_secs = self.ttl_secs();
        let before = self.entries.len();
        self.entries
            .retain(|_, entry| Self::within_ttl(entry.cached_at, now, ttl_secs));
        let evicted = before - self.entries.len();
        if evicted > 0 {
            self.stats.entries = self.entries.len();
            self.stats.size_bytes = Self::total_size_bytes(&self.entries);
            self.dirty = true;
        }
        evicted
    }

    /// Current cache statistics (entry count, payload bytes, hit-rate average).
    pub fn stats(&self) -> &CacheStats {
        &self.stats
    }

    /// Number of entries currently held in memory.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// True if the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Current wall-clock time in seconds since the Unix epoch (0 if the clock is before it).
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0)
    }

    /// Configured TTL in seconds, saturating instead of overflowing for huge values.
    fn ttl_secs(&self) -> u64 {
        (self.config.ttl_days as u64).saturating_mul(Self::SECS_PER_DAY)
    }

    /// True if an entry stored at `cached_at` is at most `ttl_secs` old at `now`.
    /// Timestamps in the future (clock skew) count as age zero.
    fn within_ttl(cached_at: u64, now: u64, ttl_secs: u64) -> bool {
        now.saturating_sub(cached_at) <= ttl_secs
    }

    /// Modification time of `path` in seconds since the Unix epoch, if it can be read.
    fn mtime_secs(path: &Path) -> Option<u64> {
        fs::metadata(path)
            .ok()
            .and_then(|m| m.modified().ok())
            .and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok())
            .map(|d| d.as_secs())
    }

    /// True if the source file is known to be newer than when `entry` was stored.
    /// Missing or unreadable mtimes never invalidate an entry.
    fn source_changed(entry: &CacheEntry, file_path: &Path) -> bool {
        match entry.file_mtime {
            Some(cached) => matches!(Self::mtime_secs(file_path), Some(current) if current > cached),
            None => false,
        }
    }

    /// Payload size of one embedding in bytes.
    fn embedding_bytes(embedding: &[f32]) -> usize {
        std::mem::size_of_val(embedding)
    }

    /// Sum of the embedding payload sizes of all `entries`.
    fn total_size_bytes(entries: &HashMap<String, CacheEntry>) -> usize {
        entries
            .values()
            .map(|e| Self::embedding_bytes(&e.embedding))
            .sum()
    }
}
impl Drop for EmbeddingCache {
fn drop(&mut self) {
let _ = self.flush();
}
}
#[cfg(test)]
mod cache_tests {
    use super::*;
    use crate::Language;
    use std::path::{Path, PathBuf};
    use tempfile::tempdir;

    /// Config rooted at `dir` with the settings most tests share:
    /// a 100 MB size limit and a 7-day TTL.
    fn config_in(dir: &Path) -> CacheConfig {
        CacheConfig {
            cache_dir: dir.to_path_buf(),
            max_size_mb: 100,
            ttl_days: 7,
        }
    }

    /// Rust chunk at `path` for function `name`, with `content_hash` set to the
    /// hex MD5 of `content`.
    fn chunk_at(path: &str, name: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: PathBuf::from(path),
            function_name: Some(name.to_string()),
            class_name: None,
            line_start: 1,
            line_end: 10,
            content: content.to_string(),
            content_hash: format!("{:x}", md5::compute(content)),
            language: Language::Rust,
        }
    }

    /// Chunk for function `name` located in `test/<name>.rs`.
    fn create_test_chunk(name: &str, content: &str) -> CodeChunk {
        chunk_at(&format!("test/{}.rs", name), name, content)
    }

    #[test]
    fn cache_config_default_values() {
        let defaults = CacheConfig::default();
        assert!(defaults.cache_dir.ends_with("tldr/embeddings"));
        assert_eq!(defaults.max_size_mb, 500);
        assert_eq!(defaults.ttl_days, 30);
    }

    #[test]
    fn cache_open_creates_directory() {
        let temp = tempdir().unwrap();
        let cache_dir = temp.path().join("cache");
        let _cache = EmbeddingCache::open(config_in(&cache_dir)).unwrap();
        assert!(cache_dir.exists());
    }

    #[test]
    fn cache_put_get_roundtrip() {
        let temp = tempdir().unwrap();
        let mut cache = EmbeddingCache::open(config_in(temp.path())).unwrap();
        let chunk = create_test_chunk("foo", "fn foo() {}");
        let embedding = vec![0.1, 0.2, 0.3];

        cache.put(&chunk, embedding.clone(), EmbeddingModel::ArcticM);

        assert_eq!(cache.get(&chunk, EmbeddingModel::ArcticM), Some(embedding));
    }

    #[test]
    fn cache_miss_on_content_hash_change() {
        let temp = tempdir().unwrap();
        let mut cache = EmbeddingCache::open(config_in(temp.path())).unwrap();
        let original = create_test_chunk("foo", "fn foo() {}");
        cache.put(&original, vec![0.1, 0.2, 0.3], EmbeddingModel::ArcticM);

        // Same file and function, different body, hence a different content hash.
        let edited = create_test_chunk("foo", "fn foo() { return 1; }");
        assert!(cache.get(&edited, EmbeddingModel::ArcticM).is_none());
    }

    #[test]
    fn cache_miss_on_model_change() {
        let temp = tempdir().unwrap();
        let mut cache = EmbeddingCache::open(config_in(temp.path())).unwrap();
        let chunk = create_test_chunk("foo", "fn foo() {}");
        cache.put(&chunk, vec![0.1, 0.2, 0.3], EmbeddingModel::ArcticM);

        assert!(cache.get(&chunk, EmbeddingModel::ArcticL).is_none());
    }

    #[test]
    fn cache_flush_persists_to_disk() {
        let temp = tempdir().unwrap();
        let config = config_in(temp.path());
        let chunk = create_test_chunk("foo", "fn foo() {}");
        let embedding = vec![0.1, 0.2, 0.3];

        {
            let mut writer = EmbeddingCache::open(config.clone()).unwrap();
            writer.put(&chunk, embedding.clone(), EmbeddingModel::ArcticM);
            writer.flush().unwrap();
        }

        // A fresh instance must see what the first one flushed.
        let mut reader = EmbeddingCache::open(config).unwrap();
        assert_eq!(reader.get(&chunk, EmbeddingModel::ArcticM), Some(embedding));
    }

    #[test]
    fn cache_evict_stale_removes_old_entries() {
        let temp = tempdir().unwrap();
        // `config_in` uses a 7-day TTL.
        let mut cache = EmbeddingCache::open(config_in(temp.path())).unwrap();
        let chunk = create_test_chunk("foo", "fn foo() {}");
        cache.put(&chunk, vec![0.1, 0.2, 0.3], EmbeddingModel::ArcticM);
        assert_eq!(cache.len(), 1);

        // Backdate the entry to 8 days ago, one day past the TTL.
        let key = CacheKey::from_chunk(&chunk, EmbeddingModel::ArcticM).to_key_string();
        if let Some(entry) = cache.entries.get_mut(&key) {
            let now = SystemTime::now()
                .duration_since(SystemTime::UNIX_EPOCH)
                .map(|d| d.as_secs())
                .unwrap_or(0);
            entry.cached_at = now - (8 * 24 * 60 * 60);
        }

        assert_eq!(cache.evict_stale(), 1);
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn cache_stats_tracking() {
        let temp = tempdir().unwrap();
        let mut cache = EmbeddingCache::open(config_in(temp.path())).unwrap();
        assert_eq!(cache.stats().entries, 0);
        assert_eq!(cache.stats().size_bytes, 0);

        let foo = create_test_chunk("foo", "fn foo() {}");
        let bar = create_test_chunk("bar", "fn bar() {}");
        let embedding = vec![0.1_f32, 0.2, 0.3];
        cache.put(&foo, embedding.clone(), EmbeddingModel::ArcticM);
        cache.put(&bar, embedding.clone(), EmbeddingModel::ArcticM);

        assert_eq!(cache.stats().entries, 2);
        // 2 entries x 3 f32 values x 4 bytes.
        assert_eq!(cache.stats().size_bytes, 24);
    }

    #[test]
    fn cache_key_includes_function_identity() {
        let temp = tempdir().unwrap();
        let mut cache = EmbeddingCache::open(config_in(temp.path())).unwrap();

        // Identical source text in two different files/functions.
        let content = "fn template() { return 1; }";
        let foo = chunk_at("test/foo.rs", "foo", content);
        let bar = chunk_at("test/bar.rs", "bar", content);
        let foo_embedding = vec![0.1, 0.2, 0.3];
        let bar_embedding = vec![0.4, 0.5, 0.6];

        cache.put(&foo, foo_embedding.clone(), EmbeddingModel::ArcticM);
        cache.put(&bar, bar_embedding.clone(), EmbeddingModel::ArcticM);
        assert_eq!(cache.len(), 2);

        let foo_hit = cache.get(&foo, EmbeddingModel::ArcticM);
        let bar_hit = cache.get(&bar, EmbeddingModel::ArcticM);
        assert_eq!(foo_hit, Some(foo_embedding));
        assert_eq!(bar_hit, Some(bar_embedding));
    }

    #[test]
    fn cache_ttl_checked_on_read() {
        let temp = tempdir().unwrap();
        let config = CacheConfig {
            ttl_days: 0,
            ..config_in(temp.path())
        };
        let mut cache = EmbeddingCache::open(config).unwrap();
        let chunk = create_test_chunk("foo", "fn foo() {}");
        cache.put(&chunk, vec![0.1, 0.2, 0.3], EmbeddingModel::ArcticM);

        // Reading never deletes: whatever `get` decides, the entry stays in the
        // map until `evict_stale` runs.
        let _result = cache.get(&chunk, EmbeddingModel::ArcticM);
        let key = CacheKey::from_chunk(&chunk, EmbeddingModel::ArcticM).to_key_string();
        assert!(cache.entries.contains_key(&key));
    }

    #[test]
    fn cache_len_and_is_empty() {
        let temp = tempdir().unwrap();
        let mut cache = EmbeddingCache::open(config_in(temp.path())).unwrap();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);

        let chunk = create_test_chunk("foo", "fn foo() {}");
        cache.put(&chunk, vec![0.1, 0.2], EmbeddingModel::ArcticM);

        assert!(!cache.is_empty());
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn cache_handles_corrupted_file() {
        let temp = tempdir().unwrap();
        fs::write(temp.path().join("cache.json"), "not valid json{{{").unwrap();

        let opened = EmbeddingCache::open(config_in(temp.path()));
        assert!(opened.is_ok());
        assert!(opened.unwrap().is_empty());
    }

    #[test]
    fn cache_cleans_up_temp_files() {
        let temp = tempdir().unwrap();
        let orphan = temp.path().join("cache.json.tmp");
        fs::write(&orphan, "orphaned temp file").unwrap();
        assert!(orphan.exists());

        let _cache = EmbeddingCache::open(config_in(temp.path())).unwrap();
        assert!(!orphan.exists());
    }
}