use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use anyhow::{Result, bail};
use dashmap::DashMap;
use model2vec_rs::model::StaticModel;
use tokio::sync::{OnceCell, RwLock};
use crate::VelesIndex;
use crate::persist;
/// Capacity (number of cached repos) used by [`IndexCache::new`].
pub const DEFAULT_CACHE_SIZE: usize = 10;
/// A single slot of [`IndexCache`], keyed by the repo string.
struct CacheEntry {
/// Lazily initialised, shareable handle to the index. The `OnceCell` is
/// wrapped in an `Arc` so it can be cloned out of the map and awaited
/// without keeping a map guard alive.
cell: Arc<OnceCell<Arc<RwLock<VelesIndex>>>>,
/// Value of the cache's counter at the most recent access; the entry with
/// the smallest stamp is the one chosen for eviction.
last_access: AtomicU64,
}
/// Concurrent cache of [`VelesIndex`] handles keyed by repo string, with
/// least-recently-accessed eviction once more than `capacity` entries exist.
pub struct IndexCache {
/// Repo string -> cache slot.
entries: DashMap<String, CacheEntry>,
/// Model that is cloned and handed to every index that gets built or loaded.
model: StaticModel,
/// Entry count above which an eviction is triggered (always at least 1).
capacity: usize,
/// Monotonic counter used as a logical clock for `last_access` stamps.
counter: AtomicU64,
}
impl IndexCache {
pub fn new(model: StaticModel) -> Self {
Self::with_capacity(model, DEFAULT_CACHE_SIZE)
}
pub fn with_capacity(model: StaticModel, capacity: usize) -> Self {
Self {
entries: DashMap::with_capacity(capacity.max(1)),
model,
capacity: capacity.max(1),
counter: AtomicU64::new(0),
}
}
pub async fn get_or_load(
&self,
repo: &str,
include_text_files: bool,
) -> Result<Arc<RwLock<VelesIndex>>> {
let cell = {
let entry = self
.entries
.entry(repo.to_string())
.or_insert_with(|| CacheEntry {
cell: Arc::new(OnceCell::new()),
last_access: AtomicU64::new(0),
});
entry.last_access.store(self.tick(), Ordering::Relaxed);
entry.cell.clone()
};
let index = cell
.get_or_try_init(|| async {
let built = self.build_index(repo, include_text_files)?;
anyhow::Ok(Arc::new(RwLock::new(built)))
})
.await
.map_err(|e| anyhow::anyhow!("failed to load {repo}: {e}"))?;
if self.entries.len() > self.capacity {
self.evict_lru();
}
Ok(index.clone())
}
pub fn peek(&self, repo: &str) -> Option<Arc<RwLock<VelesIndex>>> {
let entry = self.entries.get(repo)?;
entry.last_access.store(self.tick(), Ordering::Relaxed);
entry.cell.get().cloned()
}
pub fn invalidate(&self, repo: &str) -> bool {
self.entries.remove(repo).is_some()
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn capacity(&self) -> usize {
self.capacity
}
fn tick(&self) -> u64 {
self.counter.fetch_add(1, Ordering::Relaxed)
}
fn evict_lru(&self) {
let oldest = self
.entries
.iter()
.min_by_key(|e| e.value().last_access.load(Ordering::Relaxed))
.map(|e| e.key().clone());
if let Some(key) = oldest {
self.entries.remove(&key);
}
}
fn build_index(&self, repo: &str, include_text_files: bool) -> Result<VelesIndex> {
let model = self.model.clone();
let path = Path::new(repo);
if path.is_dir() {
if persist::index_exists(path) {
match VelesIndex::load(path, model.clone()) {
Ok(idx) => return Ok(idx),
Err(_) => {
}
}
}
VelesIndex::from_path(path, Some(model), None, include_text_files)
} else if repo.starts_with("https://") || repo.starts_with("http://") {
VelesIndex::from_git(repo, None, Some(model), include_text_files)
} else {
bail!("Invalid repo: must be a local directory or https:// URL")
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the model shared by the tests via `load_model(None)`.
    fn test_model() -> StaticModel {
        crate::model::load_model(None).expect("test model load")
    }

    /// Creates a temporary directory containing a single `a.rs` with `source`
    /// as its content. Returns the directory guard (which must be kept alive
    /// for as long as the path is used) together with the path as a string.
    fn scratch_repo(source: &str) -> (tempfile::TempDir, String) {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("a.rs"), source).unwrap();
        let repo = dir.path().to_string_lossy().into_owned();
        (dir, repo)
    }

    /// Two lookups of the same repo must hand back the very same handle.
    #[tokio::test]
    async fn caches_same_repo_across_calls() {
        let cache = IndexCache::new(test_model());
        let (_dir, repo) = scratch_repo("fn hello() {}\n");

        let first = cache.get_or_load(&repo, false).await.unwrap();
        let second = cache.get_or_load(&repo, false).await.unwrap();

        assert!(Arc::ptr_eq(&first, &second), "cache miss on repeat lookup");
        assert_eq!(cache.len(), 1);
    }

    /// With capacity 2, loading a third repo evicts the least recently used one.
    #[tokio::test]
    async fn evicts_lru_when_over_capacity() {
        let cache = IndexCache::with_capacity(test_model(), 2);
        let repos: Vec<(tempfile::TempDir, String)> = (0..3)
            .map(|i| scratch_repo(&format!("fn fn_{i}() {{}}\n")))
            .collect();

        // Access order 0, 1, 0, 2: repo 1 is the stalest when repo 2 arrives.
        for slot in [0usize, 1, 0, 2] {
            let _ = cache.get_or_load(&repos[slot].1, false).await.unwrap();
        }

        assert_eq!(cache.len(), 2);
        assert!(cache.entries.contains_key(&repos[0].1));
        assert!(cache.entries.contains_key(&repos[2].1));
        assert!(!cache.entries.contains_key(&repos[1].1));
    }

    /// `invalidate` reports whether an entry was actually removed.
    #[tokio::test]
    async fn invalidate_removes_entry() {
        let cache = IndexCache::new(test_model());
        let (_dir, repo) = scratch_repo("fn x() {}\n");

        let _ = cache.get_or_load(&repo, false).await.unwrap();

        assert!(cache.invalidate(&repo));
        assert!(cache.is_empty());
        assert!(!cache.invalidate(&repo));
    }
}