Skip to main content

lance_index/vector/
pq.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Product Quantization
5//!
6
7use std::sync::Arc;
8
9use arrow::datatypes::{self, ArrowPrimitiveType};
10use arrow_array::{Array, FixedSizeListArray, UInt8Array, cast::AsArray};
11use arrow_array::{ArrayRef, Float32Array, PrimitiveArray};
12use arrow_schema::{DataType, Field};
13use deepsize::DeepSizeOf;
14use distance::build_distance_table_dot;
15use lance_arrow::*;
16use lance_core::{Error, Result, assume_eq};
17use lance_linalg::distance::{DistanceType, Dot, L2, l2::L2Prepared};
18use lance_table::utils::LanceIteratorExtension;
19use num_traits::Float;
20use prost::Message;
21use storage::{PQ_METADATA_KEY, ProductQuantizationMetadata, ProductQuantizationStorage};
22use tracing::instrument;
23
24pub mod builder;
25pub mod distance;
26pub mod storage;
27pub mod transform;
28pub(crate) mod utils;
29
30use self::distance::{
31    build_distance_table_l2, build_distance_table_l2_prepared, compute_pq_distance,
32};
33pub use self::utils::num_centroids;
34use super::quantizer::{
35    Quantization, QuantizationMetadata, QuantizationType, Quantizer, QuantizerBuildParams,
36};
37use super::{PQ_CODE_COLUMN, pb};
38use crate::vector::kmeans::compute_partition;
39pub use builder::PQBuildParams;
40use utils::get_sub_vector_centroids;
41
42#[derive(Debug, Clone)]
43pub struct ProductQuantizer {
44    pub num_sub_vectors: usize,
45    pub num_bits: u32,
46    pub dimension: usize,
47    pub codebook: FixedSizeListArray,
48    pub distance_type: DistanceType,
49    /// Pre-transposed L2 targets per sub-vector for fast f32 L2 batch computation.
50    /// Only populated when codebook is f32 and distance_type is L2
51    /// (Cosine is converted to L2 before construction, so it benefits too).
52    l2_targets: Option<Vec<L2Prepared>>,
53}
54
55impl DeepSizeOf for ProductQuantizer {
56    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
57        self.codebook.get_array_memory_size()
58            + self.num_sub_vectors.deep_size_of_children(_context)
59            + self.num_bits.deep_size_of_children(_context)
60            + self.dimension.deep_size_of_children(_context)
61            + self.distance_type.deep_size_of_children(_context)
62            + self
63                .l2_targets
64                .as_ref()
65                .map_or(0, |v| v.iter().map(|t| t.size_bytes()).sum())
66    }
67}
68
69impl ProductQuantizer {
70    /// Build per-sub-vector L2Prepared from the codebook if applicable (f32 + L2).
71    fn build_l2_targets(
72        codebook: &FixedSizeListArray,
73        distance_type: DistanceType,
74        num_sub_vectors: usize,
75        num_bits: u32,
76        dimension: usize,
77    ) -> Option<Vec<L2Prepared>> {
78        if codebook.value_type() != DataType::Float32 || distance_type != DistanceType::L2 {
79            return None;
80        }
81        let values = codebook
82            .values()
83            .as_primitive::<datatypes::Float32Type>()
84            .values();
85        let sub_dim = dimension / num_sub_vectors;
86        let num_centroids = 2_usize.pow(num_bits);
87        let block_size = sub_dim * num_centroids;
88
89        let targets = (0..num_sub_vectors)
90            .map(|sub_idx| {
91                let block_start = sub_idx * block_size;
92                let block = &values[block_start..block_start + block_size];
93                L2Prepared::new(block, sub_dim)
94            })
95            .collect();
96        Some(targets)
97    }
98
99    pub fn new(
100        num_sub_vectors: usize,
101        num_bits: u32,
102        dimension: usize,
103        codebook: FixedSizeListArray,
104        distance_type: DistanceType,
105    ) -> Self {
106        let l2_targets = Self::build_l2_targets(
107            &codebook,
108            distance_type,
109            num_sub_vectors,
110            num_bits,
111            dimension,
112        );
113        Self {
114            num_bits,
115            num_sub_vectors,
116            dimension,
117            codebook,
118            distance_type,
119            l2_targets,
120        }
121    }
122
123    pub fn from_proto(proto: &pb::Pq, distance_type: DistanceType) -> Result<Self> {
124        let distance_type = match distance_type {
125            DistanceType::Cosine => DistanceType::L2,
126            _ => distance_type,
127        };
128        let codebook = match proto.codebook_tensor.as_ref() {
129            Some(tensor) => FixedSizeListArray::try_from(tensor)?,
130            None => FixedSizeListArray::try_new_from_values(
131                Float32Array::from(proto.codebook.clone()),
132                proto.dimension as i32,
133            )?,
134        };
135        Ok(Self::new(
136            proto.num_sub_vectors as usize,
137            proto.num_bits,
138            proto.dimension as usize,
139            codebook,
140            distance_type,
141        ))
142    }
143
144    #[instrument(name = "ProductQuantizer::transform", level = "debug", skip_all)]
145    fn transform<T: ArrowPrimitiveType>(&self, vectors: &dyn Array) -> Result<ArrayRef>
146    where
147        T::Native: Float + L2 + Dot,
148    {
149        match self.num_bits {
150            4 => self.transform_impl::<4, T>(vectors),
151            8 => self.transform_impl::<8, T>(vectors),
152            _ => Err(Error::index(format!(
153                "ProductQuantization: num_bits {} not supported",
154                self.num_bits
155            ))),
156        }
157    }
158
159    fn transform_impl<const NUM_BITS: u32, T: ArrowPrimitiveType>(
160        &self,
161        vectors: &dyn Array,
162    ) -> Result<ArrayRef>
163    where
164        T::Native: Float + L2 + Dot,
165    {
166        let fsl = vectors
167            .as_fixed_size_list_opt()
168            .ok_or(Error::index(format!(
169                "Expect to be a FixedSizeList<float> vector array, got: {:?} array",
170                vectors.data_type()
171            )))?;
172        let num_sub_vectors = self.num_sub_vectors;
173        let dim = self.dimension;
174        if NUM_BITS == 4 && !num_sub_vectors.is_multiple_of(2) {
175            return Err(Error::index(format!(
176                "PQ: num_sub_vectors must be divisible by 2 for num_bits=4, but got {}",
177                num_sub_vectors,
178            )));
179        }
180        let codebook = self.codebook.values().as_primitive::<T>();
181
182        let distance_type = self.distance_type;
183
184        let flatten_data = fsl.values().as_primitive::<T>();
185        let sub_dim = dim / num_sub_vectors;
186        let num_centroids = 2_usize.pow(NUM_BITS);
187        let total_code_length = fsl.len() * num_sub_vectors / (8 / NUM_BITS as usize);
188
189        let values = if let Some(targets) = &self.l2_targets {
190            // Fast path for f32 + L2: use pre-transposed codebook.
191            // SAFETY: l2_targets is only populated when T::Native is f32.
192            let flat_f32: &[f32] = unsafe {
193                std::slice::from_raw_parts(
194                    flatten_data.values().as_ptr() as *const f32,
195                    flatten_data.values().len(),
196                )
197            };
198            let mut values = vec![0u8; total_code_length];
199            let bytes_per_vector = num_sub_vectors / (8 / NUM_BITS as usize);
200            let mut dist_buf = vec![0.0f32; num_centroids];
201            for (vec_idx, vector) in flat_f32.chunks_exact(dim).enumerate() {
202                let out = &mut values[vec_idx * bytes_per_vector..][..bytes_per_vector];
203                if NUM_BITS == 4 {
204                    for (pair_idx, pair) in vector.chunks_exact(sub_dim * 2).enumerate() {
205                        let lo = targets[pair_idx * 2]
206                            .nearest_into(&pair[..sub_dim], &mut dist_buf)
207                            .unwrap_or(0) as u8;
208                        let hi = targets[pair_idx * 2 + 1]
209                            .nearest_into(&pair[sub_dim..], &mut dist_buf)
210                            .unwrap_or(0) as u8;
211                        out[pair_idx] = (hi << 4) | lo;
212                    }
213                } else {
214                    for (sub_idx, sv) in vector.chunks_exact(sub_dim).enumerate() {
215                        out[sub_idx] = targets[sub_idx]
216                            .nearest_into(sv, &mut dist_buf)
217                            .unwrap_or(0) as u8;
218                    }
219                }
220            }
221            values
222        } else {
223            flatten_data
224                .values()
225                .chunks_exact(dim)
226                .flat_map(|vector| {
227                    let sub_vec_code: Vec<u8> = vector
228                        .chunks_exact(sub_dim)
229                        .enumerate()
230                        .map(|(sub_idx, sub_vector)| {
231                            let centroids = get_sub_vector_centroids::<NUM_BITS, _>(
232                                codebook.values(),
233                                dim,
234                                num_sub_vectors,
235                                sub_idx,
236                            );
237                            assume_eq!(centroids.len(), num_centroids * sub_dim);
238                            compute_partition(centroids, sub_vector, distance_type).unwrap_or(0)
239                                as u8
240                        })
241                        .collect();
242                    if NUM_BITS == 4 {
243                        sub_vec_code
244                            .chunks_exact(2)
245                            .map(|v| (v[1] << 4) | v[0])
246                            .collect::<Vec<_>>()
247                    } else {
248                        sub_vec_code
249                    }
250                })
251                .exact_size(total_code_length)
252                .collect::<Vec<_>>()
253        };
254
255        let num_sub_vectors_in_byte = if NUM_BITS == 4 {
256            num_sub_vectors / 2
257        } else {
258            num_sub_vectors
259        };
260
261        debug_assert_eq!(values.len(), fsl.len() * num_sub_vectors_in_byte);
262        Ok(Arc::new(FixedSizeListArray::try_new_from_values(
263            UInt8Array::from(values),
264            num_sub_vectors_in_byte as i32,
265        )?))
266    }
267
268    // the code must be transposed
269    pub fn compute_distances(&self, query: &dyn Array, code: &UInt8Array) -> Result<Float32Array> {
270        if code.is_empty() {
271            return Ok(Float32Array::from(Vec::<f32>::new()));
272        }
273
274        match self.distance_type {
275            DistanceType::L2 => self.l2_distances(query, code),
276            DistanceType::Cosine => {
277                // it seems we implemented cosine distance at some version,
278                // but from now on, we should use normalized L2 distance.
279                debug_assert!(
280                    false,
281                    "cosine distance should be converted to normalized L2 distance"
282                );
283                // L2 over normalized vectors:  ||x - y|| = x^2 + y^2 - 2 * xy = 1 + 1 - 2 * xy = 2 * (1 - xy)
284                // Cosine distance: 1 - |xy| / (||x|| * ||y||) = 1 - xy / (x^2 * y^2) = 1 - xy / (1 * 1) = 1 - xy
285                // Therefore, Cosine = L2 / 2
286                let l2_dists = self.l2_distances(query, code)?;
287                Ok(l2_dists.values().iter().map(|v| *v / 2.0).collect())
288            }
289            DistanceType::Dot => self.dot_distances(query, code),
290            _ => panic!(
291                "ProductQuantization: distance type {} not supported",
292                self.distance_type
293            ),
294        }
295    }
296
297    /// Pre-compute L2 distance from the query to all code.
298    ///
299    /// It returns the squared L2 distance.
300    fn l2_distances(&self, key: &dyn Array, code: &UInt8Array) -> Result<Float32Array> {
301        let distance_table = self.build_l2_distance_table(key)?;
302        Ok(self.compute_l2_distance(&distance_table, code.values()))
303    }
304
305    /// Parameters
306    /// ----------
307    ///  - query: the query vector, with shape (dimension, )
308    ///  - code: the PQ code in one partition.
309    ///
310    fn dot_distances(&self, key: &dyn Array, code: &UInt8Array) -> Result<Float32Array> {
311        match key.data_type() {
312            DataType::Float16 => {
313                self.dot_distances_impl::<datatypes::Float16Type>(key.as_primitive(), code)
314            }
315            DataType::Float32 => {
316                self.dot_distances_impl::<datatypes::Float32Type>(key.as_primitive(), code)
317            }
318            DataType::Float64 => {
319                self.dot_distances_impl::<datatypes::Float64Type>(key.as_primitive(), code)
320            }
321            _ => Err(Error::index(format!(
322                "unsupported data type: {}",
323                key.data_type()
324            ))),
325        }
326    }
327
328    fn dot_distances_impl<T: ArrowPrimitiveType>(
329        &self,
330        key: &PrimitiveArray<T>,
331        code: &UInt8Array,
332    ) -> Result<Float32Array>
333    where
334        T::Native: Dot,
335    {
336        let distance_table = build_distance_table_dot(
337            self.codebook.values().as_primitive::<T>().values(),
338            self.num_bits,
339            self.num_sub_vectors,
340            key.values(),
341        );
342
343        let distances = compute_pq_distance(
344            &distance_table,
345            self.num_bits,
346            self.num_sub_vectors,
347            code.values(),
348            0,
349        );
350
351        let diff = self.num_sub_vectors as f32 - 1.0;
352        let distances = distances.into_iter().map(|d| d - diff).collect::<Vec<_>>();
353        Ok(distances.into())
354    }
355
356    fn build_l2_distance_table(&self, key: &dyn Array) -> Result<Vec<f32>> {
357        match key.data_type() {
358            DataType::Float16 => {
359                Ok(self.build_l2_distance_table_impl::<datatypes::Float16Type>(key.as_primitive()))
360            }
361            DataType::Float32 => {
362                if let Some(targets) = &self.l2_targets {
363                    let query = key.as_primitive::<datatypes::Float32Type>().values();
364                    Ok(build_distance_table_l2_prepared(targets, query))
365                } else {
366                    Ok(self
367                        .build_l2_distance_table_impl::<datatypes::Float32Type>(key.as_primitive()))
368                }
369            }
370            DataType::Float64 => {
371                Ok(self.build_l2_distance_table_impl::<datatypes::Float64Type>(key.as_primitive()))
372            }
373            _ => Err(Error::index(format!(
374                "unsupported data type: {}",
375                key.data_type()
376            ))),
377        }
378    }
379
380    fn build_l2_distance_table_impl<T: ArrowPrimitiveType>(
381        &self,
382        key: &PrimitiveArray<T>,
383    ) -> Vec<f32>
384    where
385        T::Native: L2,
386    {
387        build_distance_table_l2(
388            self.codebook.values().as_primitive::<T>().values(),
389            self.num_bits,
390            self.num_sub_vectors,
391            key.values(),
392        )
393    }
394
395    /// Compute L2 distance from the query to all code.
396    ///
397    /// Type parameters
398    /// ---------------
399    /// - C: the tile size of code-book to run at once.
400    /// - V: the tile size of PQ code to run at once.
401    ///
402    /// Parameters
403    /// ----------
404    /// - distance_table: the pre-computed L2 distance table.
405    ///   It is a flatten array of [num_sub_vectors, num_centroids] f32.
406    /// - code: the PQ code to be used to compute the distances.
407    ///
408    /// Returns
409    /// -------
410    ///  The squared L2 distance.
411    #[inline]
412    fn compute_l2_distance(&self, distance_table: &[f32], code: &[u8]) -> Float32Array {
413        Float32Array::from(compute_pq_distance(
414            distance_table,
415            self.num_bits,
416            self.num_sub_vectors,
417            code,
418            100,
419        ))
420    }
421
422    /// Get the centroids for one sub-vector.
423    ///
424    /// Returns a flatten `num_centroids * sub_vector_width` f32 array.
425    pub fn centroids<T: ArrowPrimitiveType>(&self, sub_vector_idx: usize) -> &[T::Native] {
426        match self.num_bits {
427            4 => get_sub_vector_centroids::<4, _>(
428                self.codebook.values().as_primitive::<T>().values(),
429                self.dimension,
430                self.num_sub_vectors,
431                sub_vector_idx,
432            ),
433            8 => get_sub_vector_centroids::<8, _>(
434                self.codebook.values().as_primitive::<T>().values(),
435                self.dimension,
436                self.num_sub_vectors,
437                sub_vector_idx,
438            ),
439            _ => panic!(
440                "ProductQuantization: num_bits {} not supported",
441                self.num_bits
442            ),
443        }
444    }
445}
446
447impl Quantization for ProductQuantizer {
448    type BuildParams = PQBuildParams;
449    type Metadata = ProductQuantizationMetadata;
450    type Storage = ProductQuantizationStorage;
451
452    fn build(
453        data: &dyn Array,
454        distance_type: DistanceType,
455        params: &Self::BuildParams,
456    ) -> Result<Self> {
457        assert_eq!(data.null_count(), 0);
458        let fsl = data.as_fixed_size_list_opt().ok_or(Error::index(format!(
459            "PQ builder: input is not a FixedSizeList: {}",
460            data.data_type()
461        )))?;
462
463        if let Some(codebook) = params.codebook.as_ref() {
464            return Ok(Self::new(
465                params.num_sub_vectors,
466                params.num_bits as u32,
467                fsl.value_length() as usize,
468                FixedSizeListArray::try_new_from_values(codebook.clone(), fsl.value_length())?,
469                distance_type,
470            ));
471        }
472
473        params.build(data, distance_type)
474    }
475
476    fn retrain(&mut self, data: &dyn Array) -> Result<()> {
477        assert_eq!(data.null_count(), 0);
478        let params = PQBuildParams::with_codebook(
479            self.num_sub_vectors,
480            self.num_bits as usize,
481            Arc::new(self.codebook.clone()),
482        );
483
484        *self = params.build(data, self.distance_type)?;
485        Ok(())
486    }
487
488    fn code_dim(&self) -> usize {
489        self.num_sub_vectors
490    }
491
492    fn column(&self) -> &'static str {
493        PQ_CODE_COLUMN
494    }
495
496    fn use_residual(distance_type: DistanceType) -> bool {
497        PQBuildParams::use_residual(distance_type)
498    }
499
500    fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
501        let fsl = vectors
502            .as_fixed_size_list_opt()
503            .ok_or(Error::index(format!(
504                "Expect to be a FixedSizeList<float> vector array, got: {:?} array",
505                vectors.data_type()
506            )))?;
507
508        match fsl.value_type() {
509            DataType::Float16 => self.transform::<datatypes::Float16Type>(vectors),
510            DataType::Float32 => self.transform::<datatypes::Float32Type>(vectors),
511            DataType::Float64 => self.transform::<datatypes::Float64Type>(vectors),
512            _ => Err(Error::index(format!(
513                "unsupported data type: {}",
514                fsl.value_type()
515            ))),
516        }
517    }
518
519    fn metadata_key() -> &'static str {
520        PQ_METADATA_KEY
521    }
522
523    fn quantization_type() -> QuantizationType {
524        QuantizationType::Product
525    }
526
527    fn metadata(&self, args: Option<QuantizationMetadata>) -> Self::Metadata {
528        let codebook_position = match &args {
529            Some(args) => args.codebook_position,
530            None => Some(0),
531        };
532
533        let codebook_position = codebook_position.expect("codebook position should be set");
534        ProductQuantizationMetadata {
535            codebook_position,
536            nbits: self.num_bits,
537            num_sub_vectors: self.num_sub_vectors,
538            dimension: self.dimension,
539            codebook: Some(self.codebook.clone()),
540            codebook_tensor: Vec::new(),
541            transposed: args.map(|args| args.transposed).unwrap_or_default(),
542        }
543    }
544
545    fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result<Quantizer> {
546        let distance_type = match distance_type {
547            DistanceType::Cosine => DistanceType::L2,
548            _ => distance_type,
549        };
550        let codebook = match metadata.codebook.as_ref() {
551            Some(fsl) => fsl.clone(),
552            None => {
553                let tensor = pb::Tensor::decode(metadata.codebook_tensor.as_ref())?;
554                FixedSizeListArray::try_from(&tensor)?
555            }
556        };
557        Ok(Quantizer::Product(Self::new(
558            metadata.num_sub_vectors,
559            metadata.nbits,
560            metadata.dimension,
561            codebook,
562            distance_type,
563        )))
564    }
565
566    fn field(&self) -> Field {
567        let num_bytes_per_sub_vector = self.num_sub_vectors * self.num_bits as usize / 8;
568        Field::new(
569            PQ_CODE_COLUMN,
570            DataType::FixedSizeList(
571                Arc::new(Field::new("item", DataType::UInt8, true)),
572                num_bytes_per_sub_vector as i32,
573            ),
574            true,
575        )
576    }
577}
578
579impl TryFrom<&ProductQuantizer> for pb::Pq {
580    type Error = Error;
581
582    fn try_from(pq: &ProductQuantizer) -> Result<Self> {
583        let tensor = pb::Tensor::try_from(&pq.codebook)?;
584        Ok(Self {
585            num_bits: pq.num_bits,
586            num_sub_vectors: pq.num_sub_vectors as u32,
587            dimension: pq.dimension as u32,
588            codebook: vec![],
589            codebook_tensor: Some(tensor),
590        })
591    }
592}
593
594impl TryFrom<Quantizer> for ProductQuantizer {
595    type Error = Error;
596    fn try_from(value: Quantizer) -> Result<Self> {
597        match value {
598            Quantizer::Product(pq) => Ok(pq),
599            _ => Err(Error::index("Expect to be a ProductQuantizer".to_string())),
600        }
601    }
602}
603
#[cfg(test)]
mod tests {
    use super::*;

