use crate::multivector::types::WarpIndexConfig;
use crate::Result;
use serde::{Deserialize, Serialize};
/// Residual-quantization codec.
///
/// An embedding is represented by the id of a centroid plus a per-dimension
/// bucket code for the residual `embedding - centroid`; the bucket tables
/// stored here are used to produce and to score those codes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResidualCodec {
/// Flat centroid storage: centroid `id` occupies `[id * dim, (id + 1) * dim)`.
centroids: Vec<f32>,
/// Number of centroids scanned by `find_nearest_centroid`.
num_centroids: usize,
/// Number of values per embedding / centroid.
dim: usize,
/// Per-dimension bucket boundaries: dimension `d` reads `2^nbits - 1`
/// consecutive values starting at `d * (2^nbits - 1)`.
bucket_cutoffs: Vec<f32>,
/// Per-dimension bucket weights: the weight for code `c` in dimension `d` is
/// read at `d * 2^nbits + c`.
bucket_weights: Vec<f32>,
/// Bits per residual code; packing and unpacking support only 2 and 4.
nbits: u8,
}
impl ResidualCodec {
pub fn train(
embeddings: &[f32],
dim: usize,
num_centroids: usize,
nbits: u8,
iterations: usize,
) -> Result<Self> {
if nbits != 2 && nbits != 4 {
return Err(crate::Error::InvalidInput("nbits must be 2 or 4".to_string()));
}
if dim == 0 {
return Err(crate::Error::InvalidInput("dim must be > 0".to_string()));
}
let n = embeddings.len() / dim;
if n < num_centroids {
return Err(crate::Error::InvalidInput(format!(
"Insufficient training data: {n} samples for {num_centroids} centroids"
)));
}
contract_pre_embedding_lookup!(embeddings);
let centroids = Self::kmeans_clustering(embeddings, dim, num_centroids, iterations);
let residuals = Self::compute_all_residuals(embeddings, dim, ¢roids, num_centroids);
let (bucket_cutoffs, bucket_weights) =
Self::learn_quantization_params(&residuals, dim, nbits);
Ok(Self { centroids, num_centroids, dim, bucket_cutoffs, bucket_weights, nbits })
}
/// Builds a codec directly from already-computed parameters.
///
/// Apart from `dim > 0`, nothing is validated here: the buffer lengths are not
/// checked against `num_centroids`, `dim` or `nbits`, and an `nbits` other
/// than 2 or 4 only fails later, when codes are packed or unpacked.
///
/// # Panics
///
/// Panics if `dim` is zero.
#[must_use]
pub fn with_params(
    centroids: Vec<f32>,
    num_centroids: usize,
    dim: usize,
    bucket_cutoffs: Vec<f32>,
    bucket_weights: Vec<f32>,
    nbits: u8,
) -> Self {
    assert!(dim != 0, "dim must be > 0: division by zero in centroid/residual arithmetic");
    Self {
        centroids,
        num_centroids,
        dim,
        bucket_cutoffs,
        bucket_weights,
        nbits,
    }
}
/// Returns the number of centroids this codec scans when assigning embeddings.
#[must_use]
pub fn num_centroids(&self) -> usize {
self.num_centroids
}
/// Returns the number of values per embedding (and per centroid).
#[must_use]
pub fn dim(&self) -> usize {
self.dim
}
/// Returns the number of bits used for each residual code.
#[must_use]
pub fn nbits(&self) -> u8 {
self.nbits
}
/// Returns the number of bytes needed to hold `dim` codes of `nbits` bits
/// each, rounded up to a whole byte.
#[must_use]
pub fn packed_size(&self) -> usize {
    let total_bits = self.dim * usize::from(self.nbits);
    (total_bits + 7) / 8
}
/// Returns the `dim` values of centroid `id`.
///
/// # Panics
///
/// Panics if the centroid's range lies outside the stored centroid buffer.
#[must_use]
pub fn centroid(&self, id: usize) -> &[f32] {
    let begin = id * self.dim;
    let end = begin + self.dim;
    &self.centroids[begin..end]
}
/// Returns the flat centroid buffer (centroid `id` starts at `id * dim`).
#[must_use]
pub fn centroids(&self) -> &[f32] {
&self.centroids
}
/// Returns the id of the centroid with the smallest squared Euclidean
/// distance to `embedding`.
///
/// Ties go to the lowest id. The result is `0` when no centroid has a
/// distance strictly below `f32::MAX` (including when there are no centroids).
#[must_use]
pub fn find_nearest_centroid(&self, embedding: &[f32]) -> usize {
    contract_pre_configuration!(embedding);
    let (nearest, _) =
        (0..self.num_centroids).fold((0usize, f32::MAX), |(best_id, best_dist), id| {
            let dist = Self::squared_distance(embedding, self.centroid(id));
            // Strict comparison keeps the earliest centroid on ties.
            if dist < best_dist {
                (id, dist)
            } else {
                (best_id, best_dist)
            }
        });
    nearest
}
/// Compresses `embedding` into the id of its nearest centroid and the packed
/// bucket codes of the residual `embedding - centroid`.
#[must_use]
pub fn compress(&self, embedding: &[f32]) -> (usize, Vec<u8>) {
    contract_pre_embedding_lookup!(embedding);
    let centroid_id = self.find_nearest_centroid(embedding);
    let residual: Vec<f32> = embedding
        .iter()
        .zip(self.centroid(centroid_id))
        .map(|(value, center)| value - center)
        .collect();
    let packed = self.pack_codes(&self.quantize_residual(&residual));
    (centroid_id, packed)
}
/// Approximates the query/document dot product from a compressed residual.
///
/// The residual part is the sum over dimensions of
/// `query_token[d] * bucket_weights[d * 2^nbits + code_d]`; it is added to the
/// caller-supplied `centroid_score`. `centroid_id` is accepted but not used.
///
/// # Panics
///
/// Panics if `query_token` is shorter than the number of unpacked codes or if
/// a weight index falls outside the weight table.
#[must_use]
pub fn decompress_score(
    &self,
    query_token: &[f32],
    centroid_id: usize,
    centroid_score: f32,
    packed_residual: &[u8],
) -> f32 {
    let _ = centroid_id;
    let codes = self.unpack_codes(packed_residual);
    let buckets_per_dim = 1usize << self.nbits;
    let contribution = |(d, &code): (usize, &u8)| {
        let query_value = query_token[d];
        let weight = self.bucket_weights[d * buckets_per_dim + code as usize];
        query_value * weight
    };
    let residual_score: f32 = codes.iter().enumerate().map(contribution).sum();
    centroid_score + residual_score
}
/// Returns the dot product of `query_token` with centroid `centroid_id`.
#[must_use]
pub fn centroid_score(&self, query_token: &[f32], centroid_id: usize) -> f32 {
    Self::dot_product(query_token, self.centroid(centroid_id))
}
/// Maps every residual value to a bucket code.
///
/// The code for dimension `d` is the index of the first of that dimension's
/// cutoffs that is strictly greater than the value, or the last bucket when
/// no cutoff is greater.
fn quantize_residual(&self, residual: &[f32]) -> Vec<u8> {
    let num_buckets = 1usize << self.nbits;
    let mut codes = Vec::with_capacity(residual.len());
    for (d, &value) in residual.iter().enumerate() {
        let begin = d * (num_buckets - 1);
        let dim_cutoffs = &self.bucket_cutoffs[begin..begin + num_buckets - 1];
        let bucket = dim_cutoffs
            .iter()
            .position(|&cutoff| value < cutoff)
            .unwrap_or(num_buckets - 1);
        codes.push(bucket as u8);
    }
    codes
}
/// Packs bucket codes into bytes, lowest-order bits first.
///
/// With 2-bit codes four codes share a byte, with 4-bit codes two do. Each
/// code is masked to its width, and a final partial group leaves the unused
/// high bits zero.
///
/// # Panics
///
/// Panics if `nbits` is neither 2 nor 4.
fn pack_codes(&self, codes: &[u8]) -> Vec<u8> {
    let (codes_per_byte, mask) = match self.nbits {
        2 => (4usize, 0x03u8),
        4 => (2usize, 0x0Fu8),
        _ => panic!("Unsupported nbits: {}", self.nbits),
    };
    let width = usize::from(self.nbits);
    codes
        .chunks(codes_per_byte)
        .map(|group| {
            group
                .iter()
                .enumerate()
                .fold(0u8, |byte, (slot, &code)| byte | ((code & mask) << (slot * width)))
        })
        .collect()
}
/// Unpacks bytes produced by `pack_codes` back into at most `dim` bucket
/// codes, lowest-order bits first.
///
/// Both code widths share one allocation-free path; the previous 4-bit branch
/// built a temporary `Vec` for every packed byte.
///
/// # Panics
///
/// Panics if `nbits` is neither 2 nor 4.
fn unpack_codes(&self, packed: &[u8]) -> Vec<u8> {
    let (codes_per_byte, mask) = match self.nbits {
        2 => (4u32, 0x03u8),
        4 => (2u32, 0x0Fu8),
        _ => panic!("Unsupported nbits: {}", self.nbits),
    };
    let width = u32::from(self.nbits);
    packed
        .iter()
        .flat_map(move |&byte| {
            (0..codes_per_byte).map(move |slot| (byte >> (slot * width)) & mask)
        })
        // Padding codes in the final byte are dropped.
        .take(self.dim)
        .collect()
}
fn kmeans_clustering(embeddings: &[f32], dim: usize, k: usize, iterations: usize) -> Vec<f32> {
let n = embeddings.len() / dim;
let mut centroids = Self::kmeans_plus_plus_init(embeddings, dim, k);
let mut assignments = vec![0usize; n];
for _ in 0..iterations {
for i in 0..n {
let point = &embeddings[i * dim..(i + 1) * dim];
let mut best_dist = f32::MAX;
let mut best_c = 0;
for c in 0..k {
let centroid = ¢roids[c * dim..(c + 1) * dim];
let dist = Self::squared_distance(point, centroid);
if dist < best_dist {
best_dist = dist;
best_c = c;
}
}
assignments[i] = best_c;
}
let mut new_centroids = vec![0.0f32; k * dim];
let mut counts = vec![0usize; k];
for i in 0..n {
let c = assignments[i];
counts[c] += 1;
let point = &embeddings[i * dim..(i + 1) * dim];
for d in 0..dim {
new_centroids[c * dim + d] += point[d];
}
}
for c in 0..k {
if counts[c] > 0 {
for d in 0..dim {
new_centroids[c * dim + d] /= counts[c] as f32;
}
} else {
for d in 0..dim {
new_centroids[c * dim + d] = centroids[c * dim + d];
}
}
}
centroids = new_centroids;
}
centroids
}
fn kmeans_plus_plus_init(embeddings: &[f32], dim: usize, k: usize) -> Vec<f32> {
let n = embeddings.len() / dim;
let mut centroids = Vec::with_capacity(k * dim);
let mut rng_state = 42u64;
let first_idx = Self::simple_random(&mut rng_state, n);
centroids.extend_from_slice(&embeddings[first_idx * dim..(first_idx + 1) * dim]);
let mut distances = vec![f32::MAX; n];
for _ in 1..k {
let num_centroids = centroids.len() / dim;
for i in 0..n {
let point = &embeddings[i * dim..(i + 1) * dim];
let centroid = ¢roids[(num_centroids - 1) * dim..num_centroids * dim];
let dist = Self::squared_distance(point, centroid);
distances[i] = distances[i].min(dist);
}
let total: f32 = distances.iter().sum();
if total <= 0.0 {
let idx = Self::simple_random(&mut rng_state, n);
centroids.extend_from_slice(&embeddings[idx * dim..(idx + 1) * dim]);
continue;
}
let threshold = Self::simple_random_f32(&mut rng_state) * total;
let mut cumsum = 0.0f32;
let mut chosen = 0;
for (i, &d) in distances.iter().enumerate() {
cumsum += d;
if cumsum >= threshold {
chosen = i;
break;
}
}
centroids.extend_from_slice(&embeddings[chosen * dim..(chosen + 1) * dim]);
}
centroids
}
fn compute_all_residuals(
embeddings: &[f32],
dim: usize,
centroids: &[f32],
num_centroids: usize,
) -> Vec<f32> {
let n = embeddings.len() / dim;
let mut residuals = Vec::with_capacity(n * dim);
for i in 0..n {
let point = &embeddings[i * dim..(i + 1) * dim];
let mut best_c = 0;
let mut best_dist = f32::MAX;
for c in 0..num_centroids {
let centroid = ¢roids[c * dim..(c + 1) * dim];
let dist = Self::squared_distance(point, centroid);
if dist < best_dist {
best_dist = dist;
best_c = c;
}
}
let centroid = ¢roids[best_c * dim..(best_c + 1) * dim];
for d in 0..dim {
residuals.push(point[d] - centroid[d]);
}
}
residuals
}
/// Learns per-dimension bucket cutoffs and weights from training residuals.
///
/// `residuals` is read as `residuals.len() / dim` rows of `dim` values. For
/// every dimension the values are sorted and split into `2^nbits`
/// equal-count buckets: the cutoffs are the values at the bucket boundaries
/// (`2^nbits - 1` per dimension) and the weights are the mean of each bucket
/// (`2^nbits` per dimension). Returns `(cutoffs, weights)`.
fn learn_quantization_params(residuals: &[f32], dim: usize, nbits: u8) -> (Vec<f32>, Vec<f32>) {
    let num_buckets = 1usize << nbits;
    let n = residuals.len() / dim;
    let mut cutoffs = Vec::with_capacity(dim * (num_buckets - 1));
    let mut weights = Vec::with_capacity(dim * num_buckets);
    for d in 0..dim {
        // Gather column `d` of the row-major residual matrix.
        let mut column: Vec<f32> =
            residuals.iter().skip(d).step_by(dim).take(n).copied().collect();
        column.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));

