#![allow(clippy::unwrap_used)]
use std::sync::{Arc, Barrier};
use std::thread;
use tempfile::tempdir;
use super::bm25::Bm25Index;
use super::bm25_persistence::{load_snapshot, save_snapshot, snapshot_path};
use super::bm25_persistence_wal::{
wal_append_add_document, wal_append_remove_document, wal_path_for_bm25, wal_replay,
wal_truncate,
};
/// Returns the fixed five-document corpus shared by the persistence tests.
///
/// Each entry is `(document id, document text)`; ids run from 1 to 5.
fn sample_corpus() -> Vec<(u64, &'static str)> {
    const DOCS: [(u64, &str); 5] = [
        (1, "rust programming language systems"),
        (2, "python programming data science"),
        (3, "java programming enterprise"),
        (4, "rust memory safety concurrency"),
        (5, "typescript programming web frontend"),
    ];
    DOCS.to_vec()
}
/// Orders `(id, score)` pairs by ascending id and hands the vector back.
///
/// The sort is stable, so entries sharing an id keep their relative order.
fn sort_by_id(mut results: Vec<(u64, f32)>) -> Vec<(u64, f32)> {
    results.sort_by(|left, right| left.0.cmp(&right.0));
    results
}
/// Asserts that two result lists are identical down to the bit pattern of
/// every score.
///
/// Scores are compared through `f32::to_bits`, so `0.0` and `-0.0` count as
/// different while two NaNs with the same payload count as equal.
///
/// # Panics
///
/// Panics if the lengths differ, or at the first position where the ids or
/// the score bits differ.
fn assert_bitwise_eq(a: &[(u64, f32)], b: &[(u64, f32)]) {
    assert_eq!(a.len(), b.len(), "result lengths differ: {a:?} vs {b:?}");
    for (&(left_id, left_score), &(right_id, right_score)) in a.iter().zip(b) {
        assert_eq!(left_id, right_id, "ids differ: {a:?} vs {b:?}");
        assert_eq!(
            left_score.to_bits(),
            right_score.to_bits(),
            "score bits differ for id {}: {:?} vs {:?}",
            left_id,
            left_score,
            right_score
        );
    }
}
/// Saving an index with `save_snapshot` and reading it back with
/// `load_snapshot` must yield the same document count and bit-identical
/// search results.
#[test]
fn test_save_load_snapshot_roundtrip() {
    const QUERIES: [&str; 4] = ["rust", "programming", "python", "web"];

    let tmp = tempdir().unwrap();

    let source = Bm25Index::new();
    for (doc_id, body) in sample_corpus() {
        source.add_document(doc_id, body);
    }

    save_snapshot(tmp.path(), &source).expect("save_snapshot");
    let restored = load_snapshot(tmp.path())
        .expect("load_snapshot")
        .expect("snapshot file should exist");

    assert_eq!(restored.len(), source.len(), "doc_count should match");

    // Result order is normalised by id before the bitwise comparison.
    for q in QUERIES {
        let before = sort_by_id(source.search(q, 10));
        let after = sort_by_id(restored.search(q, 10));
        assert_bitwise_eq(&before, &after);
    }
}
/// Ten add-document WAL entries replayed into an empty index must produce
/// the same state as applying those adds directly.
#[test]
fn test_wal_append_add_document_replay() {
    let tmp = tempdir().unwrap();
    let log = wal_path_for_bm25(tmp.path());

    // `live` receives every add directly and serves as the expected state.
    let live = Bm25Index::new();
    for i in 0u64..10 {
        let text = format!("document number {i} rust persistence");
        wal_append_add_document(&log, i, &text).unwrap();
        live.add_document(i, &text);
    }

    // `recovered` is built from the WAL alone.
    let recovered = Bm25Index::new();
    let applied = wal_replay(&log, &recovered).expect("wal_replay");

    assert_eq!(applied, 10, "replay should apply all 10 entries");
    assert_eq!(recovered.len(), live.len());
    for q in ["rust", "persistence", "number"] {
        let want = sort_by_id(live.search(q, 10));
        let got = sort_by_id(recovered.search(q, 10));
        assert_bitwise_eq(&want, &got);
    }
}
/// Remove entries in the WAL must be replayed too: the replay count covers
/// adds plus removes, and the recovered index matches one that had the same
/// documents removed directly.
#[test]
fn test_wal_append_remove_document_replay() {
    let tmp = tempdir().unwrap();
    let log = wal_path_for_bm25(tmp.path());
    let corpus = sample_corpus();

    let live = Bm25Index::new();
    for &(doc_id, body) in &corpus {
        wal_append_add_document(&log, doc_id, body).unwrap();
        live.add_document(doc_id, body);
    }

    // Log both removals first, then apply them to the live index.
    wal_append_remove_document(&log, 2).unwrap();
    wal_append_remove_document(&log, 4).unwrap();
    for removed in [2, 4] {
        live.remove_document(removed);
    }

    let recovered = Bm25Index::new();
    let applied = wal_replay(&log, &recovered).expect("wal_replay");

    // One entry per corpus document plus the two removals.
    let expected_count = u64::try_from(corpus.len() + 2).expect("sample corpus is small");
    assert_eq!(applied, expected_count, "adds + removes must be counted");
    assert_eq!(recovered.len(), live.len());

    let want = sort_by_id(live.search("programming", 10));
    let got = sort_by_id(recovered.search("programming", 10));
    assert_bitwise_eq(&want, &got);
}
/// Recovery path: loading a snapshot and then replaying the WAL written after
/// it must reproduce the live index exactly.
///
/// The test snapshots the sample corpus, truncates the WAL, applies three
/// adds and one remove to both the WAL and the live index, and then checks
/// that snapshot + replay gives the same document count and bit-identical
/// results for every query.
#[test]
fn test_snapshot_plus_wal_replay_preserves_query_topk() {
    let dir = tempdir().unwrap();
    let wal_path = wal_path_for_bm25(dir.path());

    // Baseline state captured by the snapshot.
    let initial = Bm25Index::new();
    for (id, text) in sample_corpus() {
        initial.add_document(id, text);
    }
    save_snapshot(dir.path(), &initial).unwrap();
    wal_truncate(&wal_path).unwrap();

    // Post-snapshot mutations: each is logged to the WAL and applied live.
    let mutations: &[(u64, &str)] = &[
        (10, "rust async runtime tokio"),
        (11, "data science machine learning python"),
        (12, "web assembly rust performance"),
    ];
    for &(id, text) in mutations {
        wal_append_add_document(&wal_path, id, text).unwrap();
        initial.add_document(id, text);
    }
    wal_append_remove_document(&wal_path, 3).unwrap();
    initial.remove_document(3);

    // `wal_replay` takes the index by shared reference, so the binding does
    // not need to be `mut`.
    let reloaded = load_snapshot(dir.path())
        .unwrap()
        .expect("snapshot should exist");
    let replayed_count = wal_replay(&wal_path, &reloaded).unwrap();

    // Three adds plus the single remove.
    let expected_replay = u64::try_from(mutations.len() + 1).expect("mutations is small");
    assert_eq!(replayed_count, expected_replay);
    assert_eq!(reloaded.len(), initial.len());

    for query in ["rust", "python", "web", "programming"] {
        let expected = sort_by_id(initial.search(query, 10));
        let actual = sort_by_id(reloaded.search(query, 10));
        assert_bitwise_eq(&expected, &actual);
    }
}
/// With nothing on disk, `load_snapshot` must return `Ok(None)` rather than
/// an error, and replaying the WAL path must apply zero entries and leave
/// the index empty.
#[test]
fn test_no_snapshot_falls_back_to_payload_rebuild() {
    let tmp = tempdir().unwrap();
    let root = tmp.path();

    let snapshot = load_snapshot(root).expect("load_snapshot should not error on missing file");
    assert!(
        snapshot.is_none(),
        "absent snapshot must signal the caller to run the legacy rebuild path"
    );

    // No WAL entry was ever appended in this directory.
    let log = wal_path_for_bm25(root);
    let empty = Bm25Index::new();
    let applied = wal_replay(&log, &empty).expect("wal_replay");

    assert_eq!(applied, 0);
    assert!(empty.is_empty());
}
/// A snapshot file holding arbitrary bytes must make `load_snapshot` return
/// `Err`, and the error text must mention "bm25" or "snapshot" so the failure
/// can be attributed.
#[test]
fn test_corrupted_snapshot_surfaces_error_without_silent_data_loss() {
    // Arbitrary bytes written where the snapshot file is expected.
    const GARBAGE: &[u8; 11] = b"\xFF\x00\xAA\xBB\xCC\xDE\xAD\xBE\xEF\x42\x17";

    let tmp = tempdir().unwrap();
    std::fs::write(&snapshot_path(tmp.path()), GARBAGE).unwrap();

    let err = match load_snapshot(tmp.path()) {
        Ok(_) => {
            panic!("corrupt snapshot MUST surface as Err, never Ok(None) (issue #618 learning)")
        }
        Err(err) => err,
    };

    // The context check is case-insensitive.
    let msg = err.to_string().to_lowercase();
    assert!(
        msg.contains("bm25") || msg.contains("snapshot"),
        "error message should identify BM25 or snapshot context: {msg}"
    );
}
/// Checkpoint sequence: after `save_snapshot` followed by `wal_truncate`, the
/// WAL file must be zero bytes, replay must apply nothing, and the snapshot
/// alone must carry every document.
#[test]
fn test_wal_append_then_save_snapshot_truncates_wal() {
    let tmp = tempdir().unwrap();
    let log = wal_path_for_bm25(tmp.path());
    // Current on-disk size of the WAL file in bytes.
    let wal_size = || std::fs::metadata(&log).unwrap().len();

    let live = Bm25Index::new();
    for (doc_id, body) in sample_corpus() {
        wal_append_add_document(&log, doc_id, body).unwrap();
        live.add_document(doc_id, body);
    }

    assert!(log.exists(), "WAL file should exist after appends");
    assert!(wal_size() > 0, "WAL should be non-empty before snapshot");

    // Checkpoint: persist the full state, then drop the log.
    save_snapshot(tmp.path(), &live).unwrap();
    wal_truncate(&log).unwrap();

    assert_eq!(wal_size(), 0, "wal_truncate must reduce WAL to zero bytes");

    let restored = load_snapshot(tmp.path()).unwrap().expect("snapshot exists");
    let applied = wal_replay(&log, &restored).unwrap();
    assert_eq!(applied, 0, "truncated WAL must replay zero entries");
    assert_eq!(restored.len(), live.len());
}
/// Number of writer threads spawned by
/// `test_concurrent_add_document_safe_with_wal_append`.
const CONCURRENT_THREADS: u64 = 4;
/// Number of documents each writer thread appends to the WAL in that test.
const CONCURRENT_PER_THREAD: u64 = 25;
/// Several threads append add-document entries to the same WAL file at once;
/// replay must then recover every entry.
///
/// Each thread writes `CONCURRENT_PER_THREAD` documents with globally unique
/// ids, and each document carries a unique `marker{id:05}` token so it can be
/// looked up individually after replay.
#[test]
fn test_concurrent_add_document_safe_with_wal_append() {
    let dir = tempdir().unwrap();
    let wal_path = Arc::new(wal_path_for_bm25(dir.path()));
    let thread_count =
        usize::try_from(CONCURRENT_THREADS).expect("CONCURRENT_THREADS fits in usize");
    // The barrier releases all writers together to maximise contention.
    let barrier = Arc::new(Barrier::new(thread_count));

    let handles: Vec<_> = (0..CONCURRENT_THREADS)
        .map(|t| {
            let wal = Arc::clone(&wal_path);
            let bar = Arc::clone(&barrier);
            thread::spawn(move || {
                bar.wait();
                for i in 0..CONCURRENT_PER_THREAD {
                    // Thread `t` owns the id range [t * PER_THREAD, (t + 1) * PER_THREAD).
                    let id = t * CONCURRENT_PER_THREAD + i;
                    let marker = format!("marker{id:05}");
                    let text = format!("concurrent thread{t} doc{i} {marker}");
                    wal_append_add_document(&wal, id, &text).expect("wal append");
                }
            })
        })
        .collect();
    handles
        .into_iter()
        .for_each(|handle| handle.join().expect("thread join"));

    let replayed_index = Bm25Index::new();
    let count = wal_replay(&wal_path, &replayed_index).expect("wal_replay");
    let expected = CONCURRENT_THREADS * CONCURRENT_PER_THREAD;
    assert_eq!(
        count, expected,
        "all {expected} concurrent appends must survive replay"
    );
    assert_eq!(
        replayed_index.len(),
        usize::try_from(expected).expect("expected fits in usize")
    );

    // The per-thread id ranges tile 0..expected, so one ascending pass visits
    // every written document.
    for id in 0..expected {
        let needle = format!("marker{id:05}");
        let hits = replayed_index.search(&needle, 2);
        assert!(
            hits.iter().any(|&(hid, _)| hid == id),
            "id {id} missing after concurrent replay (hits: {hits:?})"
        );
    }
}