    use std::iter::repeat_n;

    use approx::assert_relative_eq;
    use arrow::datatypes::UInt8Type;
    use arrow_array::Float16Array;
    use half::f16;
    use lance_linalg::distance::l2_distance_batch;
    use lance_linalg::kernels::argmin;
    use lance_testing::datagen::generate_random_array;
    use num_traits::Zero;
    use storage::transpose;

    /// An f16 codebook round-trips into a protobuf tensor with the right shape.
    #[test]
    fn test_f16_pq_to_protobuf() {
        let pq = ProductQuantizer::new(
            4,
            8,
            16,
            FixedSizeListArray::try_new_from_values(
                Float16Array::from_iter_values(repeat_n(f16::zero(), 256 * 16)),
                16,
            )
            .unwrap(),
            DistanceType::L2,
        );
        let proto: pb::Pq = pb::Pq::try_from(&pq).unwrap();
        assert_eq!(proto.num_bits, 8);
        assert_eq!(proto.num_sub_vectors, 4);
        assert_eq!(proto.dimension, 16);
        assert!(proto.codebook.is_empty());
        assert!(proto.codebook_tensor.is_some());

        let tensor = proto.codebook_tensor.as_ref().unwrap();
        assert_eq!(tensor.data_type, pb::tensor::DataType::Float16 as i32);
        assert_eq!(tensor.shape, vec![256, 16]);
    }

