/// Four-byte tag that opens every encoded IVF index payload (ASCII `IVF1`).
pub const IVF_INDEX_MAGIC: &[u8; 4] = &[b'I', b'V', b'F', b'1'];
/// Serializable snapshot of one inverted list: a centroid plus the ids and
/// vectors stored in that list.
///
/// The codec writes `ids` and `vectors` with independent length prefixes and
/// never checks that they have the same length. NOTE(review): presumably
/// `ids[i]` identifies `vectors[i]` — confirm with the code that builds this.
#[derive(Debug, Clone, PartialEq)]
pub struct IvfListLayout {
/// Centroid coordinates; encoded as a `u32` count followed by little-endian `f32`s.
pub centroid: Vec<f32>,
/// Entry ids; encoded as a `u32` count followed by little-endian `u64`s.
pub ids: Vec<u64>,
/// Stored vectors; encoded as a `u32` count, then each vector as its own
/// `u32`-prefixed run of little-endian `f32`s (lengths are not checked
/// against `centroid` or the index dimension).
pub vectors: Vec<Vec<f32>>,
}
/// Serializable snapshot of a whole IVF index, as written by
/// [`encode_ivf_index`] and read back by [`decode_ivf_index`].
///
/// The codec stores each scalar as-is and never cross-checks them against
/// `lists` (for example `n_lists` is encoded separately from `lists.len()`).
#[derive(Debug, Clone, PartialEq)]
pub struct IvfIndexLayout {
/// Configured number of inverted lists; encoded as `u32`, independent of `lists.len()`.
pub n_lists: usize,
/// Presumably the number of lists probed per query (inferred from the name — confirm);
/// encoded as `u32`.
pub n_probes: usize,
/// Vector dimension; encoded as `u32`. NOTE(review): not validated against the
/// lengths of centroids or vectors.
pub dimension: usize,
/// Presumably the training iteration cap (inferred from the name — confirm); encoded as `u32`.
pub max_iterations: usize,
/// Presumably the training convergence tolerance (inferred from the name — confirm);
/// encoded as a little-endian `f32`.
pub convergence_threshold: f32,
/// Trained flag; encoded as a single byte, `1` for true and `0` for false.
pub trained: bool,
/// Presumably the number of stored vectors (confirm); encoded as `u64`.
pub count: usize,
/// Presumably the next id to hand out (confirm); encoded as little-endian `u64`.
pub next_id: u64,
/// The inverted lists, written after a `u32` list count.
pub lists: Vec<IvfListLayout>,
}
/// Reasons the IVF decoder can reject a byte buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IvfCodecError {
    /// The buffer is shorter than the minimum length the decoder checks up front.
    TooShort,
    /// The buffer does not begin with the expected magic bytes.
    InvalidMagic,
    /// A field or element extends past the end of the buffer.
    Truncated,
}

impl std::fmt::Display for IvfCodecError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Map each variant to its fixed, human-readable message.
        let message = match *self {
            Self::TooShort => "data too short",
            Self::InvalidMagic => "invalid IVF magic",
            Self::Truncated => "truncated IVF payload",
        };
        f.write_str(message)
    }
}

