use std::{
cmp::Ordering,
fs::{File, metadata},
io::{BufReader, BufWriter, Error as IoError, ErrorKind, Read, Write},
path::{Path, PathBuf},
sync::Arc,
};
use rayon::prelude::*;
use crate::superfile::{
BuildError,
format::{
self, FST_SEPARATOR, RESERVED_PREFIX,
checksum::{crc32c, crc32c_append},
vec::{
CLUSTER_IDX_COUNT_OFFSET, CLUSTER_IDX_ENTRY_BYTES, MAGIC_BYTES, U32_BYTES, U64_BYTES,
sub_hdr,
},
},
vector::{
distance::{Metric, SQ8_RESIDUAL_DIVISOR, l2_sq},
kmeans::{assign_to_centroids, kmeans},
quant::BitQuantizer,
rerank_codec::RerankCodec,
reservoir::{Reservoir, default_kmeans_sample_size},
rotation::RandomRotation,
spill::{ChunkedVectorSource, InMemoryVectorSource, MmapVectorSource, SpillWriter},
sq8_simd::{Sq8EncodeConsts, sq8_encode_row, update_min_max},
},
};
// Local aliases for the on-disk layout sizes declared in `format::vec`.
const OUTER_HEADER_SIZE: usize = format::vec::OUTER_HEADER_SIZE;
const DIR_ENTRY_SIZE: usize = format::vec::DIR_ENTRY_SIZE;
const SUB_HEADER_SIZE: usize = format::vec::SUB_HEADER_SIZE;
// Inclusive bounds on a column's dimensionality; `register_column` rejects
// anything outside this range.
const VECTOR_DIM_MIN: usize = 16;
const VECTOR_DIM_MAX: usize = 4096;
// XORed with a column's `rot_seed` to derive the seed handed to its `Reservoir`.
const RESERVOIR_SEED_XOR_MASK: u64 = 0x5a5a_5a5a_5a5a_5a5a;
// Passed to `kmeans` when training centroids (presumably the iteration count —
// confirm against the `kmeans` signature).
const KMEANS_ITERS: usize = 5;
// Inputs to `chunk_rows_for_dim`: a byte budget divided by the f32 row size,
// then clamped to the [MIN, MAX] row range.
const PASS2_CHUNK_MEM_BUDGET_BYTES: usize = 128 << 20;
const PASS2_CHUNK_ROWS_MIN: usize = 1024;
const PASS2_CHUNK_ROWS_MAX: usize = 65_536;
// Tiers used by `n_cent_row_count_cap`: the row-count thresholds and the
// centroid cap returned for each tier.
const N_CENT_LARGE_DOC_THRESHOLD: usize = 5_000_000;
const N_CENT_LARGE: usize = 4096;
const N_CENT_MEDIUM_DOC_THRESHOLD: usize = 100_000;
const N_CENT_MEDIUM: usize = 1024;
const N_CENT_SMALL: usize = 64;
// Multiplier applied to the summary radius before it is stored as a `u32`
// in the subsection header.
const SUMMARY_RADIUS_SCALE: f32 = 100.0;
// Divisor that turns a per-dimension (max - min) span into an SQ8 scale.
const SQ8_CODE_MAX: f32 = 255.0;
// Symmetric clamp applied to the rounded residual before it is cast to `i8`.
const SQ8_RESIDUAL_I8_CLAMP: f32 = 127.0;
/// Returns the largest centroid count allowed for a column with `n_docs` rows.
///
/// Three tiers are selected by comparing `n_docs` against the large and medium
/// row-count thresholds, checked from largest to smallest.
fn n_cent_row_count_cap(n_docs: usize) -> usize {
    match n_docs {
        rows if rows >= N_CENT_LARGE_DOC_THRESHOLD => N_CENT_LARGE,
        rows if rows >= N_CENT_MEDIUM_DOC_THRESHOLD => N_CENT_MEDIUM,
        _ => N_CENT_SMALL,
    }
}
/// Maps a distance metric to the numeric id stored in the column directory.
fn metric_id(m: Metric) -> u32 {
    use format::vec::{METRIC_ID_COSINE, METRIC_ID_L2SQ, METRIC_ID_NEGDOT};

    match m {
        Metric::NegDot => METRIC_ID_NEGDOT,
        Metric::Cosine => METRIC_ID_COSINE,
        Metric::L2Sq => METRIC_ID_L2SQ,
    }
}
/// Per-column settings supplied when a vector column is registered with
/// `VectorBuilder::register_column`.
#[derive(Debug, Clone)]
pub struct VectorConfig {
/// Column name. Registration rejects names containing the FST separator
/// byte, names starting with the reserved prefix, and duplicates.
pub column: String,
/// Number of `f32` components per vector; every added vector must have
/// exactly this length.
pub dim: usize,
/// Requested centroid count. The build step clamps it by row count and
/// sample size, so the count written to the file may be smaller.
pub n_cent: usize,
/// Seed passed to the random rotation and to k-means; XORed with a fixed
/// mask to seed the sampling reservoir.
pub rot_seed: u64,
/// Distance metric recorded in the column directory.
pub metric: Metric,
/// Codec that decides which per-vector rerank payload is stored.
pub rerank_codec: RerankCodec,
}
impl VectorConfig {
    /// Builds a configuration that uses the default rerank codec.
    pub fn new(column: String, dim: usize, n_cent: usize, rot_seed: u64, metric: Metric) -> Self {
        let rerank_codec = RerankCodec::default();
        Self {
            column,
            dim,
            n_cent,
            rot_seed,
            metric,
            rerank_codec,
        }
    }

    /// Returns the same configuration with `codec` as its rerank codec.
    #[must_use]
    pub fn with_rerank_codec(self, codec: RerankCodec) -> Self {
        Self {
            rerank_codec: codec,
            ..self
        }
    }
}
// Default per-column in-memory budget (256 MiB) before `add` starts spilling
// rows to a scratch file; override with `set_spill_threshold_bytes`.
const DEFAULT_SPILL_THRESHOLD_BYTES: usize = 256 * 1024 * 1024;
/// Mutable build state kept for one registered vector column.
struct ColumnState {
// Configuration supplied at registration.
config: VectorConfig,
// Number of vectors accepted by `add` so far.
n_docs: u32,
// Sample updated with every added vector; its contents are later the
// k-means training input.
reservoir: Reservoir,
// Rows held in memory (flattened f32s) until the byte threshold would be
// exceeded.
pre_spill_buffer: Vec<f32>,
// On-disk writer; `Some` once the in-memory buffer has overflowed, after
// which every new row goes straight to it.
spill: Option<SpillWriter>,
// Byte budget for `pre_spill_buffer`, copied from the builder at
// registration time.
spill_threshold_bytes: usize,
}
/// Lazily created temporary directory used for spill and bucket files.
#[derive(Default)]
struct ScratchDir {
// Directory in which the temp dir is created; `None` means
// `tempfile::tempdir()` picks the location.
parent: Option<PathBuf>,
// The temp dir itself; stays `None` until `path` is first called.
tempdir: Option<tempfile::TempDir>,
}
impl ScratchDir {
    /// Creates a scratch handle whose temp dir will live under `parent`.
    ///
    /// Fails if `parent` cannot be stat'ed or is not a directory. No temp dir
    /// is created yet.
    fn in_parent(parent: PathBuf) -> Result<Self, BuildError> {
        if metadata(&parent)?.is_dir() {
            return Ok(Self {
                parent: Some(parent),
                tempdir: None,
            });
        }
        let reason = format!("VectorBuilder scratch path is not a directory: {parent:?}");
        Err(BuildError::Io(IoError::new(ErrorKind::InvalidInput, reason)))
    }

