Skip to main content

bytesandbrains_codec/pq/
mod.rs

1mod code;
2mod distance;
3mod sdc;
4
5pub use code::{PQCode, PQCode8, bytes_for_nbits};
6pub use distance::PQDistanceTable;
7pub use sdc::SDCTable;
8
9use std::fmt;
10use std::sync::Arc;
11
12use bb_core::{
13    Codec,
14    embedding::{Embedding, EmbeddingSpace},
15    index::{OpId, OpRef},
16};
17use bb_ml::KMeans;
18
19/// Product Quantizer for vector compression and fast distance computation.
20///
21/// Product Quantization splits a D-dimensional vector into M subvectors,
22/// learns a codebook for each subspace via k-means, and encodes each
23/// subvector as the index of its nearest centroid.
24///
25/// Const generics:
26/// - M: number of subquantizers
27/// - NBITS: bits per centroid index (determines storage and centroid count)
28///
29/// The embedding dimension must be divisible by M.
30///
31/// After training, ProductQuantizer implements `EmbeddingSpace` with
32/// `EmbeddingData = PQCode<M, NBITS>`, allowing direct use with FlatIndex
33/// and other structures.
34///
35/// This implementation supports:
36/// - Training via k-means++ on each subspace
37/// - Encoding: vector -> PQCode<M, NBITS>
38/// - Decoding: PQCode<M, NBITS> -> reconstructed vector
39/// - ADC: Asymmetric Distance Computation via precomputed distance tables
40/// - SDC: Symmetric Distance Computation via precomputed centroid-to-centroid distances
41///
42/// Note: PQ internally uses L2 distance for subspace quantization.
43#[derive(Clone)]
44pub struct ProductQuantizer<S: EmbeddingSpace, const M: usize, const NBITS: usize>
45where
46    [(); bytes_for_nbits(NBITS)]:,
47{
48    space: S,
49    dsub: usize,
50    d: usize,
51    /// Shared codebook data — Arc so cloning PQ across actors is cheap (~pointer copy).
52    centroids: Arc<Vec<f32>>,
53    sdc_table: Option<Arc<SDCTable<M, NBITS>>>,
54    trained: bool,
55    next_op_id: u64,
56    /// Reusable buffer for subvector operations to avoid per-call allocations
57    subvec_buffer: Vec<f32>,
58}
59
60impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> fmt::Debug for ProductQuantizer<S, M, NBITS>
61where
62    [(); bytes_for_nbits(NBITS)]:,
63{
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        f.debug_struct("ProductQuantizer")
66            .field("M", &M)
67            .field("NBITS", &NBITS)
68            .field("ksub", &(1usize << NBITS))
69            .field("dsub", &self.dsub)
70            .field("trained", &self.trained)
71            .finish()
72    }
73}
74
75impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> PartialEq for ProductQuantizer<S, M, NBITS>
76where
77    [(); bytes_for_nbits(NBITS)]:,
78{
79    fn eq(&self, other: &Self) -> bool {
80        self.dsub == other.dsub
81            && self.trained == other.trained
82            && self.centroids == other.centroids
83    }
84}
85
86impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> Eq for ProductQuantizer<S, M, NBITS>
87where
88    [(); bytes_for_nbits(NBITS)]:,
89{}
90
91impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> ProductQuantizer<S, M, NBITS>
92where
93    [(); bytes_for_nbits(NBITS)]:,
94    <S::EmbeddingData as Embedding>::Scalar: Into<f32> + From<f32>,
95{
96    /// Number of centroids per subspace (2^NBITS)
97    pub const KSUB: usize = 1 << NBITS;
98
99    /// Create a new Product Quantizer.
100    ///
101    /// # Arguments
102    /// * `space` - The embedding space
103    ///
104    /// # Panics
105    /// Panics if dimension is not divisible by M.
106    pub fn new(space: S) -> Self {
107        let d = S::EmbeddingData::length();
108        assert!(
109            d % M == 0,
110            "dimension {} must be divisible by M={}",
111            d,
112            M
113        );
114        let dsub = d / M;
115
116        Self {
117            space,
118            dsub,
119            d,
120            centroids: Arc::new(Vec::new()),
121            sdc_table: None,
122            trained: false,
123            next_op_id: 1,
124            // Pre-allocate buffer for subvector operations
125            subvec_buffer: vec![0.0; dsub],
126        }
127    }
128
129    /// Get a reference to the underlying embedding space.
130    pub fn space(&self) -> &S {
131        &self.space
132    }
133
134    fn alloc_op_id(&mut self) -> OpId {
135        let id = OpId(self.next_op_id);
136        self.next_op_id += 1;
137        id
138    }
139
140    /// Number of subquantizers.
141    pub fn m(&self) -> usize {
142        M
143    }
144
145    /// Number of centroids per subspace.
146    pub fn ksub(&self) -> usize {
147        Self::KSUB
148    }
149
150    pub fn dsub(&self) -> usize {
151        self.dsub
152    }
153
154    /// Find the nearest centroid in a given subspace.
155    /// The subvector must already be in self.subvec_buffer.
156    fn find_nearest_centroid(&self, subspace: usize) -> usize {
157        let ksub = Self::KSUB;
158        let mut best_idx = 0;
159        let mut best_dist = f32::MAX;
160
161        for k in 0..ksub {
162            let centroid_offset = (subspace * ksub + k) * self.dsub;
163            let centroid = &self.centroids[centroid_offset..centroid_offset + self.dsub];
164
165            let dist: f32 = self.subvec_buffer
166                .iter()
167                .zip(centroid.iter())
168                .map(|(&s, &c)| {
169                    let diff = s - c;
170                    diff * diff
171                })
172                .sum();
173
174            if dist < best_dist {
175                best_dist = dist;
176                best_idx = k;
177            }
178        }
179
180        best_idx
181    }
182
183    /// Copy a subvector from the embedding slice into the reusable buffer.
184    fn fill_subvec_buffer(&mut self, slice: &[<S::EmbeddingData as Embedding>::Scalar], subspace: usize) {
185        let start = subspace * self.dsub;
186        for i in 0..self.dsub {
187            self.subvec_buffer[i] = slice[start + i].into();
188        }
189    }
190
191    /// Encode a single embedding to a PQ code.
192    pub fn encode_embedding(&mut self, embedding: &S::EmbeddingData) -> PQCode<M, NBITS> {
193        assert!(self.trained, "codec must be trained before encoding");
194
195        let slice = embedding.as_slice();
196        let mut code = PQCode::<M, NBITS>::zeros();
197
198        for subspace in 0..M {
199            self.fill_subvec_buffer(slice, subspace);
200            let nearest = self.find_nearest_centroid(subspace);
201            code.set(subspace, nearest as u32);
202        }
203
204        code
205    }
206
207    /// Decode a PQ code to a reconstructed embedding.
208    pub fn decode_code(&self, code: &PQCode<M, NBITS>) -> S::EmbeddingData {
209        assert!(self.trained, "codec must be trained before decoding");
210
211        let ksub = Self::KSUB;
212        let mut result = vec![0.0f32; self.d];
213
214        for m in 0..M {
215            let c = code.get(m) as usize;
216            let centroid_offset = (m * ksub + c) * self.dsub;
217            let centroid = &self.centroids[centroid_offset..centroid_offset + self.dsub];
218
219            let start = m * self.dsub;
220            result[start..start + self.dsub].copy_from_slice(centroid);
221        }
222
223        let scalars: Vec<<S::EmbeddingData as Embedding>::Scalar> =
224            result.into_iter().map(|x| x.into()).collect();
225        S::EmbeddingData::from_slice(&scalars)
226    }
227
228    /// Train the quantizer on a dataset using k-means++.
229    pub fn train_on(&mut self, data: &[S::EmbeddingData]) {
230        assert!(!data.is_empty(), "training data cannot be empty");
231
232        let ksub = Self::KSUB;
233        assert!(
234            data.len() >= ksub,
235            "need at least {} data points (ksub), got {}",
236            ksub,
237            data.len()
238        );
239
240        // Allocate flat centroid storage: M * ksub * dsub
241        let mut centroids = vec![0.0; M * ksub * self.dsub];
242
243        // Train each subspace independently
244        for subspace in 0..M {
245            // Extract subvectors for this subspace
246            let subvectors: Vec<Vec<f32>> = data
247                .iter()
248                .map(|emb| {
249                    let slice = emb.as_slice();
250                    let start = subspace * self.dsub;
251                    let end = start + self.dsub;
252                    slice[start..end].iter().map(|&s| s.into()).collect()
253                })
254                .collect();
255
256            // Run k-means++ on this subspace
257            let kmeans = KMeans::fit(&subvectors, ksub, 25);
258
259            // Copy centroids to flat storage
260            for (k, centroid) in kmeans.centroids.iter().enumerate() {
261                let offset = (subspace * ksub + k) * self.dsub;
262                centroids[offset..offset + self.dsub].copy_from_slice(centroid);
263            }
264        }
265
266        // Build SDC table using the embedding space's distance metric
267        let sdc = SDCTable::from_centroids_with_distance(
268            &centroids,
269            self.dsub,
270            S::slice_distance,
271        );
272        self.centroids = Arc::new(centroids);
273        self.sdc_table = Some(Arc::new(sdc));
274        self.trained = true;
275    }
276
277    /// Build a distance table for ADC (Asymmetric Distance Computation).
278    pub fn build_distance_table(&mut self, query: &S::EmbeddingData) -> PQDistanceTable<S, M, NBITS> {
279        assert!(self.trained, "codec must be trained before distance computation");
280
281        let ksub = Self::KSUB;
282        let query_slice = query.as_slice();
283        let mut table = Vec::with_capacity(M * ksub);
284
285        for subspace in 0..M {
286            self.fill_subvec_buffer(query_slice, subspace);
287
288            for k in 0..ksub {
289                let centroid_offset = (subspace * ksub + k) * self.dsub;
290                let centroid = &self.centroids[centroid_offset..centroid_offset + self.dsub];
291
292                let dist = S::slice_distance(&self.subvec_buffer, centroid);
293
294                table.push(S::DistanceValue::from(dist));
295            }
296        }
297
298        PQDistanceTable::new(table, ksub)
299    }
300
301    /// Get the SDC table for symmetric distance computation.
302    ///
303    /// Returns None if the quantizer is not trained.
304    pub fn sdc_table(&self) -> Option<&SDCTable<M, NBITS>> {
305        self.sdc_table.as_deref()
306    }
307}
308
309/// Eager operation reference for local/synchronous operations.
310///
311/// This is a simple wrapper for synchronous operations that complete immediately.
312/// The result is stored and returned on first call to `finish()`. Subsequent calls
313/// will return a clone of the original error (if the operation failed) or an
314/// `AlreadyFinished` error (if the operation succeeded).
315pub struct EagerOpRef<T, E> {
316    id: OpId,
317    result: Option<Result<T, E>>,
318    /// Cached error for repeated finish() calls
319    cached_error: Option<E>,
320}
321
322impl<T, E: Clone> EagerOpRef<T, E> {
323    pub fn ok(id: OpId, value: T) -> Self {
324        Self {
325            id,
326            result: Some(Ok(value)),
327            cached_error: None,
328        }
329    }
330
331    pub fn err(id: OpId, error: E) -> Self {
332        Self {
333            id,
334            result: Some(Err(error.clone())),
335            cached_error: Some(error),
336        }
337    }
338}
339
/// Error types that have a dedicated value for "this operation was already
/// finished".
pub trait FinishableError: Clone {
    /// Produce the error that reports a repeated `finish()`.
    fn already_finished() -> Self;
}
345
346impl FinishableError for PQError {
347    fn already_finished() -> Self {
348        PQError::AlreadyFinished
349    }
350}
351
352impl<T, E: FinishableError> OpRef for EagerOpRef<T, E> {
353    type Info = ();
354    type Stats = ();
355    type Result = T;
356    type Error = E;
357
358    fn id(&self) -> &OpId {
359        &self.id
360    }
361
362    fn info(&self) -> Option<Self::Info> {
363        Some(())
364    }
365
366    fn stats(&self) -> Option<Self::Stats> {
367        Some(())
368    }
369
370    fn is_finished(&self) -> bool {
371        true
372    }
373
374    fn finish(&mut self) -> Result<Self::Result, Self::Error> {
375        match self.result.take() {
376            Some(result) => result,
377            None => {
378                // finish() called again - return cached error or AlreadyFinished
379                Err(self.cached_error.clone().unwrap_or_else(E::already_finished))
380            }
381        }
382    }
383}
384
/// Errors reported by PQ codec operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PQError {
    /// An operation needed a trained codec, but no codebook is available yet.
    NotTrained,
    /// `OpRef::finish()` was invoked again after the result was already taken.
    AlreadyFinished,
    /// Reading, writing, encoding or decoding a codebook failed; the payload
    /// is the underlying error's message.
    #[cfg(feature = "codec")]
    SerializationError(String),
}
396
397impl std::fmt::Display for PQError {
398    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
399        match self {
400            PQError::NotTrained => write!(f, "codec not trained"),
401            PQError::AlreadyFinished => write!(f, "operation already finished"),
402            #[cfg(feature = "codec")]
403            PQError::SerializationError(e) => write!(f, "serialization error: {}", e),
404        }
405    }
406}
407
408impl std::error::Error for PQError {}
409
410impl<S: EmbeddingSpace + Default, const M: usize, const NBITS: usize> Default
411    for ProductQuantizer<S, M, NBITS>
412where
413    [(); bytes_for_nbits(NBITS)]:,
414    <S::EmbeddingData as Embedding>::Scalar: Into<f32> + From<f32>,
415{
416    fn default() -> Self {
417        Self::new(S::default())
418    }
419}
420
421// =========================================================================
422// Codebook Serialization
423// =========================================================================
424
/// On-disk form of a trained PQ codebook.
/// Holds everything required to rebuild a `ProductQuantizer`.
// NOTE(review): this struct is gated on the `serde` feature while the impl
// that saves/loads it is gated on `codec` — confirm that `codec` enables `serde`.
#[cfg(feature = "serde")]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct PQCodebook {
    /// Flat centroid storage of the quantizer.
    pub centroids: Vec<f32>,
    /// Raw contents of the SDC table.
    pub sdc_table: Vec<f32>,
    /// `ksub` value reported by the SDC table.
    pub sdc_ksub: usize,
    /// Width of one subvector.
    pub dsub: usize,
    /// Full embedding dimension.
    pub d: usize,
    /// Number of subquantizers (`M`) the codebook was built with.
    pub m: usize,
    /// Bits per centroid index (`NBITS`) the codebook was built with.
    pub nbits: usize,
}
438
439#[cfg(feature = "codec")]
440impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> ProductQuantizer<S, M, NBITS>
441where
442    [(); bytes_for_nbits(NBITS)]:,
443    <S::EmbeddingData as Embedding>::Scalar: Into<f32> + From<f32>,
444{
445    /// Save the trained codebook to a file.
446    pub fn save_codebook(&self, path: &std::path::Path) -> Result<(), PQError> {
447        if !self.trained {
448            return Err(PQError::NotTrained);
449        }
450        let sdc = self.sdc_table.as_ref().ok_or(PQError::NotTrained)?;
451        let codebook = PQCodebook {
452            centroids: (*self.centroids).clone(),
453            sdc_table: sdc.table_data().to_vec(),
454            sdc_ksub: sdc.ksub(),
455            dsub: self.dsub,
456            d: self.d,
457            m: M,
458            nbits: NBITS,
459        };
460        let encoded = bincode::serialize(&codebook)
461            .map_err(|e| PQError::SerializationError(e.to_string()))?;
462        std::fs::write(path, encoded)
463            .map_err(|e| PQError::SerializationError(e.to_string()))?;
464        Ok(())
465    }
466
467    /// Load a trained codebook from a file and reconstruct the ProductQuantizer.
468    pub fn load_codebook(path: &std::path::Path, space: S) -> Result<Self, PQError> {
469        let data = std::fs::read(path)
470            .map_err(|e| PQError::SerializationError(e.to_string()))?;
471        let codebook: PQCodebook = bincode::deserialize(&data)
472            .map_err(|e| PQError::SerializationError(e.to_string()))?;
473
474        if codebook.m != M || codebook.nbits != NBITS {
475            return Err(PQError::SerializationError(format!(
476                "Codebook M={}, NBITS={} does not match expected M={}, NBITS={}",
477                codebook.m, codebook.nbits, M, NBITS
478            )));
479        }
480
481        let sdc_table = SDCTable::from_raw(codebook.sdc_table, codebook.sdc_ksub);
482
483        Ok(Self {
484            space,
485            dsub: codebook.dsub,
486            d: codebook.d,
487            centroids: Arc::new(codebook.centroids),
488            sdc_table: Some(Arc::new(sdc_table)),
489            trained: true,
490            next_op_id: 0,
491            subvec_buffer: vec![0.0; codebook.dsub],
492        })
493    }
494}
495
496impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> Codec<S> for ProductQuantizer<S, M, NBITS>
497where
498    [(); bytes_for_nbits(NBITS)]:,
499    <S::EmbeddingData as Embedding>::Scalar: Into<f32> + From<f32>,
500{
501    type Encoded = PQCode<M, NBITS>;
502    type EncodeRef<'b> = EagerOpRef<PQCode<M, NBITS>, PQError> where Self: 'b;
503    type DecodeRef<'b> = EagerOpRef<S::EmbeddingData, PQError> where Self: 'b;
504    type TrainRef<'b> = EagerOpRef<(), PQError> where Self: 'b;
505    type ObserveRef<'b> = EagerOpRef<(), PQError> where Self: 'b;
506
507    fn encode(&mut self, embedding: &S::EmbeddingData) -> Self::EncodeRef<'_> {
508        let id = OpId(0);
509        if !self.trained {
510            return EagerOpRef::err(id, PQError::NotTrained);
511        }
512        EagerOpRef::ok(id, self.encode_embedding(embedding))
513    }
514
515    fn encode_batch(&mut self, embeddings: &[S::EmbeddingData]) -> Vec<Self::EncodeRef<'_>> {
516        embeddings.iter().map(|e| {
517            let id = OpId(0);
518            if !self.trained {
519                return EagerOpRef::err(id, PQError::NotTrained);
520            }
521            EagerOpRef::ok(id, self.encode_embedding(e))
522        }).collect()
523    }
524
525    fn decode(&self, encoded: &Self::Encoded) -> Self::DecodeRef<'_> {
526        let id = OpId(0);
527        if !self.trained {
528            return EagerOpRef::err(id, PQError::NotTrained);
529        }
530        EagerOpRef::ok(id, self.decode_code(encoded))
531    }
532
533    fn decode_batch(&self, encoded: &[Self::Encoded]) -> Vec<Self::DecodeRef<'_>> {
534        encoded.iter().map(|e| self.decode(e)).collect()
535    }
536
537    fn code_size(&self) -> Option<usize> {
538        Some(PQCode::<M, NBITS>::TOTAL_BYTES)
539    }
540
541    fn train(&mut self, embeddings: &[S::EmbeddingData]) -> Self::TrainRef<'_> {
542        let id = self.alloc_op_id();
543        self.train_on(embeddings);
544        EagerOpRef::ok(id, ())
545    }
546
547    fn observe(&mut self, _embedding: &S::EmbeddingData) -> Self::ObserveRef<'_> {
548        // Online training not implemented - no-op
549        EagerOpRef::ok(self.alloc_op_id(), ())
550    }
551
552    fn observe_batch(&mut self, embeddings: &[S::EmbeddingData]) -> Vec<Self::ObserveRef<'_>> {
553        embeddings.iter().map(|_| self.observe(&S::EmbeddingData::zeros())).collect()
554    }
555
556    fn is_trained(&self) -> bool {
557        self.trained
558    }
559}
560
561// ProductQuantizer implements EmbeddingSpace directly for PQCode<M, NBITS>
562impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> EmbeddingSpace for ProductQuantizer<S, M, NBITS>
563where
564    [(); bytes_for_nbits(NBITS)]:,
565    <S::EmbeddingData as Embedding>::Scalar: Into<f32> + From<f32>,
566{
567    type EmbeddingData = PQCode<M, NBITS>;
568    type DistanceValue = S::DistanceValue;
569    type Prepared = PQCode<M, NBITS>;
570
571    fn space_id(&self) -> &'static str {
572        "pq"
573    }
574
575    fn distance(&self, lhs: &Self::EmbeddingData, rhs: &Self::EmbeddingData) -> Self::DistanceValue {
576        let sdc = self.sdc_table.as_ref().expect("ProductQuantizer must be trained before computing distances");
577        S::DistanceValue::from(sdc.distance(lhs, rhs))
578    }
579
580    fn prepare(&self, embedding: &Self::EmbeddingData) -> Self::Prepared {
581        embedding.clone()
582    }
583
584    fn distance_prepared(
585        &self,
586        prepared: &Self::Prepared,
587        target: &Self::EmbeddingData,
588    ) -> Self::DistanceValue {
589        self.distance(prepared, target)
590    }
591
592    fn length() -> usize {
593        PQCode::<M, NBITS>::TOTAL_BYTES
594    }
595
596    fn slice_distance(a: &[f32], b: &[f32]) -> f32 {
597        S::slice_distance(a, b)
598    }
599
600    fn infinite_mapping(native_distance: &Self::DistanceValue) -> f32 {
601        S::infinite_mapping(native_distance)
602    }
603}
604
#[cfg(test)]
mod tests {
    use super::*;
    use bb_core::embedding::{F32Embedding, F32L2Space};

