1mod code;
2mod distance;
3mod sdc;
4
5pub use code::{PQCode, PQCode8, bytes_for_nbits};
6pub use distance::PQDistanceTable;
7pub use sdc::SDCTable;
8
9use std::fmt;
10use std::sync::Arc;
11
12use bb_core::{
13 Codec,
14 embedding::{Embedding, EmbeddingSpace},
15 index::{OpId, OpRef},
16};
17use bb_ml::KMeans;
18
/// Product quantizer over an embedding space `S`.
///
/// An embedding of length `d` is split into `M` sub-vectors of length `dsub`;
/// each sub-vector is represented by the index of one of `1 << NBITS`
/// centroids belonging to its subspace.
#[derive(Clone)]
pub struct ProductQuantizer<S: EmbeddingSpace, const M: usize, const NBITS: usize>
where
    [(); bytes_for_nbits(NBITS)]:,
{
    /// The embedding space supplied by the caller at construction time.
    space: S,
    /// Length of one sub-vector (`d / M`).
    dsub: usize,
    /// Full embedding length.
    d: usize,
    /// Flat centroid storage: centroid `k` of subspace `m` occupies the `dsub`
    /// floats starting at `(m * KSUB + k) * dsub`. Empty until trained.
    centroids: Arc<Vec<f32>>,
    /// Code-to-code distance table; `None` until training or loading fills it.
    sdc_table: Option<Arc<SDCTable<M, NBITS>>>,
    /// Set once centroids have been learned or loaded.
    trained: bool,
    /// Next value handed out by `alloc_op_id`.
    next_op_id: u64,
    /// Scratch space of `dsub` floats holding the sub-vector currently being
    /// quantized or compared.
    subvec_buffer: Vec<f32>,
}
59
60impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> fmt::Debug for ProductQuantizer<S, M, NBITS>
61where
62 [(); bytes_for_nbits(NBITS)]:,
63{
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 f.debug_struct("ProductQuantizer")
66 .field("M", &M)
67 .field("NBITS", &NBITS)
68 .field("ksub", &(1usize << NBITS))
69 .field("dsub", &self.dsub)
70 .field("trained", &self.trained)
71 .finish()
72 }
73}
74
75impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> PartialEq for ProductQuantizer<S, M, NBITS>
76where
77 [(); bytes_for_nbits(NBITS)]:,
78{
79 fn eq(&self, other: &Self) -> bool {
80 self.dsub == other.dsub
81 && self.trained == other.trained
82 && self.centroids == other.centroids
83 }
84}
85
// Marker impl: equality is defined by the `PartialEq` impl for this type.
// NOTE(review): that comparison includes `f32` centroids, and `NaN != NaN`, so
// the reflexivity `Eq` promises only holds while no centroid is NaN — confirm
// this is acceptable for callers.
impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> Eq for ProductQuantizer<S, M, NBITS>
where
    [(); bytes_for_nbits(NBITS)]:,
{}
90
91impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> ProductQuantizer<S, M, NBITS>
92where
93 [(); bytes_for_nbits(NBITS)]:,
94 <S::EmbeddingData as Embedding>::Scalar: Into<f32> + From<f32>,
95{
96 pub const KSUB: usize = 1 << NBITS;
98
99 pub fn new(space: S) -> Self {
107 let d = S::EmbeddingData::length();
108 assert!(
109 d % M == 0,
110 "dimension {} must be divisible by M={}",
111 d,
112 M
113 );
114 let dsub = d / M;
115
116 Self {
117 space,
118 dsub,
119 d,
120 centroids: Arc::new(Vec::new()),
121 sdc_table: None,
122 trained: false,
123 next_op_id: 1,
124 subvec_buffer: vec![0.0; dsub],
126 }
127 }
128
129 pub fn space(&self) -> &S {
131 &self.space
132 }
133
134 fn alloc_op_id(&mut self) -> OpId {
135 let id = OpId(self.next_op_id);
136 self.next_op_id += 1;
137 id
138 }
139
140 pub fn m(&self) -> usize {
142 M
143 }
144
145 pub fn ksub(&self) -> usize {
147 Self::KSUB
148 }
149
150 pub fn dsub(&self) -> usize {
151 self.dsub
152 }
153
154 fn find_nearest_centroid(&self, subspace: usize) -> usize {
157 let ksub = Self::KSUB;
158 let mut best_idx = 0;
159 let mut best_dist = f32::MAX;
160
161 for k in 0..ksub {
162 let centroid_offset = (subspace * ksub + k) * self.dsub;
163 let centroid = &self.centroids[centroid_offset..centroid_offset + self.dsub];
164
165 let dist: f32 = self.subvec_buffer
166 .iter()
167 .zip(centroid.iter())
168 .map(|(&s, &c)| {
169 let diff = s - c;
170 diff * diff
171 })
172 .sum();
173
174 if dist < best_dist {
175 best_dist = dist;
176 best_idx = k;
177 }
178 }
179
180 best_idx
181 }
182
183 fn fill_subvec_buffer(&mut self, slice: &[<S::EmbeddingData as Embedding>::Scalar], subspace: usize) {
185 let start = subspace * self.dsub;
186 for i in 0..self.dsub {
187 self.subvec_buffer[i] = slice[start + i].into();
188 }
189 }
190
191 pub fn encode_embedding(&mut self, embedding: &S::EmbeddingData) -> PQCode<M, NBITS> {
193 assert!(self.trained, "codec must be trained before encoding");
194
195 let slice = embedding.as_slice();
196 let mut code = PQCode::<M, NBITS>::zeros();
197
198 for subspace in 0..M {
199 self.fill_subvec_buffer(slice, subspace);
200 let nearest = self.find_nearest_centroid(subspace);
201 code.set(subspace, nearest as u32);
202 }
203
204 code
205 }
206
207 pub fn decode_code(&self, code: &PQCode<M, NBITS>) -> S::EmbeddingData {
209 assert!(self.trained, "codec must be trained before decoding");
210
211 let ksub = Self::KSUB;
212 let mut result = vec![0.0f32; self.d];
213
214 for m in 0..M {
215 let c = code.get(m) as usize;
216 let centroid_offset = (m * ksub + c) * self.dsub;
217 let centroid = &self.centroids[centroid_offset..centroid_offset + self.dsub];
218
219 let start = m * self.dsub;
220 result[start..start + self.dsub].copy_from_slice(centroid);
221 }
222
223 let scalars: Vec<<S::EmbeddingData as Embedding>::Scalar> =
224 result.into_iter().map(|x| x.into()).collect();
225 S::EmbeddingData::from_slice(&scalars)
226 }
227
228 pub fn train_on(&mut self, data: &[S::EmbeddingData]) {
230 assert!(!data.is_empty(), "training data cannot be empty");
231
232 let ksub = Self::KSUB;
233 assert!(
234 data.len() >= ksub,
235 "need at least {} data points (ksub), got {}",
236 ksub,
237 data.len()
238 );
239
240 let mut centroids = vec![0.0; M * ksub * self.dsub];
242
243 for subspace in 0..M {
245 let subvectors: Vec<Vec<f32>> = data
247 .iter()
248 .map(|emb| {
249 let slice = emb.as_slice();
250 let start = subspace * self.dsub;
251 let end = start + self.dsub;
252 slice[start..end].iter().map(|&s| s.into()).collect()
253 })
254 .collect();
255
256 let kmeans = KMeans::fit(&subvectors, ksub, 25);
258
259 for (k, centroid) in kmeans.centroids.iter().enumerate() {
261 let offset = (subspace * ksub + k) * self.dsub;
262 centroids[offset..offset + self.dsub].copy_from_slice(centroid);
263 }
264 }
265
266 let sdc = SDCTable::from_centroids_with_distance(
268 ¢roids,
269 self.dsub,
270 S::slice_distance,
271 );
272 self.centroids = Arc::new(centroids);
273 self.sdc_table = Some(Arc::new(sdc));
274 self.trained = true;
275 }
276
277 pub fn build_distance_table(&mut self, query: &S::EmbeddingData) -> PQDistanceTable<S, M, NBITS> {
279 assert!(self.trained, "codec must be trained before distance computation");
280
281 let ksub = Self::KSUB;
282 let query_slice = query.as_slice();
283 let mut table = Vec::with_capacity(M * ksub);
284
285 for subspace in 0..M {
286 self.fill_subvec_buffer(query_slice, subspace);
287
288 for k in 0..ksub {
289 let centroid_offset = (subspace * ksub + k) * self.dsub;
290 let centroid = &self.centroids[centroid_offset..centroid_offset + self.dsub];
291
292 let dist = S::slice_distance(&self.subvec_buffer, centroid);
293
294 table.push(S::DistanceValue::from(dist));
295 }
296 }
297
298 PQDistanceTable::new(table, ksub)
299 }
300
301 pub fn sdc_table(&self) -> Option<&SDCTable<M, NBITS>> {
305 self.sdc_table.as_deref()
306 }
307}
308
/// Operation handle for work that has already completed by the time the
/// handle is created.
pub struct EagerOpRef<T, E> {
    /// Identifier of the operation this handle refers to.
    id: OpId,
    /// The outcome, present until `finish` takes it.
    result: Option<Result<T, E>>,
    /// Copy of the error (if the operation failed) so that later `finish`
    /// calls can report it again after `result` has been taken.
    cached_error: Option<E>,
}
321
322impl<T, E: Clone> EagerOpRef<T, E> {
323 pub fn ok(id: OpId, value: T) -> Self {
324 Self {
325 id,
326 result: Some(Ok(value)),
327 cached_error: None,
328 }
329 }
330
331 pub fn err(id: OpId, error: E) -> Self {
332 Self {
333 id,
334 result: Some(Err(error.clone())),
335 cached_error: Some(error),
336 }
337 }
338}
339
/// Error types that can express "this operation's result was already taken".
pub trait FinishableError: Clone {
    /// Returns the error value reported when a result is requested again
    /// after it has already been handed out.
    fn already_finished() -> Self;
}
345
346impl FinishableError for PQError {
347 fn already_finished() -> Self {
348 PQError::AlreadyFinished
349 }
350}
351
352impl<T, E: FinishableError> OpRef for EagerOpRef<T, E> {
353 type Info = ();
354 type Stats = ();
355 type Result = T;
356 type Error = E;
357
358 fn id(&self) -> &OpId {
359 &self.id
360 }
361
362 fn info(&self) -> Option<Self::Info> {
363 Some(())
364 }
365
366 fn stats(&self) -> Option<Self::Stats> {
367 Some(())
368 }
369
370 fn is_finished(&self) -> bool {
371 true
372 }
373
374 fn finish(&mut self) -> Result<Self::Result, Self::Error> {
375 match self.result.take() {
376 Some(result) => result,
377 None => {
378 Err(self.cached_error.clone().unwrap_or_else(E::already_finished))
380 }
381 }
382 }
383}
384
/// Errors reported by the product quantizer and its operation handles.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PQError {
    /// The quantizer was used before being trained.
    NotTrained,
    /// The result of an operation was requested after it had been taken.
    AlreadyFinished,
    /// Reading, writing, or validating a codebook failed; carries a
    /// human-readable description. Only present with the `codec` feature.
    #[cfg(feature = "codec")]
    SerializationError(String),
}
396
397impl std::fmt::Display for PQError {
398 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
399 match self {
400 PQError::NotTrained => write!(f, "codec not trained"),
401 PQError::AlreadyFinished => write!(f, "operation already finished"),
402 #[cfg(feature = "codec")]
403 PQError::SerializationError(e) => write!(f, "serialization error: {}", e),
404 }
405 }
406}
407
// Uses the trait's default methods; no underlying source error is exposed.
impl std::error::Error for PQError {}
409
410impl<S: EmbeddingSpace + Default, const M: usize, const NBITS: usize> Default
411 for ProductQuantizer<S, M, NBITS>
412where
413 [(); bytes_for_nbits(NBITS)]:,
414 <S::EmbeddingData as Embedding>::Scalar: Into<f32> + From<f32>,
415{
416 fn default() -> Self {
417 Self::new(S::default())
418 }
419}
420
/// Serializable snapshot of a trained quantizer's codebook.
// NOTE(review): this type is gated on the `serde` feature, while the
// `save_codebook` / `load_codebook` methods that use it are gated on `codec`
// — confirm that `codec` enables `serde`, otherwise `codec` alone won't build.
#[cfg(feature = "serde")]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct PQCodebook {
    /// Flat centroid values copied from the quantizer.
    pub centroids: Vec<f32>,
    /// Raw contents of the code-to-code distance table.
    pub sdc_table: Vec<f32>,
    /// The `ksub` value reported by the distance table.
    pub sdc_ksub: usize,
    /// Length of each sub-vector.
    pub dsub: usize,
    /// Full embedding length.
    pub d: usize,
    /// Number of subspaces (`M`) the codebook was built with.
    pub m: usize,
    /// Bits per sub-code (`NBITS`) the codebook was built with.
    pub nbits: usize,
}
438
#[cfg(feature = "codec")]
impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> ProductQuantizer<S, M, NBITS>
where
    [(); bytes_for_nbits(NBITS)]:,
    <S::EmbeddingData as Embedding>::Scalar: Into<f32> + From<f32>,
{
    /// Serializes the trained codebook with `bincode` and writes it to `path`.
    ///
    /// # Errors
    ///
    /// Returns [`PQError::NotTrained`] if the quantizer has no trained state,
    /// and [`PQError::SerializationError`] if encoding or writing fails.
    pub fn save_codebook(&self, path: &std::path::Path) -> Result<(), PQError> {
        if !self.trained {
            return Err(PQError::NotTrained);
        }
        let sdc = self.sdc_table.as_ref().ok_or(PQError::NotTrained)?;
        let codebook = PQCodebook {
            centroids: (*self.centroids).clone(),
            sdc_table: sdc.table_data().to_vec(),
            sdc_ksub: sdc.ksub(),
            dsub: self.dsub,
            d: self.d,
            m: M,
            nbits: NBITS,
        };
        let encoded = bincode::serialize(&codebook)
            .map_err(|e| PQError::SerializationError(e.to_string()))?;
        std::fs::write(path, encoded).map_err(|e| PQError::SerializationError(e.to_string()))
    }

    /// Reads a codebook from `path` and builds a trained quantizer over `space`.
    ///
    /// The stored shape is checked against this type before it is accepted,
    /// so a mismatched file is rejected here instead of causing an
    /// out-of-bounds panic on the first encode or decode.
    ///
    /// # Errors
    ///
    /// Returns [`PQError::SerializationError`] if the file cannot be read or
    /// decoded, or if its `M`, `NBITS`, dimensions, or centroid count do not
    /// match this quantizer type.
    pub fn load_codebook(path: &std::path::Path, space: S) -> Result<Self, PQError> {
        let data = std::fs::read(path)
            .map_err(|e| PQError::SerializationError(e.to_string()))?;
        let codebook: PQCodebook = bincode::deserialize(&data)
            .map_err(|e| PQError::SerializationError(e.to_string()))?;

        if codebook.m != M || codebook.nbits != NBITS {
            return Err(PQError::SerializationError(format!(
                "Codebook M={}, NBITS={} does not match expected M={}, NBITS={}",
                codebook.m, codebook.nbits, M, NBITS
            )));
        }

        // The embedding length is fixed by the space's data type, exactly as in `new`.
        let expected_d = S::EmbeddingData::length();
        if codebook.d != expected_d || codebook.dsub.checked_mul(M) != Some(codebook.d) {
            return Err(PQError::SerializationError(format!(
                "Codebook d={}, dsub={} does not match expected d={} with M={}",
                codebook.d, codebook.dsub, expected_d, M
            )));
        }

        // One `dsub`-long centroid per (subspace, code) pair: M * KSUB * dsub == KSUB * d.
        let expected_len = Self::KSUB * codebook.d;
        if codebook.centroids.len() != expected_len {
            return Err(PQError::SerializationError(format!(
                "Codebook holds {} centroid values, expected {}",
                codebook.centroids.len(),
                expected_len
            )));
        }

        let sdc_table = SDCTable::from_raw(codebook.sdc_table, codebook.sdc_ksub);

        Ok(Self {
            space,
            dsub: codebook.dsub,
            d: codebook.d,
            centroids: Arc::new(codebook.centroids),
            sdc_table: Some(Arc::new(sdc_table)),
            trained: true,
            // Start at 1, like `new`, so allocated ids never equal the fixed
            // `OpId(0)` used by encode/decode handles.
            next_op_id: 1,
            subvec_buffer: vec![0.0; codebook.dsub],
        })
    }
}
495
496impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> Codec<S> for ProductQuantizer<S, M, NBITS>
497where
498 [(); bytes_for_nbits(NBITS)]:,
499 <S::EmbeddingData as Embedding>::Scalar: Into<f32> + From<f32>,
500{
501 type Encoded = PQCode<M, NBITS>;
502 type EncodeRef<'b> = EagerOpRef<PQCode<M, NBITS>, PQError> where Self: 'b;
503 type DecodeRef<'b> = EagerOpRef<S::EmbeddingData, PQError> where Self: 'b;
504 type TrainRef<'b> = EagerOpRef<(), PQError> where Self: 'b;
505 type ObserveRef<'b> = EagerOpRef<(), PQError> where Self: 'b;
506
507 fn encode(&mut self, embedding: &S::EmbeddingData) -> Self::EncodeRef<'_> {
508 let id = OpId(0);
509 if !self.trained {
510 return EagerOpRef::err(id, PQError::NotTrained);
511 }
512 EagerOpRef::ok(id, self.encode_embedding(embedding))
513 }
514
515 fn encode_batch(&mut self, embeddings: &[S::EmbeddingData]) -> Vec<Self::EncodeRef<'_>> {
516 embeddings.iter().map(|e| {
517 let id = OpId(0);
518 if !self.trained {
519 return EagerOpRef::err(id, PQError::NotTrained);
520 }
521 EagerOpRef::ok(id, self.encode_embedding(e))
522 }).collect()
523 }
524
525 fn decode(&self, encoded: &Self::Encoded) -> Self::DecodeRef<'_> {
526 let id = OpId(0);
527 if !self.trained {
528 return EagerOpRef::err(id, PQError::NotTrained);
529 }
530 EagerOpRef::ok(id, self.decode_code(encoded))
531 }
532
533 fn decode_batch(&self, encoded: &[Self::Encoded]) -> Vec<Self::DecodeRef<'_>> {
534 encoded.iter().map(|e| self.decode(e)).collect()
535 }
536
537 fn code_size(&self) -> Option<usize> {
538 Some(PQCode::<M, NBITS>::TOTAL_BYTES)
539 }
540
541 fn train(&mut self, embeddings: &[S::EmbeddingData]) -> Self::TrainRef<'_> {
542 let id = self.alloc_op_id();
543 self.train_on(embeddings);
544 EagerOpRef::ok(id, ())
545 }
546
547 fn observe(&mut self, _embedding: &S::EmbeddingData) -> Self::ObserveRef<'_> {
548 EagerOpRef::ok(self.alloc_op_id(), ())
550 }
551
552 fn observe_batch(&mut self, embeddings: &[S::EmbeddingData]) -> Vec<Self::ObserveRef<'_>> {
553 embeddings.iter().map(|_| self.observe(&S::EmbeddingData::zeros())).collect()
554 }
555
556 fn is_trained(&self) -> bool {
557 self.trained
558 }
559}
560
561impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> EmbeddingSpace for ProductQuantizer<S, M, NBITS>
563where
564 [(); bytes_for_nbits(NBITS)]:,
565 <S::EmbeddingData as Embedding>::Scalar: Into<f32> + From<f32>,
566{
567 type EmbeddingData = PQCode<M, NBITS>;
568 type DistanceValue = S::DistanceValue;
569 type Prepared = PQCode<M, NBITS>;
570
571 fn space_id(&self) -> &'static str {
572 "pq"
573 }
574
575 fn distance(&self, lhs: &Self::EmbeddingData, rhs: &Self::EmbeddingData) -> Self::DistanceValue {
576 let sdc = self.sdc_table.as_ref().expect("ProductQuantizer must be trained before computing distances");
577 S::DistanceValue::from(sdc.distance(lhs, rhs))
578 }
579
580 fn prepare(&self, embedding: &Self::EmbeddingData) -> Self::Prepared {
581 embedding.clone()
582 }
583
584 fn distance_prepared(
585 &self,
586 prepared: &Self::Prepared,
587 target: &Self::EmbeddingData,
588 ) -> Self::DistanceValue {
589 self.distance(prepared, target)
590 }
591
592 fn length() -> usize {
593 PQCode::<M, NBITS>::TOTAL_BYTES
594 }
595
596 fn slice_distance(a: &[f32], b: &[f32]) -> f32 {
597 S::slice_distance(a, b)
598 }
599
600 fn infinite_mapping(native_distance: &Self::DistanceValue) -> f32 {
601 S::infinite_mapping(native_distance)
602 }
603}
604
#[cfg(test)]
mod tests {
    use super::*;
    use bb_core::embedding::{F32Embedding, F32L2Space};