        cutoffs.extend((1..num_buckets).map(|b| column[((b * n) / num_buckets).min(n - 1)]));
        weights.extend((0..num_buckets).map(|b| {
            let lo = (b * n) / num_buckets;
            // Every bucket covers at least one value, even with few samples.
            let hi = (((b + 1) * n) / num_buckets).max(lo + 1).min(n);
            let bucket = &column[lo..hi];
            bucket.iter().sum::<f32>() / bucket.len() as f32
        }));
    }
    (cutoffs, weights)
}
/// Returns the squared Euclidean distance between `a` and `b`, computed over
/// the shorter of the two lengths.
fn squared_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let diff = lhs - rhs;
            diff.powi(2)
        })
        .sum()
}
/// Returns the dot product of `a` and `b`, computed over the shorter of the
/// two lengths.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum()
}
/// Advances the linear congruential generator in `state` and returns a value
/// in `0..max` taken from its high bits.
///
/// Panics (remainder by zero) if `max` is zero.
fn simple_random(state: &mut u64, max: usize) -> usize {
    let next = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
    *state = next;
    let high_bits = (next >> 33) as usize;
    high_bits % max
}
/// Advances the linear congruential generator in `state` and returns a value
/// in `[0.0, 1.0)`.
///
/// The top 24 bits of the state are used because they convert to `f32`
/// exactly. Dividing a 31-bit value by `u32::MAX`, as was done before, could
/// never produce anything above 0.5, which skewed the weighted draw in
/// k-means++ seeding towards early samples.
fn simple_random_f32(state: &mut u64) -> f32 {
    *state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
    // 16_777_216 = 2^24, the number of distinct 24-bit values.
    ((*state >> 40) as f32) / 16_777_216.0
}
}
/// Trains `ResidualCodec` instances from the settings held in a
/// `WarpIndexConfig`.
pub struct ResidualCodecBuilder {
/// Source of `token_dim`, `num_centroids`, `nbits` and `kmeans_iterations`.
config: WarpIndexConfig,
}
impl ResidualCodecBuilder {
    /// Creates a builder that trains codecs with the settings in `config`.
    #[must_use]
    pub fn new(config: WarpIndexConfig) -> Self {
        Self { config }
    }

