use super::persistence::{
self, load_sidecars, save_sidecars, HnswMappingsData, HnswMeta, HnswVectorsData,
};
use super::sharded_mappings::ShardedMappings;
use super::sharded_vectors::ShardedVectors;
use crate::distance::DistanceMetric;
use crate::StorageMode;
use std::collections::HashMap;
use std::path::Path;
use tempfile::TempDir;
/// Builds an `HnswMeta` fixture stamped with the given `generation`.
///
/// Every other field is fixed: dimension 4, cosine metric, vector storage
/// enabled, and `StorageMode::Full`.
fn build_meta(generation: u64) -> HnswMeta {
    let metric = DistanceMetric::Cosine;
    let storage_mode = StorageMode::Full;
    HnswMeta {
        generation,
        dimension: 4,
        metric,
        enable_vector_storage: true,
        storage_mode,
    }
}
/// Builds a `ShardedMappings` fixture holding two entries.
///
/// External id 1 maps to internal index 0 and id 2 maps to index 1; the
/// reverse map mirrors those pairs, and the next index passed to
/// `from_parts` is 2.
fn build_mappings() -> ShardedMappings {
    let forward = HashMap::from([(1_u64, 0_usize), (2_u64, 1_usize)]);
    let reverse = HashMap::from([(0_usize, 1_u64), (1_usize, 2_u64)]);
    let next_idx = 2;
    ShardedMappings::from_parts(forward, reverse, next_idx)
}
/// Builds a `ShardedVectors` fixture created with `new(4)` and filled with
/// two 4-element vectors at indices 0 and 1.
fn build_vectors() -> ShardedVectors {
    let store = ShardedVectors::new(4);
    let batch = vec![
        (0_usize, vec![1.0_f32, 0.0, 0.0, 0.0]),
        (1_usize, vec![0.0_f32, 1.0, 0.0, 0.0]),
    ];
    store.insert_batch(batch);
    store
}
fn write_legacy_4tuple_meta(path: &Path, meta: &HnswMeta) -> std::io::Result<()> {
let metric_u8 = meta.metric as u8;
let storage_mode_u8 = match meta.storage_mode {
StorageMode::Full => 0u8,
StorageMode::SQ8 => 1,
StorageMode::Binary => 2,
StorageMode::ProductQuantization => 3,
StorageMode::RaBitQ => 4,
};
let bytes = postcard::to_allocvec(&(
meta.dimension,
metric_u8,
meta.enable_vector_storage,
storage_mode_u8,
))
.map_err(std::io::Error::other)?;
std::fs::write(path.join("native_meta.bin"), bytes)
}
fn write_legacy_3tuple_meta(path: &Path, meta: &HnswMeta) -> std::io::Result<()> {
let metric_u8 = meta.metric as u8;
let bytes = postcard::to_allocvec(&(meta.dimension, metric_u8, meta.enable_vector_storage))
.map_err(std::io::Error::other)?;
std::fs::write(path.join("native_meta.bin"), bytes)
}
fn write_legacy_3tuple_mappings(path: &Path, data: &HnswMappingsData) -> std::io::Result<()> {
let bytes = postcard::to_allocvec(&(&data.id_to_idx, &data.idx_to_id, data.next_idx))
.map_err(std::io::Error::other)?;
std::fs::write(path.join("native_mappings.bin"), bytes)
}
fn write_legacy_plain_vectors(path: &Path, data: &HnswVectorsData) -> std::io::Result<()> {
let bytes = postcard::to_allocvec(&data.vectors).map_err(std::io::Error::other)?;
std::fs::write(path.join("native_vectors.bin"), bytes)
}
/// Converts `mappings` into an `HnswMappingsData` record carrying the given
/// `generation`, using the three parts returned by `as_parts`.
fn mappings_data(mappings: &ShardedMappings, generation: u64) -> HnswMappingsData {
    let parts = mappings.as_parts();
    HnswMappingsData {
        id_to_idx: parts.0,
        idx_to_id: parts.1,
        next_idx: parts.2,
        generation,
    }
}
/// Converts `vectors` into an `HnswVectorsData` record carrying the given
/// `generation`, using the output of `collect_for_parallel` as the payload.
fn vectors_data(vectors: &ShardedVectors, generation: u64) -> HnswVectorsData {
    let vectors = vectors.collect_for_parallel();
    HnswVectorsData {
        vectors,
        generation,
    }
}
/// Saves the sidecars three times in a fresh directory, each time with the
/// value returned by `next_generation`, and checks that meta, mappings, and
/// vectors all end up stamped with generation 3.
#[test]
fn test_save_sidecars_stamps_monotonic_generation() {
    let dir = TempDir::new().expect("test: temp dir");
    let path = dir.path();
    let mappings = build_mappings();
    let vectors = build_vectors();
    let meta = build_meta(0);

    // Three consecutive saves; each asks the directory for its next generation.
    for save_label in ["test: first save", "test: second save", "test: third save"] {
        let generation = persistence::next_generation(path).expect("test: next_generation");
        save_sidecars(path, &mappings, &vectors, &meta, generation).expect(save_label);
    }

    let meta_generation = persistence::load_meta(path)
        .expect("test: load meta")
        .generation;
    assert_eq!(
        meta_generation, 3,
        "meta generation should be bumped once per save"
    );

    let mappings_generation = persistence::load_mappings(path)
        .expect("test: load mappings")
        .generation;
    assert_eq!(
        mappings_generation, 3,
        "mappings generation must match meta"
    );

    let vectors_generation = persistence::load_vectors(path)
        .expect("test: load vectors")
        .generation;
    assert_eq!(
        vectors_generation, 3,
        "vectors generation must match meta"
    );
}
/// Writes graph marker, mappings, vectors, and meta at generation 4, then
/// rewrites only the mappings at generation 5, and expects `load_sidecars`
/// to fail with `InvalidData` and a message naming the mappings generation.
#[test]
fn test_load_sidecars_detects_stale_mappings() {
    let tmp = TempDir::new().expect("test: temp dir");
    let root = tmp.path();
    let mappings = build_mappings();
    let vectors = build_vectors();
    let meta_gen4 = build_meta(4);

    // Consistent generation-4 state.
    persistence::save_graph_generation(root, 4).expect("test: graph gen 4");
    let mappings_gen4 = mappings_data(&mappings, 4);
    persistence::save_mappings(root, &mappings_gen4).expect("test: save mappings 4");
    let vectors_gen4 = vectors_data(&vectors, 4);
    persistence::save_vectors(root, &vectors_gen4).expect("test: save vectors 4");
    persistence::save_meta(root, &meta_gen4).expect("test: save meta 4");

    // Only the mappings sidecar moves on to generation 5.
    let mappings_gen5 = mappings_data(&mappings, 5);
    persistence::save_mappings(root, &mappings_gen5).expect("test: save mappings 5");

    let meta_on_disk = persistence::load_meta(root).expect("test: reload meta");
    let failure = load_sidecars(root, &meta_on_disk)
        .expect_err("test: stale mappings must trigger InvalidData");
    assert_eq!(failure.kind(), std::io::ErrorKind::InvalidData);
    let message = failure.to_string();
    assert!(
        message.contains("mappings generation"),
        "error should mention mappings generation, got: {failure}"
    );
}
/// Writes a full generation-4 state, then advances the graph marker,
/// mappings, and meta to generation 5 while leaving vectors at 4, and
/// expects `load_sidecars` to fail with `InvalidData` and a message naming
/// the vectors generation.
#[test]
fn test_load_sidecars_detects_stale_vectors() {
    let tmp = TempDir::new().expect("test: temp dir");
    let root = tmp.path();
    let mappings = build_mappings();
    let vectors = build_vectors();

    // Consistent generation-4 state.
    persistence::save_graph_generation(root, 4).expect("test: graph gen 4");
    let mappings_gen4 = mappings_data(&mappings, 4);
    persistence::save_mappings(root, &mappings_gen4).expect("test: save mappings 4");
    let vectors_gen4 = vectors_data(&vectors, 4);
    persistence::save_vectors(root, &vectors_gen4).expect("test: save vectors 4");
    let meta_gen4 = build_meta(4);
    persistence::save_meta(root, &meta_gen4).expect("test: save meta 4");

    // Everything except the vectors sidecar advances to generation 5.
    persistence::save_graph_generation(root, 5).expect("test: graph gen 5");
    let mappings_gen5 = mappings_data(&mappings, 5);
    persistence::save_mappings(root, &mappings_gen5).expect("test: save mappings 5");
    let meta_gen5 = build_meta(5);
    persistence::save_meta(root, &meta_gen5).expect("test: save meta 5");

    let meta_on_disk = persistence::load_meta(root).expect("test: reload meta");
    let failure = load_sidecars(root, &meta_on_disk)
        .expect_err("test: stale vectors must trigger InvalidData");
    assert_eq!(failure.kind(), std::io::ErrorKind::InvalidData);
    let message = failure.to_string();
    assert!(
        message.contains("vectors generation"),
        "error should mention vectors generation, got: {failure}"
    );
}
/// Writes graph marker, vectors, and meta at generation 5 but mappings at
/// generation 10, and expects `load_sidecars` to fail with `InvalidData` and
/// a message naming the mappings generation.
#[test]
fn test_load_sidecars_detects_newer_mappings_than_meta() {
    let tmp = TempDir::new().expect("test: temp dir");
    let root = tmp.path();
    let mappings = build_mappings();
    let vectors = build_vectors();

    persistence::save_graph_generation(root, 5).expect("test: graph gen 5");
    // Mappings are stamped ahead of everything else.
    let mappings_gen10 = mappings_data(&mappings, 10);
    persistence::save_mappings(root, &mappings_gen10).expect("test: save mappings 10");
    let vectors_gen5 = vectors_data(&vectors, 5);
    persistence::save_vectors(root, &vectors_gen5).expect("test: save vectors 5");
    let meta_gen5 = build_meta(5);
    persistence::save_meta(root, &meta_gen5).expect("test: save meta 5");

    let meta_on_disk = persistence::load_meta(root).expect("test: reload meta");
    let failure = load_sidecars(root, &meta_on_disk)
        .expect_err("test: newer mappings than meta must trigger InvalidData");
    assert_eq!(failure.kind(), std::io::ErrorKind::InvalidData);
    let message = failure.to_string();
    assert!(
        message.contains("mappings generation"),
        "error should mention mappings generation, got: {failure}"
    );
}
/// Writes the 4-tuple meta, 3-tuple mappings, and plain vectors files (none
/// of which carry a generation), then checks that the meta loads with
/// generation 0 and that `load_sidecars` succeeds and reports vector storage
/// as enabled.
#[test]
fn test_backward_compat_legacy_meta_without_generation_loads() {
    let tmp = TempDir::new().expect("test: temp dir");
    let root = tmp.path();
    let mappings = build_mappings();
    let vectors = build_vectors();
    let meta = build_meta(0);

    write_legacy_4tuple_meta(root, &meta).expect("test: legacy meta");
    let legacy_mappings = mappings_data(&mappings, 0);
    write_legacy_3tuple_mappings(root, &legacy_mappings).expect("test: legacy mappings");
    let legacy_vectors = vectors_data(&vectors, 0);
    write_legacy_plain_vectors(root, &legacy_vectors).expect("test: legacy vectors");

    let meta_on_disk = persistence::load_meta(root).expect("test: legacy meta loads");
    assert_eq!(
        meta_on_disk.generation, 0,
        "legacy meta must default to generation 0"
    );

    let loaded = load_sidecars(root, &meta_on_disk).expect("test: legacy sidecars load");
    // Third element of the returned tuple is the vector-storage flag.
    let enable_vs = loaded.2;
    assert!(enable_vs, "enable_vector_storage must round-trip from meta");
}
/// Writes the 3-tuple meta (no storage mode, no generation) alongside the
/// 3-tuple mappings and plain vectors files, then checks that the meta loads
/// with generation 0 and `StorageMode::Full`, and that `load_sidecars`
/// succeeds.
#[test]
fn test_backward_compat_legacy_3tuple_meta_loads() {
    let tmp = TempDir::new().expect("test: temp dir");
    let root = tmp.path();
    let mappings = build_mappings();
    let vectors = build_vectors();
    let meta = build_meta(0);

    write_legacy_3tuple_meta(root, &meta).expect("test: 3-tuple meta");
    let legacy_mappings = mappings_data(&mappings, 0);
    write_legacy_3tuple_mappings(root, &legacy_mappings).expect("test: legacy mappings");
    let legacy_vectors = vectors_data(&vectors, 0);
    write_legacy_plain_vectors(root, &legacy_vectors).expect("test: legacy vectors");

    let meta_on_disk = persistence::load_meta(root).expect("test: 3-tuple meta loads");
    assert_eq!(
        meta_on_disk.generation, 0,
        "3-tuple meta must default to generation 0"
    );
    assert_eq!(
        meta_on_disk.storage_mode,
        StorageMode::Full,
        "3-tuple meta must default storage_mode to Full"
    );

    load_sidecars(root, &meta_on_disk).expect("test: legacy sidecars load cleanly");
}
/// Seeds mappings, vectors, and meta at generation 7, checks that
/// `next_generation` returns 8, saves the sidecars with that value, and
/// verifies the reloaded meta carries generation 8.
#[test]
fn test_save_then_load_roundtrip_gen_bumped() {
    let tmp = TempDir::new().expect("test: temp dir");
    let root = tmp.path();
    let mappings = build_mappings();
    let vectors = build_vectors();

    // Seed an existing generation-7 state on disk.
    let seed_mappings = mappings_data(&mappings, 7);
    persistence::save_mappings(root, &seed_mappings).expect("test: seed mappings");
    let seed_vectors = vectors_data(&vectors, 7);
    persistence::save_vectors(root, &seed_vectors).expect("test: seed vectors");
    let seed_meta = build_meta(7);
    persistence::save_meta(root, &seed_meta).expect("test: seed meta");

    let meta_in = build_meta(0);
    let bumped = persistence::next_generation(root).expect("test: next_generation");
    assert_eq!(bumped, 8, "next_generation must bump from 7 to 8");

    save_sidecars(root, &mappings, &vectors, &meta_in, bumped).expect("test: save bumps gen");

    let meta_on_disk = persistence::load_meta(root).expect("test: reload meta");
    assert_eq!(
        meta_on_disk.generation, 8,
        "save_sidecars must bump to next generation"
    );
}
/// Checks that `next_generation` returns 1 for an empty directory and that
/// saving the sidecars with that value yields a meta stamped generation 1.
#[test]
fn test_save_when_no_prior_state_starts_at_gen_1() {
    let tmp = TempDir::new().expect("test: temp dir");
    let root = tmp.path();
    let mappings = build_mappings();
    let vectors = build_vectors();
    let meta_in = build_meta(0);

    let first_generation = persistence::next_generation(root).expect("test: next_generation");
    assert_eq!(
        first_generation, 1,
        "next_generation on a fresh directory must return 1"
    );

    save_sidecars(root, &mappings, &vectors, &meta_in, first_generation)
        .expect("test: first save");

    let stored_generation = persistence::load_meta(root)
        .expect("test: reload meta")
        .generation;
    assert_eq!(
        stored_generation, 1,
        "first save on a fresh directory must land at generation 1"
    );
}
/// Writes a full generation-4 state, then advances only the graph generation
/// marker to 5, and expects `load_sidecars` to fail with `InvalidData` and a
/// message naming the graph generation.
#[test]
fn test_load_sidecars_detects_stale_graph_generation() {
    let tmp = TempDir::new().expect("test: temp dir");
    let root = tmp.path();
    let mappings = build_mappings();
    let vectors = build_vectors();

    // Consistent generation-4 state.
    persistence::save_graph_generation(root, 4).expect("test: graph gen 4");
    let mappings_gen4 = mappings_data(&mappings, 4);
    persistence::save_mappings(root, &mappings_gen4).expect("test: save mappings 4");
    let vectors_gen4 = vectors_data(&vectors, 4);
    persistence::save_vectors(root, &vectors_gen4).expect("test: save vectors 4");
    let meta_gen4 = build_meta(4);
    persistence::save_meta(root, &meta_gen4).expect("test: save meta 4");

    // Only the graph marker moves on to generation 5.
    persistence::save_graph_generation(root, 5).expect("test: graph gen 5 only");

    let meta_on_disk = persistence::load_meta(root).expect("test: reload meta");
    let failure = load_sidecars(root, &meta_on_disk)
        .expect_err("test: stale graph must trigger InvalidData");
    assert_eq!(failure.kind(), std::io::ErrorKind::InvalidData);
    let message = failure.to_string();
    assert!(
        message.contains("graph generation"),
        "error should mention graph generation, got: {failure}"
    );
}
/// Checks that `load_graph_generation` on an empty directory succeeds and
/// returns 0 rather than reporting an error.
#[test]
fn test_backward_compat_no_graph_generation_marker_loads_as_zero() {
    let tmp = TempDir::new().expect("test: temp dir");
    let marker_generation = persistence::load_graph_generation(tmp.path())
        .expect("test: missing graph marker must not be an error");
    assert_eq!(
        marker_generation, 0,
        "missing native_hnsw.gen must be treated as generation 0"
    );
}
/// Saves the graph generation marker as 42 and then overwrites it with 9999,
/// checking after each save that `load_graph_generation` returns the value
/// just written.
#[test]
fn test_save_graph_generation_roundtrip() {
    let tmp = TempDir::new().expect("test: temp dir");
    let root = tmp.path();

    // (value, save label, reload label, assertion message), run in order.
    let rounds = [
        (
            42,
            "test: save marker",
            "test: reload marker",
            "graph generation marker must round-trip",
        ),
        (
            9999,
            "test: overwrite marker",
            "test: reload marker 2",
            "overwritten marker must round-trip",
        ),
    ];
    for (value, save_label, reload_label, message) in rounds {
        persistence::save_graph_generation(root, value).expect(save_label);
        let observed = persistence::load_graph_generation(root).expect(reload_label);
        assert_eq!(observed, value, "{message}");
    }
}
/// Checks two behaviours of `next_generation`: a `native_meta.bin` filled
/// with 32 bytes of `0xFF` makes it return an error, while a directory with
/// no meta file at all makes it return 1.
#[test]
fn test_next_generation_propagates_corrupted_meta_error() {
    // Case 1: an undecodable meta file must surface as an error.
    let corrupted_dir = TempDir::new().expect("test: temp dir");
    let corrupted_meta = corrupted_dir.path().join("native_meta.bin");
    std::fs::write(corrupted_meta, [0xFF_u8; 32]).expect("test: seed corrupted meta");
    let result = persistence::next_generation(corrupted_dir.path());
    assert!(
        result.is_err(),
        "corrupted meta must propagate, not silently reset to gen 1 (got {result:?})"
    );

    // Case 2: a directory without any meta file starts at generation 1.
    let fresh = TempDir::new().expect("test: fresh dir");
    let fresh_generation =
        persistence::next_generation(fresh.path()).expect("test: missing meta is not an error");
    assert_eq!(
        fresh_generation, 1,
        "missing meta must yield generation 1, not propagate NotFound"
    );
}