manifoldb_vector/quantization/
pq.rs

1//! Product Quantization implementation.
2//!
3//! Product Quantization compresses vectors by splitting them into subspaces
4//! and quantizing each subspace independently.
5
6use crate::distance::DistanceMetric;
7use crate::error::VectorError;
8
9use super::config::PQConfig;
10use super::training::{KMeans, KMeansConfig};
11
/// A compressed vector code from Product Quantization.
///
/// Each entry is the index of a centroid in the corresponding subspace
/// codebook. Indices are held one per byte in memory; `bits_per_code` only
/// controls how densely they are packed by [`PQCode::to_bytes`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PQCode {
    /// The centroid indices, one per segment.
    codes: Vec<u8>,
    /// Number of bits per code (for serialization).
    bits_per_code: u8,
}

impl PQCode {
    /// Create a new PQ code from raw indices.
    ///
    /// # Arguments
    ///
    /// - `codes`: Centroid indices, one per segment
    /// - `bits_per_code`: Number of bits per index (typically 8 for 256 centroids)
    #[must_use]
    pub fn new(codes: Vec<u8>, bits_per_code: u8) -> Self {
        Self { codes, bits_per_code }
    }

    /// Get the number of segments.
    #[must_use]
    pub fn num_segments(&self) -> usize {
        self.codes.len()
    }

    /// Get the code (centroid index) for a segment, or `None` if out of range.
    #[must_use]
    pub fn get(&self, segment: usize) -> Option<u8> {
        self.codes.get(segment).copied()
    }

    /// Get all codes as a slice.
    #[must_use]
    pub fn as_slice(&self) -> &[u8] {
        &self.codes
    }