    /// Trains a codec on `embeddings` using the configured token dimension,
    /// centroid count, code width and k-means iteration count.
    ///
    /// # Errors
    ///
    /// Returns whatever error `ResidualCodec::train` reports for these inputs.
    pub fn train(&self, embeddings: &[f32]) -> Result<ResidualCodec> {
        contract_pre_embedding_lookup!(embeddings);
        let cfg = &self.config;
        ResidualCodec::train(
            embeddings,
            cfg.token_dim,
            cfg.num_centroids,
            cfg.nbits,
            cfg.kmeans_iterations,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Produces `n * dim` deterministic pseudo-random values from a fixed-seed
    /// linear congruential generator.
    fn generate_test_embeddings(n: usize, dim: usize) -> Vec<f32> {
        let mut embeddings = Vec::with_capacity(n * dim);
        let mut rng_state = 12345u64;
        for _ in 0..(n * dim) {
            rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
            let val = ((rng_state >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0;
            embeddings.push(val);
        }
        embeddings
    }

    #[test]
    fn test_codec_train_2bit() {
        let embeddings = generate_test_embeddings(1000, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();
        assert_eq!(codec.num_centroids(), 16);
        assert_eq!(codec.dim(), 32);
        assert_eq!(codec.nbits(), 2);
    }

    #[test]
    fn test_codec_train_4bit() {
        let embeddings = generate_test_embeddings(1000, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 4, 5).unwrap();
        assert_eq!(codec.nbits(), 4);
    }

    #[test]
    fn test_codec_train_insufficient_data() {
        let embeddings = generate_test_embeddings(5, 32);
        let result = ResidualCodec::train(&embeddings, 32, 16, 2, 5);
        assert!(result.is_err());
    }

    #[test]
    fn test_codec_train_invalid_nbits() {
        let embeddings = generate_test_embeddings(100, 32);
        let result = ResidualCodec::train(&embeddings, 32, 16, 3, 5);
        assert!(result.is_err());
    }

    #[test]
    fn test_codec_train_dim_zero() {
        let result = ResidualCodec::train(&[], 0, 4, 2, 3);
        assert!(result.is_err());
    }

    #[test]
    #[should_panic(expected = "dim must be > 0")]
    fn test_codec_with_params_dim_zero() {
        let _ = ResidualCodec::with_params(vec![], 0, 0, vec![], vec![], 2);
    }

    #[test]
    fn test_codec_compress() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();
        let test_vec = &embeddings[0..32];
        let (centroid_id, packed) = codec.compress(test_vec);
        assert!(centroid_id < 16);
        assert_eq!(packed.len(), codec.packed_size());
    }

    #[test]
    fn test_codec_packed_size_2bit() {
        let embeddings = generate_test_embeddings(500, 128);
        let codec = ResidualCodec::train(&embeddings, 128, 16, 2, 5).unwrap();
        assert_eq!(codec.packed_size(), 32);
    }

    #[test]
    fn test_codec_packed_size_4bit() {
        let embeddings = generate_test_embeddings(500, 128);
        let codec = ResidualCodec::train(&embeddings, 128, 16, 4, 5).unwrap();
        assert_eq!(codec.packed_size(), 64);
    }

    #[test]
    fn test_pack_unpack_2bit() {
        let embeddings = generate_test_embeddings(500, 8);
        let codec = ResidualCodec::train(&embeddings, 8, 16, 2, 5).unwrap();
        let codes: Vec<u8> = vec![0, 1, 2, 3, 0, 1, 2, 3];
        let packed = codec.pack_codes(&codes);
        let unpacked = codec.unpack_codes(&packed);
        assert_eq!(codes, unpacked);
    }

    #[test]
    fn test_pack_unpack_4bit() {
        let embeddings = generate_test_embeddings(500, 8);
        let codec = ResidualCodec::train(&embeddings, 8, 16, 4, 5).unwrap();
        let codes: Vec<u8> = vec![0, 5, 10, 15, 1, 6, 11, 14];
        let packed = codec.pack_codes(&codes);
        let unpacked = codec.unpack_codes(&packed);
        assert_eq!(codes, unpacked);
    }

    /// The approximate score must stay within a loose tolerance of the exact
    /// dot product.
    #[test]
    fn test_decompress_score() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();
        let query = &embeddings[0..32];
        let doc = &embeddings[32..64];
        let (centroid_id, packed) = codec.compress(doc);
        let centroid_score = codec.centroid_score(query, centroid_id);
        let approx_score = codec.decompress_score(query, centroid_id, centroid_score, &packed);
        let exact_score: f32 = query.iter().zip(doc.iter()).map(|(q, d)| q * d).sum();
        let error = (approx_score - exact_score).abs();
        assert!(
            error < exact_score.abs() * 0.5 + 1.0,
            "Error too large: approx={}, exact={}, error={}",
            approx_score,
            exact_score,
            error
        );
    }

    #[test]
    fn test_centroid_score() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();
        let query = &embeddings[0..32];
        let centroid = codec.centroid(0);
        let expected: f32 = query.iter().zip(centroid.iter()).map(|(q, c)| q * c).sum();
        let actual = codec.centroid_score(query, 0);
        assert!((expected - actual).abs() < 1e-6);
    }

    /// A centroid queried against the codec must resolve to itself.
    #[test]
    fn test_find_nearest_centroid() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();
        let centroid_0 = codec.centroid(0).to_vec();
        let nearest = codec.find_nearest_centroid(&centroid_0);
        assert_eq!(nearest, 0);
    }

    #[test]
    fn test_codec_builder() {
        let config = WarpIndexConfig::new(2, 16, 32).with_kmeans_iterations(5);
        let builder = ResidualCodecBuilder::new(config);
        let embeddings = generate_test_embeddings(500, 32);
        let codec = builder.train(&embeddings).unwrap();
        assert_eq!(codec.num_centroids(), 16);
        assert_eq!(codec.dim(), 32);
    }

    #[test]
    fn test_codec_serialization() {
        let embeddings = generate_test_embeddings(500, 16);
        let codec = ResidualCodec::train(&embeddings, 16, 8, 2, 5).unwrap();
        let json = serde_json::to_string(&codec).unwrap();
        let deserialized: ResidualCodec = serde_json::from_str(&json).unwrap();
        assert_eq!(codec.num_centroids(), deserialized.num_centroids());
        assert_eq!(codec.dim(), deserialized.dim());
        assert_eq!(codec.nbits(), deserialized.nbits());
    }

    proptest! {
        #[test]
        fn prop_compress_produces_valid_centroid_id(
            seed in 0u64..1000
        ) {
            let mut embeddings = Vec::with_capacity(200 * 16);
            let mut rng_state = seed;
            for _ in 0..(200 * 16) {
                rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
                embeddings.push(((rng_state >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0);
            }
            let codec = ResidualCodec::train(&embeddings, 16, 8, 2, 3).unwrap();
            let test_vec = &embeddings[0..16];
            let (centroid_id, _) = codec.compress(test_vec);
            prop_assert!(centroid_id < 8);
        }

        #[test]
        fn prop_packed_size_matches_config(
            nbits in prop::sample::select(vec![2u8, 4]),
            dim in 8usize..64
        ) {
            // 100 samples always exceed the 8 requested centroids, so a
            // training failure here is a real bug and must fail the test.
            let embeddings = generate_test_embeddings(100, dim);
            let codec = ResidualCodec::train(&embeddings, dim, 8, nbits, 3).unwrap();
            let expected_size = (dim * nbits as usize + 7) / 8;
            prop_assert_eq!(codec.packed_size(), expected_size);
        }
    }
}