impl std::error::Error for IvfCodecError {}
/// Serializes `layout` into the little-endian `IVF1` wire format read by
/// [`decode_ivf_index`].
///
/// Field order:
/// 1. [`IVF_INDEX_MAGIC`] (4 bytes);
/// 2. `n_lists`, `n_probes`, `dimension`, `max_iterations`, each as `u32`;
/// 3. `convergence_threshold` as `f32`;
/// 4. `trained` as one byte (`1` or `0`);
/// 5. `count` and `next_id`, each as `u64`;
/// 6. the number of lists as `u32`. Each list then follows as:
///    - a `u32`-prefixed run of centroid `f32`s;
///    - a `u32`-prefixed run of `u64` ids;
///    - a `u32` vector count, then one `u32`-prefixed run of `f32`s per vector.
///
/// NOTE(review): `usize` fields and lengths are narrowed with `as u32` (and
/// `count` with `as u64`), which silently truncates out-of-range values.
pub fn encode_ivf_index(layout: &IvfIndexLayout) -> Vec<u8> {
    // Appends `value` narrowed to a little-endian u32 (truncating, like the wire format).
    fn put_u32(out: &mut Vec<u8>, value: usize) {
        out.extend_from_slice(&(value as u32).to_le_bytes());
    }

    // Appends a u32 element count followed by every f32 in little-endian order.
    fn put_f32_run(out: &mut Vec<u8>, values: &[f32]) {
        put_u32(out, values.len());
        for value in values {
            out.extend_from_slice(&value.to_le_bytes());
        }
    }

    let mut out = Vec::new();
    out.extend_from_slice(IVF_INDEX_MAGIC);
    put_u32(&mut out, layout.n_lists);
    put_u32(&mut out, layout.n_probes);
    put_u32(&mut out, layout.dimension);
    put_u32(&mut out, layout.max_iterations);
    out.extend_from_slice(&layout.convergence_threshold.to_le_bytes());
    out.push(u8::from(layout.trained));
    out.extend_from_slice(&(layout.count as u64).to_le_bytes());
    out.extend_from_slice(&layout.next_id.to_le_bytes());
    put_u32(&mut out, layout.lists.len());
    for list in &layout.lists {
        put_f32_run(&mut out, &list.centroid);
        put_u32(&mut out, list.ids.len());
        for id in &list.ids {
            out.extend_from_slice(&id.to_le_bytes());
        }
        put_u32(&mut out, list.vectors.len());
        for vector in &list.vectors {
            put_f32_run(&mut out, vector);
        }
    }
    out
}
/// Reads a little-endian `u32` at `*pos` and advances `*pos` past it.
///
/// Fails with `IvfCodecError::Truncated`, leaving `*pos` unchanged, when fewer
/// than four bytes remain.
fn read_u32(buf: &[u8], pos: &mut usize) -> Result<u32, IvfCodecError> {
    let start = *pos;
    let end = start + 4;
    let field = match buf.get(start..end) {
        Some(field) => field,
        None => return Err(IvfCodecError::Truncated),
    };
    let mut raw = [0u8; 4];
    raw.copy_from_slice(field);
    *pos = end;
    Ok(u32::from_le_bytes(raw))
}
/// Reads a little-endian `u64` at `*pos` and advances `*pos` past it.
///
/// Fails with `IvfCodecError::Truncated`, leaving `*pos` unchanged, when fewer
/// than eight bytes remain.
fn read_u64(buf: &[u8], pos: &mut usize) -> Result<u64, IvfCodecError> {
    let end = *pos + 8;
    let field = buf.get(*pos..end).ok_or(IvfCodecError::Truncated)?;
    let mut raw = [0u8; 8];
    raw.copy_from_slice(field);
    *pos = end;
    Ok(u64::from_le_bytes(raw))
}
/// Reads a little-endian IEEE-754 `f32` at `*pos` and advances `*pos` past it.
///
/// Fails with `IvfCodecError::Truncated`, leaving `*pos` unchanged, when fewer
/// than four bytes remain.
fn read_f32(buf: &[u8], pos: &mut usize) -> Result<f32, IvfCodecError> {
    let end = *pos + 4;
    let field = buf.get(*pos..end).ok_or(IvfCodecError::Truncated)?;
    let mut raw = [0u8; 4];
    raw.copy_from_slice(field);
    *pos = end;
    Ok(f32::from_le_bytes(raw))
}
pub fn decode_ivf_index(bytes: &[u8]) -> Result<IvfIndexLayout, IvfCodecError> {
if bytes.len() < 41 {
return Err(IvfCodecError::TooShort);
}
if &bytes[0..4] != IVF_INDEX_MAGIC {
return Err(IvfCodecError::InvalidMagic);
}
let mut pos = 4usize;
let n_lists = read_u32(bytes, &mut pos)? as usize;
let n_probes = read_u32(bytes, &mut pos)? as usize;
let dimension = read_u32(bytes, &mut pos)? as usize;
let max_iterations = read_u32(bytes, &mut pos)? as usize;
let convergence_threshold = read_f32(bytes, &mut pos)?;
if pos >= bytes.len() {
return Err(IvfCodecError::Truncated);
}
let trained = bytes[pos] == 1;
pos += 1;
let count = read_u64(bytes, &mut pos)? as usize;
let next_id = read_u64(bytes, &mut pos)?;
let list_count = read_u32(bytes, &mut pos)? as usize;
let mut lists = Vec::with_capacity(list_count);
for _ in 0..list_count {
let centroid_len = read_u32(bytes, &mut pos)? as usize;
let mut centroid = Vec::with_capacity(centroid_len);
for _ in 0..centroid_len {
centroid.push(read_f32(bytes, &mut pos)?);
}
let id_count = read_u32(bytes, &mut pos)? as usize;
let mut ids = Vec::with_capacity(id_count);
for _ in 0..id_count {
ids.push(read_u64(bytes, &mut pos)?);
}
let vector_count = read_u32(bytes, &mut pos)? as usize;
let mut vectors = Vec::with_capacity(vector_count);
for _ in 0..vector_count {
let vector_len = read_u32(bytes, &mut pos)? as usize;
let mut vector = Vec::with_capacity(vector_len);
for _ in 0..vector_len {
vector.push(read_f32(bytes, &mut pos)?);
}
vectors.push(vector);
}
lists.push(IvfListLayout {
centroid,
ids,
vectors,
});
}
Ok(IvfIndexLayout {
n_lists,
n_probes,
dimension,
max_iterations,
convergence_threshold,
trained,
count,
next_id,
lists,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    fn sample() -> IvfIndexLayout {
        IvfIndexLayout {
            n_lists: 4,
            n_probes: 2,
            dimension: 3,
            max_iterations: 50,
            convergence_threshold: 1e-4,
            trained: true,
            count: 3,
            next_id: 3,
            lists: vec![
                IvfListLayout {
                    centroid: vec![0.0, 1.0, 2.0],
                    ids: vec![0, 2],
                    vectors: vec![vec![0.0, 1.0, 2.0], vec![0.1, 1.1, 2.1]],
                },
                IvfListLayout {
                    centroid: vec![9.0, 9.0, 9.0],
                    ids: vec![1],
                    vectors: vec![vec![9.0, 9.0, 9.0]],
                },
            ],
        }
    }

    /// Hand-assembled encoding of `sample()`, written field by field from the wire
    /// format. The fixture test therefore checks every byte, not just a few offsets.
    fn sample_fixture() -> Vec<u8> {
        fn put_u32(out: &mut Vec<u8>, value: u32) {
            out.extend_from_slice(&value.to_le_bytes());
        }
        fn put_u64(out: &mut Vec<u8>, value: u64) {
            out.extend_from_slice(&value.to_le_bytes());
        }
        fn put_f32_run(out: &mut Vec<u8>, values: &[f32]) {
            put_u32(out, values.len() as u32);
            for value in values {
                out.extend_from_slice(&value.to_le_bytes());
            }
        }

        let mut out = Vec::new();
        out.extend_from_slice(b"IVF1");
        put_u32(&mut out, 4); // n_lists
        put_u32(&mut out, 2); // n_probes
        put_u32(&mut out, 3); // dimension
        put_u32(&mut out, 50); // max_iterations
        out.extend_from_slice(&1e-4f32.to_le_bytes()); // convergence_threshold
        out.push(1); // trained
        put_u64(&mut out, 3); // count
        put_u64(&mut out, 3); // next_id
        put_u32(&mut out, 2); // list count

        // List 0.
        put_f32_run(&mut out, &[0.0, 1.0, 2.0]);
        put_u32(&mut out, 2);
        put_u64(&mut out, 0);
        put_u64(&mut out, 2);
        put_u32(&mut out, 2);
        put_f32_run(&mut out, &[0.0, 1.0, 2.0]);
        put_f32_run(&mut out, &[0.1, 1.1, 2.1]);

        // List 1.
        put_f32_run(&mut out, &[9.0, 9.0, 9.0]);
        put_u32(&mut out, 1);
        put_u64(&mut out, 1);
        put_u32(&mut out, 1);
        put_f32_run(&mut out, &[9.0, 9.0, 9.0]);
        out
    }

    #[test]
    fn round_trip_preserves_layout() {
        let layout = sample();
        let bytes = encode_ivf_index(&layout);
        let decoded = decode_ivf_index(&bytes).expect("decode");
        assert_eq!(decoded, layout);
    }

    #[test]
    fn fixture_bytes_are_byte_identical() {
        let bytes = encode_ivf_index(&sample());
        assert_eq!(&bytes[0..4], b"IVF1", "magic must lead the payload");
        assert_eq!(&bytes[4..8], &4u32.to_le_bytes());
        assert_eq!(bytes[24], 1);
        // 45-byte header + 72-byte list 0 + 48-byte list 1.
        assert_eq!(bytes.len(), 165);
        assert_eq!(bytes, sample_fixture());
    }

    #[test]
    fn rejects_short_and_bad_magic() {
        assert_eq!(decode_ivf_index(&[0u8; 10]), Err(IvfCodecError::TooShort));
        let mut bytes = encode_ivf_index(&sample());
        bytes[0] = b'X';
        assert_eq!(decode_ivf_index(&bytes), Err(IvfCodecError::InvalidMagic));
    }

    #[test]
    fn rejects_every_truncated_prefix() {
        let bytes = encode_ivf_index(&sample());
        for len in 0..bytes.len() {
            assert!(
                decode_ivf_index(&bytes[..len]).is_err(),
                "prefix of {} bytes decoded successfully",
                len
            );
        }
        assert_eq!(
            decode_ivf_index(&bytes[..bytes.len() - 1]),
            Err(IvfCodecError::Truncated)
        );
    }

    #[test]
    fn round_trips_index_without_lists() {
        let layout = IvfIndexLayout {
            lists: Vec::new(),
            ..sample()
        };
        let bytes = encode_ivf_index(&layout);
        assert_eq!(bytes.len(), 45);
        assert_eq!(decode_ivf_index(&bytes), Ok(layout));
    }
}