    type Space = F32L2Space<8>;

    /// Builds `n` 8-dimensional vectors; vector `i` is `i` plus per-dimension
    /// offsets of 0.0, 0.1, ..., 0.7.
    fn make_test_vectors(n: usize) -> Vec<F32Embedding<8>> {
        let mut vectors = Vec::with_capacity(n);
        for i in 0..n {
            let base = i as f32;
            vectors.push(F32Embedding([
                base,
                base + 0.1,
                base + 0.2,
                base + 0.3,
                base + 0.4,
                base + 0.5,
                base + 0.6,
                base + 0.7,
            ]));
        }
        vectors
    }

    /// A fresh 8-bit quantizer reports its shape and is untrained.
    #[test]
    fn test_pq_creation() {
        let quantizer = ProductQuantizer::<Space, 2, 8>::new(F32L2Space::<8>);
        assert_eq!(quantizer.m(), 2);
        assert_eq!(quantizer.ksub(), 256);
        assert_eq!(quantizer.dsub(), 4);
        assert!(!quantizer.is_trained());
    }

    /// Same as above with 2 bits per sub-code (4 centroids).
    #[test]
    fn test_pq_creation_nbits2() {
        let quantizer = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);
        assert_eq!(quantizer.m(), 2);
        assert_eq!(quantizer.ksub(), 4);
        assert_eq!(quantizer.dsub(), 4);
        assert!(!quantizer.is_trained());
    }

    /// Training flips the trained flag and encoding yields an `M`-entry code.
    #[test]
    fn test_pq_train_and_encode() {
        let samples = make_test_vectors(100);
        let mut quantizer = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);

        quantizer.train_on(&samples);
        assert!(quantizer.is_trained());

        let encoded = quantizer.encode_embedding(&samples[0]);
        assert_eq!(encoded.m(), 2);
    }

    /// A round trip through encode/decode stays within a loose tolerance.
    #[test]
    fn test_pq_encode_decode() {
        let samples = make_test_vectors(100);
        let mut quantizer = ProductQuantizer::<Space, 2, 4>::new(F32L2Space::<8>);

        quantizer.train_on(&samples);

        let source = &samples[50];
        let encoded = quantizer.encode_embedding(source);
        let rebuilt = quantizer.decode_code(&encoded);

        let expected = source.as_slice();
        let actual = rebuilt.as_slice();

        for dim in 0..8 {
            let delta = (expected[dim] - actual[dim]).abs();
            assert!(delta < 10.0, "dimension {} differs by {}", dim, delta);
        }
    }

    /// Asymmetric distance (raw query vs. code) is non-negative.
    #[test]
    fn test_pq_adc_distance() {
        let samples = make_test_vectors(100);
        let mut quantizer = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);

        quantizer.train_on(&samples);

        let query = &samples[0];
        let target_code = quantizer.encode_embedding(&samples[1]);

        let lookup = quantizer.build_distance_table(query);
        let distance = lookup.distance(&target_code);

        assert!(distance.value() >= 0.0);
    }

    /// Symmetric distance (code vs. code): zero to itself, non-negative otherwise.
    #[test]
    fn test_pq_sdc() {
        let samples = make_test_vectors(100);
        let mut quantizer = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);

        quantizer.train_on(&samples);

        let first = quantizer.encode_embedding(&samples[0]);
        let second = quantizer.encode_embedding(&samples[1]);

        let table = quantizer
            .sdc_table()
            .expect("should have SDC table after training");

        assert_eq!(table.distance(&first, &first), 0.0);

        let between = table.distance(&first, &second);
        assert!(between >= 0.0);
    }

    /// The quantizer behaves as an embedding space over its own codes.
    #[test]
    fn test_pq_embedding_space() {
        let samples = make_test_vectors(100);
        let mut quantizer = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);

        quantizer.train_on(&samples);

        let first = quantizer.encode_embedding(&samples[0]);
        let second = quantizer.encode_embedding(&samples[1]);

        let direct = quantizer.distance(&first, &second);
        assert!(direct.value() >= 0.0);

        assert_eq!(quantizer.distance(&first, &first).value(), 0.0);

        // The prepared path must agree with the direct path.
        let prepared = quantizer.prepare(&first);
        let via_prepared = quantizer.distance_prepared(&prepared, &second);
        assert_eq!(direct, via_prepared);
    }

    /// Train, encode and decode through the `Codec` trait handles.
    #[test]
    fn test_codec_trait() {
        let samples = make_test_vectors(100);
        let mut quantizer = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);

        let mut training = quantizer.train(&samples);
        assert!(training.is_finished());
        training.finish().unwrap();

        assert!(quantizer.is_trained());

        let mut encoding = quantizer.encode(&samples[0]);
        assert!(encoding.is_finished());
        let encoded = encoding.finish().unwrap();

        let mut decoding = quantizer.decode(&encoded);
        assert!(decoding.is_finished());
        let _rebuilt = decoding.finish().unwrap();
    }

    /// Code size in bytes for 8-bit and 10-bit sub-codes with four subspaces.
    #[test]
    fn test_code_size() {
        let eight_bit = ProductQuantizer::<Space, 4, 8>::new(F32L2Space::<8>);
        assert_eq!(eight_bit.code_size(), Some(4));

        let ten_bit = ProductQuantizer::<Space, 4, 10>::new(F32L2Space::<8>);
        assert_eq!(ten_bit.code_size(), Some(8));
    }

    /// 10-bit sub-codes (1024 centroids); ignored by default.
    #[test]
    #[ignore]
    fn test_pq_nbits10() {
        let mut quantizer = ProductQuantizer::<Space, 2, 10>::new(F32L2Space::<8>);

        let mut samples: Vec<F32Embedding<8>> = Vec::with_capacity(2000);
        for i in 0..2000 {
            let base = (i as f32) * 0.01;
            samples.push(F32Embedding([
                base,
                base + 0.1,
                base + 0.2,
                base + 0.3,
                base + 0.4,
                base + 0.5,
                base + 0.6,
                base + 0.7,
            ]));
        }

        quantizer.train_on(&samples);
        assert!(quantizer.is_trained());
        assert_eq!(quantizer.ksub(), 1024);

        let encoded = quantizer.encode_embedding(&samples[500]);
        let first_index = encoded.get(0);
        let second_index = encoded.get(1);
        assert!(first_index < 1024);
        assert!(second_index < 1024);
    }
}