    /// Returns the temp dir path, creating the directory on first use.
    fn path(&mut self) -> Result<&Path, BuildError> {
        let dir = match self.tempdir.take() {
            Some(existing) => existing,
            None => match self.parent.as_deref() {
                Some(parent) => tempfile::TempDir::new_in(parent)?,
                None => tempfile::tempdir()?,
            },
        };
        // Store the directory (back) and hand out a path borrowed from it.
        Ok(self.tempdir.insert(dir).path())
    }
}
/// Accumulates vectors for one or more columns and serializes them into a
/// single vector blob via `finish` / `finish_to`.
pub struct VectorBuilder {
// Registered columns; a column's id is its index in this vector.
columns: Vec<ColumnState>,
// Lazily created temp dir for spill and bucket files.
scratch_dir: ScratchDir,
// Spill threshold handed to columns registered from now on.
spill_threshold_bytes: usize,
}
/// `VectorBuilder::default()` is equivalent to `VectorBuilder::new()`.
impl Default for VectorBuilder {
fn default() -> Self {
Self::new()
}
}
impl VectorBuilder {
/// Creates an empty builder with a default-located scratch directory and
/// the default spill threshold.
pub fn new() -> Self {
    let scratch_dir = ScratchDir::default();
    Self {
        columns: Vec::new(),
        scratch_dir,
        spill_threshold_bytes: DEFAULT_SPILL_THRESHOLD_BYTES,
    }
}
/// Creates an empty builder whose temporary files live under `scratch`.
///
/// # Errors
/// Fails if `scratch` cannot be inspected or is not a directory.
pub fn with_scratch(scratch: PathBuf) -> Result<Self, BuildError> {
    let scratch_dir = ScratchDir::in_parent(scratch)?;
    Ok(Self {
        columns: Vec::new(),
        scratch_dir,
        spill_threshold_bytes: DEFAULT_SPILL_THRESHOLD_BYTES,
    })
}
/// Sets the in-memory byte budget given to columns registered after this
/// call. Already registered columns keep the threshold they were created
/// with, because `register_column` copies the value into the column state.
pub fn set_spill_threshold_bytes(&mut self, threshold: usize) {
self.spill_threshold_bytes = threshold;
}
/// Registers a new vector column and returns its id (its index among the
/// registered columns).
///
/// # Errors
/// Checked in this order: separator byte in the name, reserved name prefix,
/// dimension out of range, duplicate name, unimplemented rerank codec.
pub fn register_column(&mut self, config: VectorConfig) -> Result<u32, BuildError> {
    if config.column.bytes().any(|b| b == FST_SEPARATOR) {
        return Err(BuildError::ReservedSeparatorInColumnName(config.column));
    }
    if config.column.starts_with(RESERVED_PREFIX) {
        return Err(BuildError::ReservedPrefixInColumnName(config.column));
    }
    if !matches!(config.dim, VECTOR_DIM_MIN..=VECTOR_DIM_MAX) {
        return Err(BuildError::VectorDimOutOfRange {
            column: config.column.clone(),
            dim: config.dim,
        });
    }
    let name_taken = self
        .columns
        .iter()
        .map(|existing| existing.config.column.as_str())
        .any(|existing_name| existing_name == config.column);
    if name_taken {
        return Err(BuildError::DuplicateColumnName(config.column));
    }
    if !config.rerank_codec.is_implemented() {
        return Err(BuildError::VectorRerankCodecUnimplemented {
            column: config.column.clone(),
            codec: config.rerank_codec.name(),
        });
    }

    let column_id = self.columns.len() as u32;
    // The reservoir gets its own seed, derived from the rotation seed.
    let reservoir = Reservoir::new(
        default_kmeans_sample_size(config.n_cent),
        config.dim,
        config.rot_seed ^ RESERVOIR_SEED_XOR_MASK,
    );
    self.columns.push(ColumnState {
        config,
        n_docs: 0,
        reservoir,
        pre_spill_buffer: Vec::new(),
        spill: None,
        spill_threshold_bytes: self.spill_threshold_bytes,
    });
    Ok(column_id)
}
/// Replaces the column's sampling reservoir with a fresh one of capacity
/// `sample_size`, seeded the same way as at registration.
///
/// The previous reservoir is dropped, so anything it had sampled is lost.
///
/// # Errors
/// Returns `FtsColumnTypeInvalid` when `column_id` is not registered.
pub fn set_kmeans_sample_size(
    &mut self,
    column_id: u32,
    sample_size: usize,
) -> Result<(), BuildError> {
    let col = match self.columns.get_mut(column_id as usize) {
        Some(col) => col,
        None => {
            return Err(BuildError::FtsColumnTypeInvalid {
                column: format!("(unregistered vector column_id {column_id})"),
                actual: "n/a".to_string(),
            });
        }
    };
    let seed = col.config.rot_seed ^ RESERVOIR_SEED_XOR_MASK;
    col.reservoir = Reservoir::new(sample_size, col.config.dim, seed);
    Ok(())
}
/// Appends one vector to the column identified by `column_id`.
///
/// Rows are kept in memory while they fit the column's spill threshold. The
/// first row that would exceed it moves the buffered rows plus the new row
/// into a scratch file, and every later row is written to that file.
///
/// # Errors
/// Returns `FtsColumnTypeInvalid` for an unregistered column or a vector of
/// the wrong length, and propagates scratch-dir and spill I/O failures.
pub fn add(&mut self, column_id: u32, vec: &[f32]) -> Result<(), BuildError> {
    let col = match self.columns.get_mut(column_id as usize) {
        Some(col) => col,
        None => {
            return Err(BuildError::FtsColumnTypeInvalid {
                column: format!("(unregistered vector column_id {column_id})"),
                actual: "n/a".to_string(),
            });
        }
    };
    if vec.len() != col.config.dim {
        return Err(BuildError::FtsColumnTypeInvalid {
            column: col.config.column.clone(),
            actual: format!("vec.len()={} != dim={}", vec.len(), col.config.dim),
        });
    }
    col.reservoir.update(vec);

    // Already spilling: append straight to the spill file.
    if let Some(spill) = col.spill.as_mut() {
        spill.write_vec(vec)?;
        col.n_docs += 1;
        return Ok(());
    }

    // Still within the in-memory budget: buffer the row.
    let buffered_bytes = col.pre_spill_buffer.len() * 4;
    let incoming_bytes = vec.len() * 4;
    if buffered_bytes + incoming_bytes <= col.spill_threshold_bytes {
        col.pre_spill_buffer.extend_from_slice(vec);
        col.n_docs += 1;
        return Ok(());
    }

    // Budget exceeded: move the buffered rows and this row to disk.
    let spill_path = self
        .scratch_dir
        .path()?
        .join(format!("infino_input_spill_col{column_id}.bin"));
    let mut spill = SpillWriter::create(spill_path)?;
    spill.write_all(bytemuck::cast_slice(&col.pre_spill_buffer))?;
    spill.write_vec(vec)?;
    // Assign a fresh Vec so the old buffer's allocation is released.
    col.pre_spill_buffer = Vec::new();
    col.spill = Some(spill);
    col.n_docs += 1;
    Ok(())
}
/// Serializes the builder into an in-memory blob.
///
/// Delegates to `finish_to`. The initial capacity only covers the outer
/// header and directory; the vector grows as subsections are written.
pub fn finish(self) -> Result<Vec<u8>, BuildError> {
    let capacity_hint = OUTER_HEADER_SIZE + self.columns.len() * DIR_ENTRY_SIZE + 8;
    let mut blob: Vec<u8> = Vec::with_capacity(capacity_hint);
    self.finish_to(&mut blob).map(|()| blob)
}
/// Serializes every registered column and streams the finished blob into `w`.
///
/// Bytes are written in this order: the fixed-size outer header, one
/// directory entry per column, a CRC32C of the directory, each column's
/// subsection bytes, and finally a CRC32C accumulated over everything
/// written before it.
///
/// The writer is flushed before returning, so a buffered writer passed by
/// value reports late I/O failures here instead of losing them on drop.
///
/// # Errors
/// Propagates scratch-directory, subsection-build, write and flush failures.
pub fn finish_to<W: Write>(self, mut w: W) -> Result<(), BuildError> {
    let VectorBuilder {
        columns,
        mut scratch_dir,
        spill_threshold_bytes: _,
    } = self;
    let n_columns = columns.len() as u32;
    // The header records the largest per-column row count.
    let n_docs: u64 = columns
        .iter()
        .map(|c| u64::from(c.n_docs))
        .max()
        .unwrap_or(0);
    // Configs are cloned up front because building a subsection consumes
    // its column state.
    let column_configs: Vec<VectorConfig> = columns.iter().map(|c| c.config.clone()).collect();

    let mut subsections: Vec<SubsectionBytes> = Vec::with_capacity(columns.len());
    if !columns.is_empty() {
        let scratch_path = scratch_dir.path()?.to_path_buf();
        for (col_idx, col) in columns.into_iter().enumerate() {
            subsections.push(build_subsection_streaming(
                col_idx as u32,
                col,
                &scratch_path,
            )?);
        }
    }

    let directory_offset = OUTER_HEADER_SIZE as u64;
    let directory_size = (n_columns as usize) * DIR_ENTRY_SIZE;
    // The first subsection starts right after the directory and its CRC.
    let mut subsection_start_off =
        directory_offset + directory_size as u64 + format::CRC_BYTES as u64;

    let mut directory: Vec<u8> = Vec::with_capacity(directory_size);
    for (i, (cfg, sub)) in column_configs.iter().zip(&subsections).enumerate() {
        let summary_offset_abs = subsection_start_off + sub.summary_offset_in_sub as u64;
        // One directory entry; field order defines the on-disk layout.
        directory.extend_from_slice(&(i as u32).to_le_bytes());
        directory.extend_from_slice(&(cfg.dim as u32).to_le_bytes());
        directory.extend_from_slice(&(sub.n_cent as u32).to_le_bytes());
        directory.extend_from_slice(&metric_id(cfg.metric).to_le_bytes());
        directory.extend_from_slice(&cfg.rot_seed.to_le_bytes());
        directory.extend_from_slice(&subsection_start_off.to_le_bytes());
        directory.extend_from_slice(&(sub.bytes.len() as u64).to_le_bytes());
        directory.extend_from_slice(&summary_offset_abs.to_le_bytes());
        directory.extend_from_slice(&((cfg.dim * 4) as u32).to_le_bytes());
        directory.push(cfg.rerank_codec.codec_id());
        directory.extend_from_slice(&[0u8; 3]);
        directory.extend_from_slice(&(sub.codec_meta_offset_in_sub as u32).to_le_bytes());
        directory.extend_from_slice(&(sub.codec_meta_size as u32).to_le_bytes());
        debug_assert_eq!(directory.len() % DIR_ENTRY_SIZE, 0);
        subsection_start_off += sub.bytes.len() as u64;
    }
    let dir_crc = crc32c(&directory);

    let mut outer_header: [u8; OUTER_HEADER_SIZE] = [0; OUTER_HEADER_SIZE];
    {
        // Writing through a shrinking `&mut [u8]` errors out if the fields
        // overrun the header and lets us assert that they fill it exactly.
        let mut cursor = &mut outer_header[..];
        cursor
            .write_all(format::vec::OUTER_MAGIC)
            .map_err(BuildError::Io)?;
        cursor
            .write_all(&format::vec::VERSION.to_le_bytes())
            .map_err(BuildError::Io)?;
        cursor
            .write_all(&n_columns.to_le_bytes())
            .map_err(BuildError::Io)?;
        cursor
            .write_all(&n_docs.to_le_bytes())
            .map_err(BuildError::Io)?;
        cursor
            .write_all(&directory_offset.to_le_bytes())
            .map_err(BuildError::Io)?;
        debug_assert!(cursor.is_empty());
    }

    // The trailing CRC covers every byte written before it.
    let mut outer_crc_acc: u32 = 0;
    w.write_all(&outer_header).map_err(BuildError::Io)?;
    outer_crc_acc = crc32c_append(outer_crc_acc, &outer_header);
    w.write_all(&directory).map_err(BuildError::Io)?;
    outer_crc_acc = crc32c_append(outer_crc_acc, &directory);
    let dir_crc_le = dir_crc.to_le_bytes();
    w.write_all(&dir_crc_le).map_err(BuildError::Io)?;
    outer_crc_acc = crc32c_append(outer_crc_acc, &dir_crc_le);
    drop(directory);

    // Consume the subsections so each buffer is freed once written.
    for sub in subsections {
        w.write_all(&sub.bytes).map_err(BuildError::Io)?;
        outer_crc_acc = crc32c_append(outer_crc_acc, &sub.bytes);
    }
    w.write_all(&outer_crc_acc.to_le_bytes())
        .map_err(BuildError::Io)?;
    // Surface buffered-write failures instead of leaving them to Drop.
    w.flush().map_err(BuildError::Io)?;

    drop(scratch_dir);
    Ok(())
}
}
/// One column's serialized subsection plus the values its directory entry needs.
struct SubsectionBytes {
// Complete subsection bytes, including the trailing CRC32C.
bytes: Vec<u8>,
// Centroid count actually used for this column (after clamping).
n_cent: usize,
// Offset of the summary centroid, relative to the subsection start.
summary_offset_in_sub: usize,
// Offset of the codec metadata within the subsection; 0 when the codec
// has no metadata.
codec_meta_offset_in_sub: usize,
// Size in bytes of the codec metadata region.
codec_meta_size: usize,
}
// Buffer capacity (64 KiB) used for both writing and re-reading the
// per-centroid bucket files.
const BUCKET_BUF_SIZE: usize = 64 * 1024;
/// Returns how many rows of dimension `dim` one pass-2 chunk holds.
///
/// The memory budget is divided by the f32 row size (`dim` is treated as at
/// least 1) and the result is clamped to the configured row-count range.
fn chunk_rows_for_dim(dim: usize) -> usize {
    let row_bytes = dim.max(1) * 4;
    let rows_within_budget = PASS2_CHUNK_MEM_BUDGET_BYTES / row_bytes;
    rows_within_budget
        .max(PASS2_CHUNK_ROWS_MIN)
        .min(PASS2_CHUNK_ROWS_MAX)
}
fn build_subsection_streaming(
column_id: u32,
col: ColumnState,
scratch: &Path,
) -> Result<SubsectionBytes, BuildError> {
let ColumnState {
config: cfg,
n_docs: n_docs_u32,
reservoir,
pre_spill_buffer,
spill,
spill_threshold_bytes: _,
} = col;
let dim = cfg.dim;
let n_docs = n_docs_u32 as usize;
let sample_rows = reservoir.n_rows();
let n_cent = cfg
.n_cent
.max(1)
.min(n_cent_row_count_cap(n_docs))
.min(n_docs.max(1))
.min(sample_rows.max(1));
let centroids = if sample_rows == 0 || n_docs == 0 {
vec![0.0f32; n_cent * dim]
} else {
kmeans(reservoir.sample(), dim, n_cent, KMEANS_ITERS, cfg.rot_seed)
};
drop(reservoir);
let mut summary_centroid = vec![0.0f32; dim];
if !centroids.is_empty() {
let mut acc = vec![0.0f64; dim];
for c in 0..n_cent {
let cv = ¢roids[c * dim..(c + 1) * dim];
for (a, &x) in acc.iter_mut().zip(cv) {
*a += x as f64;
}
}
let inv = 1.0 / (n_cent as f64);
for (s, a) in summary_centroid.iter_mut().zip(&acc) {
*s = (*a * inv) as f32;
}
}
let rotation = RandomRotation::new(dim, cfg.rot_seed);
let quant = BitQuantizer::new(dim);
let code_bytes = quant.code_bytes();
let mut bucket_writers: Vec<BufWriter<File>> = Vec::with_capacity(n_cent);
for c in 0..n_cent {
let path = scratch.join(format!("infino_bucket_col{column_id}_c{c}.bin"));
let file = File::create(&path)?;
bucket_writers.push(BufWriter::with_capacity(BUCKET_BUF_SIZE, file));
}
let mut bucket_counts = vec![0u32; n_cent];
let chunk_rows = chunk_rows_for_dim(dim);
let mut summary_radius_sq_max: f32 = 0.0;
let codec = cfg.rerank_codec;
let sq8_family = matches!(codec, RerankCodec::Sq8ResidualEpsilon);
let (mut sq8_min_arr, mut sq8_max_arr): (Vec<f32>, Vec<f32>) = if sq8_family {
(
vec![f32::INFINITY; n_cent * dim],
vec![f32::NEG_INFINITY; n_cent * dim],
)
} else {
(Vec::new(), Vec::new())
};
if n_docs > 0 {
let mut source: Box<dyn ChunkedVectorSource> = if let Some(spill) = spill {
debug_assert!(
pre_spill_buffer.is_empty(),
"spill active but pre_spill_buffer still has {} f32s",
pre_spill_buffer.len()
);
let path = spill.finish()?;
Box::new(MmapVectorSource::open(&path, dim, chunk_rows)?)
} else {
Box::new(InMemoryVectorSource::new(
Arc::new(pre_spill_buffer),
dim,
chunk_rows,
))
};
let sq8_acc: Option<(&mut [f32], &mut [f32])> = if sq8_family {
Some((&mut sq8_min_arr, &mut sq8_max_arr))
} else {
None
};
run_pass2(
source.as_mut(),
dim,
n_cent,
code_bytes,
¢roids,
&rotation,
&quant,
&summary_centroid,
&mut bucket_writers,
&mut bucket_counts,
&mut summary_radius_sq_max,
codec,
sq8_acc,
)?;
}
let sq8_quantizers: Vec<(Vec<f32>, Vec<f32>)> = if sq8_family {
(0..n_cent)
.map(|c| {
let off = c * dim;
derive_sq8_quantizer_from_min_max(
&sq8_min_arr[off..off + dim],
&sq8_max_arr[off..off + dim],
)
})
.collect()
} else {
Vec::new()
};
drop(sq8_min_arr);
drop(sq8_max_arr);
let mut bucket_files: Vec<File> = Vec::with_capacity(n_cent);
for w in bucket_writers {
let mut inner = w.into_inner().map_err(|e| BuildError::Io(e.into_error()))?;
inner.flush()?;
bucket_files.push(inner);
}
drop(bucket_files);
let summary_radius_x100 = (summary_radius_sq_max.sqrt() * SUMMARY_RADIUS_SCALE)
.max(0.0)
.min(u32::MAX as f32) as u32;
let cluster_order = centroid_storage_order(¢roids, n_cent, dim);
let summary_size = dim * 4;
let centroids_size = n_cent * dim * 4;
let cluster_idx_size = n_cent * CLUSTER_IDX_ENTRY_BYTES;
let codec_meta_size = codec.codec_meta_bytes(dim, n_docs, n_cent, cfg.metric);
let per_vec_bytes = codec.per_vector_bytes(dim);
let per_cluster_blocks_size = n_docs * (code_bytes + format::vec::DOC_ID_BYTES + per_vec_bytes);
let summary_off = SUB_HEADER_SIZE;
let centroids_off = summary_off + summary_size;
let cluster_idx_off = centroids_off + centroids_size;
let codec_meta_off = cluster_idx_off + cluster_idx_size;
let per_cluster_blocks_off = codec_meta_off + codec_meta_size;
let total_size_before_crc = SUB_HEADER_SIZE
+ summary_size
+ centroids_size
+ cluster_idx_size
+ codec_meta_size
+ per_cluster_blocks_size;
let mut bytes = vec![0u8; total_size_before_crc];
bytes[0..MAGIC_BYTES].copy_from_slice(format::vec::SUB_MAGIC);
bytes[sub_hdr::VERSION_OFF..sub_hdr::VERSION_OFF + U32_BYTES]
.copy_from_slice(&format::vec::SUBSECTION_VERSION.to_le_bytes());
bytes[sub_hdr::CODEC_META_SIZE_OFF..sub_hdr::CODEC_META_SIZE_OFF + U32_BYTES]
.copy_from_slice(&(codec_meta_size as u32).to_le_bytes());
bytes[sub_hdr::SUMMARY_OFF_OFF..sub_hdr::SUMMARY_OFF_OFF + U64_BYTES]
.copy_from_slice(&(summary_off as u64).to_le_bytes());
bytes[sub_hdr::SUMMARY_RADIUS_X100_OFF..sub_hdr::SUMMARY_RADIUS_X100_OFF + U32_BYTES]
.copy_from_slice(&summary_radius_x100.to_le_bytes());
bytes[sub_hdr::CENTROIDS_OFF_OFF..sub_hdr::CENTROIDS_OFF_OFF + U64_BYTES]
.copy_from_slice(&(centroids_off as u64).to_le_bytes());
bytes[sub_hdr::CLUSTER_IDX_OFF_OFF..sub_hdr::CLUSTER_IDX_OFF_OFF + U64_BYTES]
.copy_from_slice(&(cluster_idx_off as u64).to_le_bytes());
bytes[sub_hdr::PER_CLUSTER_BLOCKS_OFF_OFF..sub_hdr::PER_CLUSTER_BLOCKS_OFF_OFF + U64_BYTES]
.copy_from_slice(&(per_cluster_blocks_off as u64).to_le_bytes());
bytes[summary_off..summary_off + summary_size]
.copy_from_slice(bytemuck::cast_slice(&summary_centroid));
bytes[centroids_off..centroids_off + centroids_size]
.copy_from_slice(bytemuck::cast_slice(¢roids));
let mut cluster_index: Vec<(usize, u32, u32)> = Vec::with_capacity(n_cent);
{
let mut acc_off = 0u32;
for ¢roid_id in &cluster_order {
let cnt = bucket_counts[centroid_id];
let idx_base = cluster_idx_off + centroid_id * CLUSTER_IDX_ENTRY_BYTES;
bytes[idx_base..idx_base + CLUSTER_IDX_COUNT_OFFSET]
.copy_from_slice(&acc_off.to_le_bytes());
bytes[idx_base + CLUSTER_IDX_COUNT_OFFSET..idx_base + CLUSTER_IDX_ENTRY_BYTES]
.copy_from_slice(&cnt.to_le_bytes());
cluster_index.push((centroid_id, acc_off, cnt));
acc_off += cnt;
}
debug_assert_eq!(acc_off as usize, n_docs);
}
let sq8_scale_block_off = codec_meta_off;
let sq8_offset_block_off = sq8_scale_block_off + n_cent * dim * 4;
let sq8_norms_block_off = if sq8_family && matches!(cfg.metric, Metric::L2Sq | Metric::Cosine) {
Some(sq8_offset_block_off + n_cent * dim * 4)
} else {
None
};
if sq8_family {
for (cid, (scale_c, offset_c)) in sq8_quantizers.iter().enumerate().take(n_cent) {
let sc_off = sq8_scale_block_off + cid * dim * 4;
bytes[sc_off..sc_off + dim * 4].copy_from_slice(bytemuck::cast_slice(scale_c));
let oc_off = sq8_offset_block_off + cid * dim * 4;
bytes[oc_off..oc_off + dim * 4].copy_from_slice(bytemuck::cast_slice(offset_c));
}
}
let full_row_bytes_in_bucket = if codec.writes_full() { dim * 4 } else { 0 };
let mut id_block: Vec<u8> = Vec::new();
let mut code_block: Vec<u8> = Vec::new();
let mut full_block: Vec<u8> = Vec::new();
let cluster_stride = code_bytes + format::vec::DOC_ID_BYTES + per_vec_bytes;
let mut block_cursor = 0usize;
for &(centroid_id, cluster_off_u32, cluster_count_u32) in &cluster_index {
if cluster_count_u32 == 0 {
continue;
}
let cluster_off = cluster_off_u32 as usize;
let cluster_count = cluster_count_u32 as usize;
let path = scratch.join(format!("infino_bucket_col{column_id}_c{centroid_id}.bin"));
let mut reader = BufReader::with_capacity(BUCKET_BUF_SIZE, File::open(&path)?);
id_block.resize(cluster_count * 4, 0);
code_block.resize(cluster_count * code_bytes, 0);
if full_row_bytes_in_bucket > 0 {
full_block.resize(cluster_count * full_row_bytes_in_bucket, 0);
}
for i in 0..cluster_count {
reader.read_exact(&mut id_block[i * 4..(i + 1) * 4])?;
reader.read_exact(&mut code_block[i * code_bytes..(i + 1) * code_bytes])?;
if full_row_bytes_in_bucket > 0 {
let off = i * full_row_bytes_in_bucket;
reader.read_exact(&mut full_block[off..off + full_row_bytes_in_bucket])?;
}
}
let block_base = per_cluster_blocks_off + block_cursor;
let codes_len = cluster_count * code_bytes;
let ids_len = cluster_count * 4;
bytes[block_base..block_base + codes_len].copy_from_slice(&code_block);
bytes[block_base + codes_len..block_base + codes_len + ids_len].copy_from_slice(&id_block);
match codec {
RerankCodec::RabitqOnly => {}
RerankCodec::Fp32 => {
let full_base = block_base + codes_len + ids_len;
bytes[full_base..full_base + cluster_count * dim * 4].copy_from_slice(&full_block);
}
RerankCodec::Sq8ResidualEpsilon => {
let cluster_rows: &[f32] = bytemuck::cast_slice(&full_block);
let (scale_c, offset_c) = &sq8_quantizers[centroid_id];
let ec = Sq8EncodeConsts::from_scale_offset(scale_c, offset_c);
let full_chunk_base = block_base + codes_len + ids_len;
encode_sq8_residual_cluster_simd(
cluster_rows,
dim,
cluster_count,
cluster_off,
full_chunk_base,
sq8_norms_block_off,
&ec.inv_scale,
&ec.c2,
scale_c,
offset_c,
&mut bytes,
);
}
}
block_cursor += cluster_count * cluster_stride;
}
debug_assert_eq!(block_cursor, per_cluster_blocks_size);
debug_assert_eq!(bytes.len(), total_size_before_crc);
let crc = crc32c(&bytes);
let mut out = bytes;
out.extend_from_slice(&crc.to_le_bytes());
Ok(SubsectionBytes {
bytes: out,
n_cent,
summary_offset_in_sub: summary_off,
codec_meta_offset_in_sub: if codec_meta_size == 0 {
0
} else {
codec_meta_off
},
codec_meta_size,
})
}
/// Encodes one cluster's rows as SQ8 codes plus `i8` residuals, in place in `bytes`.
///
/// Each row occupies `2 * dim` bytes starting at `full_chunk_base`: `dim`
/// code bytes written by `sq8_encode_row`, then `dim` residual bytes. When
/// `sq8_norms_block_off` is `Some`, the squared norm of the reconstructed
/// row is also stored as an `f32` at the row's document position
/// (`cluster_doc_off + row index`) within the norms block.
#[allow(clippy::too_many_arguments)]
fn encode_sq8_residual_cluster_simd(
    cluster_rows: &[f32],
    dim: usize,
    cluster_count: usize,
    cluster_doc_off: usize,
    full_chunk_base: usize,
    sq8_norms_block_off: Option<usize>,
    inv_scale_c: &[f32],
    c2_c: &[f32],
    scale_c: &[f32],
    offset_c: &[f32],
    bytes: &mut [u8],
) {
    debug_assert_eq!(cluster_rows.len(), cluster_count * dim);
    let bytes_per_row = dim * 2;
    let track_norms = sq8_norms_block_off.is_some();
    for row_idx in 0..cluster_count {
        let row = &cluster_rows[row_idx * dim..(row_idx + 1) * dim];
        let row_start = full_chunk_base + row_idx * bytes_per_row;
        let mut norm_sq = 0.0f64;
        {
            // First half of the row slot holds codes, second half residuals.
            let (codes, residuals) =
                bytes[row_start..row_start + bytes_per_row].split_at_mut(dim);
            sq8_encode_row(row, inv_scale_c, c2_c, &mut codes[..]);
            for d in 0..dim {
                // Value reconstructed from the coarse code alone.
                let coarse = f32::from(codes[d]) * scale_c[d] + offset_c[d];
                let step = scale_c[d] / SQ8_RESIDUAL_DIVISOR;
                let residual: i8 = if step > 0.0 {
                    ((row[d] - coarse) / step)
                        .round()
                        .clamp(-SQ8_RESIDUAL_I8_CLAMP, SQ8_RESIDUAL_I8_CLAMP) as i8
                } else {
                    0
                };
                // Stored as the i8's raw byte.
                residuals[d] = residual as u8;
                if track_norms {
                    let reconstructed = coarse + f32::from(residual) * step;
                    norm_sq += f64::from(reconstructed) * f64::from(reconstructed);
                }
            }
        }
        if let Some(norms_off) = sq8_norms_block_off {
            let at = norms_off + (cluster_doc_off + row_idx) * 4;
            bytes[at..at + 4].copy_from_slice(&(norm_sq as f32).to_le_bytes());
        }
    }
}
/// Derives per-dimension `(scale, offset)` vectors from observed minima and maxima.
///
/// For a positive, finite span the offset is the minimum and the scale is
/// `span / SQ8_CODE_MAX`. Otherwise the scale is 1 and the offset is the
/// minimum when that is finite, else 0.
#[inline]
fn derive_sq8_quantizer_from_min_max(min: &[f32], max: &[f32]) -> (Vec<f32>, Vec<f32>) {
    debug_assert_eq!(min.len(), max.len());
    min.iter()
        .enumerate()
        .map(|(d, &lo)| {
            let span = max[d] - lo;
            if span > 0.0 && span.is_finite() {
                (span / SQ8_CODE_MAX, lo)
            } else if lo.is_finite() {
                (1.0, lo)
            } else {
                (1.0, 0.0)
            }
        })
        .unzip()
}
/// Returns the centroid ids `0..n_cent` in the order their clusters are stored.
///
/// The order comes from recursively halving the id list: see
/// `order_centroids_recursive`.
fn centroid_storage_order(centroids: &[f32], n_cent: usize, dim: usize) -> Vec<usize> {
    let mut order: Vec<usize> = (0..n_cent).collect();
    order_centroids_recursive(&mut order, centroids, dim);
    order
}

