1use crate::multivector::types::WarpIndexConfig;
12use crate::Result;
13use serde::{Deserialize, Serialize};
14
/// Residual vector codec: each embedding is stored as the id of its nearest centroid plus a
/// per-dimension `nbits`-bit code for the residual `embedding - centroid`.
///
/// Build one with [`ResidualCodec::train`] (learns all parameters from data) or
/// [`ResidualCodec::with_params`] (restores previously learned parameters).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResidualCodec {
    /// Centroid coordinates, row-major: centroid `i` occupies `i * dim..(i + 1) * dim`.
    centroids: Vec<f32>,
    /// Number of centroids searched by nearest-centroid lookup.
    num_centroids: usize,
    /// Dimensionality of every embedding, centroid and residual.
    dim: usize,
    /// Residual bucket boundaries, `2^nbits - 1` per dimension, grouped by dimension.
    /// A residual value is assigned to the first bucket whose cutoff is greater than it
    /// (the last bucket if none is). `train` fills these from sorted quantiles.
    bucket_cutoffs: Vec<f32>,
    /// Residual reconstruction values, `2^nbits` per dimension: entry
    /// `d * 2^nbits + code` is the value used for dimension `d` with that code.
    /// `train` sets each to the mean of the training residuals in that quantile bucket.
    bucket_weights: Vec<f32>,
    /// Bits per residual code. `train` only accepts 2 or 4; packing/unpacking panics
    /// for any other value.
    nbits: u8,
}
48
49impl ResidualCodec {
50 pub fn train(
64 embeddings: &[f32],
65 dim: usize,
66 num_centroids: usize,
67 nbits: u8,
68 iterations: usize,
69 ) -> Result<Self> {
70 if nbits != 2 && nbits != 4 {
71 return Err(crate::Error::InvalidInput("nbits must be 2 or 4".to_string()));
72 }
73
74 if dim == 0 {
75 return Err(crate::Error::InvalidInput("dim must be > 0".to_string()));
76 }
77 let n = embeddings.len() / dim;
78 if n < num_centroids {
79 return Err(crate::Error::InvalidInput(format!(
80 "Insufficient training data: {n} samples for {num_centroids} centroids"
81 )));
82 }
83
84 contract_pre_embedding_lookup!(embeddings);
86
87 let centroids = Self::kmeans_clustering(embeddings, dim, num_centroids, iterations);
89
90 let residuals = Self::compute_all_residuals(embeddings, dim, ¢roids, num_centroids);
92
93 let (bucket_cutoffs, bucket_weights) =
95 Self::learn_quantization_params(&residuals, dim, nbits);
96
97 Ok(Self { centroids, num_centroids, dim, bucket_cutoffs, bucket_weights, nbits })
98 }
99
100 #[must_use]
106 pub fn with_params(
107 centroids: Vec<f32>,
108 num_centroids: usize,
109 dim: usize,
110 bucket_cutoffs: Vec<f32>,
111 bucket_weights: Vec<f32>,
112 nbits: u8,
113 ) -> Self {
114 assert!(dim > 0, "dim must be > 0: division by zero in centroid/residual arithmetic");
115 Self { centroids, num_centroids, dim, bucket_cutoffs, bucket_weights, nbits }
116 }
117
118 #[must_use]
120 pub fn num_centroids(&self) -> usize {
121 self.num_centroids
122 }
123
124 #[must_use]
126 pub fn dim(&self) -> usize {
127 self.dim
128 }
129
130 #[must_use]
132 pub fn nbits(&self) -> u8 {
133 self.nbits
134 }
135
136 #[must_use]
138 pub fn packed_size(&self) -> usize {
139 (self.dim * self.nbits as usize + 7) / 8
140 }
141
142 #[must_use]
144 pub fn centroid(&self, id: usize) -> &[f32] {
145 let start = id * self.dim;
146 &self.centroids[start..start + self.dim]
147 }
148
149 #[must_use]
151 pub fn centroids(&self) -> &[f32] {
152 &self.centroids
153 }
154
155 #[must_use]
157 pub fn find_nearest_centroid(&self, embedding: &[f32]) -> usize {
158 contract_pre_configuration!(embedding);
160 let mut best_id = 0;
161 let mut best_dist = f32::MAX;
162
163 for c in 0..self.num_centroids {
164 let centroid = self.centroid(c);
165 let dist = Self::squared_distance(embedding, centroid);
166 if dist < best_dist {
167 best_dist = dist;
168 best_id = c;
169 }
170 }
171
172 best_id
173 }
174
175 #[must_use]
177 pub fn compress(&self, embedding: &[f32]) -> (usize, Vec<u8>) {
178 contract_pre_embedding_lookup!(embedding);
180 let centroid_id = self.find_nearest_centroid(embedding);
182 let centroid = self.centroid(centroid_id);
183
184 let residual: Vec<f32> =
186 embedding.iter().zip(centroid.iter()).map(|(e, c)| e - c).collect();
187
188 let codes = self.quantize_residual(&residual);
190
191 let packed = self.pack_codes(&codes);
193
194 (centroid_id, packed)
195 }
196
197 #[must_use]
208 pub fn decompress_score(
209 &self,
210 query_token: &[f32],
211 centroid_id: usize,
212 centroid_score: f32,
213 packed_residual: &[u8],
214 ) -> f32 {
215 let _ = centroid_id; let codes = self.unpack_codes(packed_residual);
219
220 let num_buckets = 1usize << self.nbits;
222 let residual_score: f32 = codes
223 .iter()
224 .enumerate()
225 .map(|(d, &code)| {
226 let weight_idx = d * num_buckets + code as usize;
227 query_token[d] * self.bucket_weights[weight_idx]
228 })
229 .sum();
230
231 centroid_score + residual_score
232 }
233
234 #[must_use]
236 pub fn centroid_score(&self, query_token: &[f32], centroid_id: usize) -> f32 {
237 let centroid = self.centroid(centroid_id);
238 Self::dot_product(query_token, centroid)
239 }
240
241 fn quantize_residual(&self, residual: &[f32]) -> Vec<u8> {
243 let num_buckets = 1usize << self.nbits;
244
245 residual
246 .iter()
247 .enumerate()
248 .map(|(d, &value)| {
249 let cutoff_start = d * (num_buckets - 1);
251 let cutoffs = &self.bucket_cutoffs[cutoff_start..cutoff_start + num_buckets - 1];
252
253 cutoffs.iter().position(|&c| value < c).unwrap_or(num_buckets - 1) as u8
255 })
256 .collect()
257 }
258
259 fn pack_codes(&self, codes: &[u8]) -> Vec<u8> {
261 match self.nbits {
262 2 => {
263 codes
265 .chunks(4)
266 .map(|chunk| {
267 let mut byte = 0u8;
268 for (i, &code) in chunk.iter().enumerate() {
269 byte |= (code & 0x03) << (i * 2);
270 }
271 byte
272 })
273 .collect()
274 }
275 4 => {
276 codes
278 .chunks(2)
279 .map(|chunk| {
280 let low = chunk.first().copied().unwrap_or(0) & 0x0F;
281 let high = chunk.get(1).copied().unwrap_or(0) & 0x0F;
282 low | (high << 4)
283 })
284 .collect()
285 }
286 _ => panic!("Unsupported nbits: {}", self.nbits),
287 }
288 }
289
290 fn unpack_codes(&self, packed: &[u8]) -> Vec<u8> {
292 match self.nbits {
293 2 => packed
294 .iter()
295 .flat_map(|&byte| (0..4).map(move |i| (byte >> (i * 2)) & 0x03))
296 .take(self.dim)
297 .collect(),
298 4 => packed
299 .iter()
300 .flat_map(|&byte| vec![byte & 0x0F, (byte >> 4) & 0x0F])
301 .take(self.dim)
302 .collect(),
303 _ => panic!("Unsupported nbits: {}", self.nbits),
304 }
305 }
306
307 fn kmeans_clustering(embeddings: &[f32], dim: usize, k: usize, iterations: usize) -> Vec<f32> {
311 let n = embeddings.len() / dim;
312
313 let mut centroids = Self::kmeans_plus_plus_init(embeddings, dim, k);
315 let mut assignments = vec![0usize; n];
316
317 for _ in 0..iterations {
318 for i in 0..n {
320 let point = &embeddings[i * dim..(i + 1) * dim];
321 let mut best_dist = f32::MAX;
322 let mut best_c = 0;
323
324 for c in 0..k {
325 let centroid = ¢roids[c * dim..(c + 1) * dim];
326 let dist = Self::squared_distance(point, centroid);
327 if dist < best_dist {
328 best_dist = dist;
329 best_c = c;
330 }
331 }
332 assignments[i] = best_c;
333 }
334
335 let mut new_centroids = vec![0.0f32; k * dim];
337 let mut counts = vec![0usize; k];
338
339 for i in 0..n {
340 let c = assignments[i];
341 counts[c] += 1;
342 let point = &embeddings[i * dim..(i + 1) * dim];
343 for d in 0..dim {
344 new_centroids[c * dim + d] += point[d];
345 }
346 }
347
348 for c in 0..k {
349 if counts[c] > 0 {
350 for d in 0..dim {
351 new_centroids[c * dim + d] /= counts[c] as f32;
352 }
353 } else {
354 for d in 0..dim {
356 new_centroids[c * dim + d] = centroids[c * dim + d];
357 }
358 }
359 }
360
361 centroids = new_centroids;
362 }
363
364 centroids
365 }
366
367 fn kmeans_plus_plus_init(embeddings: &[f32], dim: usize, k: usize) -> Vec<f32> {
369 let n = embeddings.len() / dim;
370 let mut centroids = Vec::with_capacity(k * dim);
371 let mut rng_state = 42u64; let first_idx = Self::simple_random(&mut rng_state, n);
375 centroids.extend_from_slice(&embeddings[first_idx * dim..(first_idx + 1) * dim]);
376
377 let mut distances = vec![f32::MAX; n];
378
379 for _ in 1..k {
380 let num_centroids = centroids.len() / dim;
381
382 for i in 0..n {
384 let point = &embeddings[i * dim..(i + 1) * dim];
385 let centroid = ¢roids[(num_centroids - 1) * dim..num_centroids * dim];
386 let dist = Self::squared_distance(point, centroid);
387 distances[i] = distances[i].min(dist);
388 }
389
390 let total: f32 = distances.iter().sum();
392 if total <= 0.0 {
393 let idx = Self::simple_random(&mut rng_state, n);
395 centroids.extend_from_slice(&embeddings[idx * dim..(idx + 1) * dim]);
396 continue;
397 }
398
399 let threshold = Self::simple_random_f32(&mut rng_state) * total;
400 let mut cumsum = 0.0f32;
401 let mut chosen = 0;
402
403 for (i, &d) in distances.iter().enumerate() {
404 cumsum += d;
405 if cumsum >= threshold {
406 chosen = i;
407 break;
408 }
409 }
410
411 centroids.extend_from_slice(&embeddings[chosen * dim..(chosen + 1) * dim]);
412 }
413
414 centroids
415 }
416
417 fn compute_all_residuals(
419 embeddings: &[f32],
420 dim: usize,
421 centroids: &[f32],
422 num_centroids: usize,
423 ) -> Vec<f32> {
424 let n = embeddings.len() / dim;
425 let mut residuals = Vec::with_capacity(n * dim);
426
427 for i in 0..n {
428 let point = &embeddings[i * dim..(i + 1) * dim];
429
430 let mut best_c = 0;
432 let mut best_dist = f32::MAX;
433 for c in 0..num_centroids {
434 let centroid = ¢roids[c * dim..(c + 1) * dim];
435 let dist = Self::squared_distance(point, centroid);
436 if dist < best_dist {
437 best_dist = dist;
438 best_c = c;
439 }
440 }
441
442 let centroid = ¢roids[best_c * dim..(best_c + 1) * dim];
444 for d in 0..dim {
445 residuals.push(point[d] - centroid[d]);
446 }
447 }
448
449 residuals
450 }
451
452 fn learn_quantization_params(residuals: &[f32], dim: usize, nbits: u8) -> (Vec<f32>, Vec<f32>) {
454 let num_buckets = 1usize << nbits;
455 let n = residuals.len() / dim;
456
457 let mut cutoffs = Vec::with_capacity(dim * (num_buckets - 1));
458 let mut weights = Vec::with_capacity(dim * num_buckets);
459
460 for d in 0..dim {
461 let mut values: Vec<f32> = (0..n).map(|i| residuals[i * dim + d]).collect();
463 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
464
465 for b in 1..num_buckets {
467 let quantile_idx = (b * n) / num_buckets;
468 cutoffs.push(values[quantile_idx.min(n - 1)]);
469 }
470
471 for b in 0..num_buckets {
473 let start = (b * n) / num_buckets;
474 let end = ((b + 1) * n) / num_buckets;
475 let end = end.max(start + 1).min(n);
476
477 let sum: f32 = values[start..end].iter().sum();
478 let mean = sum / (end - start) as f32;
479 weights.push(mean);
480 }
481 }
482
483 (cutoffs, weights)
484 }
485
486 fn squared_distance(a: &[f32], b: &[f32]) -> f32 {
489 a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum()
490 }
491
492 fn dot_product(a: &[f32], b: &[f32]) -> f32 {
493 a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
494 }
495
496 fn simple_random(state: &mut u64, max: usize) -> usize {
497 *state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
498 ((*state >> 33) as usize) % max
499 }
500
501 fn simple_random_f32(state: &mut u64) -> f32 {
502 *state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
503 ((*state >> 33) as f32) / (u32::MAX as f32)
504 }
505}
506
/// Trains a [`ResidualCodec`] using parameters taken from a [`WarpIndexConfig`].
pub struct ResidualCodecBuilder {
    /// Supplies `token_dim`, `num_centroids`, `nbits` and `kmeans_iterations` for training.
    config: WarpIndexConfig,
}
511
512impl ResidualCodecBuilder {
513 #[must_use]
515 pub fn new(config: WarpIndexConfig) -> Self {
516 Self { config }
517 }
518
519 pub fn train(&self, embeddings: &[f32]) -> Result<ResidualCodec> {
521 contract_pre_embedding_lookup!(embeddings);
523 ResidualCodec::train(
524 embeddings,
525 self.config.token_dim,
526 self.config.num_centroids,
527 self.config.nbits,
528 self.config.kmeans_iterations,
529 )
530 }
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Deterministic pseudo-random test vectors (`n * dim` floats) from a seeded LCG.
    ///
    /// NOTE: `(state >> 33)` carries only 31 bits, so dividing by `u32::MAX` yields
    /// `[0, 0.5)` and every value lands in `[-1, 0)`. The formula is left unchanged so
    /// the tests keep running on the same fixed data.
    fn generate_seeded_embeddings(n: usize, dim: usize, seed: u64) -> Vec<f32> {
        let mut rng_state = seed;
        (0..n * dim)
            .map(|_| {
                rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
                ((rng_state >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0
            })
            .collect()
    }

    /// Fixed-seed test vectors shared by the example-based tests.
    fn generate_test_embeddings(n: usize, dim: usize) -> Vec<f32> {
        generate_seeded_embeddings(n, dim, 12345)
    }

    #[test]
    fn test_codec_train_2bit() {
        let embeddings = generate_test_embeddings(1000, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        assert_eq!(codec.num_centroids(), 16);
        assert_eq!(codec.dim(), 32);
        assert_eq!(codec.nbits(), 2);
    }

    #[test]
    fn test_codec_train_4bit() {
        let embeddings = generate_test_embeddings(1000, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 4, 5).unwrap();

        assert_eq!(codec.nbits(), 4);
    }

    #[test]
    fn test_codec_train_insufficient_data() {
        let embeddings = generate_test_embeddings(5, 32);
        let result = ResidualCodec::train(&embeddings, 32, 16, 2, 5);

        assert!(result.is_err());
    }

    #[test]
    fn test_codec_train_invalid_nbits() {
        let embeddings = generate_test_embeddings(100, 32);
        let result = ResidualCodec::train(&embeddings, 32, 16, 3, 5);

        assert!(result.is_err());
    }

    #[test]
    fn test_codec_train_dim_zero() {
        let result = ResidualCodec::train(&[], 0, 4, 2, 3);
        assert!(result.is_err());
    }

    #[test]
    #[should_panic(expected = "dim must be > 0")]
    fn test_codec_with_params_dim_zero() {
        let _ = ResidualCodec::with_params(vec![], 0, 0, vec![], vec![], 2);
    }

    #[test]
    fn test_codec_compress() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        let test_vec = &embeddings[0..32];
        let (centroid_id, packed) = codec.compress(test_vec);

        assert!(centroid_id < 16);
        assert_eq!(packed.len(), codec.packed_size());
    }

    #[test]
    fn test_codec_packed_size_2bit() {
        let embeddings = generate_test_embeddings(500, 128);
        let codec = ResidualCodec::train(&embeddings, 128, 16, 2, 5).unwrap();

        assert_eq!(codec.packed_size(), 32);
    }

    #[test]
    fn test_codec_packed_size_4bit() {
        let embeddings = generate_test_embeddings(500, 128);
        let codec = ResidualCodec::train(&embeddings, 128, 16, 4, 5).unwrap();

        assert_eq!(codec.packed_size(), 64);
    }

    #[test]
    fn test_pack_unpack_2bit() {
        let embeddings = generate_test_embeddings(500, 8);
        let codec = ResidualCodec::train(&embeddings, 8, 16, 2, 5).unwrap();

        let codes: Vec<u8> = vec![0, 1, 2, 3, 0, 1, 2, 3];
        let packed = codec.pack_codes(&codes);
        let unpacked = codec.unpack_codes(&packed);

        assert_eq!(codes, unpacked);
    }

    #[test]
    fn test_pack_unpack_4bit() {
        let embeddings = generate_test_embeddings(500, 8);
        let codec = ResidualCodec::train(&embeddings, 8, 16, 4, 5).unwrap();

        let codes: Vec<u8> = vec![0, 5, 10, 15, 1, 6, 11, 14];
        let packed = codec.pack_codes(&codes);
        let unpacked = codec.unpack_codes(&packed);

        assert_eq!(codes, unpacked);
    }

    #[test]
    fn test_decompress_score() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        let query = &embeddings[0..32];
        let doc = &embeddings[32..64];

        let (centroid_id, packed) = codec.compress(doc);

        let centroid_score = codec.centroid_score(query, centroid_id);

        let approx_score = codec.decompress_score(query, centroid_id, centroid_score, &packed);

        let exact_score: f32 = query.iter().zip(doc.iter()).map(|(q, d)| q * d).sum();

        let error = (approx_score - exact_score).abs();
        assert!(
            error < exact_score.abs() * 0.5 + 1.0,
            "Error too large: approx={}, exact={}, error={}",
            approx_score,
            exact_score,
            error
        );
    }

    #[test]
    fn test_centroid_score() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        let query = &embeddings[0..32];
        let centroid = codec.centroid(0);

        let expected: f32 = query.iter().zip(centroid.iter()).map(|(q, c)| q * c).sum();
        let actual = codec.centroid_score(query, 0);

        assert!((expected - actual).abs() < 1e-6);
    }

    #[test]
    fn test_find_nearest_centroid() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        let centroid_0 = codec.centroid(0).to_vec();
        let nearest = codec.find_nearest_centroid(&centroid_0);
        assert_eq!(nearest, 0);
    }

    #[test]
    fn test_codec_builder() {
        let config = WarpIndexConfig::new(2, 16, 32).with_kmeans_iterations(5);
        let builder = ResidualCodecBuilder::new(config);

        let embeddings = generate_test_embeddings(500, 32);
        let codec = builder.train(&embeddings).unwrap();

        assert_eq!(codec.num_centroids(), 16);
        assert_eq!(codec.dim(), 32);
    }

    #[test]
    fn test_codec_serialization() {
        let embeddings = generate_test_embeddings(500, 16);
        let codec = ResidualCodec::train(&embeddings, 16, 8, 2, 5).unwrap();

        let json = serde_json::to_string(&codec).unwrap();
        let deserialized: ResidualCodec = serde_json::from_str(&json).unwrap();

        assert_eq!(codec.num_centroids(), deserialized.num_centroids());
        assert_eq!(codec.dim(), deserialized.dim());
        assert_eq!(codec.nbits(), deserialized.nbits());
    }

    proptest! {
        #[test]
        fn prop_compress_produces_valid_centroid_id(
            seed in 0u64..1000
        ) {
            let embeddings = generate_seeded_embeddings(200, 16, seed);

            let codec = ResidualCodec::train(&embeddings, 16, 8, 2, 3).unwrap();
            let test_vec = &embeddings[0..16];
            let (centroid_id, _) = codec.compress(test_vec);

            prop_assert!(centroid_id < 8);
        }

        #[test]
        fn prop_packed_size_matches_config(
            nbits in prop::sample::select(vec![2u8, 4]),
            dim in 8usize..64
        ) {
            let embeddings = generate_test_embeddings(100, dim);

            // Training must succeed here; silently skipping on error would let the
            // property pass vacuously.
            let codec = ResidualCodec::train(&embeddings, dim, 8, nbits, 3)
                .expect("100 samples are enough to train 8 centroids");
            let expected_size = (dim * nbits as usize + 7) / 8;
            prop_assert_eq!(codec.packed_size(), expected_size);
        }
    }
}