/// Four-byte tag that opens every encoded HNSW index frame.
pub const HNSW_INDEX_MAGIC: [u8; 4] = *b"HNSW";
/// Format version written by the encoder and the only one the decoder accepts.
pub const HNSW_INDEX_VERSION_V1: u32 = 1;
/// Value stored in the entry-point field when the frame has no entry point.
pub const HNSW_INDEX_NO_ENTRY_POINT: u64 = u64::MAX;
/// Size in bytes of the fixed header that precedes the per-node records.
pub const HNSW_INDEX_HEADER_LEN: usize = 4 // magic
    + 4 // version
    + 4 // dimension
    + 4 // m
    + 4 // m_max0
    + 4 // ef_construction
    + 4 // ef_search
    + 8 // ml
    + 1 // metric
    + 4 // max_layer
    + 8 // entry_point
    + 8; // node_count
/// One node record of an encoded HNSW index frame.
///
/// The encoder writes neither the vector length nor the number of connection
/// lists; the decoder reads `dimension` vector components and `max_layer + 1`
/// connection lists. A node only round-trips when its fields agree with that.
#[derive(Debug, Clone, PartialEq)]
pub struct HnswNodeFrame {
/// Node identifier, encoded as a little-endian `u64`.
pub id: u64,
/// Top layer index recorded for this node; the decoder reads
/// `max_layer + 1` connection lists.
pub max_layer: u32,
/// Vector components, each encoded as a little-endian `f32` with no length prefix.
pub vector: Vec<f32>,
/// One list per layer, each encoded as a `u32` count followed by that many
/// `u64` values (presumably neighbour node ids — confirm against the index code).
pub connections: Vec<Vec<u64>>,
}
/// In-memory form of a complete encoded HNSW index: header fields plus nodes.
///
/// Apart from `dimension` and `entry_point`, the header fields are stored and
/// restored verbatim; this codec does not interpret or validate them.
#[derive(Debug, Clone, PartialEq)]
pub struct HnswIndexFrame {
/// Number of `f32` components the decoder reads for every node vector.
pub dimension: u32,
/// Stored verbatim (presumably the HNSW per-layer connection limit — confirm).
pub m: u32,
/// Stored verbatim (presumably the layer-0 connection limit — confirm).
pub m_max0: u32,
/// Stored verbatim (presumably the build-time candidate list size — confirm).
pub ef_construction: u32,
/// Stored verbatim (presumably the query-time candidate list size — confirm).
pub ef_search: u32,
/// Stored verbatim as a little-endian `f64` (presumably the level multiplier — confirm).
pub ml: f64,
/// One-byte metric code; its meaning is defined outside this codec.
pub metric: u8,
/// Stored verbatim; not cross-checked against the nodes' own `max_layer` values.
pub max_layer: u32,
/// Entry-point id. `None` is encoded as `HNSW_INDEX_NO_ENTRY_POINT`, so
/// `Some(u64::MAX)` cannot round-trip: it decodes as `None`.
pub entry_point: Option<u64>,
/// Node records in the order they are written and read.
pub nodes: Vec<HnswNodeFrame>,
}
/// Reasons `decode_hnsw_index_frame` can reject its input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HnswIndexFrameError {
/// The input has fewer than the 8 bytes needed for the magic tag and version.
TooShort,
/// The first four bytes are not `HNSW_INDEX_MAGIC`.
InvalidMagic,
/// The version field is not `HNSW_INDEX_VERSION_V1`; carries the value that was read.
UnsupportedVersion(u32),
/// The input ended before the field named by `reason` could be read at byte
/// `offset`. The decoder also reports this variant when a node's layer count
/// (`max_layer + 1`) overflows `u32` or `dimension` does not fit in `usize`.
Truncated { offset: usize, reason: &'static str },
}
/// Renders each error variant as a short human-readable message.
impl std::fmt::Display for HnswIndexFrameError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages need no formatting machinery.
            Self::TooShort => f.write_str("Data too short"),
            Self::InvalidMagic => f.write_str("Invalid magic number"),
            Self::UnsupportedVersion(found) => write!(f, "Unsupported version: {}", found),
            Self::Truncated { offset, reason } => write!(
                f,
                "truncated HNSW payload at offset {}: {}",
                offset, reason
            ),
        }
    }
}
// Marker impl: the trait's default methods are used as-is; the required
// `Debug` and `Display` impls are the derive and the `Display` impl in this file.
impl std::error::Error for HnswIndexFrameError {}
pub fn encode_hnsw_index_frame(frame: &HnswIndexFrame) -> Vec<u8> {
let mut bytes = Vec::new();
bytes.extend_from_slice(&HNSW_INDEX_MAGIC);
bytes.extend_from_slice(&HNSW_INDEX_VERSION_V1.to_le_bytes());
bytes.extend_from_slice(&frame.dimension.to_le_bytes());
bytes.extend_from_slice(&frame.m.to_le_bytes());
bytes.extend_from_slice(&frame.m_max0.to_le_bytes());
bytes.extend_from_slice(&frame.ef_construction.to_le_bytes());
bytes.extend_from_slice(&frame.ef_search.to_le_bytes());
bytes.extend_from_slice(&frame.ml.to_le_bytes());
bytes.push(frame.metric);
bytes.extend_from_slice(&frame.max_layer.to_le_bytes());
bytes.extend_from_slice(
&frame
.entry_point
.unwrap_or(HNSW_INDEX_NO_ENTRY_POINT)
.to_le_bytes(),
);
bytes.extend_from_slice(&(frame.nodes.len() as u64).to_le_bytes());
for node in &frame.nodes {
bytes.extend_from_slice(&node.id.to_le_bytes());
bytes.extend_from_slice(&node.max_layer.to_le_bytes());
for &val in &node.vector {
bytes.extend_from_slice(&val.to_le_bytes());
}
for conns in &node.connections {
bytes.extend_from_slice(&(conns.len() as u32).to_le_bytes());
for &conn in conns {
bytes.extend_from_slice(&conn.to_le_bytes());
}
}
}
bytes
}
pub fn decode_hnsw_index_frame(bytes: &[u8]) -> Result<HnswIndexFrame, HnswIndexFrameError> {
if bytes.len() < 8 {
return Err(HnswIndexFrameError::TooShort);
}
if bytes[0..4] != HNSW_INDEX_MAGIC {
return Err(HnswIndexFrameError::InvalidMagic);
}
let version = u32::from_le_bytes(bytes[4..8].try_into().expect("u32 length checked"));
if version != HNSW_INDEX_VERSION_V1 {
return Err(HnswIndexFrameError::UnsupportedVersion(version));
}
let mut pos = 8;
let dimension = read_u32(bytes, &mut pos, "dimension")?;
let m = read_u32(bytes, &mut pos, "m")?;
let m_max0 = read_u32(bytes, &mut pos, "m_max0")?;
let ef_construction = read_u32(bytes, &mut pos, "ef_construction")?;
let ef_search = read_u32(bytes, &mut pos, "ef_search")?;
let ml = read_f64(bytes, &mut pos, "ml")?;
let metric = read_u8(bytes, &mut pos, "metric")?;
let max_layer = read_u32(bytes, &mut pos, "max_layer")?;
let ep_value = read_u64(bytes, &mut pos, "entry_point")?;
let entry_point = if ep_value == HNSW_INDEX_NO_ENTRY_POINT {
None
} else {
Some(ep_value)
};
let node_count = read_u64(bytes, &mut pos, "node_count")?;
let dim = usize::try_from(dimension).map_err(|_| HnswIndexFrameError::Truncated {
offset: pos,
reason: "dimension",
})?;
let mut nodes = Vec::new();
for _ in 0..node_count {
let id = read_u64(bytes, &mut pos, "node id")?;
let node_max_layer = read_u32(bytes, &mut pos, "node max_layer")?;
let mut vector = Vec::new();
for _ in 0..dim {
vector.push(read_f32(bytes, &mut pos, "node vector")?);
}
let layer_count = node_max_layer
.checked_add(1)
.ok_or(HnswIndexFrameError::Truncated {
offset: pos,
reason: "node layer count",
})?;
let mut connections = Vec::new();
for _ in 0..layer_count {
let conn_count = read_u32(bytes, &mut pos, "connection count")?;
let mut conn_list = Vec::new();
for _ in 0..conn_count {
conn_list.push(read_u64(bytes, &mut pos, "connection")?);
}
connections.push(conn_list);
}
nodes.push(HnswNodeFrame {
id,
max_layer: node_max_layer,
vector,
connections,
});
}
Ok(HnswIndexFrame {
dimension,
m,
m_max0,
ef_construction,
ef_search,
ml,
metric,
max_layer,
entry_point,
nodes,
})
}
/// Reads one byte at `*pos` and advances `*pos` by one.
///
/// `reason` names the field being read and is carried in the error.
///
/// # Errors
///
/// Returns [`HnswIndexFrameError::Truncated`] with the current offset when
/// `*pos` is at or beyond the end of `bytes`; `*pos` is left unchanged.
fn read_u8(bytes: &[u8], pos: &mut usize, reason: &'static str) -> Result<u8, HnswIndexFrameError> {
    // `get` does the bounds check without computing `*pos + 1`, which would
    // overflow for a cursor at `usize::MAX`.
    match bytes.get(*pos) {
        Some(&value) => {
            // In range here, so `*pos < bytes.len()` and the increment cannot overflow.
            *pos += 1;
            Ok(value)
        }
        None => Err(HnswIndexFrameError::Truncated {
            offset: *pos,
            reason,
        }),
    }
}
/// Reads a little-endian `u32` at `*pos` and advances `*pos` by four.
///
/// `reason` names the field being read and is carried in the error.
///
/// # Errors
///
/// Returns [`HnswIndexFrameError::Truncated`] with the current offset when
/// fewer than four bytes remain at `*pos`; `*pos` is left unchanged.
fn read_u32(
    bytes: &[u8],
    pos: &mut usize,
    reason: &'static str,
) -> Result<u32, HnswIndexFrameError> {
    let start = *pos;
    // `checked_add` + `get` reject an out-of-range cursor without the
    // overflow that `start + 4` would hit near `usize::MAX`.
    let field = start
        .checked_add(4)
        .and_then(|end| bytes.get(start..end));
    match field {
        Some(raw) => {
            let mut buf = [0u8; 4];
            buf.copy_from_slice(raw);
            *pos = start + 4;
            Ok(u32::from_le_bytes(buf))
        }
        None => Err(HnswIndexFrameError::Truncated {
            offset: start,
            reason,
        }),
    }
}
/// Reads a little-endian `u64` at `*pos` and advances `*pos` by eight.
///
/// `reason` names the field being read and is carried in the error.
///
/// # Errors
///
/// Returns [`HnswIndexFrameError::Truncated`] with the current offset when
/// fewer than eight bytes remain at `*pos`; `*pos` is left unchanged.
fn read_u64(
    bytes: &[u8],
    pos: &mut usize,
    reason: &'static str,
) -> Result<u64, HnswIndexFrameError> {
    let start = *pos;
    // `checked_add` + `get` reject an out-of-range cursor without the
    // overflow that `start + 8` would hit near `usize::MAX`.
    let field = start
        .checked_add(8)
        .and_then(|end| bytes.get(start..end));
    match field {
        Some(raw) => {
            let mut buf = [0u8; 8];
            buf.copy_from_slice(raw);
            *pos = start + 8;
            Ok(u64::from_le_bytes(buf))
        }
        None => Err(HnswIndexFrameError::Truncated {
            offset: start,
            reason,
        }),
    }
}
/// Reads a little-endian `f32` at `*pos` and advances `*pos` by four.
///
/// `reason` names the field being read and is carried in the error.
///
/// # Errors
///
/// Returns [`HnswIndexFrameError::Truncated`] with the current offset when
/// fewer than four bytes remain at `*pos`; `*pos` is left unchanged.
fn read_f32(
    bytes: &[u8],
    pos: &mut usize,
    reason: &'static str,
) -> Result<f32, HnswIndexFrameError> {
    let start = *pos;
    // `checked_add` + `get` reject an out-of-range cursor without the
    // overflow that `start + 4` would hit near `usize::MAX`.
    let field = start
        .checked_add(4)
        .and_then(|end| bytes.get(start..end));
    match field {
        Some(raw) => {
            let mut buf = [0u8; 4];
            buf.copy_from_slice(raw);
            *pos = start + 4;
            Ok(f32::from_le_bytes(buf))
        }
        None => Err(HnswIndexFrameError::Truncated {
            offset: start,
            reason,
        }),
    }
}
/// Reads a little-endian `f64` at `*pos` and advances `*pos` by eight.
///
/// `reason` names the field being read and is carried in the error.
///
/// # Errors
///
/// Returns [`HnswIndexFrameError::Truncated`] with the current offset when
/// fewer than eight bytes remain at `*pos`; `*pos` is left unchanged.
fn read_f64(
    bytes: &[u8],
    pos: &mut usize,
    reason: &'static str,
) -> Result<f64, HnswIndexFrameError> {
    let start = *pos;
    // `checked_add` + `get` reject an out-of-range cursor without the
    // overflow that `start + 8` would hit near `usize::MAX`.
    let field = start
        .checked_add(8)
        .and_then(|end| bytes.get(start..end));
    match field {
        Some(raw) => {
            let mut buf = [0u8; 8];
            buf.copy_from_slice(raw);
            *pos = start + 8;
            Ok(f64::from_le_bytes(buf))
        }
        None => Err(HnswIndexFrameError::Truncated {
            offset: start,
            reason,
        }),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte offset of the entry-point field: the header's second-to-last `u64`.
    const ENTRY_POINT_OFFSET: usize = HNSW_INDEX_HEADER_LEN - 8 - 8;
    /// Byte offset of the node-count field: the header's last `u64`.
    const NODE_COUNT_OFFSET: usize = HNSW_INDEX_HEADER_LEN - 8;

    /// Two-node frame whose nodes satisfy the codec's implicit invariants:
    /// every vector has `dimension` components and every node has
    /// `max_layer + 1` connection lists.
    fn sample_frame() -> HnswIndexFrame {
        HnswIndexFrame {
            dimension: 3,
            m: 16,
            m_max0: 32,
            ef_construction: 100,
            ef_search: 50,
            ml: 0.360_673_760_222_104_4,
            metric: 1,
            max_layer: 2,
            entry_point: Some(7),
            nodes: vec![
                HnswNodeFrame {
                    id: 7,
                    max_layer: 2,
                    vector: vec![1.0, 2.0, 3.0],
                    connections: vec![vec![1, 2], vec![2], vec![]],
                },
                HnswNodeFrame {
                    id: 1,
                    max_layer: 0,
                    vector: vec![-1.5, 0.0, 4.25],
                    connections: vec![vec![7]],
                },
            ],
        }
    }

    #[test]
    fn hnsw_index_frame_round_trips() {
        let frame = sample_frame();
        let encoded = encode_hnsw_index_frame(&frame);
        let decoded = decode_hnsw_index_frame(&encoded).unwrap();
        assert_eq!(decoded, frame);
        assert_eq!(encode_hnsw_index_frame(&decoded), encoded);
    }

    #[test]
    fn hnsw_index_frame_pins_byte_layout() {
        let frame = sample_frame();
        let encoded = encode_hnsw_index_frame(&frame);
        assert_eq!(&encoded[0..4], b"HNSW");
        assert_eq!(&encoded[4..8], &1u32.to_le_bytes());
        assert_eq!(&encoded[8..12], &3u32.to_le_bytes());
        // Magic + version (8), five u32 fields (dimension..ef_search), f64 `ml`.
        let metric_off = 8 + 4 * 5 + 8;
        assert_eq!(encoded[metric_off], 1);
        assert_eq!(
            &encoded[ENTRY_POINT_OFFSET..ENTRY_POINT_OFFSET + 8],
            &7u64.to_le_bytes()
        );
        assert_eq!(
            &encoded[NODE_COUNT_OFFSET..HNSW_INDEX_HEADER_LEN],
            &2u64.to_le_bytes()
        );
        // The first node record starts right after the header, with its id.
        assert_eq!(
            &encoded[HNSW_INDEX_HEADER_LEN..HNSW_INDEX_HEADER_LEN + 8],
            &7u64.to_le_bytes()
        );
        // Node 7: 8 + 4 + 12 + (4 + 16) + (4 + 8) + 4 = 60 bytes.
        // Node 1: 8 + 4 + 12 + (4 + 8) = 36 bytes.
        assert_eq!(encoded.len(), HNSW_INDEX_HEADER_LEN + 60 + 36);
    }

    #[test]
    fn hnsw_index_frame_encodes_missing_entry_point_as_sentinel() {
        let mut frame = sample_frame();
        frame.entry_point = None;
        frame.nodes.clear();
        let encoded = encode_hnsw_index_frame(&frame);
        assert_eq!(encoded.len(), HNSW_INDEX_HEADER_LEN);
        assert_eq!(
            &encoded[ENTRY_POINT_OFFSET..ENTRY_POINT_OFFSET + 8],
            &HNSW_INDEX_NO_ENTRY_POINT.to_le_bytes()
        );
        let decoded = decode_hnsw_index_frame(&encoded).unwrap();
        assert_eq!(decoded.entry_point, None);
    }

    #[test]
    fn hnsw_index_frame_rejects_bad_input() {
        assert_eq!(
            decode_hnsw_index_frame(&[0u8; 4]),
            Err(HnswIndexFrameError::TooShort)
        );

        let mut bad_magic = encode_hnsw_index_frame(&sample_frame());
        bad_magic[0] = b'X';
        assert_eq!(
            decode_hnsw_index_frame(&bad_magic),
            Err(HnswIndexFrameError::InvalidMagic)
        );

        let mut bad_version = encode_hnsw_index_frame(&sample_frame());
        bad_version[4..8].copy_from_slice(&2u32.to_le_bytes());
        assert_eq!(
            decode_hnsw_index_frame(&bad_version),
            Err(HnswIndexFrameError::UnsupportedVersion(2))
        );

        // Every strict prefix must be rejected: shorter than magic + version
        // is `TooShort`, anything longer runs out of bytes mid-parse.
        let encoded = encode_hnsw_index_frame(&sample_frame());
        for cut in 0..8 {
            assert_eq!(
                decode_hnsw_index_frame(&encoded[..cut]),
                Err(HnswIndexFrameError::TooShort)
            );
        }
        for cut in 8..encoded.len() {
            assert!(matches!(
                decode_hnsw_index_frame(&encoded[..cut]),
                Err(HnswIndexFrameError::Truncated { .. })
            ));
        }
    }

    #[test]
    fn hnsw_index_frame_does_not_preallocate_untrusted_counts() {
        let mut frame = sample_frame();
        frame.nodes.clear();
        let mut encoded = encode_hnsw_index_frame(&frame);
        encoded[NODE_COUNT_OFFSET..NODE_COUNT_OFFSET + 8]
            .copy_from_slice(&u64::MAX.to_le_bytes());
        assert!(matches!(
            decode_hnsw_index_frame(&encoded),
            Err(HnswIndexFrameError::Truncated { .. })
        ));
    }
}