use crate::distance::DistanceMetric;
use std::collections::HashMap;
use std::io::Write;
use std::path::Path;
use std::sync::atomic::{AtomicU64, Ordering};
/// Index-level metadata persisted to `native_meta.bin` by `save_meta` and read back by
/// `load_meta`.
pub(crate) struct HnswMeta {
/// Vector dimensionality; also used to size the `ShardedVectors` store on load.
pub dimension: usize,
/// Distance metric; persisted as a single byte (`metric as u8` on save, decoded by
/// `metric_from_u8` on load).
pub metric: DistanceMetric,
/// When false, `load_vectors_or_disable` skips reading `native_vectors.bin` and returns an
/// empty vector store.
pub enable_vector_storage: bool,
/// Vector storage mode (e.g. `Full`, `SQ8`). Metadata files written in the older 3-field
/// layout load as `StorageMode::Full`.
pub storage_mode: crate::StorageMode,
}
/// Id <-> internal-index mappings persisted to `native_mappings.bin` by `save_mappings` and
/// read back by `load_mappings`.
pub(crate) struct HnswMappingsData {
/// Maps `u64` ids (presumably caller-facing ids — confirm) to internal `usize` indices.
pub id_to_idx: HashMap<u64, usize>,
/// Reverse of `id_to_idx`. Both maps are serialized as-is; consistency between them is not
/// validated on load.
pub idx_to_id: HashMap<usize, u64>,
/// Presumably the next unused internal index to hand out — TODO confirm against the caller.
pub next_idx: usize,
}
/// Raw vectors persisted to `native_vectors.bin` by `save_vectors` and read back by
/// `load_vectors`.
pub(crate) struct HnswVectorsData {
/// `(index, vector)` pairs, produced by `ShardedVectors::collect_for_parallel` on save and
/// fed to `ShardedVectors::insert_batch` on load. Presumably the `usize` is the same
/// internal index used in `HnswMappingsData` — confirm.
pub vectors: Vec<(usize, Vec<f32>)>,
}
pub(crate) fn save_meta(path: &Path, meta: &HnswMeta) -> std::io::Result<()> {
let meta_path = path.join("native_meta.bin");
let bytes = postcard::to_allocvec(&(
meta.dimension,
meta.metric as u8,
meta.enable_vector_storage,
storage_mode_to_u8(meta.storage_mode),
))
.map_err(std::io::Error::other)?;
atomic_write(&meta_path, &bytes)
}
/// Reads `<path>/native_meta.bin` and decodes it into an [`HnswMeta`].
///
/// Two layouts are accepted:
/// - current: `(dimension, metric, enable_vector_storage, storage_mode)`;
/// - legacy: `(dimension, metric, enable_vector_storage)`, which loads with
///   `StorageMode::Full`.
///
/// # Errors
/// Returns the I/O error if the file cannot be read. Returns an error if neither layout
/// decodes; in that case the legacy-layout decode error is the one reported. Returns an
/// `InvalidData` error if the metric byte is unknown.
pub(crate) fn load_meta(path: &Path) -> std::io::Result<HnswMeta> {
    let bytes = std::fs::read(path.join("native_meta.bin"))?;

    // A legacy file has no trailing storage-mode byte, so decoding it as the 4-tuple is
    // expected to fail and fall through to the 3-tuple decoder.
    let (dimension, metric_byte, enable_vector_storage, storage_mode) =
        match postcard::from_bytes::<(usize, u8, bool, u8)>(&bytes) {
            Ok((dim, metric, with_vectors, mode)) => {
                (dim, metric, with_vectors, storage_mode_from_u8(mode))
            }
            Err(_) => {
                let (dim, metric, with_vectors): (usize, u8, bool) =
                    postcard::from_bytes(&bytes).map_err(std::io::Error::other)?;
                (dim, metric, with_vectors, crate::StorageMode::Full)
            }
        };

    Ok(HnswMeta {
        dimension,
        metric: metric_from_u8(metric_byte)?,
        enable_vector_storage,
        storage_mode,
    })
}
pub(crate) fn save_mappings(path: &Path, data: &HnswMappingsData) -> std::io::Result<()> {
let mappings_path = path.join("native_mappings.bin");
let bytes = postcard::to_allocvec(&(&data.id_to_idx, &data.idx_to_id, data.next_idx))
.map_err(std::io::Error::other)?;
atomic_write(&mappings_path, &bytes)
}
/// Reads `<path>/native_mappings.bin` and decodes the `(id_to_idx, idx_to_id, next_idx)` tuple
/// written by `save_mappings`.
///
/// # Errors
/// Returns the I/O error if the file cannot be read, or an error if postcard decoding fails.
pub(crate) fn load_mappings(path: &Path) -> std::io::Result<HnswMappingsData> {
    type Encoded = (HashMap<u64, usize>, HashMap<usize, u64>, usize);

    let raw = std::fs::read(path.join("native_mappings.bin"))?;
    postcard::from_bytes::<Encoded>(&raw)
        .map(|(id_to_idx, idx_to_id, next_idx)| HnswMappingsData {
            id_to_idx,
            idx_to_id,
            next_idx,
        })
        .map_err(std::io::Error::other)
}
/// Serializes the `(index, vector)` pairs with postcard and atomically writes them to
/// `<path>/native_vectors.bin`.
///
/// # Errors
/// Returns an error if serialization fails or the atomic write fails.
pub(crate) fn save_vectors(path: &Path, data: &HnswVectorsData) -> std::io::Result<()> {
    let target = path.join("native_vectors.bin");
    postcard::to_allocvec(&data.vectors)
        .map_err(std::io::Error::other)
        .and_then(|encoded| atomic_write(&target, &encoded))
}
/// Reads `<path>/native_vectors.bin` and decodes the `(index, vector)` pairs written by
/// `save_vectors`.
///
/// # Errors
/// Returns the I/O error (including `NotFound`) if the file cannot be read, or an error if
/// postcard decoding fails.
pub(crate) fn load_vectors(path: &Path) -> std::io::Result<HnswVectorsData> {
    let raw = std::fs::read(path.join("native_vectors.bin"))?;
    let decoded = postcard::from_bytes::<Vec<(usize, Vec<f32>)>>(&raw);
    decoded
        .map(|vectors| HnswVectorsData { vectors })
        .map_err(std::io::Error::other)
}
/// Writes `data` to `final_path` via a uniquely named sibling temp file that is then renamed
/// into place.
///
/// The temp name is `<file_name>.tmp.<pid>.<thread id debug>.<seq>`. `seq` comes from a
/// process-wide counter, so concurrent callers never collide on the temp path.
///
/// # Errors
/// Returns the first error from `atomic_write_inner`. The temp file is removed on failure
/// (best effort; a failed removal is ignored).
fn atomic_write(final_path: &Path, data: &[u8]) -> std::io::Result<()> {
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let base = final_path.file_name().unwrap_or_default().to_string_lossy();
    let sequence = COUNTER.fetch_add(1, Ordering::Relaxed);
    let tmp_path = final_path.with_file_name(format!(
        "{base}.tmp.{}.{:?}.{sequence}",
        std::process::id(),
        std::thread::current().id(),
    ));

    match atomic_write_inner(&tmp_path, final_path, data) {
        Ok(()) => Ok(()),
        Err(err) => {
            // Best-effort cleanup: the temp file may not exist (e.g. creation failed).
            let _ = std::fs::remove_file(&tmp_path);
            Err(err)
        }
    }
}
/// Writes `data` to `tmp_path`, syncs it to stable storage, then renames it over `final_path`.
///
/// On Unix the directory containing `final_path` is also synced after the rename. Without
/// that step, a crash shortly after `rename` can lose the new directory entry even though the
/// file contents were already synced.
///
/// # Errors
/// Propagates any error from creating, writing or syncing the temp file, from the rename, or
/// (on Unix) from syncing the parent directory. The caller is responsible for removing
/// `tmp_path` on failure.
fn atomic_write_inner(tmp_path: &Path, final_path: &Path, data: &[u8]) -> std::io::Result<()> {
    let mut file = std::fs::File::create(tmp_path)?;
    // A single `write_all` on the file; no userspace buffering is needed for one write.
    file.write_all(data)?;
    // Make the bytes durable before they become visible under `final_path`.
    file.sync_all()?;
    // Close the temp file before it is renamed into place.
    drop(file);
    std::fs::rename(tmp_path, final_path)?;
    sync_parent_dir(final_path)
}