/// Sorts `order` along the dimension with the widest span among the listed
/// centroids, splits it at the midpoint, and recurses into both halves.
///
/// Ties are broken by centroid id. NaN coordinates sort after every number.
/// That makes the comparator a total order, which `sort_unstable_by`
/// requires; ordering of NaN-free input is the same as plain `partial_cmp`.
fn order_centroids_recursive(order: &mut [usize], centroids: &[f32], dim: usize) {
    if order.len() <= 1 || dim == 0 {
        return;
    }
    // Pick the dimension with the largest (max - min); dimension 0 is used
    // when no dimension has a positive span.
    let mut best_dim = 0usize;
    let mut best_span = 0.0f32;
    for d in 0..dim {
        let mut lo = f32::INFINITY;
        let mut hi = f32::NEG_INFINITY;
        for &c in order.iter() {
            let v = centroids[c * dim + d];
            lo = lo.min(v);
            hi = hi.max(v);
        }
        let span = hi - lo;
        if span > best_span {
            best_span = span;
            best_dim = d;
        }
    }
    order.sort_unstable_by(|&a, &b| {
        let va = centroids[a * dim + best_dim];
        let vb = centroids[b * dim + best_dim];
        va.partial_cmp(&vb)
            // `None` means at least one side is NaN: order NaN last.
            .unwrap_or_else(|| va.is_nan().cmp(&vb.is_nan()))
            .then_with(|| a.cmp(&b))
    });
    let mid = order.len() / 2;
    let (left, right) = order.split_at_mut(mid);
    order_centroids_recursive(left, centroids, dim);
    order_centroids_recursive(right, centroids, dim);
}
/// Streams every row from `source`, chunk by chunk, into the per-centroid buckets.
///
/// For each chunk: rows are assigned to centroids, rotated and bit-quantized
/// in parallel, and the largest squared distance to `summary_centroid` is
/// folded into `summary_radius_sq_max`. Each row then gets a record appended
/// to its centroid's bucket: doc id, code, and the raw f32 row when the codec
/// writes full rows. When `sq8_min_max` is provided, the per-centroid
/// min/max arrays are updated with the row as well.
///
/// Doc ids are assigned sequentially in the order rows are read.
#[allow(clippy::too_many_arguments)]
fn run_pass2(
    source: &mut dyn ChunkedVectorSource,
    dim: usize,
    n_cent: usize,
    code_bytes: usize,
    centroids: &[f32],
    rotation: &RandomRotation,
    quant: &BitQuantizer,
    summary_centroid: &[f32],
    bucket_writers: &mut [BufWriter<File>],
    bucket_counts: &mut [u32],
    summary_radius_sq_max: &mut f32,
    codec: RerankCodec,
    mut sq8_min_max: Option<(&mut [f32], &mut [f32])>,
) -> Result<(), BuildError> {
    // Scratch buffers sized for the largest chunk and reused for every chunk.
    let rows_per_chunk = source.chunk_rows();
    let mut rotated_buf = vec![0f32; rows_per_chunk * dim];
    let mut assign_buf = vec![0u32; rows_per_chunk];
    let mut code_buf = vec![0u8; rows_per_chunk * code_bytes];
    let mut next_doc_id: u32 = 0;

    while let Some(chunk) = source.next_chunk() {
        let rows = chunk.len() / dim;
        debug_assert!(rows <= rows_per_chunk);
        let vectors = &chunk[..rows * dim];

        let assignments = &mut assign_buf[..rows];
        assign_to_centroids(vectors, centroids, dim, n_cent, assignments);

        // Rotate each row, then quantize the rotated row into its code.
        rotated_buf[..rows * dim]
            .par_chunks_mut(dim)
            .zip(vectors.par_chunks(dim))
            .for_each(|(dst, src)| rotation.apply(src, dst));
        code_buf[..rows * code_bytes]
            .par_chunks_mut(code_bytes)
            .enumerate()
            .for_each(|(r, code_out)| {
                quant.encode_rotated_into(&rotated_buf[r * dim..(r + 1) * dim], code_out);
            });

        let farthest_sq = (0..rows)
            .into_par_iter()
            .map(|r| l2_sq(&vectors[r * dim..(r + 1) * dim], summary_centroid))
            .reduce(|| 0.0f32, f32::max);
        if farthest_sq > *summary_radius_sq_max {
            *summary_radius_sq_max = farthest_sq;
        }

        let write_full = codec.writes_full();
        let mut min_max = sq8_min_max.as_mut();
        for (r, &assigned) in assignments.iter().enumerate() {
            let cid = assigned as usize;
            let row = &vectors[r * dim..(r + 1) * dim];
            let doc_id = next_doc_id + r as u32;

            let bucket = &mut bucket_writers[cid];
            bucket.write_all(&doc_id.to_le_bytes())?;
            bucket.write_all(&code_buf[r * code_bytes..(r + 1) * code_bytes])?;
            if write_full {
                bucket.write_all(bytemuck::cast_slice(row))?;
            }
            if let Some((mn, mx)) = min_max.as_deref_mut() {
                let off = cid * dim;
                update_min_max(row, &mut mn[off..off + dim], &mut mx[off..off + dim]);
            }
            bucket_counts[cid] += 1;
        }
        next_doc_id += rows as u32;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
use std::fs::{read, write};
use super::*;
fn cfg(name: &str, dim: usize) -> VectorConfig {
VectorConfig {
column: name.to_string(),
dim,
n_cent: 4,
rot_seed: 7,
metric: Metric::L2Sq,
rerank_codec: RerankCodec::Fp32,
}
}
/// Column ids are handed out in registration order, starting at 0.
#[test]
fn register_column_returns_sequential_ids() {
    let mut builder = VectorBuilder::new();
    let first = builder
        .register_column(cfg("a", 16))
        .expect("register column");
    let second = builder
        .register_column(cfg("b", 32))
        .expect("register column");
    assert_eq!(first, 0);
    assert_eq!(second, 1);
}
/// A column name containing the separator byte (0x1F) is rejected.
#[test]
fn register_column_rejects_separator_in_name() {
    let mut builder = VectorBuilder::new();
    let outcome = builder.register_column(cfg("a\x1Fb", 16));
    let err = outcome.expect_err("expected error");
    assert!(matches!(err, BuildError::ReservedSeparatorInColumnName(_)));
}
/// A column name starting with the reserved prefix is rejected.
#[test]
fn register_column_rejects_inf_prefix() {
    let mut builder = VectorBuilder::new();
    let outcome = builder.register_column(cfg("inf.embedding", 16));
    let err = outcome.expect_err("expected error");
    assert!(matches!(err, BuildError::ReservedPrefixInColumnName(_)));
}
/// A dimension below the minimum is rejected.
#[test]
fn register_column_rejects_dim_too_small() {
    let mut builder = VectorBuilder::new();
    let outcome = builder.register_column(cfg("a", 8));
    let err = outcome.expect_err("expected error");
    assert!(matches!(err, BuildError::VectorDimOutOfRange { .. }));
}
/// A dimension above the maximum is rejected.
#[test]
fn register_column_rejects_dim_too_large() {
    let mut builder = VectorBuilder::new();
    let outcome = builder.register_column(cfg("a", 5000));
    let err = outcome.expect_err("expected error");
    assert!(matches!(err, BuildError::VectorDimOutOfRange { .. }));
}
/// Registering the same name twice fails, even with a different dimension.
#[test]
fn register_column_rejects_duplicate() {
    let mut builder = VectorBuilder::new();
    builder
        .register_column(cfg("a", 16))
        .expect("register column");
    let outcome = builder.register_column(cfg("a", 32));
    let err = outcome.expect_err("expected error");
    assert!(matches!(err, BuildError::DuplicateColumnName(_)));
}
/// Adding to a column id that was never registered is an error.
#[test]
fn add_rejects_unknown_column_id() {
    let mut builder = VectorBuilder::new();
    builder
        .register_column(cfg("a", 16))
        .expect("register column");
    let outcome = builder.add(99, &[0.0; 16]);
    let err = outcome.expect_err("expected error");
    assert!(matches!(err, BuildError::FtsColumnTypeInvalid { .. }));
}
/// Adding a vector whose length differs from the column's `dim` is an error.
#[test]
fn add_rejects_wrong_dim() {
    let mut builder = VectorBuilder::new();
    builder
        .register_column(cfg("a", 16))
        .expect("register column");
    let outcome = builder.add(0, &[0.0; 8]);
    let err = outcome.expect_err("expected error");
    assert!(matches!(err, BuildError::FtsColumnTypeInvalid { .. }));
}
/// The blob starts with the outer magic, the format version and the column count.
#[test]
fn finish_emits_valid_outer_header() {
    let mut builder = VectorBuilder::new();
    builder
        .register_column(cfg("a", 16))
        .expect("register column");
    for i in 0..32 {
        let row: Vec<f32> = (0..16).map(|j| (i + j) as f32).collect();
        builder.add(0, &row).expect("add to vector builder");
    }
    let blob = builder.finish().expect("finish");

    // Little-endian u32 at byte offset `at`.
    let read_u32 =
        |at: usize| u32::from_le_bytes([blob[at], blob[at + 1], blob[at + 2], blob[at + 3]]);
    assert_eq!(&blob[0..8], format::vec::OUTER_MAGIC);
    assert_eq!(read_u32(8), format::vec::VERSION);
    assert_eq!(read_u32(12), 1);
}
/// A registered but empty column still yields a blob with a valid magic and
/// a zero doc count in the header.
#[test]
fn finish_with_no_docs_produces_valid_blob() {
    let mut builder = VectorBuilder::new();
    builder
        .register_column(cfg("a", 16))
        .expect("register column");
    let blob = builder.finish().expect("finish");
    assert_eq!(&blob[0..8], format::vec::OUTER_MAGIC);

    let mut n_docs_raw = [0u8; 8];
    n_docs_raw.copy_from_slice(&blob[16..24]);
    let n_docs = u64::from_le_bytes(n_docs_raw);
    assert_eq!(n_docs, 0);
}
/// With a single row, the directory must record the clamped centroid count
/// (1), not the configured one, and the reader must still open the blob.
#[test]
fn sq8_tiny_shard_writes_physical_n_cent_to_directory() {
    use bytes::Bytes;
    use crate::superfile::vector::reader::VectorReader;

    let dim = 16;
    let configured_n_cent = 4;
    let column_cfg = VectorConfig::new("v".into(), dim, configured_n_cent, 7, Metric::Cosine)
        .with_rerank_codec(RerankCodec::Sq8ResidualEpsilon);

    let mut builder = VectorBuilder::new();
    builder
        .register_column(column_cfg)
        .expect("register sq8 column");
    builder.add(0, &[1.0; 16]).expect("add single row");
    let blob = builder.finish().expect("finish tiny sq8 shard");

    // n_cent sits at byte offset 8 of the first directory entry.
    let n_cent_at = OUTER_HEADER_SIZE + 8;
    let physical_n_cent = u32::from_le_bytes(
        blob[n_cent_at..n_cent_at + 4]
            .try_into()
            .expect("n_cent bytes"),
    );
    assert_eq!(
        physical_n_cent, 1,
        "directory must describe physical IVF layout, not configured n_cent"
    );

    let json = format!(
        r#"[{{"column":"v","dim":{dim},"n_cent":{configured_n_cent},"rot_seed":7,"metric":"cosine"}}]"#
    );
    let reader = VectorReader::open(Bytes::from(blob), &json).expect("open tiny sq8 shard");
    assert_eq!(reader.n_docs(), 1);
}
/// Two columns with different dimensions each get a directory entry that
/// records its own `dim`.
#[test]
fn finish_two_columns_at_different_dims() {
    let mut builder = VectorBuilder::new();
    builder
        .register_column(cfg("a", 16))
        .expect("register column");
    builder
        .register_column(cfg("b", 32))
        .expect("register column");
    for _ in 0..16 {
        builder.add(0, &[1.0; 16]).expect("add to vector builder");
        builder.add(1, &[1.0; 32]).expect("add to vector builder");
    }
    let blob = builder.finish().expect("finish");

    // Little-endian u32 at byte offset `at`.
    let read_u32 =
        |at: usize| u32::from_le_bytes([blob[at], blob[at + 1], blob[at + 2], blob[at + 3]]);
    assert_eq!(read_u32(12), 2);

    // `dim` sits at byte offset 4 of each directory entry.
    let dir_off = OUTER_HEADER_SIZE;
    let entry_a_dim = read_u32(dir_off + 4);
    let entry_b_dim = read_u32(dir_off + DIR_ENTRY_SIZE + 4);
    assert_eq!(entry_a_dim, 16);
    assert_eq!(entry_b_dim, 32);
}
/// With a spill threshold of 0, every row takes the spill-to-disk path.
/// The resulting blob must still carry a valid magic, one column, and the
/// full row count in its header.
#[test]
fn build_via_forced_spill_path_round_trips() {
    let dim = 16;
    let n_docs = 64usize;
    let n_cent = 4usize;
    let mut b = VectorBuilder::new();
    // Must be set before registration: columns copy the threshold then.
    b.set_spill_threshold_bytes(0);
    b.register_column(VectorConfig {
        column: "v".into(),
        dim,
        n_cent,
        rot_seed: 7,
        metric: Metric::L2Sq,
        rerank_codec: RerankCodec::Fp32,
    })
    .expect("register column");
    for d in 0..n_docs {
        let mut row = vec![0.0f32; dim];
        row[0] = d as f32;
        row[1] = (d as f32) * 0.5;
        row[2] = -(d as f32);
        b.add(0, &row).expect("add via forced-spill path");
    }
    let blob = b.finish().expect("finish via forced-spill path");
    assert_eq!(&blob[0..8], format::vec::OUTER_MAGIC);
    let n_cols = u32::from_le_bytes([blob[12], blob[13], blob[14], blob[15]]);
    assert_eq!(n_cols, 1);
    let n_docs_hdr = u64::from_le_bytes(blob[16..24].try_into().expect("8 bytes"));
    assert_eq!(n_docs_hdr, n_docs as u64);
}
/// Builds the same corpus through the in-memory path and the forced-spill
/// path, then checks that both blobs return each row as its own nearest
/// neighbour.
#[tokio::test]
async fn forced_spill_path_matches_in_ram_path_on_self_nn() {
    use bytes::Bytes;
    use crate::superfile::vector::reader::VectorReader;

    let dim = 16;
    let n_docs = 50;
    let n_cent = 4;

    // Deterministic corpus, flattened row-major.
    let mut corpus = Vec::with_capacity(n_docs * dim);
    for d in 0..n_docs {
        corpus.extend((0..dim).map(|j| ((d as f32) * 0.07 + (j as f32) * 0.13).sin()));
    }

    let build = |force_spill: bool| -> Vec<u8> {
        let mut builder = VectorBuilder::new();
        if force_spill {
            builder.set_spill_threshold_bytes(0);
        }
        builder
            .register_column(VectorConfig {
                column: "v".into(),
                dim,
                n_cent,
                rot_seed: 7,
                metric: Metric::L2Sq,
                rerank_codec: RerankCodec::Fp32,
            })
            .expect("register column");
        for row in 0..n_docs {
            builder
                .add(0, &corpus[row * dim..(row + 1) * dim])
                .expect("add to vector builder");
        }
        builder.finish().expect("finish")
    };
    let blob_ram = build(false);
    let blob_spill = build(true);

    let json = format!(
        r#"[{{"column":"v","dim":{dim},"n_cent":{n_cent},"rot_seed":7,"metric":"l2sq"}}]"#
    );
    let r_ram = VectorReader::open(Bytes::from(blob_ram), &json).expect("open ram");
    let r_spill = VectorReader::open(Bytes::from(blob_spill), &json).expect("open spill");

    // Probe every cluster and rerank more candidates than there are rows.
    let nprobe = n_cent;
    let rerank_mult = n_docs + 1;
    for q in 0..n_docs {
        let query = &corpus[q * dim..(q + 1) * dim];
        let top_ram = r_ram
            .search("v", query, 1, nprobe, rerank_mult)
            .await
            .expect("search ram");
        let top_spill = r_spill
            .search("v", query, 1, nprobe, rerank_mult)
            .await
            .expect("search spill");
        assert_eq!(
            top_ram[0].0 as usize, q,
            "in-RAM path missed self-NN at q={q}"
        );
        assert_eq!(
            top_spill[0].0 as usize, q,
            "spill path missed self-NN at q={q}"
        );
    }
}
/// `finish_to` into a `Vec<u8>` must emit exactly the bytes that `finish`
/// returns for an identically populated builder.
#[test]
fn finish_to_matches_finish_byte_for_byte() {
    /// Returns a fresh builder holding one 16-dim column "v" and 32 rows.
    fn populated_builder() -> VectorBuilder {
        let mut builder = VectorBuilder::new();
        builder
            .register_column(cfg("v", 16))
            .expect("register column");
        for i in 0..32 {
            // Row i holds (i + j) * 0.1 for j in 0..16.
            let v: Vec<f32> = (i..i + 16).map(|sum| (sum as f32) * 0.1).collect();
            builder.add(0, &v).expect("add to vector builder");
        }
        builder
    }

    let via_finish = populated_builder().finish().expect("finish");

    let mut via_finish_to: Vec<u8> = Vec::new();
    populated_builder()
        .finish_to(&mut via_finish_to)
        .expect("finish_to Vec<u8>");

    assert_eq!(
        via_finish, via_finish_to,
        "finish_to must produce identical bytes to finish"
    );
}
/// Streams a built blob through a `Cursor` and verifies the outer framing:
/// the blob starts with `OUTER_MAGIC`, and its last four bytes are the
/// little-endian CRC32C of everything that precedes them.
#[test]
fn finish_to_cursor_round_trips_outer_crc() {
    use std::io::Cursor;
    let mut b = VectorBuilder::new();
    b.register_column(cfg("v", 16)).expect("register column");
    for i in 0..32 {
        let v: Vec<f32> = (0..16).map(|j| ((i + j) as f32) * 0.1).collect();
        b.add(0, &v).expect("add to vector builder");
    }
    let mut buf: Vec<u8> = Vec::new();
    // The cursor is a temporary, so its mutable borrow of `buf` ends here.
    b.finish_to(Cursor::new(&mut buf)).expect("finish_to Cursor");

    // Check the length before touching any index: a truncated blob must fail
    // with this message rather than an opaque slice-index panic below.
    assert!(
        buf.len() >= OUTER_HEADER_SIZE + DIR_ENTRY_SIZE + 4 + 4,
        "blob too short: {} bytes",
        buf.len()
    );
    assert_eq!(
        &buf[0..8],
        format::vec::OUTER_MAGIC,
        "outer magic preserved"
    );

    // Split into the checksummed body and the trailing 4-byte CRC.
    let (body, crc_bytes) = buf.split_at(buf.len() - 4);
    let trailing_crc = u32::from_le_bytes(crc_bytes.try_into().expect("4 bytes"));
    let recomputed = crc32c(body);
    assert_eq!(
        trailing_crc, recomputed,
        "trailing outer CRC32C must match recomputed body CRC"
    );
}
#[tokio::test]
async fn finish_to_temp_file_round_trips_through_reader() {
use std::io::BufWriter;
use bytes::Bytes;
use crate::superfile::vector::reader::VectorReader;
let dim = 16usize;
let n_docs = 32usize;
let n_cent = 4usize;
let mut b = VectorBuilder::new();
b.register_column(VectorConfig {
column: "v".into(),
dim,
n_cent,
rot_seed: 7,
metric: Metric::L2Sq,
rerank_codec: RerankCodec::Fp32,
})
.expect("register column");
for d in 0..n_docs {
let row: Vec<f32> = (0..dim)
.map(|j| ((d as f32) * 0.07 + (j as f32) * 0.13).sin())
.collect();
b.add(0, &row).expect("add to vector builder");
}
let tmp = tempfile::tempdir().expect("tempdir");
let path = tmp.path().join("vector_blob.bin");
{
let file = File::create(&path).expect("create blob file");
let writer = BufWriter::new(file);
b.finish_to(writer).expect("finish_to BufWriter<File>");
}
let blob = read(&path).expect("read blob file");
let json = format!(
r#"[{{"column":"v","dim":{dim},"n_cent":{n_cent},"rot_seed":7,"metric":"l2sq"}}]"#
);
let reader = VectorReader::open(Bytes::from(blob), &json)
.expect("open VectorReader from streamed blob");
let query: Vec<f32> = (0..dim).map(|j| ((j as f32) * 0.13).sin()).collect();
let hits = reader
.search("v", &query, 5, n_cent, n_docs + 1)
.await
.expect("kNN search");
assert!(!hits.is_empty(), "search returned no hits");
}
/// `VectorConfig::new` must store its arguments verbatim and pick
/// `RerankCodec::default()`; `with_rerank_codec` must replace the codec while
/// keeping the column name.
#[test]
fn vector_config_new_and_with_rerank_codec() {
    let (dim, n_cent, rot_seed) = (16usize, 4usize, 7u64);

    let defaults = VectorConfig::new("v".into(), dim, n_cent, rot_seed, Metric::Cosine);
    assert_eq!(defaults.column, "v");
    assert_eq!(defaults.dim, dim);
    assert_eq!(defaults.n_cent, n_cent);
    assert_eq!(defaults.rot_seed, rot_seed);
    assert_eq!(defaults.metric, Metric::Cosine);
    assert_eq!(defaults.rerank_codec, RerankCodec::default());

    let with_fp32 = defaults.with_rerank_codec(RerankCodec::Fp32);
    assert_eq!(with_fp32.rerank_codec, RerankCodec::Fp32);
    assert_eq!(with_fp32.column, "v");
}
/// A builder obtained through `Default` accepts a column registration and
/// assigns it id 0.
#[test]
fn vector_builder_default_matches_new() {
    let mut builder = VectorBuilder::default();
    let first_id = builder
        .register_column(cfg("a", 16))
        .expect("register column");
    assert_eq!(first_id, 0);
}
/// `set_kmeans_sample_size` succeeds for column id 0 after one registration
/// and fails with `BuildError::FtsColumnTypeInvalid` for the unregistered id 9.
#[test]
fn set_kmeans_sample_size_ok_and_unregistered() {
    const REQUESTED_SAMPLE: usize = 1024;

    let mut builder = VectorBuilder::new();
    builder
        .register_column(cfg("a", 16))
        .expect("register column");

    // Id 0 is the column registered just above.
    builder
        .set_kmeans_sample_size(0, REQUESTED_SAMPLE)
        .expect("resize sample for registered column");

    // Id 9 was never registered.
    let err = builder
        .set_kmeans_sample_size(9, REQUESTED_SAMPLE)
        .expect_err("unregistered column id");
    assert!(matches!(err, BuildError::FtsColumnTypeInvalid { .. }));
}
/// `with_scratch` must accept an existing directory and reject a path that
/// points at a regular file with `BuildError::Io`.
#[test]
fn with_scratch_accepts_dir_and_rejects_file() {
    let scratch_root = tempfile::tempdir().expect("tempdir");

    // Directory case: the builder is created and can register a column.
    let mut builder = VectorBuilder::with_scratch(scratch_root.path().to_path_buf())
        .expect("scratch under existing dir");
    let first_id = builder
        .register_column(cfg("a", 16))
        .expect("register column");
    assert_eq!(first_id, 0);

    // File case: create a one-byte file and offer it as the scratch path.
    let file_path = scratch_root.path().join("not-a-dir");
    write(&file_path, b"x").expect("write file");
    if let Err(err) = VectorBuilder::with_scratch(file_path) {
        assert!(matches!(err, BuildError::Io(_)));
    } else {
        panic!("scratch path is a file, expected rejection");
    }
}
}