    /// PQ L2 distances match a brute-force sum of per-sub-vector distances.
    #[test]
    fn test_l2_distance() {
        const DIM: usize = 512;
        const TOTAL: usize = 66; // 64 + 2 to make sure the remainder is handled correctly.
        let codebook = generate_random_array(256 * DIM);
        let pq = ProductQuantizer::new(
            16,
            8,
            DIM,
            FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap(),
            DistanceType::L2,
        );
        let pq_code = UInt8Array::from_iter_values((0..16 * TOTAL).map(|v| v as u8));
        let query = generate_random_array(DIM);

        let transposed_pq_codes = transpose(&pq_code, TOTAL, 16);
        let dists = pq.compute_distances(&query, &transposed_pq_codes).unwrap();

        let sub_vec_len = DIM / 16;
        let expected = pq_code
            .values()
            .chunks(16)
            .map(|code| {
                code.iter()
                    .enumerate()
                    .flat_map(|(sub_idx, c)| {
                        let subvec_centroids = pq.centroids::<datatypes::Float32Type>(sub_idx);
                        let subvec =
                            &query.values()[sub_idx * sub_vec_len..(sub_idx + 1) * sub_vec_len];
                        l2_distance_batch(
                            subvec,
                            &subvec_centroids
                                [*c as usize * sub_vec_len..(*c as usize + 1) * sub_vec_len],
                            sub_vec_len,
                        )
                    })
                    .sum::<f32>()
            })
            .collect::<Vec<_>>();

        // `zip` stops at the shorter side, so pin the lengths first; otherwise a
        // truncated result would pass silently.
        assert_eq!(expected.len(), TOTAL);
        assert_eq!(dists.len(), expected.len());
        dists
            .values()
            .iter()
            .zip(expected.iter())
            .for_each(|(v, e)| {
                assert_relative_eq!(*v, *e, epsilon = 1e-4);
            });
    }