    type Space = F32L2Space<8>;

    /// Build `n` 8-dimensional vectors; vector `i` is a ramp starting at `i`
    /// and rising by 0.1 per component.
    fn make_test_vectors(n: usize) -> Vec<F32Embedding<8>> {
        let mut vectors = Vec::with_capacity(n);
        for i in 0..n {
            let val = i as f32;
            vectors.push(F32Embedding([
                val,
                val + 0.1,
                val + 0.2,
                val + 0.3,
                val + 0.4,
                val + 0.5,
                val + 0.6,
                val + 0.7,
            ]));
        }
        vectors
    }

    #[test]
    fn test_pq_creation() {
        let pq = ProductQuantizer::<Space, 2, 8>::new(F32L2Space::<8>);

        assert_eq!(pq.m(), 2);
        assert_eq!(pq.ksub(), 256); // 2^8
        assert_eq!(pq.dsub(), 4); // 8 dimensions / 2 subquantizers
        assert!(!pq.is_trained());
    }

    #[test]
    fn test_pq_creation_nbits2() {
        let pq = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);

        assert_eq!(pq.m(), 2);
        assert_eq!(pq.ksub(), 4); // 2^2
        assert_eq!(pq.dsub(), 4);
        assert!(!pq.is_trained());
    }

    #[test]
    fn test_pq_train_and_encode() {
        // 2 bits -> 4 centroids per subspace.
        let mut pq = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);
        let samples = make_test_vectors(100);

        pq.train_on(&samples);
        assert!(pq.is_trained());

        let code = pq.encode_embedding(&samples[0]);
        assert_eq!(code.m(), 2);
    }

    #[test]
    fn test_pq_encode_decode() {
        // 4 bits -> 16 centroids per subspace, for a tighter reconstruction.
        let mut pq = ProductQuantizer::<Space, 2, 4>::new(F32L2Space::<8>);
        let samples = make_test_vectors(100);
        pq.train_on(&samples);

        let original = &samples[50];
        let code = pq.encode_embedding(original);
        let decoded = pq.decode_code(&code);

        // The round trip only has to stay within a generous quantization error.
        let expected = original.as_slice();
        let actual = decoded.as_slice();
        for i in 0..8 {
            let diff = (expected[i] - actual[i]).abs();
            assert!(diff < 10.0, "dimension {} differs by {}", i, diff);
        }
    }

    #[test]
    fn test_pq_adc_distance() {
        let mut pq = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);
        let samples = make_test_vectors(100);
        pq.train_on(&samples);

        let query = &samples[0];
        let target = pq.encode_embedding(&samples[1]);

        let table = pq.build_distance_table(query);
        let dist = table.distance(&target);

        // An ADC distance must never be negative.
        assert!(dist.value() >= 0.0);
    }

    #[test]
    fn test_pq_sdc() {
        let mut pq = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);
        let samples = make_test_vectors(100);
        pq.train_on(&samples);

        // Encoding needs `&mut pq`, so do it before borrowing the SDC table.
        let first = pq.encode_embedding(&samples[0]);
        let second = pq.encode_embedding(&samples[1]);

        let sdc = pq.sdc_table().expect("should have SDC table after training");

        // A code is at distance 0 from itself.
        assert_eq!(sdc.distance(&first, &first), 0.0);

        // Distances between codes must never be negative.
        let dist = sdc.distance(&first, &second);
        assert!(dist >= 0.0);
    }

    #[test]
    fn test_pq_embedding_space() {
        let mut pq = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);
        let samples = make_test_vectors(100);
        pq.train_on(&samples);

        let first = pq.encode_embedding(&samples[0]);
        let second = pq.encode_embedding(&samples[1]);

        // Use the `EmbeddingSpace` interface directly on the quantizer.
        let dist = pq.distance(&first, &second);
        assert!(dist.value() >= 0.0);

        // A code is at distance 0 from itself.
        assert_eq!(pq.distance(&first, &first).value(), 0.0);

        // The prepared path must agree with the plain one.
        let prepared = pq.prepare(&first);
        let dist_prepared = pq.distance_prepared(&prepared, &second);
        assert_eq!(dist, dist_prepared);
    }

    #[test]
    fn test_codec_trait() {
        let mut pq = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);
        let samples = make_test_vectors(100);

        let mut train_ref = pq.train(&samples);
        assert!(train_ref.is_finished());
        train_ref.finish().unwrap();

        assert!(pq.is_trained());

        let mut encode_ref = pq.encode(&samples[0]);
        assert!(encode_ref.is_finished());
        let code = encode_ref.finish().unwrap();

        let mut decode_ref = pq.decode(&code);
        assert!(decode_ref.is_finished());
        let _decoded = decode_ref.finish().unwrap();
    }

    #[test]
    fn test_code_size() {
        // M=4, NBITS=8: one byte per index, 4 bytes per code.
        let one_byte_indices = ProductQuantizer::<Space, 4, 8>::new(F32L2Space::<8>);
        assert_eq!(one_byte_indices.code_size(), Some(4));

        // M=4, NBITS=10: two bytes per index, 8 bytes per code.
        let two_byte_indices = ProductQuantizer::<Space, 4, 10>::new(F32L2Space::<8>);
        assert_eq!(two_byte_indices.code_size(), Some(8));
    }

    #[test]
    #[ignore] // Slow: k-means with 1024 centroids takes ~200s in debug mode
    fn test_pq_nbits10() {
        // NBITS=10 means 1024 centroids per subspace, so training needs at
        // least 1024 vectors; 2000 are generated here.
        let mut pq = ProductQuantizer::<Space, 2, 10>::new(F32L2Space::<8>);

        let mut samples: Vec<F32Embedding<8>> = Vec::with_capacity(2000);
        for i in 0..2000 {
            let val = (i as f32) * 0.01;
            samples.push(F32Embedding([
                val,
                val + 0.1,
                val + 0.2,
                val + 0.3,
                val + 0.4,
                val + 0.5,
                val + 0.6,
                val + 0.7,
            ]));
        }

        pq.train_on(&samples);
        assert!(pq.is_trained());
        assert_eq!(pq.ksub(), 1024);

        let code = pq.encode_embedding(&samples[500]);
        // Every stored index must be a valid centroid index for 10 bits.
        let idx0 = code.get(0);
        let idx1 = code.get(1);
        assert!(idx0 < 1024);
        assert!(idx1 < 1024);
    }
}