use std::fs::File;
use std::io::{self, BufWriter, Read, Seek, SeekFrom, Write};
use std::path::Path;
use std::sync::Arc;
use crate::distance::Distance;
use crate::hnsw::{Config, Hnsw, PruneStrategy, VecStore};
use crate::payload::Payload;
/// File signature stored in the first 8 bytes of every index file.
pub(crate) const MAGIC: &[u8; 8] = b"HNSWNDX\0";
/// Format version written by `write_hnsw`; `read_header` rejects any other value.
pub(crate) const VERSION: u32 = 1;
/// Size in bytes of the fixed header. The raw vector data begins right after it.
pub(crate) const VECTORS_OFFSET: usize = 256;
// Byte offsets of the little-endian header fields. `version` is a u32; every
// other numeric field is a u64, so the fields are packed back to back.
const OFF_VERSION: usize = 8;
// Number of stored vectors (nodes).
const OFF_N: usize = 12;
// Vector dimensionality; 0 is written when the index has no dimension (`dim == None`).
const OFF_DIM: usize = 20;
// `Config::m`.
const OFF_M: usize = 28;
// `Config::m0()` as resolved at write time.
const OFF_M0: usize = 36;
// `Config::ef_construction`.
const OFF_EF: usize = 44;
// Entry-point node id; `usize::MAX` encodes "no entry point".
const OFF_EP_ID: usize = 52;
// Level of the entry point.
const OFF_EP_LEVEL: usize = 60;
// Four flag bytes: use_heuristic, extend_candidates, keep_pruned, prune_strategy
// (0 = Simple, 1 = Heuristic). Every other header byte is zero.
const OFF_FLAGS: usize = 68;
/// Decodes a little-endian `u32` stored at `buf[off..off + 4]`.
///
/// Panics if `buf` holds fewer than `off + 4` bytes.
fn read_u32(buf: &[u8], off: usize) -> u32 {
    let mut bytes = [0u8; 4];
    bytes.copy_from_slice(&buf[off..off + 4]);
    u32::from_le_bytes(bytes)
}
/// Decodes a little-endian `u64` stored at `buf[off..off + 8]`.
///
/// Panics if `buf` holds fewer than `off + 8` bytes.
fn read_u64(buf: &[u8], off: usize) -> u64 {
    let mut bytes = [0u8; 8];
    bytes.copy_from_slice(&buf[off..off + 8]);
    u64::from_le_bytes(bytes)
}
/// Little-endian byte encoding of a `u32`.
fn u32_le(v: u32) -> [u8; 4] {
    u32::to_le_bytes(v)
}
/// Little-endian byte encoding of a `u64`.
fn u64_le(v: u64) -> [u8; 8] {
    u64::to_le_bytes(v)
}
/// Little-endian byte encoding of an `f32` (IEEE-754 bit pattern).
fn f32_le(v: f32) -> [u8; 4] {
    f32::to_le_bytes(v)
}
/// Expected byte length of the graph section written for `index`, used to
/// pre-size the output file.
///
/// Sums the fixed header, the `f32` vector data, one `u32` level and one
/// `u64` offset per node, and every layer record (a `u32` count followed by
/// 8 bytes per neighbour pair).
fn hnsw_section_bytes<D: Distance>(index: &Hnsw<D>) -> u64 {
    let nodes = index.vec_store.len() as u64;
    let dim = index.dim.unwrap_or(0) as u64;

    let mut total = VECTORS_OFFSET as u64;
    total += nodes * dim * 4; // f32 vector data
    total += nodes * 4; // level table
    total += nodes * 8; // offset table
    for layers in &index.connections {
        for layer in layers {
            total += 4 + layer.len() as u64 * 8;
        }
    }
    total
}
/// Byte length of a payload section holding `n` fixed-stride payloads: a
/// 16-byte `(count, stride)` header followed by `n * stride` bytes.
///
/// Returns `None` when `L` has no fixed stride (variable-length payloads).
fn fixed_payload_section_bytes<L: Payload>(n: usize) -> Option<u64> {
    let stride = L::fixed_stride()?;
    Some(16 + n as u64 * stride as u64)
}
/// Writes `index` to `path` followed by an empty payload section.
///
/// The file is created (or truncated), pre-sized to its final length on a
/// best-effort basis, and then filled through a buffered writer.
///
/// # Errors
/// Returns any I/O error from creating, writing, or flushing the file.
pub fn save<D: Distance>(index: &Hnsw<D>, path: impl AsRef<Path>) -> io::Result<()> {
    let out = File::create(path)?;
    // Graph section plus the 16-byte empty payload header. Pre-sizing is only
    // an allocation hint, so a failure here is deliberately ignored.
    let expected_len = hnsw_section_bytes(index) + 16;
    let _ = out.set_len(expected_len);

    let mut writer = BufWriter::new(out);
    write_hnsw(index, &mut writer)?;
    write_empty_payload(&mut writer)?;
    // Flush explicitly so write errors surface instead of being lost on drop.
    writer.flush()
}
/// Writes `index` to `path` together with one payload per node.
///
/// When `L` has a fixed stride, the file is pre-sized (best effort) to its
/// exact final length. Variable-length payloads skip that hint.
///
/// # Panics
/// Panics if `payloads.len()` differs from `index.len()`.
///
/// # Errors
/// Returns any I/O error from creating, writing, or flushing the file.
pub fn save_with_payload<D, L>(
    index: &Hnsw<D>,
    payloads: &[L],
    path: impl AsRef<Path>,
) -> io::Result<()>
where
    D: Distance,
    L: Payload,
{
    let payload_count = payloads.len();
    let node_count = index.len();
    assert_eq!(
        payload_count, node_count,
        "payload count ({}) must match index size ({})",
        payload_count, node_count
    );

    let out = File::create(path)?;
    let graph_len = hnsw_section_bytes(index);
    let size_hint =
        fixed_payload_section_bytes::<L>(payload_count).map(|extra| graph_len + extra);
    if let Some(len) = size_hint {
        // Allocation hint only; failure is not an error.
        let _ = out.set_len(len);
    }

    let mut writer = BufWriter::new(out);
    write_hnsw(index, &mut writer)?;
    write_payloads(payloads, &mut writer)?;
    writer.flush()
}
/// Loads an index file fully into memory, ignoring any payload section.
///
/// The file is read through an [`io::BufReader`]. Decoding the graph section
/// issues many small reads (a 4-byte count per layer, then that layer's
/// neighbour pairs), and on a bare `File` each of those would be a separate
/// system call.
///
/// # Errors
/// Returns I/O errors from opening or reading the file. A malformed header
/// (bad magic, unsupported version, ...) yields `InvalidData`.
pub fn load<D: Distance>(path: impl AsRef<Path>, metric: D) -> io::Result<Hnsw<D>> {
    let mut reader = io::BufReader::new(File::open(path)?);
    let (index, _payload_start) = read_hnsw_owned(&mut reader, metric, false)?;
    Ok(index)
}
/// Loads an index file and its payload section fully into memory.
///
/// The file is read through an [`io::BufReader`] so that decoding the graph
/// and payload sections, which issue many small reads, does not turn each read
/// into a system call. The same reader is reused for the payloads; seeking it
/// to the payload section discards any stale buffered bytes.
///
/// # Errors
/// Returns I/O errors from opening or reading the file. A malformed header or
/// payload section (including a file saved without payloads) yields `InvalidData`.
pub fn load_with_payload<D, L>(
    path: impl AsRef<Path>,
    metric: D,
) -> io::Result<(Hnsw<D>, Vec<L>)>
where
    D: Distance,
    L: Payload,
{
    let mut reader = io::BufReader::new(File::open(path)?);
    let (index, payload_start) = read_hnsw_owned(&mut reader, metric, true)?;
    let payloads = read_payloads::<L, _>(&mut reader, index.len(), payload_start)?;
    Ok((index, payloads))
}
/// Memory-maps the index file at `path`, ignoring any payload section.
///
/// Vector storage is built from the mapping (`VecStore::from_mmap`) rather
/// than copied into a heap buffer.
///
/// # Errors
/// Returns I/O errors from opening or mapping the file, and `InvalidData` for
/// a malformed header or a truncated vector section.
pub fn load_mmap<D: Distance>(path: impl AsRef<Path>, metric: D) -> io::Result<Hnsw<D>> {
    let file = File::open(path.as_ref())?;
    let (index, _, _) = read_hnsw_mmap_inner(file, metric)?;
    Ok(index)
}
/// Memory-maps `path` and returns the index together with its decoded payloads.
///
/// Vector storage comes from the mapping. Payloads are decoded into an owned
/// `Vec<L>` by reading the same mapping through an in-memory cursor.
///
/// # Errors
/// Returns I/O errors from opening or mapping the file, and `InvalidData` for a
/// malformed file or a missing payload section.
pub fn load_mmap_with_payload<D, L>(
    path: impl AsRef<Path>,
    metric: D,
) -> io::Result<(Hnsw<D>, Vec<L>)>
where
    D: Distance,
    L: Payload,
{
    let file = File::open(path.as_ref())?;
    let (index, payload_start, mapping) = read_hnsw_mmap_inner(file, metric)?;
    let payloads = {
        let mut reader = io::Cursor::new(mapping.as_ref() as &[u8]);
        read_payloads::<L, _>(&mut reader, index.len(), payload_start)?
    };
    Ok((index, payloads))
}
/// Serialises the graph section of `index` to `w`.
///
/// Layout (all integers little-endian):
/// 1. a `VECTORS_OFFSET`-byte header (magic, version, sizes, config, entry point, flags);
/// 2. the raw vector bytes from `vec_store.as_bytes()`;
/// 3. one `u32` top level per node;
/// 4. one `u64` byte offset per node (measured from the start of this section)
///    pointing at that node's first layer record;
/// 5. for each node and each layer: a `u32` neighbour count followed by
///    `(u32 id, f32 dist)` pairs.
///
/// # Errors
/// Propagates any error returned by `w`.
pub(crate) fn write_hnsw<D: Distance, W: Write>(
    index: &Hnsw<D>,
    w: &mut W,
) -> io::Result<()> {
    let n = index.vec_store.len();
    let dim = index.dim.unwrap_or(0);
    let cfg = &index.config;

    let mut hdr = [0u8; VECTORS_OFFSET];
    hdr[..8].copy_from_slice(MAGIC);
    hdr[OFF_VERSION..OFF_VERSION + 4].copy_from_slice(&u32_le(VERSION));
    hdr[OFF_N..OFF_N + 8].copy_from_slice(&u64_le(n as u64));
    hdr[OFF_DIM..OFF_DIM + 8].copy_from_slice(&u64_le(dim as u64));
    hdr[OFF_M..OFF_M + 8].copy_from_slice(&u64_le(cfg.m as u64));
    hdr[OFF_M0..OFF_M0 + 8].copy_from_slice(&u64_le(cfg.m0() as u64));
    hdr[OFF_EF..OFF_EF + 8].copy_from_slice(&u64_le(cfg.ef_construction as u64));
    // A missing entry point is stored as the `usize::MAX` sentinel.
    let (ep_id, ep_level) = index.entry_point.unwrap_or((usize::MAX, 0));
    hdr[OFF_EP_ID..OFF_EP_ID + 8].copy_from_slice(&u64_le(ep_id as u64));
    hdr[OFF_EP_LEVEL..OFF_EP_LEVEL + 8].copy_from_slice(&u64_le(ep_level as u64));
    hdr[OFF_FLAGS] = cfg.use_heuristic as u8;
    hdr[OFF_FLAGS + 1] = cfg.extend_candidates as u8;
    hdr[OFF_FLAGS + 2] = cfg.keep_pruned as u8;
    hdr[OFF_FLAGS + 3] = match cfg.prune_strategy {
        PruneStrategy::Simple => 0,
        PruneStrategy::Heuristic => 1,
    };
    w.write_all(&hdr)?;
    w.write_all(index.vec_store.as_bytes())?;

    // Level table. A node with no layer lists is recorded as level 0; the
    // layer-record loop below writes a matching empty layer for it.
    for node_conn in &index.connections {
        let level = (node_conn.len() as u32).saturating_sub(1);
        w.write_all(&u32_le(level))?;
    }

    // Offset table: where each node's first layer record starts.
    let conn_data_base: u64 = VECTORS_OFFSET as u64
        + (n as u64) * (dim as u64) * 4
        + (n as u64) * 4
        + (n as u64) * 8;
    let mut running_off = conn_data_base;
    for node_conn in &index.connections {
        w.write_all(&u64_le(running_off))?;
        if node_conn.is_empty() {
            // Account for the synthesized empty layer 0 (just its count).
            running_off += 4;
        }
        for layer_conn in node_conn {
            running_off += 4 + (layer_conn.len() as u64) * 8;
        }
    }

    // Layer records, one small buffer per layer to batch the writes.
    let mut conn_buf: Vec<u8> = Vec::new();
    for node_conn in &index.connections {
        if node_conn.is_empty() {
            // The reader always decodes `level + 1` (>= 1) layers per node, so an
            // empty layer 0 must be emitted to keep the stream aligned with the
            // level table written above.
            w.write_all(&u32_le(0))?;
            continue;
        }
        for layer_conn in node_conn {
            let n_conns = layer_conn.len();
            conn_buf.clear();
            conn_buf.reserve(4 + n_conns * 8);
            conn_buf.extend_from_slice(&u32_le(n_conns as u32));
            for &(id, dist) in layer_conn {
                conn_buf.extend_from_slice(&u32_le(id));
                conn_buf.extend_from_slice(&f32_le(dist));
            }
            w.write_all(&conn_buf)?;
        }
    }
    Ok(())
}
/// Decodes the graph section from `r` into a fully in-memory [`Hnsw`].
///
/// Returns the index and the stream position where the payload section
/// begins. The vector section declared by the header is checked against the
/// real stream length before any buffer is allocated. A corrupt or truncated
/// file therefore yields `InvalidData`, not an arithmetic overflow or an
/// oversized allocation. The mmap loader performs the same check.
///
/// # Errors
/// Returns header errors, `InvalidData` for an out-of-range vector section,
/// and any I/O error from `r`.
fn read_hnsw_owned<D: Distance, R: Read + Seek>(
    r: &mut R,
    metric: D,
    _expect_payload: bool,
) -> io::Result<(Hnsw<D>, u64)> {
    let (cfg, n, dim, ep, vec_offset, file_size) = read_header(r)?;

    // `n` and `dim` come straight from the file: compute the section size with
    // checked arithmetic and make sure it fits before allocating for it.
    let vec_bytes = n
        .checked_mul(dim)
        .and_then(|floats| floats.checked_mul(4))
        .filter(|&bytes| {
            (vec_offset as u64)
                .checked_add(bytes as u64)
                .map_or(false, |end| end <= file_size)
        })
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "file too short: vector section extends past end of file",
            )
        })?;

    r.seek(SeekFrom::Start(vec_offset as u64))?;
    let mut raw = vec![0u8; vec_bytes];
    r.read_exact(&mut raw)?;
    let floats: Vec<f32> = raw
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect();
    let mut vs = VecStore::new(dim, n);
    vs.data = floats;

    let (connections, payload_pos) = read_graph(r, n)?;
    let index = Hnsw::from_parts(cfg, metric, vs, connections, ep, if dim == 0 { None } else { Some(dim) });
    Ok((index, payload_pos))
}
/// Memory-maps `file` and decodes its graph section, keeping the vector data
/// inside the mapping.
///
/// Returns the index, the byte offset of the payload section, and the shared
/// mapping, so callers can keep reading payloads from the same bytes.
///
/// # Errors
/// Returns mapping and header errors, and `InvalidData` when the vector section
/// declared by the header does not fit in the file. This includes the case
/// where computing its size would overflow `usize`.
fn read_hnsw_mmap_inner<D: Distance>(
    file: File,
    metric: D,
) -> io::Result<(Hnsw<D>, u64, Arc<memmap2::Mmap>)> {
    // SAFETY: memmap2 requires that the mapped file is not modified or
    // truncated while the mapping is alive. This function cannot enforce that.
    // NOTE(review): callers are assumed not to mutate index files while loaded.
    let mmap = Arc::new(unsafe { memmap2::Mmap::map(&file)? });
    // Access-pattern hint for the kernel; best effort, the result is ignored.
    #[cfg(unix)]
    let _ = mmap.advise(memmap2::Advice::Random);

    let mut cursor = io::Cursor::new(mmap.as_ref() as &[u8]);
    let (cfg, n, dim, ep, vec_offset, _file_size) = read_header(&mut cursor)?;

    // Checked arithmetic: with a corrupt header, an unchecked `n * dim * 4` (or
    // the addition) could wrap in release builds, pass the bounds test, and hand
    // an out-of-range region to `VecStore::from_mmap`.
    let vec_end = n
        .checked_mul(dim)
        .and_then(|floats| floats.checked_mul(4))
        .and_then(|bytes| vec_offset.checked_add(bytes))
        .filter(|&end| end <= mmap.len())
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "file too short: vector section extends past end of file",
            )
        })?;

    let vs = VecStore::from_mmap(Arc::clone(&mmap), vec_offset, n, dim);
    cursor.seek(SeekFrom::Start(vec_end as u64))?;
    let (connections, payload_pos) = read_graph(&mut cursor, n)?;
    let index = Hnsw::from_parts(cfg, metric, vs, connections, ep, if dim == 0 { None } else { Some(dim) });
    Ok((index, payload_pos, mmap))
}
/// Reads and validates the fixed `VECTORS_OFFSET`-byte header at the current
/// position of `r`.
///
/// Returns `(config, n, dim, entry_point, vectors_offset, stream_len)`.
/// `entry_point` is `None` when the stored id is the `usize::MAX` sentinel, and
/// `stream_len` is the total length of `r`. On success the stream is left
/// positioned directly after the header.
///
/// # Errors
/// Returns `InvalidData` for wrong magic bytes, an unsupported version, or an
/// unknown prune-strategy byte. Any I/O error from reading or seeking is propagated.
fn read_header<R: Read + Seek>(
    r: &mut R,
) -> io::Result<(Config, usize, usize, Option<(usize, usize)>, usize, u64)> {
    let invalid = |msg: String| io::Error::new(io::ErrorKind::InvalidData, msg);

    let mut header = [0u8; VECTORS_OFFSET];
    r.read_exact(&mut header)?;

    if header[..8] != MAGIC[..] {
        return Err(invalid(format!("invalid magic bytes — expected {:?}", MAGIC)));
    }
    let version = read_u32(&header, OFF_VERSION);
    if version != VERSION {
        return Err(invalid(format!(
            "unsupported file version {version} (expected {VERSION})"
        )));
    }

    // Field accessors over the validated header.
    let word = |off: usize| read_u64(&header, off) as usize;
    let flag = |idx: usize| header[OFF_FLAGS + idx] != 0;

    let prune_strategy = match header[OFF_FLAGS + 3] {
        0 => PruneStrategy::Simple,
        1 => PruneStrategy::Heuristic,
        other => return Err(invalid(format!("unknown prune_strategy byte {other}"))),
    };

    let n = word(OFF_N);
    let dim = word(OFF_DIM);
    let ep_id = word(OFF_EP_ID);
    let entry_point = if ep_id == usize::MAX {
        None
    } else {
        Some((ep_id, word(OFF_EP_LEVEL)))
    };

    let config = Config {
        m: word(OFF_M),
        m0: Some(word(OFF_M0)),
        ef_construction: word(OFF_EF),
        use_heuristic: flag(0),
        extend_candidates: flag(1),
        keep_pruned: flag(2),
        prune_strategy,
        capacity: n,
    };

    // Measure the stream length, then return to just past the header.
    let after_header = r.stream_position()?;
    let stream_len = r.seek(SeekFrom::End(0))?;
    r.seek(SeekFrom::Start(after_header))?;

    Ok((config, n, dim, entry_point, VECTORS_OFFSET, stream_len))
}
/// Reads the level table, skips the offset table, and decodes every node's
/// per-layer neighbour lists from `r`, starting at its current position.
///
/// Returns the connections (`node -> layer -> [(neighbour id, distance)]`) and
/// the stream position just after the graph data, where the payload section
/// begins.
///
/// Every length taken from the stream (`n`, layer counts, neighbour counts) is
/// checked against the bytes actually remaining before anything is allocated.
/// A corrupt or truncated file therefore produces `InvalidData` instead of an
/// integer overflow or an oversized allocation.
///
/// # Errors
/// Returns `InvalidData` if the declared sizes exceed the stream, and propagates
/// any I/O error from `r`.
fn read_graph<R: Read + Seek>(
    r: &mut R,
    n: usize,
) -> io::Result<(Vec<Vec<Vec<(u32, f32)>>>, u64)> {
    /// Takes `count * width` bytes out of the `remaining` budget and returns
    /// that byte count. Fails if it overflows or exceeds the budget.
    fn reserve(remaining: &mut u64, count: usize, width: usize) -> io::Result<usize> {
        match count.checked_mul(width) {
            Some(bytes) if bytes as u64 <= *remaining => {
                *remaining -= bytes as u64;
                Ok(bytes)
            }
            _ => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "file too short: graph section extends past end of file",
            )),
        }
    }

    let start = r.stream_position()?;
    let end = r.seek(SeekFrom::End(0))?;
    r.seek(SeekFrom::Start(start))?;
    let mut remaining = end.saturating_sub(start);

    let level_bytes = reserve(&mut remaining, n, 4)?;
    let mut raw_levels = vec![0u8; level_bytes];
    r.read_exact(&mut raw_levels)?;
    let levels: Vec<u32> = raw_levels
        .chunks_exact(4)
        .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect();

    // The per-node offset table is not needed for sequential decoding. Skip it
    // rather than allocating and reading n * 8 bytes that would be discarded.
    let offset_bytes = reserve(&mut remaining, n, 8)?;
    r.seek(SeekFrom::Current(offset_bytes as i64))?;

    let mut connections = Vec::with_capacity(n);
    let mut pair_buf: Vec<u8> = Vec::new();
    let mut buf4 = [0u8; 4];
    for &level in &levels {
        // Every layer needs at least its 4-byte count. Reserving those up front
        // also bounds `n_layers` before allocating for it.
        let n_layers = (level as usize).saturating_add(1);
        reserve(&mut remaining, n_layers, 4)?;
        let mut node_conn = Vec::with_capacity(n_layers);
        for _ in 0..n_layers {
            r.read_exact(&mut buf4)?;
            let n_conns = u32::from_le_bytes(buf4) as usize;
            let byte_count = reserve(&mut remaining, n_conns, 8)?;
            pair_buf.resize(byte_count, 0);
            r.read_exact(&mut pair_buf)?;
            let layer_conn: Vec<(u32, f32)> = pair_buf
                .chunks_exact(8)
                .map(|c| {
                    let id = u32::from_le_bytes([c[0], c[1], c[2], c[3]]);
                    let dist = f32::from_le_bytes([c[4], c[5], c[6], c[7]]);
                    (id, dist)
                })
                .collect();
            node_conn.push(layer_conn);
        }
        connections.push(node_conn);
    }

    let payload_pos = r.stream_position()?;
    Ok((connections, payload_pos))
}
pub(crate) fn write_empty_payload<W: Write>(w: &mut W) -> io::Result<()> {
w.write_all(&u64_le(0))?;
w.write_all(&u64_le(0))?;
Ok(())
}
/// Writes a payload section to `w`.
///
/// Layout: `u64` count, `u64` stride, then either
/// * fixed stride (`stride > 0`): `count` records of exactly `stride` bytes, or
/// * variable (`stride == 0`): `count` absolute `u64` stream offsets followed by
///   the concatenated encodings. Each payload runs up to the next offset; the
///   last one runs to the end of the stream.
///
/// # Errors
/// Returns `InvalidInput` if a payload's encoding length differs from the
/// declared fixed stride. This check was previously debug-only, and in release
/// builds a mismatch silently produced a file that could not be read back.
/// Any error returned by `w` is propagated.
pub(crate) fn write_payloads<L: Payload, W: Write + Seek>(
    payloads: &[L],
    w: &mut W,
) -> io::Result<()> {
    let n = payloads.len();
    let stride = L::fixed_stride().unwrap_or(0) as u64;
    w.write_all(&u64_le(n as u64))?;
    w.write_all(&u64_le(stride))?;

    // One scratch buffer reused for every encoding.
    let mut scratch: Vec<u8> = Vec::new();
    if stride > 0 {
        let expected = stride as usize;
        scratch.reserve(expected);
        for p in payloads {
            scratch.clear();
            p.encode(&mut scratch);
            if scratch.len() != expected {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!(
                        "payload encoded to {} bytes, expected fixed stride {expected}",
                        scratch.len()
                    ),
                ));
            }
            w.write_all(&scratch)?;
        }
    } else {
        // The offsets table precedes the data, so encode everything first into a
        // single contiguous buffer (no per-payload allocation) and record where
        // each encoding will land in the stream.
        let data_start = w.stream_position()? + (n as u64) * 8;
        let mut data: Vec<u8> = Vec::new();
        let mut offsets: Vec<u64> = Vec::with_capacity(n);
        for p in payloads {
            scratch.clear();
            p.encode(&mut scratch);
            offsets.push(data_start + data.len() as u64);
            data.extend_from_slice(&scratch);
        }
        for &off in &offsets {
            w.write_all(&u64_le(off))?;
        }
        w.write_all(&data)?;
    }
    Ok(())
}
/// Reads the payload section that starts at `payload_section_pos` and decodes
/// exactly `n` payloads.
///
/// An empty index (`n == 0`) yields an empty vector. Previously, a zero-count
/// section was always rejected, so an empty index saved with payloads could not
/// be loaded. Every size and offset read from the stream is validated against
/// the stream length before allocating.
///
/// Variable-length payloads are read sequentially. The reader seeks only when
/// the stored offsets are not contiguous, and the last payload extends to the
/// end of the stream.
///
/// # Errors
/// Returns `InvalidData` in these cases:
/// * the file has no payload section but `n > 0`;
/// * the stored count differs from `n`;
/// * a size or offset lies outside the stream;
/// * a payload fails to decode.
///
/// Any I/O error from `r` is propagated.
pub(crate) fn read_payloads<L: Payload, R: Read + Seek>(
    r: &mut R,
    n: usize,
    payload_section_pos: u64,
) -> io::Result<Vec<L>> {
    /// Reads one little-endian `u64`.
    fn next_u64<S: Read>(s: &mut S) -> io::Result<u64> {
        let mut buf = [0u8; 8];
        s.read_exact(&mut buf)?;
        Ok(u64::from_le_bytes(buf))
    }
    /// Builds an `InvalidData` error with `msg`.
    fn invalid(msg: String) -> io::Error {
        io::Error::new(io::ErrorKind::InvalidData, msg)
    }

    // Stream length: bounds every size/offset below and terminates the last
    // variable-length payload.
    let end = r.seek(SeekFrom::End(0))?;
    r.seek(SeekFrom::Start(payload_section_pos))?;
    let payload_count = next_u64(r)?;
    let stride = next_u64(r)?;

    if payload_count == 0 && n != 0 {
        return Err(invalid("file contains no payload section".to_string()));
    }
    if payload_count != n as u64 {
        return Err(invalid(format!(
            "payload count {payload_count} != index size {n}"
        )));
    }
    if n == 0 {
        return Ok(Vec::new());
    }

    let mut payloads = Vec::with_capacity(n);
    if stride > 0 {
        let body_start = payload_section_pos.saturating_add(16);
        let total = (n as u64)
            .checked_mul(stride)
            .filter(|&t| t <= end.saturating_sub(body_start))
            .and_then(|t| usize::try_from(t).ok())
            .ok_or_else(|| {
                invalid("file too short: payload section extends past end of file".to_string())
            })?;
        // `stride <= total` here (n >= 1), so it also fits in usize.
        let stride = stride as usize;
        let mut raw = vec![0u8; total];
        r.read_exact(&mut raw)?;
        for chunk in raw.chunks_exact(stride) {
            let (p, _) = L::decode(chunk).map_err(|e| {
                io::Error::new(io::ErrorKind::InvalidData, e.to_string())
            })?;
            payloads.push(p);
        }
    } else {
        let mut offsets = Vec::with_capacity(n);
        for _ in 0..n {
            offsets.push(next_u64(r)?);
        }
        let mut pos = r.stream_position()?;
        let mut buf = Vec::new();
        for (i, &start) in offsets.iter().enumerate() {
            let stop = offsets.get(i + 1).copied().unwrap_or(end);
            if start > stop || stop > end {
                return Err(invalid(format!(
                    "payload {i} has invalid byte range {start}..{stop}"
                )));
            }
            let len = usize::try_from(stop - start).map_err(|_| {
                invalid(format!("payload {i} is too large to load"))
            })?;
            // Offsets written by `write_payloads` are contiguous, so this seek is
            // normally skipped.
            if pos != start {
                r.seek(SeekFrom::Start(start))?;
            }
            buf.resize(len, 0);
            r.read_exact(&mut buf)?;
            pos = stop;
            let (p, _) = L::decode(&buf).map_err(|e| {
                io::Error::new(io::ErrorKind::InvalidData, e.to_string())
            })?;
            payloads.push(p);
        }
    }
    Ok(payloads)
}