    /// `quantize` picks the nearest centroid (argmin of L2) for each sub-vector.
    #[test]
    fn test_pq_transform() {
        const DIM: usize = 16;
        const TOTAL: usize = 64;
        let codebook = generate_random_array(DIM * 256);
        let pq = ProductQuantizer::new(
            4,
            8,
            DIM,
            FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap(),
            DistanceType::L2,
        );

        let vectors = generate_random_array(DIM * TOTAL);
        let fsl = FixedSizeListArray::try_new_from_values(vectors.clone(), DIM as i32).unwrap();
        let pq_code = pq.quantize(&fsl).unwrap();

        let mut expected = Vec::with_capacity(TOTAL * 4);
        vectors.values().chunks_exact(DIM).for_each(|vec| {
            vec.chunks_exact(DIM / 4)
                .enumerate()
                .for_each(|(sub_idx, sub_vec)| {
                    let centroids = pq.centroids::<datatypes::Float32Type>(sub_idx);
                    let dists = l2_distance_batch(sub_vec, centroids, DIM / 4);
                    let code = argmin(dists).unwrap() as u8;
                    expected.push(code);
                });
        });

        assert_eq!(pq_code.len(), TOTAL);
        assert_eq!(
            &expected,
            pq_code
                .as_fixed_size_list()
                .values()
                .as_primitive::<UInt8Type>()
                .values()
        );
    }
}