use std::fmt;
/// Four-byte tag expected at the very start of every encoded body.
const MAGIC: &[u8; 4] = b"VRB1";
/// Size in bytes of the fixed header: magic (4) + count (4) + dimension (4)
/// + id width (1) + reserved (3).
const HEADER_LEN: usize = 16;
/// The only accepted value of the header's id-width byte (ids are 8-byte `u64`s).
const ID_WIDTH: u8 = 8;
/// Decoded contents of a bulk body, as returned by [`decode`].
#[derive(Debug, Clone, PartialEq)]
pub struct RawBulk {
/// Record ids, in the order they appear in the body.
pub ids: Vec<u64>,
/// All vector components as one flat list, in body order. A value produced
/// by [`decode`] holds `ids.len() * dimension` elements.
pub vectors: Vec<f32>,
/// The dimension field read from the header.
pub dimension: usize,
}
/// Reasons a body can fail to decode.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VrbError {
/// The body is shorter than the fixed header; `got` is its length in bytes.
TooShort {
got: usize,
},
/// The first four bytes are not the expected magic tag.
BadMagic,
/// The header's id-width byte holds the contained, unsupported value.
BadIdWidth(u8),
/// At least one of the three reserved header bytes is non-zero.
ReservedNotZero,
/// The body length implied by the header does not fit in `usize`.
Overflow,
/// The body is `got` bytes long but the header implies `expected` bytes.
LengthMismatch {
got: usize,
expected: usize,
},
}
impl fmt::Display for VrbError {
    /// Writes a one-line, human-readable description of the decode failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Variants without payload need no formatting machinery.
            Self::BadMagic => f.write_str("bad magic: expected b\"VRB1\""),
            Self::ReservedNotZero => f.write_str("reserved header bytes must be zero"),
            Self::Overflow => f.write_str("overflow computing body length"),
            // Variants that interpolate their payload and/or the format constants.
            Self::TooShort { got } => {
                write!(f, "body too short: {got} bytes (header needs {HEADER_LEN})")
            }
            Self::BadIdWidth(w) => {
                write!(
                    f,
                    "unsupported id_width {w}: only {ID_WIDTH} (u64) is supported"
                )
            }
            Self::LengthMismatch { got, expected } => {
                write!(f, "body length {got} != expected {expected}")
            }
        }
    }
}
// Marker impl with no overrides: every `Error` method keeps its default.
impl std::error::Error for VrbError {}
/// Validates the fixed-size header and extracts the record count and dimension.
///
/// Header layout (all integers little-endian):
/// bytes 0..4 magic, 4..8 count (`u32`), 8..12 dimension (`u32`),
/// byte 12 id width, bytes 13..16 reserved (must be zero).
///
/// Returns `(count, dimension)` on success.
///
/// # Errors
///
/// Checked in this order: [`VrbError::TooShort`], [`VrbError::BadMagic`],
/// [`VrbError::BadIdWidth`], [`VrbError::ReservedNotZero`].
fn parse_header(body: &[u8]) -> Result<(usize, usize), VrbError> {
    let header = match body.get(..HEADER_LEN) {
        Some(bytes) => bytes,
        None => return Err(VrbError::TooShort { got: body.len() }),
    };
    if !header.starts_with(MAGIC) {
        return Err(VrbError::BadMagic);
    }

    // Reads the little-endian u32 that starts at byte `at` of the header.
    let field = |at: usize| {
        u32::from_le_bytes([header[at], header[at + 1], header[at + 2], header[at + 3]]) as usize
    };
    let count = field(4);
    let dim = field(8);

    let id_width = header[12];
    if id_width != ID_WIDTH {
        return Err(VrbError::BadIdWidth(id_width));
    }
    if header[13..HEADER_LEN].iter().any(|&byte| byte != 0) {
        return Err(VrbError::ReservedNotZero);
    }
    Ok((count, dim))
}
/// Computes the exact body length, in bytes, implied by a header's `count`
/// and `dim`: the header, then `count` 8-byte ids, then `count * dim`
/// 4-byte floats.
///
/// # Errors
///
/// Returns [`VrbError::Overflow`] if any intermediate product or sum does not
/// fit in `usize`.
fn expected_body_len(count: usize, dim: usize) -> Result<usize, VrbError> {
    const ID_BYTES: usize = 8;
    const F32_BYTES: usize = 4;

    // Every step is checked; `?` on `Option` bails out on the first overflow.
    let total = || -> Option<usize> {
        let id_section = count.checked_mul(ID_BYTES)?;
        let vector_section = count.checked_mul(dim)?.checked_mul(F32_BYTES)?;
        HEADER_LEN
            .checked_add(id_section)?
            .checked_add(vector_section)
    };
    total().ok_or(VrbError::Overflow)
}
fn decode_ids(body: &[u8], count: usize) -> Vec<u64> {
let start = HEADER_LEN;
let end = start + count * 8;
body[start..end]
.chunks_exact(8)
.map(|c| u64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
.collect()
}
/// Reads the `count * dim` little-endian `f32` components that follow the id
/// section, returning them as one flat list in body order.
///
/// Performs no validation of its own: the slice index panics if `body` does
/// not reach the end of the vector section.
fn decode_vectors(body: &[u8], count: usize, dim: usize) -> Vec<f32> {
    // The vector section begins right after the header and the 8-byte ids.
    let offset = HEADER_LEN + count * 8;
    let vector_section = &body[offset..offset + count * dim * 4];
    let mut components = Vec::with_capacity(vector_section.len() / 4);
    for chunk in vector_section.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(chunk);
        components.push(f32::from_le_bytes(raw));
    }
    components
}
/// Decodes a complete bulk body into its ids, flat vector data and dimension.
///
/// The header is validated first; the body must then be exactly as long as
/// the header's count and dimension imply (no trailing or missing bytes).
///
/// # Errors
///
/// Returns any header error from `parse_header`, [`VrbError::Overflow`] if the
/// implied length does not fit in `usize`, or [`VrbError::LengthMismatch`] if
/// the actual length differs from the implied one.
pub fn decode(body: &[u8]) -> Result<RawBulk, VrbError> {
    let (count, dimension) = parse_header(body)?;

    let got = body.len();
    let expected = expected_body_len(count, dimension)?;
    if got != expected {
        return Err(VrbError::LengthMismatch { got, expected });
    }

    // The length check above guarantees both sections lie fully inside `body`.
    let ids = decode_ids(body, count);
    let vectors = decode_vectors(body, count, dimension);
    Ok(RawBulk {
        ids,
        vectors,
        dimension,
    })
}
/// Encodes ids and flat vector data into a bulk body: the 16-byte header,
/// then every id as a little-endian `u64`, then every component as a
/// little-endian `f32`.
///
/// The record count written to the header is `ids.len()`.
///
/// Caveats (this function never fails, so these are not reported):
/// - If `ids.len()` or `dimension` exceeds `u32::MAX`, the corresponding
///   header field is written as `u32::MAX`.
/// - `vectors.len()` is not checked against `ids.len() * dimension`; a
///   mismatched input yields a body that [`decode`] rejects with
///   [`VrbError::LengthMismatch`].
#[must_use]
pub fn encode(ids: &[u64], vectors: &[f32], dimension: usize) -> Vec<u8> {
    let count = ids.len();
    let count_field = u32::try_from(count).unwrap_or(u32::MAX);
    let dimension_field = u32::try_from(dimension).unwrap_or(u32::MAX);

    let mut out = Vec::with_capacity(HEADER_LEN + count * 8 + vectors.len() * 4);

    // Header: magic, count, dimension, id width, three reserved zero bytes.
    out.extend_from_slice(MAGIC);
    out.extend_from_slice(&count_field.to_le_bytes());
    out.extend_from_slice(&dimension_field.to_le_bytes());
    out.extend_from_slice(&[ID_WIDTH, 0, 0, 0]);

    // Payload: all ids first, then all vector components.
    out.extend(ids.iter().flat_map(|id| id.to_le_bytes()));
    out.extend(vectors.iter().flat_map(|component| component.to_le_bytes()));
    out
}
#[cfg(test)]
// Exact float comparison is intended: values are compared after a byte-level round trip.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// A well-formed body holding one id and one 2-dimensional zero vector.
    fn single_record_body() -> Vec<u8> {
        encode(&[1], &[0.0, 0.0], 2)
    }

    #[test]
    fn roundtrip_decode_encode() {
        let ids = [1u64, 2, 3];
        let vectors = [0.1f32, 0.2, 0.3, 0.4, 0.5, 0.6];

        let decoded = decode(&encode(&ids, &vectors, 2)).expect("valid body decodes");

        assert_eq!(
            decoded,
            RawBulk {
                ids: ids.to_vec(),
                vectors: vectors.to_vec(),
                dimension: 2,
            }
        );
    }

    #[test]
    fn encode_is_deterministic_and_pinned() {
        let ids = [7u64, 42];
        let vectors = [1.0f32, 2.0, 3.0, 4.0];

        let first = encode(&ids, &vectors, 2);
        let second = encode(&ids, &vectors, 2);
        assert_eq!(first, second, "encoding must be deterministic");

        // Pin the exact header bytes so accidental format changes are caught.
        assert_eq!(&first[0..4], b"VRB1");
        assert_eq!(&first[4..8], &2u32.to_le_bytes());
        assert_eq!(&first[8..12], &2u32.to_le_bytes());
        assert_eq!(first[12], 8);
        assert_eq!(&first[13..16], &[0, 0, 0]);
    }

    #[test]
    fn empty_batch_roundtrips() {
        let decoded = decode(&encode(&[], &[], 4)).expect("empty batch decodes");

        assert!(decoded.ids.is_empty());
        assert!(decoded.vectors.is_empty());
        assert_eq!(decoded.dimension, 4);
    }

    #[test]
    fn bad_magic_rejected() {
        let mut body = single_record_body();
        body[0] = b'X';

        assert_eq!(decode(&body), Err(VrbError::BadMagic));
    }

    #[test]
    fn short_body_rejected() {
        let body = [0u8; 4];

        assert_eq!(decode(&body), Err(VrbError::TooShort { got: 4 }));
    }

    #[test]
    fn bad_id_width_rejected() {
        let mut body = single_record_body();
        body[12] = 4;

        assert_eq!(decode(&body), Err(VrbError::BadIdWidth(4)));
    }

    #[test]
    fn reserved_not_zero_rejected() {
        let mut body = single_record_body();
        body[13] = 1;

        assert_eq!(decode(&body), Err(VrbError::ReservedNotZero));
    }

    #[test]
    fn length_mismatch_rejected() {
        let mut body = encode(&[1, 2], &[0.0; 4], 2);
        body.pop();

        let outcome = decode(&body);
        assert!(
            matches!(outcome, Err(VrbError::LengthMismatch { .. })),
            "expected LengthMismatch, got {outcome:?}"
        );
    }
}