#![allow(
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::float_cmp
)]
use super::sharded_index::ShardedIndex;
use super::*;
use rustc_hash::FxHashMap;
use std::io::Write as _;
use std::sync::Arc;
use tempfile::tempdir;
/// Fills a `ShardedIndex` with ids `0..100` at offsets `id * 16`, swaps in a
/// replacement map holding offsets `id * 32` via `replace_all`, and asserts
/// that every id reports its new offset and that the length is still 100.
#[test]
fn test_replace_all_atomic_no_intermediate_empty() {
    const COUNT: u64 = 100;

    let index = Arc::new(ShardedIndex::new());
    (0..COUNT).for_each(|id| {
        index.insert(id, id as usize * 16);
    });

    // Same key set, doubled offsets.
    let replacement: FxHashMap<u64, usize> =
        (0..COUNT).map(|id| (id, id as usize * 32)).collect();
    index.replace_all(replacement);

    for i in 0..COUNT {
        let expected = Some(i as usize * 32);
        assert_eq!(
            index.get(i),
            expected,
            "ID {i} should have updated offset after replace_all"
        );
    }
    assert_eq!(index.len(), 100);
}
/// Asserts that calling `replace_all` with an empty map leaves a previously
/// populated `ShardedIndex` reporting `is_empty()`.
#[test]
fn test_replace_all_with_empty_map_clears_index() {
    let index = ShardedIndex::new();
    (0..50u64).for_each(|id| {
        index.insert(id, id as usize * 8);
    });

    let nothing: FxHashMap<u64, usize> = FxHashMap::default();
    index.replace_all(nothing);

    let cleared = index.is_empty();
    assert!(cleared, "replace_all with empty map should clear index");
}
/// Runs a reader thread that probes ids `0..64` 10,000 times while the main
/// thread calls `replace_all` 100 times with fresh offsets for the same ids.
///
/// Each reader pass must observe either all 64 ids or none of them; any
/// partial view panics the reader, which fails the test at `join`.
#[test]
fn test_replace_all_concurrent_reader_sees_consistent_state() {
    const IDS: u64 = 64;

    let index = Arc::new(ShardedIndex::new());
    for id in 0..IDS {
        index.insert(id, id as usize * 10);
    }

    let reader_index = Arc::clone(&index);
    let reader = std::thread::spawn(move || {
        for _ in 0..10_000 {
            let found = (0..IDS)
                .filter(|&id| reader_index.get(id).is_some())
                .count();
            // Only used to make the failure message more readable.
            let missing = IDS as usize - found;
            assert!(
                found == 64 || found == 0,
                "Inconsistent state: found={found}, missing={missing}"
            );
        }
    });

    // Writer side: swap the full entry set once per cycle.
    for cycle in 0..100u64 {
        let new_entries: FxHashMap<u64, usize> = (0..IDS)
            .map(|id| (id, (cycle * 64 + id) as usize))
            .collect();
        index.replace_all(new_entries);
    }

    reader.join().expect("reader thread should not panic");
}
/// Stores 20 four-dimensional vectors, deletes the first 10, compacts, and
/// asserts that compaction reports a positive value and that ids `10..20`
/// still read back their original contents.
#[test]
fn test_compaction_uses_atomic_swap() {
    let workdir = tempdir().unwrap();
    let dim = 4;
    let mut db = MmapStorage::new(workdir.path(), dim).unwrap();

    (0u64..20).for_each(|id| {
        db.store(id, &[id as f32; 4]).unwrap();
    });
    (0u64..10).for_each(|id| {
        db.delete(id).unwrap();
    });

    let reclaimed = db.compact().unwrap();
    assert!(reclaimed > 0);

    // Survivors must be intact after compaction.
    for id in 10u64..20 {
        let v = db.retrieve(id).unwrap();
        let expected = Some(vec![id as f32; 4]);
        assert_eq!(v, expected);
    }
}
/// Stores two vectors, flushes and syncs only the WAL and the mmap (the
/// storage-level `flush` is never called), drops the storage, reopens the
/// same directory and asserts both vectors are readable.
#[test]
fn test_wal_replay_recovers_unflushed_stores() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;

    // First session: write, then make only the WAL and mmap durable.
    {
        let mut session = MmapStorage::new(&root, dim).unwrap();
        session.store(1, &[1.0, 2.0, 3.0]).unwrap();
        session.store(2, &[4.0, 5.0, 6.0]).unwrap();
        session.wal.write().flush().unwrap();
        session.wal.write().get_ref().sync_all().unwrap();
        session.mmap.write().flush().unwrap();
    }

    // Second session: both reads happen before either assertion.
    let reopened = MmapStorage::new(&root, dim).unwrap();
    let first = reopened.retrieve(1).unwrap();
    let second = reopened.retrieve(2).unwrap();

    assert_eq!(
        first,
        Some(vec![1.0, 2.0, 3.0]),
        "Vector 1 should be recovered from WAL"
    );
    assert_eq!(
        second,
        Some(vec![4.0, 5.0, 6.0]),
        "Vector 2 should be recovered from WAL"
    );
}
/// Stores two vectors and flushes, then deletes id 1 and flushes/syncs only
/// the WAL. After reopening, id 1 must be absent and id 2 must still be
/// readable.
#[test]
fn test_wal_replay_recovers_deletes() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;

    {
        let mut session = MmapStorage::new(&root, dim).unwrap();
        session.store(1, &[1.0, 2.0, 3.0]).unwrap();
        session.store(2, &[4.0, 5.0, 6.0]).unwrap();
        session.flush().unwrap();

        // The delete happens after the flush, so only the WAL records it.
        session.delete(1).unwrap();
        session.wal.write().flush().unwrap();
        session.wal.write().get_ref().sync_all().unwrap();
    }

    let reopened = MmapStorage::new(&root, dim).unwrap();

    let deleted = reopened.retrieve(1).unwrap();
    assert!(deleted.is_none(), "Deleted vector should not be recoverable");

    let survivor = reopened.retrieve(2).unwrap();
    assert_eq!(
        survivor,
        Some(vec![4.0, 5.0, 6.0]),
        "Non-deleted vector should survive"
    );
}
/// Hand-builds a storage directory with an index containing id 1, a 16 MiB
/// data file holding that vector at offset 0, and a WAL record for id 2 that
/// has no CRC trailer. Opening the storage must expose only id 1.
#[test]
fn test_wal_replay_skips_legacy_format() {
    // Little-endian byte encoding of a three-component f32 vector.
    let le_bytes =
        |values: [f32; 3]| -> Vec<u8> { values.iter().flat_map(|f| f.to_le_bytes()).collect() };

    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    std::fs::create_dir_all(&root).unwrap();

    // Persisted index: id 1 lives at offset 0.
    let index: FxHashMap<u64, usize> = std::iter::once((1, 0)).collect();
    let index_bytes = postcard::to_allocvec(&index).unwrap();
    std::fs::write(root.join("vectors.idx"), &index_bytes).unwrap();

    // Data file: zero-filled 16 MiB with the indexed vector at the front.
    let dim = 3;
    let stored = le_bytes([1.0, 2.0, 3.0]);
    let mut data = vec![0u8; 16 * 1024 * 1024];
    data[..stored.len()].copy_from_slice(&stored);
    std::fs::write(root.join("vectors.dat"), &data).unwrap();

    // WAL record: tag byte 1, id 2, u32 payload length, payload — no CRC.
    let payload = le_bytes([7.0, 8.0, 9.0]);
    let id_bytes = 2u64.to_le_bytes();
    let len_bytes = (payload.len() as u32).to_le_bytes();
    let parts: [&[u8]; 4] = [&[1u8], &id_bytes, &len_bytes, &payload];
    let wal: Vec<u8> = parts.concat();
    std::fs::write(root.join("vectors.wal"), &wal).unwrap();

    let storage = MmapStorage::new(&root, dim).unwrap();
    assert_eq!(storage.len(), 1, "Only indexed entry should exist");
    let legacy = storage.retrieve(2).unwrap();
    assert!(legacy.is_none(), "Legacy WAL entry should not be replayed");
}
/// Writes one vector with only the WAL and mmap flushed, reopens the
/// directory, and asserts that `vectors.wal` is zero bytes long afterwards.
#[test]
fn test_wal_replay_truncates_after_success() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;

    {
        let mut session = MmapStorage::new(&root, dim).unwrap();
        session.store(1, &[1.0, 2.0, 3.0]).unwrap();
        session.wal.write().flush().unwrap();
        session.wal.write().get_ref().sync_all().unwrap();
        session.mmap.write().flush().unwrap();
    }

    // Keep the reopened storage alive while the WAL size is inspected.
    let _storage = MmapStorage::new(&root, dim).unwrap();

    let wal_file = root.join("vectors.wal");
    let wal_len = std::fs::metadata(wal_file).unwrap().len();
    assert_eq!(
        wal_len, 0,
        "WAL should be truncated after successful replay"
    );
}
/// Writes a WAL made of a CRC-framed store for id 1, a lone byte `4`, and a
/// CRC-framed store for id 2, then asserts both ids are readable after
/// opening the storage.
#[test]
fn test_legacy_compaction_marker_skipped_mid_stream() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;

    // entry(1) | marker byte 4 | entry(2)
    let wal: Vec<u8> = [
        crc_store_entry(1, &vec3_bytes([1.0, 2.0, 3.0])),
        vec![4u8],
        crc_store_entry(2, &vec3_bytes([4.0, 5.0, 6.0])),
    ]
    .concat();
    std::fs::write(root.join("vectors.wal"), &wal).unwrap();

    let storage = MmapStorage::new(&root, dim).unwrap();

    let first = storage.retrieve(1).unwrap();
    assert_eq!(first, Some(vec![1.0, 2.0, 3.0]));

    let second = storage.retrieve(2).unwrap();
    assert_eq!(
        second,
        Some(vec![4.0, 5.0, 6.0]),
        "entries after a legacy compaction marker must be replayed"
    );
}
/// Writes a WAL that starts with a lone byte `4` followed by a CRC-framed
/// store for id 7, and asserts id 7 is readable after opening the storage.
#[test]
fn test_legacy_compaction_marker_leading_byte_detected() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;

    // marker byte 4 | entry(7)
    let entry = crc_store_entry(7, &vec3_bytes([7.0, 8.0, 9.0]));
    let wal: Vec<u8> = std::iter::once(4u8).chain(entry).collect();
    std::fs::write(root.join("vectors.wal"), &wal).unwrap();

    let storage = MmapStorage::new(&root, dim).unwrap();
    let recovered = storage.retrieve(7).unwrap();
    assert_eq!(
        recovered,
        Some(vec![7.0, 8.0, 9.0]),
        "a leading legacy marker must not disable WAL replay"
    );
}
/// Starts with only `vectors.dat.bak` (16 MiB of zeros) and an empty WAL on
/// disk, opens the storage, and asserts that it is empty, that `vectors.dat`
/// now exists, and that the `.bak` file is gone.
#[test]
fn test_bak_recovery_restores_from_backup() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    std::fs::create_dir_all(&root).unwrap();

    let (data_path, bak_path) = (root.join("vectors.dat"), root.join("vectors.dat.bak"));
    let dim = 3;

    std::fs::write(&bak_path, vec![0u8; 16 * 1024 * 1024]).unwrap();
    std::fs::write(root.join("vectors.wal"), b"").unwrap();

    // Precondition: only the backup is present.
    assert!(!data_path.exists());
    assert!(bak_path.exists());

    let storage = MmapStorage::new(&root, dim).unwrap();
    assert_eq!(storage.len(), 0);

    let restored = data_path.exists();
    assert!(restored, "vectors.dat should be restored from .bak");
    let backup_left = bak_path.exists();
    assert!(!backup_left, ".bak should be cleaned up");
}
/// Starts with both `vectors.dat` and `vectors.dat.bak` present plus an empty
/// WAL, opens the storage, and asserts the `.bak` file no longer exists.
#[test]
fn test_bak_recovery_removes_stale_backup() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    std::fs::create_dir_all(&root).unwrap();

    let data_path = root.join("vectors.dat");
    let bak_path = root.join("vectors.dat.bak");

    // Identical zero-filled 16 MiB contents for the data file and its backup.
    let zeros = vec![0u8; 16 * 1024 * 1024];
    for target in [&data_path, &bak_path] {
        std::fs::write(target, &zeros).unwrap();
    }
    std::fs::write(root.join("vectors.wal"), b"").unwrap();

    let _storage = MmapStorage::new(&root, 3).unwrap();

    let backup_left = bak_path.exists();
    assert!(!backup_left, ".bak should be removed when original exists");
}
/// Starts with a valid-sized `vectors.dat`, a leftover `vectors.dat.tmp`, and
/// an empty WAL, opens the storage, and asserts the `.tmp` file is gone.
#[test]
fn test_tmp_recovery_removes_incomplete_compaction() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    std::fs::create_dir_all(&root).unwrap();

    let tmp_path = root.join("vectors.dat.tmp");
    std::fs::write(root.join("vectors.dat"), vec![0u8; 16 * 1024 * 1024]).unwrap();
    std::fs::write(&tmp_path, b"incomplete compaction data").unwrap();
    std::fs::write(root.join("vectors.wal"), b"").unwrap();

    let _storage = MmapStorage::new(&root, 3).unwrap();

    let leftover = tmp_path.exists();
    assert!(!leftover, ".tmp should be removed on startup");
}
/// Builds one store record as raw bytes.
///
/// Layout: tag byte `1`, `id` as 8 little-endian bytes, `data.len()` as a
/// little-endian `u32`, the payload itself, then the little-endian result of
/// `crc32_hash` computed over everything before it.
fn crc_store_entry(id: u64, data: &[u8]) -> Vec<u8> {
    use crate::storage::log_payload::crc32_hash;

    let id_bytes = id.to_le_bytes();
    let len_bytes = (data.len() as u32).to_le_bytes();
    let parts: [&[u8]; 4] = [&[1u8], &id_bytes, &len_bytes, data];

    let mut record = parts.concat();
    // The checksum covers tag, id, length and payload, and is appended last.
    let checksum = crc32_hash(&record);
    record.extend_from_slice(&checksum.to_le_bytes());
    record
}
/// Encodes a three-component `f32` vector as 12 bytes: each component in
/// order, little-endian.
fn vec3_bytes(v: [f32; 3]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(v.len() * std::mem::size_of::<f32>());
    for component in v.iter() {
        bytes.extend_from_slice(&component.to_le_bytes());
    }
    bytes
}
/// Writes a WAL with one valid store for id 1 followed by a record header
/// for id 2 whose length field is `u32::MAX` and which has no payload.
/// Opening must succeed, id 1 must be readable and id 2 must be absent.
#[test]
fn test_898_replay_rejects_oversized_wal_length_no_huge_alloc() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;

    // Header only: tag 1, id 2, declared length u32::MAX.
    let id_bytes = 2u64.to_le_bytes();
    let len_bytes = u32::MAX.to_le_bytes();
    let oversized: [&[u8]; 3] = [&[1u8], &id_bytes, &len_bytes];

    let mut wal = crc_store_entry(1, &vec3_bytes([1.0, 2.0, 3.0]));
    wal.extend(oversized.concat());
    std::fs::write(root.join("vectors.wal"), &wal).unwrap();

    let storage = MmapStorage::new(&root, dim).unwrap();
    let first = storage.retrieve(1).unwrap();
    assert_eq!(first, Some(vec![1.0, 2.0, 3.0]));
    let second = storage.retrieve(2).unwrap();
    assert!(second.is_none());
}
/// Writes two valid store records followed by a truncated third record (tag
/// byte plus only 3 of the 8 id bytes). Both complete records must be
/// readable and the storage must report exactly two entries.
#[test]
fn test_898_replay_torn_tail_recovers_prior_entries() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;

    // Torn tail: tag byte and the first three bytes of id 7.
    let id_bytes = 7u64.to_le_bytes();
    let mut torn = vec![1u8];
    torn.extend_from_slice(&id_bytes[..3]);

    let wal: Vec<u8> = [
        crc_store_entry(1, &vec3_bytes([1.0, 2.0, 3.0])),
        crc_store_entry(2, &vec3_bytes([4.0, 5.0, 6.0])),
        torn,
    ]
    .concat();
    std::fs::write(root.join("vectors.wal"), &wal).unwrap();

    let storage = MmapStorage::new(&root, dim).unwrap();
    let first = storage.retrieve(1).unwrap();
    assert_eq!(first, Some(vec![1.0, 2.0, 3.0]));
    let second = storage.retrieve(2).unwrap();
    assert_eq!(second, Some(vec![4.0, 5.0, 6.0]));
    assert_eq!(storage.len(), 2, "torn tail must not corrupt prior entries");
}
/// Writes three store records where the middle one (id 2) has the last byte
/// of its frame inverted. Ids 1 and 3 must be readable, id 2 must be absent,
/// and the `wal_replay_corrupt_entries` counter must have grown.
#[test]
fn test_898_replay_midstream_crc_corruption_skips_and_continues() {
    use std::sync::atomic::Ordering;

    // Current value of the global corrupt-entry counter.
    let corrupt_entries = || {
        crate::metrics::global_guardrails_metrics()
            .wal_replay_corrupt_entries
            .load(Ordering::Relaxed)
    };

    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;

    // Invert the final byte of the frame; `crc_store_entry` appends the CRC
    // last, so this damages the checksum.
    let mut damaged = crc_store_entry(2, &vec3_bytes([4.0, 5.0, 6.0]));
    let tail = damaged.len() - 1;
    damaged[tail] ^= 0xFF;

    let wal: Vec<u8> = [
        crc_store_entry(1, &vec3_bytes([1.0, 2.0, 3.0])),
        damaged,
        crc_store_entry(3, &vec3_bytes([7.0, 8.0, 9.0])),
    ]
    .concat();
    std::fs::write(root.join("vectors.wal"), &wal).unwrap();

    let before = corrupt_entries();
    let storage = MmapStorage::new(&root, dim).unwrap();

    assert_eq!(storage.retrieve(1).unwrap(), Some(vec![1.0, 2.0, 3.0]));
    let skipped = storage.retrieve(2).unwrap();
    assert!(skipped.is_none(), "corrupt mid-stream entry must be skipped");
    assert_eq!(
        storage.retrieve(3).unwrap(),
        Some(vec![7.0, 8.0, 9.0]),
        "entries after a mid-stream corruption must still be recovered"
    );

    let after = corrupt_entries();
    assert!(after > before, "corrupt-entry metric must be incremented");
}
/// Hand-builds a 16 MiB data file whose last vector slot holds id 1 (per the
/// written index), plus a WAL with a store for id 2. After opening, both
/// vectors must be readable.
#[test]
fn test_898_replay_grows_mmap_no_silent_gap() {
    const FILE_LEN: usize = 16 * 1024 * 1024;

    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;
    let vec_size = dim * 4;
    // Offset of the final vector-sized slot in the data file.
    let near_cap = FILE_LEN - vec_size;

    let index: FxHashMap<u64, usize> = std::iter::once((1, near_cap)).collect();
    let index_bytes = postcard::to_allocvec(&index).unwrap();
    std::fs::write(root.join("vectors.idx"), index_bytes).unwrap();

    let mut data = vec![0u8; FILE_LEN];
    let slot = near_cap..near_cap + vec_size;
    data[slot].copy_from_slice(&vec3_bytes([1.0, 2.0, 3.0]));
    std::fs::write(root.join("vectors.dat"), &data).unwrap();

    let wal = crc_store_entry(2, &vec3_bytes([4.0, 5.0, 6.0]));
    std::fs::write(root.join("vectors.wal"), &wal).unwrap();

    let storage = MmapStorage::new(&root, dim).unwrap();
    let first = storage.retrieve(1).unwrap();
    assert_eq!(first, Some(vec![1.0, 2.0, 3.0]));
    let second = storage.retrieve(2).unwrap();
    assert_eq!(
        second,
        Some(vec![4.0, 5.0, 6.0]),
        "vector beyond initial mmap must be recovered, not dropped"
    );
}
/// Writes a 64-byte data file and an index that places id 1 at offset
/// 1,000,000, then asserts that opening the storage returns an error.
#[test]
fn test_898_load_index_rejects_out_of_bounds_offset() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;

    std::fs::write(root.join("vectors.dat"), vec![0u8; 64]).unwrap();

    // The offset is far past the end of the 64-byte data file.
    let index: FxHashMap<u64, usize> = std::iter::once((1, 1_000_000)).collect();
    let index_bytes = postcard::to_allocvec(&index).unwrap();
    std::fs::write(root.join("vectors.idx"), index_bytes).unwrap();
    std::fs::write(root.join("vectors.wal"), b"").unwrap();

    let rejected = MmapStorage::new(&root, dim).is_err();
    assert!(
        rejected,
        "index offset beyond data file must be rejected as corrupt"
    );
}
/// Opens the storage in `DurabilityMode::Fsync`, stores one vector, and
/// asserts that the on-disk WAL is at least `17 + dim * 4` bytes long as soon
/// as `store` has returned.
#[test]
fn test_898_fsync_store_persists_wal_before_ok() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;

    let mut db = MmapStorage::new_with_durability(&root, dim, DurabilityMode::Fsync).unwrap();
    db.store(1, &[1.0, 2.0, 3.0]).unwrap();

    // 17 presumably = 1 tag + 8 id + 4 length + 4 CRC bytes, matching the
    // frames built by `crc_store_entry` — confirm against the WAL writer.
    let min_len = (17 + dim * 4) as u64;
    let wal_len = std::fs::metadata(root.join("vectors.wal")).unwrap().len();
    assert!(
        wal_len >= min_len,
        "Fsync store must persist the WAL entry before returning Ok (got {wal_len} bytes)"
    );
}
/// Stores two vectors, calls `flush`, drops the storage, reopens it, and
/// asserts both vectors are readable, the length is 2, and `vectors.wal` is
/// empty.
#[test]
fn test_898_valid_roundtrip_and_crash_recovery_still_works() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;

    {
        let mut session = MmapStorage::new(&root, dim).unwrap();
        session.store(1, &[1.0, 2.0, 3.0]).unwrap();
        session.store(2, &[4.0, 5.0, 6.0]).unwrap();
        session.flush().unwrap();
    }

    let reopened = MmapStorage::new(&root, dim).unwrap();
    let first = reopened.retrieve(1).unwrap();
    assert_eq!(first, Some(vec![1.0, 2.0, 3.0]));
    let second = reopened.retrieve(2).unwrap();
    assert_eq!(second, Some(vec![4.0, 5.0, 6.0]));
    assert_eq!(reopened.len(), 2);

    let wal_file = root.join("vectors.wal");
    let wal_len = std::fs::metadata(wal_file).unwrap().len();
    assert_eq!(wal_len, 0, "WAL truncated only after mmap+idx made durable");
}
/// Builds one delete record as raw bytes.
///
/// Layout: tag byte `2`, `id` as 8 little-endian bytes, then the
/// little-endian result of `crc32_hash` computed over those 9 bytes.
fn crc_delete_entry(id: u64) -> Vec<u8> {
    use crate::storage::log_payload::crc32_hash;

    let mut record: Vec<u8> = std::iter::once(2u8).chain(id.to_le_bytes()).collect();
    let checksum = crc32_hash(&record);
    record.extend_from_slice(&checksum.to_le_bytes());
    record
}
/// Reads the current value of the global `wal_replay_corrupt_entries`
/// counter with relaxed ordering.
fn corrupt_entry_count() -> u64 {
    use std::sync::atomic::Ordering;

    let metrics = crate::metrics::global_guardrails_metrics();
    metrics.wal_replay_corrupt_entries.load(Ordering::Relaxed)
}
/// Opens the storage in `DurabilityMode::Fsync`, stores one vector, calls
/// `flush_full`, then deletes the vector and asserts that `vectors.wal` grew
/// by exactly 13 bytes across the `delete` call.
///
/// 13 matches the delete frame built by `crc_delete_entry` (1 tag byte +
/// 8 id bytes + 4 CRC bytes).
#[test]
fn test_898b_fsync_delete_persists_wal_before_punch_hole() {
    let dir = tempdir().unwrap();
    let path = dir.path().to_path_buf();
    let dim = 3;

    let wal_path = path.join("vectors.wal");
    let wal_len = || std::fs::metadata(&wal_path).unwrap().len();

    let mut storage = MmapStorage::new_with_durability(&path, dim, DurabilityMode::Fsync).unwrap();
    storage.store(1, &[1.0, 2.0, 3.0]).unwrap();
    storage.flush_full().unwrap();

    let wal_before = wal_len();
    storage.delete(1).unwrap();
    let wal_after = wal_len();

    // `checked_sub` instead of `-`: if the WAL ever shrinks across the delete,
    // a raw u64 subtraction would panic with an arithmetic-overflow message in
    // debug builds (or wrap in release) and hide the assertion text below.
    assert_eq!(
        wal_after.checked_sub(wal_before),
        Some(13),
        "Fsync delete must persist the WAL delete record before punch_hole"
    );
}
/// Writes a WAL containing a store for id 1 followed by a delete for id 1,
/// opens the storage, and asserts id 1 is absent and the storage is empty.
#[test]
fn test_898b_delete_survives_crash_no_zero_resurrection() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;

    // store(1) | delete(1)
    let wal: Vec<u8> = [
        crc_store_entry(1, &vec3_bytes([1.0, 2.0, 3.0])),
        crc_delete_entry(1),
    ]
    .concat();
    std::fs::write(root.join("vectors.wal"), &wal).unwrap();

    let storage = MmapStorage::new(&root, dim).unwrap();
    let resurrected = storage.retrieve(1).unwrap();
    assert!(
        resurrected.is_none(),
        "deleted id must stay deleted after replay, not resurrect as zeros"
    );
    assert_eq!(storage.len(), 0);
}
/// Stores two vectors, deletes id 1, calls `flush`, drops and reopens the
/// storage, and asserts id 1 is absent, id 2 is intact and the length is 1.
#[test]
fn test_898b_valid_delete_normal_recovery_no_regression() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;

    {
        let mut session = MmapStorage::new(&root, dim).unwrap();
        session.store(1, &[1.0, 2.0, 3.0]).unwrap();
        session.store(2, &[4.0, 5.0, 6.0]).unwrap();
        session.delete(1).unwrap();
        session.flush().unwrap();
    }

    let reopened = MmapStorage::new(&root, dim).unwrap();

    let deleted = reopened.retrieve(1).unwrap();
    assert!(deleted.is_none(), "id 1 must stay deleted");

    let kept = reopened.retrieve(2).unwrap();
    assert_eq!(kept, Some(vec![4.0, 5.0, 6.0]));
    assert_eq!(reopened.len(), 1);
}
/// Sets `next_offset` to `usize::MAX - (vector_size - 1)` so that adding one
/// vector's size to it overflows `usize`, then asserts that `store` returns
/// an error and `next_offset` is unchanged.
#[test]
fn test_898b_store_offset_overflow_leaves_next_offset_unchanged() {
    use std::sync::atomic::Ordering;

    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;
    let vector_size = dim * std::mem::size_of::<f32>();
    let mut storage = MmapStorage::new(&root, dim).unwrap();

    // One byte too close to usize::MAX for a full vector to fit.
    let poisoned = usize::MAX - (vector_size - 1);
    storage.next_offset.store(poisoned, Ordering::SeqCst);

    let before = storage.next_offset.load(Ordering::SeqCst);
    let result = storage.store(999, &[1.0, 2.0, 3.0]);
    let after = storage.next_offset.load(Ordering::SeqCst);

    assert!(result.is_err(), "overflowing store offset must be rejected");
    assert_eq!(
        before, after,
        "next_offset must not advance on the overflow error path"
    );
}
/// Batch variant of the offset-overflow check: sets `next_offset` to
/// `usize::MAX - (vector_size - 1)`, calls `store_batch` with one vector, and
/// asserts that it returns an error and `next_offset` is unchanged.
#[test]
fn test_898_store_batch_offset_overflow_leaves_next_offset_unchanged() {
    use std::sync::atomic::Ordering;

    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;
    let vector_size = dim * std::mem::size_of::<f32>();
    let mut storage = MmapStorage::new(&root, dim).unwrap();

    // One byte too close to usize::MAX for a full vector to fit.
    let poisoned = usize::MAX - (vector_size - 1);
    storage.next_offset.store(poisoned, Ordering::SeqCst);

    let before = storage.next_offset.load(Ordering::SeqCst);
    let payload = [1.0_f32, 2.0, 3.0];
    let result = storage.store_batch(&[(999_u64, &payload[..])]);
    let after = storage.next_offset.load(Ordering::SeqCst);

    assert!(result.is_err(), "overflowing batch store must be rejected");
    assert_eq!(
        before, after,
        "next_offset must not advance on the batch overflow error path"
    );
}
/// Writes a valid store for id 1 followed by a final store for id 2 whose
/// last frame byte is inverted. Id 1 must be readable, id 2 must be absent,
/// and the corrupt-entry counter must be unchanged across the open.
///
/// NOTE(review): `corrupt_entry_count` reads a process-global counter. If
/// another test that increments it runs concurrently in the same process, the
/// equality check here could fail spuriously — confirm how these tests are
/// scheduled.
#[test]
fn test_898b_torn_tail_crc_fail_at_eof_no_corrupt_metric() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;

    // Last record in the file, with the final byte of its frame flipped.
    let mut damaged_tail = crc_store_entry(2, &vec3_bytes([4.0, 5.0, 6.0]));
    let end = damaged_tail.len() - 1;
    damaged_tail[end] ^= 0xFF;

    let wal: Vec<u8> = [
        crc_store_entry(1, &vec3_bytes([1.0, 2.0, 3.0])),
        damaged_tail,
    ]
    .concat();
    std::fs::write(root.join("vectors.wal"), &wal).unwrap();

    let before = corrupt_entry_count();
    let storage = MmapStorage::new(&root, dim).unwrap();

    assert_eq!(storage.retrieve(1).unwrap(), Some(vec![1.0, 2.0, 3.0]));
    let dropped = storage.retrieve(2).unwrap();
    assert!(dropped.is_none(), "torn-tail record must be dropped");
    assert_eq!(
        before,
        corrupt_entry_count(),
        "a CRC-failing torn tail at EOF must NOT raise a corruption alert"
    );
}
/// Writes three store records where the middle one (id 2) has its last frame
/// byte inverted. Ids 1 and 3 must be readable, id 2 must be absent, and the
/// corrupt-entry counter must be larger after the open than before it.
#[test]
fn test_898b_midstream_crc_fail_with_valid_after_increments_metric() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().to_path_buf();
    let dim = 3;

    // Middle record, with the final byte of its frame flipped.
    let mut damaged = crc_store_entry(2, &vec3_bytes([4.0, 5.0, 6.0]));
    let end = damaged.len() - 1;
    damaged[end] ^= 0xFF;

    let wal: Vec<u8> = [
        crc_store_entry(1, &vec3_bytes([1.0, 2.0, 3.0])),
        damaged,
        crc_store_entry(3, &vec3_bytes([7.0, 8.0, 9.0])),
    ]
    .concat();
    std::fs::write(root.join("vectors.wal"), &wal).unwrap();

    let before = corrupt_entry_count();
    let storage = MmapStorage::new(&root, dim).unwrap();

    assert_eq!(storage.retrieve(1).unwrap(), Some(vec![1.0, 2.0, 3.0]));
    assert!(storage.retrieve(2).unwrap().is_none());
    assert_eq!(storage.retrieve(3).unwrap(), Some(vec![7.0, 8.0, 9.0]));

    let grew = corrupt_entry_count() > before;
    assert!(
        grew,
        "a CRC failure with valid framing after it must increment the metric"
    );
}