/// Syncs the directory containing `path` so a preceding rename into it is durable.
#[cfg(unix)]
fn sync_parent_dir(path: &Path) -> std::io::Result<()> {
    // A bare relative file name has an empty parent; that refers to the current directory.
    let dir = match path.parent() {
        Some(parent) if !parent.as_os_str().is_empty() => parent,
        _ => Path::new("."),
    };
    std::fs::File::open(dir)?.sync_all()
}

/// No-op outside Unix, where `std::fs::File::open` cannot portably open a directory.
#[cfg(not(unix))]
fn sync_parent_dir(_path: &Path) -> std::io::Result<()> {
    Ok(())
}
/// Loads persisted vectors into a fresh `ShardedVectors` store sized to `meta.dimension`.
///
/// Returns `(store, vector_storage_enabled)`:
/// - `meta.enable_vector_storage == false`: an empty store and `false`, without touching disk;
/// - `native_vectors.bin` is missing: an empty store and `false` (logged at debug level);
/// - otherwise: the store populated via `insert_batch` and `true`.
///
/// # Errors
/// Any error from `load_vectors` other than `NotFound` is propagated.
pub(crate) fn load_vectors_or_disable(
    path: &Path,
    meta: &HnswMeta,
) -> std::io::Result<(super::sharded_vectors::ShardedVectors, bool)> {
    use super::sharded_vectors::ShardedVectors;

    let empty_store = || ShardedVectors::new(meta.dimension);

    if !meta.enable_vector_storage {
        return Ok((empty_store(), false));
    }

    let loaded = match load_vectors(path) {
        Ok(data) => data,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
            tracing::debug!(
                "native_vectors.bin missing during HNSW load; disabling vector storage for safety"
            );
            return Ok((empty_store(), false));
        }
        Err(e) => return Err(e),
    };

    let store = empty_store();
    store.insert_batch(loaded.vectors);
    Ok((store, true))
}
/// Persists the vector store, or removes a stale vectors file when vector storage is disabled.
///
/// - `enable_vector_storage == true`: collects all `(index, vector)` pairs from `vectors` and
///   writes them via `save_vectors`.
/// - otherwise: deletes `<path>/native_vectors.bin` if present. A missing file is not an error.
///
/// # Errors
/// Propagates errors from `save_vectors`. On removal, any error other than `NotFound` (e.g.
/// permission denied) is propagated. Previously such errors were silently masked by
/// `Path::exists()`.
pub(crate) fn save_or_cleanup_vectors(
    path: &Path,
    enable_vector_storage: bool,
    vectors: &super::sharded_vectors::ShardedVectors,
) -> std::io::Result<()> {
    if enable_vector_storage {
        return save_vectors(
            path,
            &HnswVectorsData {
                vectors: vectors.collect_for_parallel(),
            },
        );
    }

    // Remove unconditionally instead of checking `exists()` first: this avoids a
    // check-then-act race and does not hide metadata errors behind `exists() == false`.
    match std::fs::remove_file(path.join("native_vectors.bin")) {
        Err(err) if err.kind() != std::io::ErrorKind::NotFound => Err(err),
        _ => Ok(()),
    }
}
/// Decodes the single-byte metric tag stored in `native_meta.bin`.
///
/// The mapping is 0 = Cosine, 1 = Euclidean, 2 = DotProduct, 3 = Hamming, 4 = Jaccard.
///
/// # Errors
/// Returns an `InvalidData` error for any other byte.
fn metric_from_u8(value: u8) -> std::io::Result<DistanceMetric> {
    let metric = match value {
        0 => DistanceMetric::Cosine,
        1 => DistanceMetric::Euclidean,
        2 => DistanceMetric::DotProduct,
        3 => DistanceMetric::Hamming,
        4 => DistanceMetric::Jaccard,
        _ => {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "Unknown distance metric",
            ))
        }
    };
    Ok(metric)
}
const fn storage_mode_to_u8(mode: crate::StorageMode) -> u8 {
match mode {
crate::StorageMode::Full => 0,
crate::StorageMode::SQ8 => 1,
crate::StorageMode::Binary => 2,
crate::StorageMode::ProductQuantization => 3,
crate::StorageMode::RaBitQ => 4,
}
}
const fn storage_mode_from_u8(value: u8) -> crate::StorageMode {
match value {
1 => crate::StorageMode::SQ8,
2 => crate::StorageMode::Binary,
3 => crate::StorageMode::ProductQuantization,
4 => crate::StorageMode::RaBitQ,
_ => crate::StorageMode::Full,
}
}