use mcpkill::cache::Cache;
use mcpkill::chunker::chunk;
use mcpkill::similarity::cosine;
use tempfile::NamedTempFile;
/// Builds a one-hot `f32` vector of length `dim`.
///
/// The single `1.0` sits at index `i % dim`, so any `i` wraps into range.
/// Every other slot is `0.0`.
///
/// # Panics
/// Panics when `dim == 0`, because the modulo divides by zero.
fn unit(i: usize, dim: usize) -> Vec<f32> {
    // Compute the hot index up front so `dim == 0` panics exactly as before.
    let hot = i % dim;
    (0..dim)
        .map(|j| if j == hot { 1.0 } else { 0.0 })
        .collect()
}
/// Opens a [`Cache`] backed by a freshly created temporary file.
///
/// The [`NamedTempFile`] handle is returned next to the cache so the caller
/// keeps the file alive for the duration of the test. `tempfile` deletes the
/// file when the handle is dropped.
///
/// NOTE(review): callers destructure this as `let (cache, tmp) = ...`.
/// Bindings drop in reverse declaration order, so `tmp` (and the file) is
/// presumably removed before `cache` releases its database handle. That is
/// harmless on Unix; confirm it behaves on Windows.
fn open_tmp_cache() -> (Cache, NamedTempFile) {
    let tmp = NamedTempFile::new().unwrap();
    let path = tmp.path().to_str().unwrap();
    let cache = Cache::new(path).unwrap();
    (cache, tmp)
}
/// Storing an entry and then searching with the identical embedding must hit.
///
/// The hit must return both stored chunks, and the stored token counts must
/// show up in `stats`.
#[test]
fn cache_exact_hit() {
    let (cache, tmp) = open_tmp_cache();
    let db = tmp.path().to_str().unwrap();
    let emb = unit(0, 4);
    let chunks = vec![
        ("chunk A".to_string(), unit(0, 4)),
        ("chunk B".to_string(), unit(1, 4)),
    ];
    cache.store("tool arg", &emb, &chunks, 800, 100).unwrap();

    // One `expect` replaces the `is_some()` + `unwrap()` pair; the panic
    // message on a miss is unchanged.
    let (_, found) = cache
        .search(&emb, 0.99)
        .unwrap()
        .expect("expected a cache hit");
    assert_eq!(found.len(), 2);
    assert_eq!(found[0].text, "chunk A");
    // Previously unchecked. Assumes chunks come back in the order they were
    // stored — TODO confirm that this is part of `search`'s contract.
    assert_eq!(found[1].text, "chunk B");

    let stats = cache.stats(db).unwrap();
    assert_eq!(stats.entries, 1);
    assert_eq!(stats.tokens_original, 800);
    assert_eq!(stats.tokens_returned, 100);
}
/// A query orthogonal to the only stored embedding must not hit the cache at
/// a 0.85 similarity threshold.
#[test]
fn cache_miss_below_threshold() {
    // `_tmp` (not `_`) keeps the backing file alive until the test ends.
    let (cache, _tmp) = open_tmp_cache();
    let stored = unit(0, 4);
    let query = unit(1, 4);

    let chunks = [("chunk".to_string(), stored.clone())];
    cache.store("tool arg", &stored, &chunks, 100, 10).unwrap();

    let hit = cache.search(&query, 0.85).unwrap();
    assert!(hit.is_none(), "orthogonal query must not hit cache");
}
/// Two `record_hit` calls on a found entry must leave `cache_hits` at 2.
#[test]
fn cache_hit_increments_hit_count() {
    let (cache, tmp) = open_tmp_cache();
    let emb = unit(0, 4);
    let chunks = [("c".to_string(), emb.clone())];
    cache.store("q", &emb, &chunks, 100, 10).unwrap();

    let hit = cache.search(&emb, 0.99).unwrap();
    let (query_id, _) = hit.unwrap();
    for _ in 0..2 {
        cache.record_hit(query_id).unwrap();
    }

    let db = tmp.path().to_str().unwrap();
    assert_eq!(cache.stats(db).unwrap().cache_hits, 2);
}
/// A row whose `created_at` and `last_used_at` are both 0 must be removed by
/// `evict_expired(7)`.
///
/// The `7` is presumably a maximum age in days — TODO confirm.
#[test]
fn evict_expired_removes_old_entries() {
    let (cache, tmp) = open_tmp_cache();
    let path = tmp.path().to_str().unwrap();

    // Seed a stale row straight into SQLite through a second connection.
    // The test bypasses `store` here, presumably because `store` cannot
    // backdate the timestamps.
    let conn = rusqlite::Connection::open(path).unwrap();
    let epoch: i64 = 0;
    conn.execute(
        "INSERT INTO queries (query_text, query_embedding, original_tokens, returned_tokens, created_at, last_used_at)
VALUES ('old', X'00000000', 100, 10, ?1, ?1)",
        rusqlite::params![epoch],
    )
    .unwrap();
    // Close the side connection before the cache touches the database again.
    drop(conn);

    let evicted = cache.evict_expired(7).unwrap();
    assert_eq!(evicted, 1, "expected 1 evicted entry");
    assert_eq!(cache.stats(path).unwrap().entries, 0);
}
/// After five stored entries, `clear_all` must report 5 removals and leave
/// `stats().entries` at 0.
#[test]
fn clear_all_empties_the_cache() {
    let (cache, tmp) = open_tmp_cache();
    let db = tmp.path().to_str().unwrap();
    let emb = unit(0, 4);
    let chunks = [("c".to_string(), emb.clone())];

    // Five distinct query texts ("q0".."q4") that all share one embedding.
    for name in (0..5).map(|i| format!("q{i}")) {
        cache.store(&name, &emb, &chunks, 100, 10).unwrap();
    }

    let entries = || cache.stats(db).unwrap().entries;
    assert_eq!(entries(), 5);

    assert_eq!(cache.clear_all().unwrap(), 5);
    assert_eq!(entries(), 0);
}
/// Checks the chunker's emptiness contract across several input shapes:
/// plain text, a markdown header, a JSON object, a JSON array, and `""`.
///
/// Input that is blank after `trim()` must yield no chunks. Any other input
/// must yield at least one chunk.
///
/// NOTE(review): despite its name, this test asserts an *empty* result for
/// the empty-string input. Consider renaming it to match what it checks.
#[test]
fn chunker_produces_non_empty_output_for_any_input() {
    let inputs = [
        "plain text",
        "## Header\nBody",
        r#"{"key": "value"}"#,
        "[1, 2, 3]",
        "",
    ];
    for input in inputs {
        let chunks = chunk(input);
        // Name the offending input so a failure inside the loop is actionable.
        if input.trim().is_empty() {
            assert!(
                chunks.is_empty(),
                "empty input should yield no chunks (input: {input:?})"
            );
        } else {
            assert!(
                !chunks.is_empty(),
                "non-empty input should yield ≥1 chunk (input: {input:?})"
            );
        }
    }
}
/// A vector's cosine similarity with itself must be 1, within a 1e-6
/// tolerance.
///
/// NOTE(review): `unit` produces a vector of norm 1, so this case cannot
/// detect an implementation that forgets to normalise. Consider also testing
/// a non-unit vector, if `cosine` is meant to normalise its inputs — confirm
/// first.
#[test]
fn cosine_self_similarity_is_one() {
    let v = unit(2, 8);
    let s = cosine(&v, &v);
    // Report the actual value; a bare assert only says "assertion failed".
    assert!((s - 1.0).abs() < 1e-6, "cosine(v, v) = {s}, expected 1.0");
}