use std::io::{Read, Write};
use crate::error::{RabitqError, Result};
use crate::index::{AnnIndex, RabitqPlusIndex};
/// Eight-byte file signature: written first by `save_index` and compared
/// byte-for-byte by `load_index` before anything else is parsed.
pub const MAGIC: &[u8; 8] = b"rbpx0001";
/// Format version written right after the magic. `load_index` rejects any
/// file whose version field is greater than this value.
pub const VERSION: u32 = 1;
/// Largest accepted vector dimension; both `save_index` and `load_index`
/// require `1 <= dim <= MAX_DIM`.
pub const MAX_DIM: u32 = 8192;
/// Largest accepted vector count; enforced on both the save and load paths.
pub const MAX_N: u32 = 100_000_000;
/// Largest accepted rerank factor; enforced on both the save and load paths.
pub const MAX_RERANK_FACTOR: u32 = 1024;
/// Wraps a human-readable message in `RabitqError::InvalidParameter`.
///
/// Every failure produced by this module (bad header fields as well as
/// underlying I/O errors) is reported through this single variant.
fn io_err(msg: impl Into<String>) -> RabitqError {
    let text: String = msg.into();
    RabitqError::InvalidParameter(text)
}
/// Writes all of `buf` to `w`, converting any `std::io::Error` into the
/// crate error type with a `write: ` prefix.
fn write_all<W: Write>(w: &mut W, buf: &[u8]) -> Result<()> {
    match w.write_all(buf) {
        Ok(()) => Ok(()),
        Err(e) => Err(io_err(format!("write: {e}"))),
    }
}
/// Fills `buf` completely from `r`, converting any `std::io::Error`
/// (including a short read) into the crate error type with a `read: ` prefix.
fn read_exact<R: Read>(r: &mut R, buf: &mut [u8]) -> Result<()> {
    match r.read_exact(buf) {
        Ok(()) => Ok(()),
        Err(e) => Err(io_err(format!("read: {e}"))),
    }
}
/// Writes `v` to `w` as four little-endian bytes.
fn write_u32<W: Write>(w: &mut W, v: u32) -> Result<()> {
    let le: [u8; 4] = v.to_le_bytes();
    write_all(w, &le)
}
/// Writes `v` to `w` as eight little-endian bytes.
fn write_u64<W: Write>(w: &mut W, v: u64) -> Result<()> {
    let le: [u8; 8] = v.to_le_bytes();
    write_all(w, &le)
}
/// Writes the bit pattern of `v` to `w` as four little-endian bytes.
fn write_f32<W: Write>(w: &mut W, v: f32) -> Result<()> {
    let le: [u8; 4] = v.to_le_bytes();
    write_all(w, &le)
}
/// Reads four bytes from `r` and decodes them as a little-endian `u32`.
fn read_u32<R: Read>(r: &mut R) -> Result<u32> {
    let mut raw = [0u8; 4];
    read_exact(r, &mut raw).map(|()| u32::from_le_bytes(raw))
}
/// Reads eight bytes from `r` and decodes them as a little-endian `u64`.
fn read_u64<R: Read>(r: &mut R) -> Result<u64> {
    let mut raw = [0u8; 8];
    read_exact(r, &mut raw).map(|()| u64::from_le_bytes(raw))
}
/// Reads four bytes from `r` and decodes them as a little-endian `f32`.
fn read_f32<R: Read>(r: &mut R) -> Result<f32> {
    let mut raw = [0u8; 4];
    read_exact(r, &mut raw).map(|()| f32::from_le_bytes(raw))
}
/// Serializes the parameters of `idx` plus the raw `items` it was built from.
///
/// # Layout
///
/// All integers and floats are little-endian:
///
/// | field           | type                 |
/// |-----------------|----------------------|
/// | magic           | 8 bytes (`MAGIC`)    |
/// | version         | `u32` (`VERSION`)    |
/// | dim             | `u32`                |
/// | seed            | `u64`                |
/// | rerank_factor   | `u32`                |
/// | n               | `u32`                |
/// | n records       | `u32` id + dim `f32` |
///
/// `seed` is written exactly as supplied by the caller. `items` is only
/// checked for count and per-vector length against `idx`; its contents are
/// not compared with what the index holds.
///
/// Every validation runs before the first byte is written, so a rejected
/// input never leaves a partial header on `w`. Values are emitted with many
/// small writes; wrap an unbuffered writer in `std::io::BufWriter`.
///
/// # Errors
///
/// Returns `RabitqError::InvalidParameter` when:
/// - `items.len()` differs from `idx.len()`,
/// - any vector's length differs from `idx.dim()`,
/// - `dim` is outside `1..=MAX_DIM`, `n` exceeds `MAX_N`, or the rerank
///   factor exceeds `MAX_RERANK_FACTOR`,
/// - any id does not fit in a `u32`,
/// - the underlying writer fails.
pub fn save_index<W: Write>(
    idx: &RabitqPlusIndex,
    seed: u64,
    items: &[(usize, Vec<f32>)],
    w: &mut W,
) -> Result<()> {
    let dim = idx.dim();
    let n = idx.len();
    let rerank_factor = idx.rerank_factor();

    if items.len() != n {
        return Err(io_err(format!(
            "items.len()={} but index.len()={}",
            items.len(),
            n
        )));
    }
    for (i, (_, v)) in items.iter().enumerate() {
        if v.len() != dim {
            return Err(io_err(format!(
                "item {i}: vector dim {} != {}",
                v.len(),
                dim
            )));
        }
    }

    // Checked conversions: a plain `as u32` would wrap an oversized value and
    // could let it slip past the range check below.
    let dim_u32 = u32::try_from(dim)
        .ok()
        .filter(|d| (1..=MAX_DIM).contains(d))
        .ok_or_else(|| io_err(format!("dim {dim} out of range (1..={MAX_DIM})")))?;
    let n_u32 = u32::try_from(n)
        .ok()
        .filter(|&count| count <= MAX_N)
        .ok_or_else(|| io_err(format!("n {n} exceeds cap {MAX_N}")))?;
    let rerank_u32 = u32::try_from(rerank_factor)
        .ok()
        .filter(|&factor| factor <= MAX_RERANK_FACTOR)
        .ok_or_else(|| {
            io_err(format!(
                "rerank_factor {rerank_factor} exceeds cap {MAX_RERANK_FACTOR}"
            ))
        })?;

    // Validate every id up front so an out-of-range id cannot abort the save
    // after the header and earlier records have already been written.
    for (id, _) in items {
        if u32::try_from(*id).is_err() {
            return Err(io_err(format!("id {id} exceeds u32::MAX")));
        }
    }

    write_all(w, MAGIC)?;
    write_u32(w, VERSION)?;
    write_u32(w, dim_u32)?;
    write_u64(w, seed)?;
    write_u32(w, rerank_u32)?;
    write_u32(w, n_u32)?;
    for (id, v) in items {
        // Lossless: every id was verified to fit in a u32 above.
        write_u32(w, *id as u32)?;
        for &x in v {
            write_f32(w, x)?;
        }
    }
    Ok(())
}
/// Reads a stream produced by `save_index` and rebuilds the index from the
/// stored vectors via `RabitqPlusIndex::from_vectors_parallel`.
///
/// The header is validated field by field, in file order: magic, version
/// (anything greater than `VERSION` is rejected), `dim` (`1..=MAX_DIM`),
/// seed, rerank factor (`<= MAX_RERANK_FACTOR`) and `n` (`<= MAX_N`).
/// Then `n` records of one `u32` id followed by `dim` little-endian `f32`
/// values are read.
///
/// The header is treated as untrusted: the declared `n` only bounds how many
/// records are read, and does not by itself trigger a proportional
/// allocation. Values are read with many small reads; wrap an unbuffered
/// reader in `std::io::BufReader`.
///
/// # Errors
///
/// Returns `RabitqError::InvalidParameter` for a bad magic, an unsupported
/// version, an out-of-range header field, or any read failure (including a
/// truncated payload). Errors from `from_vectors_parallel` are passed
/// through unchanged.
pub fn load_index<R: Read>(r: &mut R) -> Result<RabitqPlusIndex> {
    // Upper bound on the speculative reservation for the item list. `n` comes
    // from the file and may be as large as MAX_N; reserving for it before any
    // payload is read would let a 32-byte file demand gigabytes of memory.
    const MAX_PREALLOC_ITEMS: usize = 1 << 16;

    let mut magic = [0u8; 8];
    read_exact(r, &mut magic)?;
    if &magic != MAGIC {
        return Err(io_err(format!(
            "bad magic: expected {:?}, got {:?}",
            MAGIC, &magic
        )));
    }

    let version = read_u32(r)?;
    if version > VERSION {
        return Err(io_err(format!(
            "unsupported version {version} (max {VERSION})"
        )));
    }

    let dim = read_u32(r)?;
    if dim == 0 || dim > MAX_DIM {
        return Err(io_err(format!("dim {dim} out of range (1..={MAX_DIM})")));
    }

    let seed = read_u64(r)?;

    let rerank_factor = read_u32(r)?;
    if rerank_factor > MAX_RERANK_FACTOR {
        return Err(io_err(format!(
            "rerank_factor {rerank_factor} exceeds cap {MAX_RERANK_FACTOR}"
        )));
    }

    let n = read_u32(r)?;
    if n > MAX_N {
        return Err(io_err(format!("n {n} exceeds cap {MAX_N}")));
    }

    let dim_usize = dim as usize;
    // Reserve at most MAX_PREALLOC_ITEMS slots; the vector grows on demand as
    // records are actually read, so memory use tracks the real payload size.
    let mut items: Vec<(usize, Vec<f32>)> =
        Vec::with_capacity((n as usize).min(MAX_PREALLOC_ITEMS));
    for _ in 0..n {
        let id = read_u32(r)? as usize;
        // Per-record capacity is bounded by MAX_DIM, so this is safe to
        // reserve eagerly.
        let mut v = Vec::with_capacity(dim_usize);
        for _ in 0..dim_usize {
            v.push(read_f32(r)?);
        }
        items.push((id, v));
    }

    RabitqPlusIndex::from_vectors_parallel(dim_usize, seed, rerank_factor as usize, items)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::AnnIndex;
    use rand::{Rng as _, SeedableRng as _};

    /// Draws one `d`-dimensional vector; each component is
    /// `rng.gen::<f32>() * 2.0 - 1.0`, drawn in order from `rng`.
    fn random_vector(rng: &mut rand::rngs::StdRng, d: usize) -> Vec<f32> {
        (0..d).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect()
    }

    /// Builds `n` vectors of dimension `d` with dense ids `0..n`, drawn from
    /// a `StdRng` seeded with `seed`.
    fn make_dataset(n: usize, d: usize, seed: u64) -> Vec<(usize, Vec<f32>)> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        (0..n).map(|i| (i, random_vector(&mut rng, d))).collect()
    }

    /// Unwraps the error from a load attempt, panicking if the load succeeded.
    fn expect_err(res: Result<RabitqPlusIndex>) -> RabitqError {
        match res {
            Ok(_) => panic!("expected load_index to reject the input"),
            Err(e) => e,
        }
    }

    /// Starts a file image with the real magic followed by `version`.
    fn header_prefix(version: u32) -> Vec<u8> {
        let mut buf: Vec<u8> = Vec::new();
        buf.extend_from_slice(MAGIC);
        buf.extend_from_slice(&version.to_le_bytes());
        buf
    }

    /// Feeds `bytes` to `load_index`, requires it to fail, and returns the
    /// rendered error message.
    fn load_err_message(bytes: &[u8]) -> String {
        let mut cursor = std::io::Cursor::new(bytes);
        let err = expect_err(load_index(&mut cursor));
        format!("{err}")
    }

    #[test]
    fn serialize_roundtrip_preserves_search_results() {
        let d = 32;
        let n = 100;
        let seed = 1337u64;
        let rerank_factor = 3;
        let data = make_dataset(n, d, seed);

        let mut original = RabitqPlusIndex::new(d, seed, rerank_factor);
        for (id, v) in &data {
            original.add(*id, v.clone()).unwrap();
        }

        let mut buf: Vec<u8> = Vec::new();
        save_index(&original, seed, &data, &mut buf).unwrap();
        // 32-byte header (8 magic + 4 version + 4 dim + 8 seed + 4 rerank + 4 n)
        // followed by n records of one u32 id plus d f32 values.
        assert_eq!(buf.len(), 32 + n * (4 + d * 4));

        let mut cursor = std::io::Cursor::new(&buf);
        let loaded = load_index(&mut cursor).unwrap();
        assert_eq!(loaded.len(), n);
        assert_eq!(loaded.dim(), d);
        assert_eq!(loaded.rerank_factor(), rerank_factor);
        assert_eq!(
            original.external_ids(),
            loaded.external_ids(),
            "external ids must be preserved through persist roundtrip",
        );

        let mut rng = rand::rngs::StdRng::seed_from_u64(seed.wrapping_add(7));
        let k = 5;
        for q_idx in 0..10 {
            let q = random_vector(&mut rng, d);
            let a = original.search(&q, k).unwrap();
            let b = loaded.search(&q, k).unwrap();
            assert_eq!(a.len(), b.len(), "query {q_idx}: result count");
            for (ra, rb) in a.iter().zip(b.iter()) {
                assert_eq!(ra.id, rb.id, "query {q_idx}: id mismatch");
                // Compare raw bits: the reloaded index must reproduce scores
                // exactly, not merely approximately.
                assert_eq!(
                    ra.score.to_bits(),
                    rb.score.to_bits(),
                    "query {q_idx}: score bits differ ({} vs {})",
                    ra.score,
                    rb.score
                );
            }
        }
    }

    #[test]
    fn persist_preserves_non_dense_ids() {
        let d = 24;
        let n = 50;
        let seed = 20_260_423_u64;
        let rerank_factor = 4;
        // Sparse, non-zero-based ids so a row index can never be mistaken for
        // an external id.
        let external_ids: Vec<usize> = (0..n).map(|i| i * 13 + 7).collect();

        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let items: Vec<(usize, Vec<f32>)> = external_ids
            .iter()
            .map(|&id| (id, random_vector(&mut rng, d)))
            .collect();

        let mut original = RabitqPlusIndex::new(d, seed, rerank_factor);
        for (id, v) in &items {
            original.add(*id, v.clone()).unwrap();
        }

        let expected_u32: Vec<u32> = external_ids.iter().map(|&id| id as u32).collect();
        assert_eq!(
            original.external_ids(),
            expected_u32.as_slice(),
            "source index dropped non-dense ids before persist",
        );

        let mut buf: Vec<u8> = Vec::new();
        save_index(&original, seed, &items, &mut buf).unwrap();
        let mut cursor = std::io::Cursor::new(&buf);
        let loaded = load_index(&mut cursor).unwrap();

        assert_eq!(
            loaded.external_ids(),
            expected_u32.as_slice(),
            "load_index flattened non-dense ids — regression of the \
             rulake warm_from_dir limitation",
        );
        let expected_u64: Vec<u64> = external_ids.iter().map(|&id| id as u64).collect();
        assert_eq!(loaded.ids_u64(), expected_u64);

        let mut qrng = rand::rngs::StdRng::seed_from_u64(seed.wrapping_add(42));
        let k = 5;
        for q_idx in 0..5 {
            let q = random_vector(&mut qrng, d);
            let a = original.search(&q, k).unwrap();
            let b = loaded.search(&q, k).unwrap();
            assert_eq!(a.len(), b.len(), "query {q_idx}: result count");
            for (ra, rb) in a.iter().zip(b.iter()) {
                assert_eq!(ra.id, rb.id, "query {q_idx}: id mismatch after load");
                assert!(
                    external_ids.contains(&ra.id),
                    "query {q_idx}: returned id {} is not in the \
                     persisted id set (smells like a row index)",
                    ra.id,
                );
            }
        }
    }

    #[test]
    fn reject_bad_magic() {
        let mut buf: Vec<u8> = Vec::new();
        buf.extend_from_slice(b"NOPEBAD!");
        buf.extend_from_slice(&1u32.to_le_bytes()); // version
        let msg = load_err_message(&buf);
        assert!(msg.contains("bad magic"), "got: {msg}");
    }

    #[test]
    fn reject_version_too_new() {
        let buf = header_prefix(VERSION + 1);
        let msg = load_err_message(&buf);
        assert!(msg.contains("unsupported version"), "got: {msg}");
    }

    #[test]
    fn reject_oversize_fields() {
        // dim above the cap
        {
            let mut buf = header_prefix(VERSION);
            buf.extend_from_slice(&(MAX_DIM + 1).to_le_bytes()); // dim
            let msg = load_err_message(&buf);
            assert!(msg.contains("dim"), "got: {msg}");
        }
        // dim of zero
        {
            let mut buf = header_prefix(VERSION);
            buf.extend_from_slice(&0u32.to_le_bytes()); // dim
            let msg = load_err_message(&buf);
            assert!(msg.contains("dim"), "got: {msg}");
        }
        // rerank_factor above the cap
        {
            let mut buf = header_prefix(VERSION);
            buf.extend_from_slice(&32u32.to_le_bytes()); // dim
            buf.extend_from_slice(&0u64.to_le_bytes()); // seed
            buf.extend_from_slice(&(MAX_RERANK_FACTOR + 1).to_le_bytes()); // rerank_factor
            let msg = load_err_message(&buf);
            assert!(msg.contains("rerank_factor"), "got: {msg}");
        }
        // n above the cap
        {
            let mut buf = header_prefix(VERSION);
            buf.extend_from_slice(&32u32.to_le_bytes()); // dim
            buf.extend_from_slice(&0u64.to_le_bytes()); // seed
            buf.extend_from_slice(&1u32.to_le_bytes()); // rerank_factor
            buf.extend_from_slice(&(MAX_N + 1).to_le_bytes()); // n
            let msg = load_err_message(&buf);
            // Match the whole message: a bare "n " would also match
            // unrelated errors such as "unsupported version 2".
            let expected = format!("n {} exceeds cap {MAX_N}", MAX_N + 1);
            assert!(msg.contains(&expected), "got: {msg}");
        }
    }

    #[test]
    fn reject_truncated_payload() {
        // The header promises two records but only one is present.
        let mut buf = header_prefix(VERSION);
        buf.extend_from_slice(&4u32.to_le_bytes()); // dim
        buf.extend_from_slice(&0u64.to_le_bytes()); // seed
        buf.extend_from_slice(&1u32.to_le_bytes()); // rerank_factor
        buf.extend_from_slice(&2u32.to_le_bytes()); // n
        buf.extend_from_slice(&7u32.to_le_bytes()); // id of the only record
        for x in [0.25f32, -0.5, 0.75, 1.0] {
            buf.extend_from_slice(&x.to_le_bytes());
        }
        let msg = load_err_message(&buf);
        assert!(msg.contains("read"), "got: {msg}");
    }
}