    /// Convert to bytes for storage.
    ///
    /// 8-bit codes are emitted unchanged, one byte per segment. Any other
    /// width is bit-packed: code `i` occupies bits
    /// `i * bits_per_code .. (i + 1) * bits_per_code` of the output, least
    /// significant bit first. Bits of a code above `bits_per_code` are
    /// discarded so they cannot bleed into neighbouring codes.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        if self.bits_per_code == 8 {
            self.codes.clone()
        } else {
            self.pack_codes()
        }
    }

    /// Create from bytes produced by [`PQCode::to_bytes`].
    ///
    /// # Arguments
    ///
    /// - `bytes`: Packed code bytes
    /// - `num_segments`: Number of segments (codes)
    /// - `bits_per_code`: Number of bits per code
    ///
    /// # Panics
    ///
    /// Panics if `bytes` holds fewer than `num_segments * bits_per_code` bits
    /// (rounded up to whole bytes).
    #[must_use]
    pub fn from_bytes(bytes: &[u8], num_segments: usize, bits_per_code: u8) -> Self {
        if bits_per_code == 8 {
            Self { codes: bytes[..num_segments].to_vec(), bits_per_code }
        } else {
            Self::unpack_codes(bytes, num_segments, bits_per_code)
        }
    }

    /// Mask selecting the low `bits_per_code` bits of a code.
    ///
    /// Widths of 8 or more keep the whole byte. This avoids the shift overflow
    /// that `(1u8 << 8) - 1` would trigger.
    fn code_mask(bits_per_code: u8) -> u8 {
        if bits_per_code >= 8 {
            u8::MAX
        } else {
            (1u8 << bits_per_code) - 1
        }
    }

    /// Pack codes into bytes (for non-8-bit codes).
    fn pack_codes(&self) -> Vec<u8> {
        let bits = usize::from(self.bits_per_code);
        if bits == 0 {
            // Zero-width codes carry no information and occupy no bytes.
            return Vec::new();
        }
        let mask = Self::code_mask(self.bits_per_code);
        let mut bytes = vec![0u8; (self.codes.len() * bits).div_ceil(8)];

        for (i, &code) in self.codes.iter().enumerate() {
            let bit_pos = i * bits;
            let byte_idx = bit_pos / 8;

            // Shift the masked code inside a 16-bit window so the part that
            // spills past the current byte is kept instead of truncated.
            let window = u16::from(code & mask) << (bit_pos % 8);
            let [low, high] = window.to_le_bytes();
            bytes[byte_idx] |= low;
            if high != 0 {
                // A non-zero spill means this code extends into the next byte.
                // The buffer is sized to hold the last bit of every code, so
                // `byte_idx + 1` is in bounds.
                bytes[byte_idx + 1] |= high;
            }
        }

        bytes
    }

    /// Unpack codes from bytes produced by `pack_codes`.
    fn unpack_codes(bytes: &[u8], num_segments: usize, bits_per_code: u8) -> Self {
        let bits = usize::from(bits_per_code);
        let mask = Self::code_mask(bits_per_code);

        let codes = if bits == 0 {
            vec![0; num_segments]
        } else {
            (0..num_segments)
                .map(|i| {
                    let bit_pos = i * bits;
                    let byte_idx = bit_pos / 8;
                    // The final byte may have no successor; treat it as zero.
                    let next = bytes.get(byte_idx + 1).copied().unwrap_or(0);
                    let window = u16::from_le_bytes([bytes[byte_idx], next]);
                    (window >> (bit_pos % 8)).to_le_bytes()[0] & mask
                })
                .collect()
        };

        Self { codes, bits_per_code }
    }
}
137
138/// Product Quantizer for vector compression.
139///
140/// The quantizer stores codebooks (centroids) for each subspace and provides
141/// methods for encoding vectors and computing approximate distances.
142#[derive(Debug, Clone)]
143pub struct ProductQuantizer {
144    /// Configuration.
145    config: PQConfig,
146    /// Codebooks: `codebooks[segment][centroid_idx]` = centroid vector.
147    codebooks: Vec<Vec<Vec<f32>>>,
148}
149
150impl ProductQuantizer {
151    /// Train a Product Quantizer on training data.
152    ///
153    /// # Arguments
154    ///
155    /// - `config`: PQ configuration
156    /// - `training_data`: Training vectors (must all have dimension == config.dimension)
157    ///
158    /// # Errors
159    ///
160    /// Returns an error if:
161    /// - Configuration is invalid
162    /// - Training data is empty
163    /// - Training vectors have wrong dimension
164    pub fn train(config: &PQConfig, training_data: &[&[f32]]) -> Result<Self, VectorError> {
165        config.validate()?;
166
167        if training_data.is_empty() {
168            return Err(VectorError::Encoding("cannot train PQ on empty data".to_string()));
169        }
170
171        // Validate dimensions
172        for (i, v) in training_data.iter().enumerate() {
173            if v.len() != config.dimension {
174                return Err(VectorError::DimensionMismatch {
175                    expected: config.dimension,
176                    actual: v.len(),
177                });
178            }
179            if i > 1000 {
180                break; // Only check first 1000 for performance
181            }
182        }
183
184        let subspace_dim = config.subspace_dimension();
185        let mut codebooks = Vec::with_capacity(config.num_segments);
186
187        // Train a codebook for each segment
188        for segment in 0..config.num_segments {
189            let start = segment * subspace_dim;
190            let end = start + subspace_dim;
191
192            // Extract subvectors for this segment
193            let subvectors: Vec<Vec<f32>> =
194                training_data.iter().map(|v| v[start..end].to_vec()).collect();
195
196            let subvector_refs: Vec<&[f32]> = subvectors.iter().map(|v| v.as_slice()).collect();
197
198            // Train k-means on this segment
199            let kmeans_config = KMeansConfig::new(config.num_centroids)
200                .with_max_iterations(config.training_iterations)
201                .with_seed(config.seed.map(|s| s + segment as u64).unwrap_or(segment as u64));
202
203            let kmeans = KMeans::train(&subvector_refs, &kmeans_config, config.distance_metric)?;
204            codebooks.push(kmeans.centroids);
205        }
206
207        Ok(Self { config: config.clone(), codebooks })
208    }
209
210    /// Create a Product Quantizer from pre-trained codebooks.
211    ///
212    /// # Arguments
213    ///
214    /// - `config`: PQ configuration
215    /// - `codebooks`: Pre-trained codebooks, shape `[num_segments][num_centroids][subspace_dim]`
216    ///
217    /// # Errors
218    ///
219    /// Returns an error if codebooks don't match the configuration.
220    pub fn from_codebooks(
221        config: &PQConfig,
222        codebooks: Vec<Vec<Vec<f32>>>,
223    ) -> Result<Self, VectorError> {
224        config.validate()?;
225
226        if codebooks.len() != config.num_segments {
227            return Err(VectorError::Encoding(format!(
228                "expected {} codebooks, got {}",
229                config.num_segments,
230                codebooks.len()
231            )));
232        }
233
234        let subspace_dim = config.subspace_dimension();
235        for (i, codebook) in codebooks.iter().enumerate() {
236            if codebook.len() != config.num_centroids {
237                return Err(VectorError::Encoding(format!(
238                    "codebook {} has {} centroids, expected {}",
239                    i,
240                    codebook.len(),
241                    config.num_centroids
242                )));
243            }
244            for centroid in codebook {
245                if centroid.len() != subspace_dim {
246                    return Err(VectorError::DimensionMismatch {
247                        expected: subspace_dim,
248                        actual: centroid.len(),
249                    });
250                }
251            }
252        }
253
254        Ok(Self { config: config.clone(), codebooks })
255    }
256
257    /// Get the configuration.
258    #[must_use]
259    pub fn config(&self) -> &PQConfig {
260        &self.config
261    }
262
263    /// Get the codebooks.
264    #[must_use]
265    pub fn codebooks(&self) -> &[Vec<Vec<f32>>] {
266        &self.codebooks
267    }
268
269    /// Encode a vector into a PQ code.
270    ///
271    /// # Arguments
272    ///
273    /// - `vector`: Input vector with dimension == config.dimension
274    ///
275    /// # Panics
276    ///
277    /// Panics if the vector has wrong dimension.
278    #[must_use]
279    #[allow(clippy::cast_possible_truncation)]
280    pub fn encode(&self, vector: &[f32]) -> PQCode {
281        debug_assert_eq!(vector.len(), self.config.dimension);
282
283        let subspace_dim = self.config.subspace_dimension();
284        let mut codes = Vec::with_capacity(self.config.num_segments);
285
286        for (segment, codebook) in self.codebooks.iter().enumerate() {
287            let start = segment * subspace_dim;
288            let end = start + subspace_dim;
289            let subvector = &vector[start..end];
290
291            // Find nearest centroid
292            let mut min_dist = f32::MAX;
293            let mut min_idx = 0u8;
294
295            for (idx, centroid) in codebook.iter().enumerate() {
296                let dist = self.subspace_distance(subvector, centroid);
297                if dist < min_dist {
298                    min_dist = dist;
299                    min_idx = idx as u8;
300                }
301            }
302
303            codes.push(min_idx);
304        }
305
306        PQCode::new(codes, self.config.bits_per_code() as u8)
307    }
308
309    /// Decode a PQ code back to an approximate vector.
310    ///
311    /// The reconstructed vector is the concatenation of the centroids
312    /// indicated by the code.
313    #[must_use]
314    pub fn decode(&self, code: &PQCode) -> Vec<f32> {
315        let mut vector = Vec::with_capacity(self.config.dimension);
316
317        for (segment, &idx) in code.as_slice().iter().enumerate() {
318            let centroid = &self.codebooks[segment][idx as usize];
319            vector.extend_from_slice(centroid);
320        }
321
322        vector
323    }
324
325    /// Compute a distance lookup table for asymmetric distance computation (ADC).
326    ///
327    /// The table contains precomputed distances from each subvector of the query
328    /// to all centroids in the corresponding codebook.
329    ///
330    /// Shape: `table[segment][centroid_idx]` = distance
331    ///
332    /// # Arguments
333    ///
334    /// - `query`: Query vector with dimension == config.dimension
335    #[must_use]
336    pub fn compute_distance_table(&self, query: &[f32]) -> DistanceTable {
337        debug_assert_eq!(query.len(), self.config.dimension);
338
339        let subspace_dim = self.config.subspace_dimension();
340        let mut table = Vec::with_capacity(self.config.num_segments);
341
342        for (segment, codebook) in self.codebooks.iter().enumerate() {
343            let start = segment * subspace_dim;
344            let end = start + subspace_dim;
345            let subvector = &query[start..end];
346
347            let mut segment_distances = Vec::with_capacity(codebook.len());
348            for centroid in codebook {
349                let dist = self.subspace_distance(subvector, centroid);
350                segment_distances.push(dist);
351            }
352
353            table.push(segment_distances);
354        }
355
356        DistanceTable { table, metric: self.config.distance_metric }
357    }
358
359    /// Compute asymmetric distance from a precomputed distance table to a PQ code.
360    ///
361    /// This is the primary method for fast approximate nearest neighbor search.
362    /// The query vector is exact, while the database vector is compressed.
363    ///
364    /// # Arguments
365    ///
366    /// - `table`: Distance table from `compute_distance_table`
367    /// - `code`: PQ code of a database vector
368    #[must_use]
369    #[inline]
370    pub fn asymmetric_distance(&self, table: &DistanceTable, code: &PQCode) -> f32 {
371        let mut total = 0.0f32;
372
373        for (segment, &idx) in code.as_slice().iter().enumerate() {
374            total += table.table[segment][idx as usize];
375        }
376
377        // For Euclidean distance, we sum squared distances and take sqrt at the end
378        // For other metrics, we just sum
379        match self.config.distance_metric {
380            DistanceMetric::Euclidean => total.sqrt(),
381            _ => total,
382        }
383    }
384
385    /// Compute asymmetric squared distance (faster, no sqrt).
386    ///
387    /// For Euclidean distance, returns the squared distance.
388    /// For other metrics, returns the same as `asymmetric_distance`.
389    #[must_use]
390    #[inline]
391    pub fn asymmetric_distance_squared(&self, table: &DistanceTable, code: &PQCode) -> f32 {
392        let mut total = 0.0f32;
393
394        for (segment, &idx) in code.as_slice().iter().enumerate() {
395            total += table.table[segment][idx as usize];
396        }
397
398        total
399    }
400
401    /// Compute symmetric distance between two PQ codes.
402    ///
403    /// This is faster but less accurate than asymmetric distance.
404    /// Both vectors are compressed.
405    #[must_use]
406    pub fn symmetric_distance(&self, code_a: &PQCode, code_b: &PQCode) -> f32 {
407        let mut total = 0.0f32;
408
409        for segment in 0..self.config.num_segments {
410            let idx_a = code_a.as_slice()[segment] as usize;
411            let idx_b = code_b.as_slice()[segment] as usize;
412
413            let centroid_a = &self.codebooks[segment][idx_a];
414            let centroid_b = &self.codebooks[segment][idx_b];
415
416            total += self.subspace_distance(centroid_a, centroid_b);
417        }
418
419        match self.config.distance_metric {
420            DistanceMetric::Euclidean => total.sqrt(),
421            _ => total,
422        }
423    }
424
425    /// Compute distance between two subvectors.
426    #[inline]
427    fn subspace_distance(&self, a: &[f32], b: &[f32]) -> f32 {
428        match self.config.distance_metric {
429            DistanceMetric::Euclidean => {
430                // Return squared distance for efficiency (sqrt at the end)
431                a.iter().zip(b.iter()).map(|(x, y)| (x - y) * (x - y)).sum()
432            }
433            DistanceMetric::Cosine => {
434                let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
435                let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
436                let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
437                if norm_a == 0.0 || norm_b == 0.0 {
438                    1.0
439                } else {
440                    1.0 - (dot / (norm_a * norm_b))
441                }
442            }
443            DistanceMetric::DotProduct => {
444                let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
445                -dot
446            }
447            DistanceMetric::Manhattan => a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum(),
448            DistanceMetric::Chebyshev => {
449                a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).fold(0.0f32, f32::max)
450            }
451        }
452    }
453
454    /// Serialize the quantizer to bytes.
455    #[must_use]
456    pub fn to_bytes(&self) -> Vec<u8> {
457        let mut bytes = Vec::new();
458
459        // Version byte
460        bytes.push(1u8);
461
462        // Config
463        bytes.extend_from_slice(&(self.config.dimension as u32).to_le_bytes());
464        bytes.extend_from_slice(&(self.config.num_segments as u32).to_le_bytes());
465        bytes.extend_from_slice(&(self.config.num_centroids as u32).to_le_bytes());
466        bytes.push(match self.config.distance_metric {
467            DistanceMetric::Euclidean => 0,
468            DistanceMetric::Cosine => 1,
469            DistanceMetric::DotProduct => 2,
470            DistanceMetric::Manhattan => 3,
471            DistanceMetric::Chebyshev => 4,
472        });
473
474        // Codebooks
475        for codebook in &self.codebooks {
476            for centroid in codebook {
477                for &val in centroid {
478                    bytes.extend_from_slice(&val.to_le_bytes());
479                }
480            }
481        }
482
483        bytes
484    }
485
486    /// Deserialize a quantizer from bytes.
487    ///
488    /// # Errors
489    ///
490    /// Returns an error if the bytes are malformed.
491    pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
492        if bytes.len() < 14 {
493            return Err(VectorError::Encoding("PQ bytes too short".to_string()));
494        }
495
496        let version = bytes[0];
497        if version != 1 {
498            return Err(VectorError::Encoding(format!("unsupported PQ version: {}", version)));
499        }
500
501        let dimension = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]) as usize;
502        let num_segments = u32::from_le_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]) as usize;
503        let num_centroids =
504            u32::from_le_bytes([bytes[9], bytes[10], bytes[11], bytes[12]]) as usize;
505        let distance_metric = match bytes[13] {
506            0 => DistanceMetric::Euclidean,
507            1 => DistanceMetric::Cosine,
508            2 => DistanceMetric::DotProduct,
509            3 => DistanceMetric::Manhattan,
510            4 => DistanceMetric::Chebyshev,
511            m => return Err(VectorError::Encoding(format!("unknown metric: {}", m))),
512        };
513
514        let config = PQConfig::new(dimension, num_segments)
515            .with_num_centroids(num_centroids)
516            .with_distance_metric(distance_metric);
517
518        let subspace_dim = dimension / num_segments;
519        let codebook_size = num_centroids * subspace_dim * 4; // 4 bytes per f32
520        let expected_size = 14 + num_segments * codebook_size;
521
522        if bytes.len() < expected_size {
523            return Err(VectorError::Encoding(format!(
524                "PQ bytes too short: expected {}, got {}",
525                expected_size,
526                bytes.len()
527            )));
528        }
529
530        let mut offset = 14;
531        let mut codebooks = Vec::with_capacity(num_segments);
532
533        for _ in 0..num_segments {
534            let mut codebook = Vec::with_capacity(num_centroids);
535            for _ in 0..num_centroids {
536                let mut centroid = Vec::with_capacity(subspace_dim);
537                for _ in 0..subspace_dim {
538                    let val = f32::from_le_bytes([
539                        bytes[offset],
540                        bytes[offset + 1],
541                        bytes[offset + 2],
542                        bytes[offset + 3],
543                    ]);
544                    centroid.push(val);
545                    offset += 4;
546                }
547                codebook.push(centroid);
548            }
549            codebooks.push(codebook);
550        }
551
552        Self::from_codebooks(&config, codebooks)
553    }
554}
555
556/// Precomputed distance table for asymmetric distance computation.
557///
558/// Contains distances from query subvectors to all centroids in each codebook.
559#[derive(Debug, Clone)]
560pub struct DistanceTable {
561    /// Distance table: `table[segment][centroid_idx]` = distance.
562    table: Vec<Vec<f32>>,
563    /// Distance metric used.
564    metric: DistanceMetric,
565}
566
567impl DistanceTable {
568    /// Get the distance for a segment and centroid index.
569    #[must_use]
570    #[inline]
571    pub fn get(&self, segment: usize, centroid_idx: usize) -> f32 {
572        self.table[segment][centroid_idx]
573    }
574
575    /// Get the number of segments.
576    #[must_use]
577    pub fn num_segments(&self) -> usize {
578        self.table.len()
579    }
580
581    /// Get the number of centroids per segment.
582    #[must_use]
583    pub fn num_centroids(&self) -> usize {
584        self.table.first().map_or(0, Vec::len)
585    }
586
587    /// Get the distance metric used.
588    #[must_use]
589    pub fn metric(&self) -> DistanceMetric {
590        self.metric
591    }
592}
593
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic xorshift64 generator producing values in roughly `[-1, 1]`.
    ///
    /// `seed` must be nonzero: xorshift maps a zero state to zero forever.
    fn generate_random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut rng_state = seed;
        (0..n)
            .map(|_| {
                (0..dim)
                    .map(|_| {
                        rng_state ^= rng_state << 13;
                        rng_state ^= rng_state >> 7;
                        rng_state ^= rng_state << 17;
                        (rng_state as f64 / u64::MAX as f64) as f32 * 2.0 - 1.0
                    })
                    .collect()
            })
            .collect()
    }

    #[test]
    fn test_pq_code_roundtrip() {
        let code = PQCode::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 8);
        let bytes = code.to_bytes();
        let restored = PQCode::from_bytes(&bytes, 8, 8);
        assert_eq!(code, restored);
    }

    #[test]
    fn test_pq_code_4bit_roundtrip() {
        let code = PQCode::new(vec![1, 15, 8, 3], 4);
        let bytes = code.to_bytes();
        let restored = PQCode::from_bytes(&bytes, 4, 4);
        assert_eq!(code, restored);
    }

    #[test]
    fn test_pq_code_odd_width_roundtrip() {
        // Widths that do not divide 8 force codes to straddle byte boundaries.
        for bits in [3u8, 5, 7] {
            let max = (1u8 << bits) - 1;
            let codes: Vec<u8> = (0..11u8).map(|i| i.wrapping_mul(37) & max).collect();
            let code = PQCode::new(codes.clone(), bits);
            let bytes = code.to_bytes();
            assert_eq!(bytes.len(), (codes.len() * usize::from(bits)).div_ceil(8));
            let restored = PQCode::from_bytes(&bytes, codes.len(), bits);
            assert_eq!(code, restored, "roundtrip failed for {bits}-bit codes");
        }
    }

    #[test]
    fn test_pq_train_and_encode() {
        // Generate random training data
        let training_data = generate_random_vectors(100, 32, 42);
        let training_refs: Vec<&[f32]> = training_data.iter().map(|v| v.as_slice()).collect();

        let config = PQConfig::new(32, 4).with_num_centroids(16).with_seed(42);

        let pq = ProductQuantizer::train(&config, &training_refs).unwrap();

        // Encode a vector
        let vector = generate_random_vectors(1, 32, 123)[0].clone();
        let code = pq.encode(&vector);

        assert_eq!(code.num_segments(), 4);
        for i in 0..4 {
            assert!(code.get(i).unwrap() < 16);
        }
    }

    #[test]
    fn test_pq_decode() {
        let training_data = generate_random_vectors(100, 32, 42);
        let training_refs: Vec<&[f32]> = training_data.iter().map(|v| v.as_slice()).collect();

        let config = PQConfig::new(32, 4).with_num_centroids(16).with_seed(42);
        let pq = ProductQuantizer::train(&config, &training_refs).unwrap();

        let vector = generate_random_vectors(1, 32, 123)[0].clone();
        let code = pq.encode(&vector);
        let decoded = pq.decode(&code);

        assert_eq!(decoded.len(), 32);
    }

    #[test]
    fn test_asymmetric_distance() {
        let training_data = generate_random_vectors(200, 64, 42);
        let training_refs: Vec<&[f32]> = training_data.iter().map(|v| v.as_slice()).collect();

        let config = PQConfig::new(64, 8).with_num_centroids(32).with_seed(42);
        let pq = ProductQuantizer::train(&config, &training_refs).unwrap();

        // Encode database vectors
        let db_vectors = generate_random_vectors(50, 64, 100);
        let codes: Vec<PQCode> = db_vectors.iter().map(|v| pq.encode(v)).collect();

        // Query vector
        let query = generate_random_vectors(1, 64, 200)[0].clone();
        let table = pq.compute_distance_table(&query);

        // Compute approximate distances
        let approx_dists: Vec<f32> =
            codes.iter().map(|c| pq.asymmetric_distance(&table, c)).collect();

        // All distances should be non-negative for Euclidean
        for d in &approx_dists {
            assert!(*d >= 0.0, "distance should be non-negative: {}", d);
        }
    }

    #[test]
    fn test_symmetric_distance() {
        let training_data = generate_random_vectors(100, 32, 42);
        let training_refs: Vec<&[f32]> = training_data.iter().map(|v| v.as_slice()).collect();

        let config = PQConfig::new(32, 4).with_num_centroids(16).with_seed(42);
        let pq = ProductQuantizer::train(&config, &training_refs).unwrap();

        let v1 = generate_random_vectors(1, 32, 100)[0].clone();
        let v2 = generate_random_vectors(1, 32, 200)[0].clone();

        let code1 = pq.encode(&v1);
        let code2 = pq.encode(&v2);

        let dist = pq.symmetric_distance(&code1, &code2);
        assert!(dist >= 0.0);

        // Distance to self should be 0
        let self_dist = pq.symmetric_distance(&code1, &code1);
        assert!(self_dist < 1e-6);
    }

    #[test]
    fn test_pq_serialization() {
        let training_data = generate_random_vectors(100, 32, 42);
        let training_refs: Vec<&[f32]> = training_data.iter().map(|v| v.as_slice()).collect();

        let config = PQConfig::new(32, 4).with_num_centroids(16).with_seed(42);
        let pq = ProductQuantizer::train(&config, &training_refs).unwrap();

        let bytes = pq.to_bytes();
        let restored = ProductQuantizer::from_bytes(&bytes).unwrap();

        // Verify config matches
        assert_eq!(pq.config().dimension, restored.config().dimension);
        assert_eq!(pq.config().num_segments, restored.config().num_segments);
        assert_eq!(pq.config().num_centroids, restored.config().num_centroids);

        // Verify codebooks match
        for (seg, (orig, rest)) in
            pq.codebooks().iter().zip(restored.codebooks().iter()).enumerate()
        {
            for (cent, (o, r)) in orig.iter().zip(rest.iter()).enumerate() {
                for (dim, (&ov, &rv)) in o.iter().zip(r.iter()).enumerate() {
                    assert!(
                        (ov - rv).abs() < 1e-6,
                        "mismatch at seg={}, cent={}, dim={}: {} vs {}",
                        seg,
                        cent,
                        dim,
                        ov,
                        rv
                    );
                }
            }
        }
    }

    #[test]
    fn test_distance_approximation_quality() {
        // Test that PQ distances approximate true distances reasonably well
        let training_data = generate_random_vectors(500, 64, 42);
        let training_refs: Vec<&[f32]> = training_data.iter().map(|v| v.as_slice()).collect();

        let config = PQConfig::new(64, 8).with_num_centroids(256).with_seed(42);
        let pq = ProductQuantizer::train(&config, &training_refs).unwrap();

        // Test vectors
        let query = generate_random_vectors(1, 64, 100)[0].clone();
        let database = generate_random_vectors(100, 64, 200);

        let table = pq.compute_distance_table(&query);

        // Relative error of the PQ distance against the exact Euclidean distance
        let mut rel_errors = Vec::new();
        for db_vec in &database {
            let true_dist: f32 =
                query.iter().zip(db_vec.iter()).map(|(a, b)| (a - b) * (a - b)).sum::<f32>().sqrt();

            let code = pq.encode(db_vec);
            let approx_dist = pq.asymmetric_distance(&table, &code);

            // Skip near-zero true distances, where relative error is unstable
            if true_dist > 0.1 {
                rel_errors.push((approx_dist - true_dist).abs() / true_dist);
            }
        }

        // Guard against an empty sample, which would make the average NaN
        assert!(!rel_errors.is_empty(), "no database vectors were far enough from the query");

        // Average relative error should be reasonable (< 50% for this test)
        let avg_error: f32 = rel_errors.iter().sum::<f32>() / rel_errors.len() as f32;
        assert!(avg_error < 0.5, "average relative error too high: {}", avg_error);
    }
}