// lance_index/vector/pq/storage.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Product Quantization storage
5//!
6//! Used as storage backend for Graph based algorithms.
7
8use std::{cmp::min, collections::HashMap, sync::Arc};
9
10use arrow::datatypes::{self, UInt8Type};
11use arrow_array::{
12    cast::AsArray,
13    types::{Float32Type, UInt64Type},
14    FixedSizeListArray, RecordBatch, UInt64Array, UInt8Array,
15};
16use arrow_array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray};
17use arrow_schema::{DataType, SchemaRef};
18use async_trait::async_trait;
19use bytes::{Bytes, BytesMut};
20use deepsize::DeepSizeOf;
21use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt};
22use lance_core::{Error, Result, ROW_ID};
23use lance_file::previous::{
24    reader::FileReader as PreviousFileReader, writer::FileWriter as PreviousFileWriter,
25};
26use lance_io::{object_store::ObjectStore, utils::read_message};
27use lance_linalg::distance::{DistanceType, Dot, L2};
28use lance_table::utils::LanceIteratorExtension;
29use lance_table::{format::SelfDescribingFileReader, io::manifest::ManifestDescribing};
30use object_store::path::Path;
31use prost::Message;
32use serde::{Deserialize, Serialize};
33use snafu::location;
34
35use super::distance::{build_distance_table_dot, build_distance_table_l2, compute_pq_distance};
36use super::ProductQuantizer;
37use crate::frag_reuse::FragReuseIndex;
38use crate::{
39    pb,
40    vector::{
41        pq::transform::PQTransformer,
42        quantizer::{QuantizerMetadata, QuantizerStorage},
43        storage::{DistCalculator, VectorStore},
44        transform::Transformer,
45        PQ_CODE_COLUMN,
46    },
47    IndexMetadata, INDEX_METADATA_SCHEMA_KEY,
48};
49
/// Schema-metadata key under which the JSON-serialized
/// [`ProductQuantizationMetadata`] is stored in the index file.
pub const PQ_METADATA_KEY: &str = "lance:pq";
51
/// Parameters of a product quantizer, persisted alongside the PQ storage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizationMetadata {
    // Where the codebook lives: used as a message position by the v1 reader
    // (`load`) and as a global buffer index (starting at 1) by `buffer_index`.
    pub codebook_position: usize,
    // Bits per sub-vector code (4-bit codes are packed two per byte).
    pub nbits: u32,
    // Number of PQ sub-vectors per vector.
    pub num_sub_vectors: usize,
    // Dimension of the original (full) vectors.
    pub dimension: usize,

    // Decoded codebook; skipped by serde and loaded separately from a buffer.
    #[serde(skip)]
    pub codebook: Option<FixedSizeListArray>,

    // empty for v1 format
    // used for v3 format
    // deprecated in later version
    pub codebook_tensor: Vec<u8>,
    // Whether the stored PQ codes are already in transposed layout.
    pub transposed: bool,
}
68
69impl DeepSizeOf for ProductQuantizationMetadata {
70    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
71        self.codebook
72            .as_ref()
73            .map(|codebook| codebook.get_array_memory_size())
74            .unwrap_or(0)
75    }
76}
77
78impl PartialEq for ProductQuantizationMetadata {
79    fn eq(&self, other: &Self) -> bool {
80        self.num_sub_vectors == other.num_sub_vectors
81            && self.nbits == other.nbits
82            && self.dimension == other.dimension
83            && self.codebook == other.codebook
84    }
85}
86
#[async_trait]
impl QuantizerMetadata for ProductQuantizationMetadata {
    /// Global-buffer index holding the codebook, or `None` for formats that
    /// don't store it in a global buffer.
    fn buffer_index(&self) -> Option<u32> {
        if self.codebook_position > 0 {
            // the global buffer index starts from 1
            Some(self.codebook_position as u32)
        } else {
            None
        }
    }

    fn set_buffer_index(&mut self, index: u32) {
        self.codebook_position = index as usize;
    }

    /// Decode the codebook from a protobuf `Tensor` buffer.
    fn parse_buffer(&mut self, bytes: Bytes) -> Result<()> {
        debug_assert!(!bytes.is_empty());
        debug_assert!(self.codebook.is_none());
        let codebook_tensor: pb::Tensor = pb::Tensor::decode(bytes)?;
        self.codebook = Some(FixedSizeListArray::try_from(&codebook_tensor)?);
        Ok(())
    }

    /// Encode the codebook back into a protobuf `Tensor` buffer.
    fn extra_metadata(&self) -> Result<Option<Bytes>> {
        debug_assert!(self.codebook.is_some());
        let codebook_tensor: pb::Tensor = pb::Tensor::try_from(self.codebook.as_ref().unwrap())?;
        let mut bytes = BytesMut::new();
        codebook_tensor.encode(&mut bytes)?;
        Ok(Some(bytes.freeze()))
    }

    /// Load metadata (and the codebook it points at) from a v1 index file:
    /// JSON parameters from the schema metadata, then the codebook tensor
    /// read from `codebook_position`.
    async fn load(reader: &PreviousFileReader) -> Result<Self> {
        let metadata = reader
            .schema()
            .metadata
            .get(PQ_METADATA_KEY)
            .ok_or(Error::Index {
                message: format!(
                    "Reading PQ storage: metadata key {} not found",
                    PQ_METADATA_KEY
                ),
                location: location!(),
            })?;
        let mut metadata: Self = serde_json::from_str(metadata).map_err(|_| Error::Index {
            message: format!("Failed to parse PQ metadata: {}", metadata),
            location: location!(),
        })?;

        debug_assert!(metadata.codebook.is_none());
        debug_assert!(metadata.codebook_tensor.is_empty());

        // In the v1 format the position is a message offset in the file.
        let codebook_tensor: pb::Tensor =
            read_message(reader.object_reader.as_ref(), metadata.codebook_position).await?;
        metadata.codebook = Some(FixedSizeListArray::try_from(&codebook_tensor)?);
        Ok(metadata)
    }
}
144
/// Product Quantization Storage
///
/// It stores PQ code, as well as the row ID to the original vectors.
///
/// It is possible to store additional metadata to accelerate filtering later.
#[derive(Clone, Debug)]
pub struct ProductQuantizationStorage {
    // Quantizer parameters plus the decoded codebook.
    metadata: ProductQuantizationMetadata,
    // Cosine is normalized to L2 at construction time (see `new`).
    distance_type: DistanceType,
    // Two columns: row ids and the (transposed) PQ codes.
    batch: RecordBatch,

    // For easy access
    pq_code: Arc<UInt8Array>,
    row_ids: Arc<UInt64Array>,
}
160
161impl DeepSizeOf for ProductQuantizationStorage {
162    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
163        self.batch.get_array_memory_size()
164            + self
165                .metadata
166                .codebook
167                .as_ref()
168                .map(|codebook| codebook.get_array_memory_size())
169                .unwrap_or(0)
170    }
171}
172
173impl PartialEq for ProductQuantizationStorage {
174    fn eq(&self, other: &Self) -> bool {
175        self.distance_type == other.distance_type
176            && self.metadata.eq(&other.metadata)
177            && self.batch.columns().eq(other.batch.columns())
178    }
179}
180
181impl ProductQuantizationStorage {
182    #[allow(clippy::too_many_arguments)]
183    pub fn new(
184        codebook: FixedSizeListArray,
185        mut batch: RecordBatch,
186        num_bits: u32,
187        num_sub_vectors: usize,
188        dimension: usize,
189        distance_type: DistanceType,
190        transposed: bool,
191        frag_reuse_index: Option<Arc<FragReuseIndex>>,
192    ) -> Result<Self> {
193        if batch.num_columns() != 2 {
194            log::warn!(
195                "PQ storage should have 2 columns, but got {} columns: {}",
196                batch.num_columns(),
197                batch.schema(),
198            );
199            batch = batch.project(&[
200                batch.schema().index_of(ROW_ID)?,
201                batch.schema().index_of(PQ_CODE_COLUMN)?,
202            ])?;
203        }
204
205        let Some(row_ids) = batch.column_by_name(ROW_ID) else {
206            return Err(Error::Index {
207                message: "Row ID column not found from PQ storage".to_string(),
208                location: location!(),
209            });
210        };
211        let row_ids: Arc<UInt64Array> = row_ids
212            .as_primitive_opt::<UInt64Type>()
213            .ok_or(Error::Index {
214                message: "Row ID column is not of type UInt64".to_string(),
215                location: location!(),
216            })?
217            .clone()
218            .into();
219
220        if !transposed {
221            let num_sub_vectors_in_byte = if num_bits == 4 {
222                num_sub_vectors / 2
223            } else {
224                num_sub_vectors
225            };
226            let pq_col = batch[PQ_CODE_COLUMN].as_fixed_size_list();
227            let transposed_code = transpose(
228                pq_col.values().as_primitive::<UInt8Type>(),
229                row_ids.len(),
230                num_sub_vectors_in_byte,
231            );
232            let pq_code_fsl = Arc::new(FixedSizeListArray::try_new_from_values(
233                transposed_code,
234                num_sub_vectors_in_byte as i32,
235            )?);
236            batch = batch.replace_column_by_name(PQ_CODE_COLUMN, pq_code_fsl)?;
237        }
238
239        let mut pq_code: Arc<UInt8Array> = batch[PQ_CODE_COLUMN]
240            .as_fixed_size_list()
241            .values()
242            .as_primitive()
243            .clone()
244            .into();
245
246        if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
247            let transposed_codes = pq_code.values();
248            let mut new_row_ids = Vec::with_capacity(row_ids.len());
249            let mut new_codes = Vec::with_capacity(row_ids.len() * num_sub_vectors);
250
251            let row_ids_values = row_ids.values();
252            for (i, row_id) in row_ids_values.iter().enumerate() {
253                if let Some(mapped_value) = frag_reuse_index_ref.remap_row_id(*row_id) {
254                    new_row_ids.push(mapped_value);
255                    new_codes.extend(get_pq_code(
256                        transposed_codes,
257                        num_bits,
258                        num_sub_vectors,
259                        i as u32,
260                    ));
261                }
262            }
263
264            let new_row_ids = Arc::new(UInt64Array::from(new_row_ids));
265            let new_codes = UInt8Array::from(new_codes);
266            batch = if new_row_ids.is_empty() {
267                RecordBatch::new_empty(batch.schema())
268            } else {
269                let num_bytes_in_code = new_codes.len() / new_row_ids.len();
270                let new_transposed_codes =
271                    transpose(&new_codes, new_row_ids.len(), num_bytes_in_code);
272                let codes_fsl = Arc::new(FixedSizeListArray::try_new_from_values(
273                    new_transposed_codes,
274                    num_bytes_in_code as i32,
275                )?);
276                RecordBatch::try_new(batch.schema(), vec![new_row_ids, codes_fsl])?
277            };
278            pq_code = batch[PQ_CODE_COLUMN]
279                .as_fixed_size_list()
280                .values()
281                .as_primitive::<UInt8Type>()
282                .clone()
283                .into();
284        }
285
286        let distance_type = match distance_type {
287            DistanceType::Cosine => DistanceType::L2,
288            _ => distance_type,
289        };
290        let metadata = ProductQuantizationMetadata {
291            codebook_position: 0,
292            nbits: num_bits,
293            num_sub_vectors,
294            dimension,
295            codebook: Some(codebook),
296            codebook_tensor: Vec::new(), // empty for v1 format
297            transposed: true,
298        };
299        Ok(Self {
300            metadata,
301            distance_type,
302            batch,
303            pq_code,
304            row_ids,
305        })
306    }
307
    /// The underlying `(row id, PQ code)` record batch.
    pub fn batch(&self) -> &RecordBatch {
        &self.batch
    }
311
312    /// Build a PQ storage from ProductQuantizer and a RecordBatch.
313    ///
314    /// Parameters
315    /// ----------
316    /// quantizer: ProductQuantizer
317    ///    The quantizer used to transform the vectors.
318    /// batch: RecordBatch
319    ///   The batch of vectors to be transformed.
320    /// vector_col: &str
321    ///   The name of the column containing the vectors.
322    pub async fn build(
323        quantizer: ProductQuantizer,
324        batch: &RecordBatch,
325        vector_col: &str,
326        frag_reuse_index: Option<Arc<FragReuseIndex>>,
327    ) -> Result<Self> {
328        let codebook = quantizer.codebook.clone();
329        let num_bits = quantizer.num_bits;
330        let dimension = quantizer.dimension;
331        let num_sub_vectors = quantizer.num_sub_vectors;
332        let metric_type = quantizer.distance_type;
333        let transform = PQTransformer::new(quantizer, vector_col, PQ_CODE_COLUMN);
334        let batch = transform.transform(batch)?;
335        Self::new(
336            codebook,
337            batch,
338            num_bits,
339            num_sub_vectors,
340            dimension,
341            metric_type,
342            false,
343            frag_reuse_index,
344        )
345    }
346
    /// The PQ codebook.
    ///
    /// # Panics
    /// Panics if the codebook has not been loaded into the metadata.
    pub fn codebook(&self) -> &FixedSizeListArray {
        self.metadata.codebook.as_ref().unwrap()
    }
350
    /// Load full PQ storage from disk.
    ///
    /// Parameters
    /// ----------
    /// object_store: &ObjectStore
    ///   The object store to load the storage from.
    /// path: &Path
    ///  The path to the storage.
    ///
    /// Returns
    /// --------
    /// Self
    ///
    /// Currently it loads everything in memory.
    /// TODO: support lazy loading later.
    pub async fn load(
        object_store: &ObjectStore,
        path: &Path,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
    ) -> Result<Self> {
        let reader = PreviousFileReader::try_new_self_described(object_store, path, None).await?;
        let schema = reader.schema();

        // The distance type is recorded in schema-level index metadata.
        let metadata_str = schema
            .metadata
            .get(INDEX_METADATA_SCHEMA_KEY)
            .ok_or(Error::Index {
                message: format!(
                    "Reading PQ storage: index key {} not found",
                    INDEX_METADATA_SCHEMA_KEY
                ),
                location: location!(),
            })?;
        let index_metadata: IndexMetadata =
            serde_json::from_str(metadata_str).map_err(|_| Error::Index {
                message: format!("Failed to parse index metadata: {}", metadata_str),
                location: location!(),
            })?;
        let distance_type: DistanceType =
            DistanceType::try_from(index_metadata.distance_type.as_str())?;

        // PQ parameters + codebook live under a separate metadata key.
        let metadata = ProductQuantizationMetadata::load(&reader).await?;
        // Read all rows as a single partition.
        Self::load_partition(
            &reader,
            0..reader.len(),
            distance_type,
            &metadata,
            frag_reuse_index,
        )
        .await
    }
402
    /// Arrow schema of the stored batch.
    pub fn schema(&self) -> SchemaRef {
        self.batch.schema()
    }
406
407    pub fn get_row_ids(&self, ids: &[u32]) -> Vec<u64> {
408        ids.iter()
409            .map(|&id| self.row_ids.value(id as usize))
410            .collect()
411    }
412
413    /// Write the PQ storage as a Lance partition to disk,
414    /// and returns the number of rows written.
415    ///
416    pub async fn write_partition(
417        &self,
418        writer: &mut PreviousFileWriter<ManifestDescribing>,
419    ) -> Result<usize> {
420        let batch_size: usize = 10240; // TODO: make it configurable
421        for offset in (0..self.batch.num_rows()).step_by(batch_size) {
422            let length = min(batch_size, self.batch.num_rows() - offset);
423            let slice = self.batch.slice(offset, length);
424            writer.write(&[slice]).await?;
425        }
426        Ok(self.batch.num_rows())
427    }
428}
429
430pub fn transpose<T: ArrowPrimitiveType>(
431    original: &PrimitiveArray<T>,
432    num_rows: usize,
433    num_columns: usize,
434) -> PrimitiveArray<T>
435where
436    PrimitiveArray<T>: From<Vec<T::Native>>,
437{
438    if original.is_empty() {
439        return original.clone();
440    }
441
442    let mut transposed_codes = vec![T::default_value(); original.len()];
443    for (vec_idx, codes) in original.values().chunks_exact(num_columns).enumerate() {
444        for (sub_vec_idx, code) in codes.iter().enumerate() {
445            transposed_codes[sub_vec_idx * num_rows + vec_idx] = *code;
446        }
447    }
448
449    transposed_codes.into()
450}
451
452#[async_trait]
453impl QuantizerStorage for ProductQuantizationStorage {
454    type Metadata = ProductQuantizationMetadata;
455
456    fn try_from_batch(
457        batch: RecordBatch,
458        metadata: &Self::Metadata,
459        distance_type: DistanceType,
460        frag_reuse_index: Option<Arc<FragReuseIndex>>,
461    ) -> Result<Self>
462    where
463        Self: Sized,
464    {
465        let distance_type = match distance_type {
466            DistanceType::Cosine => DistanceType::L2,
467            _ => distance_type,
468        };
469
470        // now it supports only Float32Type
471        let codebook = match &metadata.codebook {
472            Some(codebook) => codebook.clone(),
473            None => {
474                // legacy format would contains codebook tensor but not codebook
475                debug_assert!(!metadata.codebook_tensor.is_empty());
476                let codebook_tensor = pb::Tensor::decode(metadata.codebook_tensor.as_slice())?;
477                FixedSizeListArray::try_from(&codebook_tensor)?
478            }
479        };
480
481        Self::new(
482            codebook,
483            batch,
484            metadata.nbits,
485            metadata.num_sub_vectors,
486            metadata.dimension,
487            distance_type,
488            metadata.transposed,
489            frag_reuse_index,
490        )
491    }
492
    /// The PQ metadata backing this storage.
    fn metadata(&self) -> &Self::Metadata {
        &self.metadata
    }
496
497    // we can't use the default implementation of remap,
498    // because PQ Storage transposed the PQ codes
499    fn remap(&self, mapping: &HashMap<u64, Option<u64>>) -> Result<Self> {
500        let transposed_codes = self.pq_code.values();
501        let mut new_row_ids = Vec::with_capacity(self.len());
502        let mut new_codes = Vec::with_capacity(self.len() * self.metadata.num_sub_vectors);
503
504        let row_ids = self.row_ids.values();
505        for (i, row_id) in row_ids.iter().enumerate() {
506            match mapping.get(row_id) {
507                Some(Some(new_id)) => {
508                    new_row_ids.push(*new_id);
509                    new_codes.extend(get_pq_code(
510                        transposed_codes,
511                        self.metadata.nbits,
512                        self.metadata.num_sub_vectors,
513                        i as u32,
514                    ));
515                }
516                Some(None) => {}
517                None => {
518                    new_row_ids.push(*row_id);
519                    new_codes.extend(get_pq_code(
520                        transposed_codes,
521                        self.metadata.nbits,
522                        self.metadata.num_sub_vectors,
523                        i as u32,
524                    ));
525                }
526            }
527        }
528
529        let new_row_ids = Arc::new(UInt64Array::from(new_row_ids));
530        let new_codes = UInt8Array::from(new_codes);
531        let batch = if new_row_ids.is_empty() {
532            RecordBatch::new_empty(self.schema())
533        } else {
534            let num_bytes_in_code = new_codes.len() / new_row_ids.len();
535            let new_transposed_codes = transpose(&new_codes, new_row_ids.len(), num_bytes_in_code);
536            let codes_fsl = Arc::new(FixedSizeListArray::try_new_from_values(
537                new_transposed_codes,
538                num_bytes_in_code as i32,
539            )?);
540            RecordBatch::try_new(self.schema(), vec![new_row_ids.clone(), codes_fsl])?
541        };
542        let transposed_codes = batch[PQ_CODE_COLUMN]
543            .as_fixed_size_list()
544            .values()
545            .as_primitive::<UInt8Type>()
546            .clone();
547
548        Ok(Self {
549            metadata: self.metadata.clone(),
550            distance_type: self.distance_type,
551            batch,
552            pq_code: Arc::new(transposed_codes),
553            row_ids: new_row_ids,
554        })
555    }
556
    /// Load a partition of PQ storage from disk.
    ///
    /// Parameters
    /// ----------
    /// - *reader: &PreviousFileReader
    async fn load_partition(
        reader: &PreviousFileReader,
        range: std::ops::Range<usize>,
        distance_type: DistanceType,
        metadata: &Self::Metadata,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
    ) -> Result<Self> {
        // Hard coded to float32 for now
        let codebook = metadata
            .codebook
            .as_ref()
            .ok_or(Error::Index {
                message: "Codebook not found in PQ metadata".to_string(),
                location: location!(),
            })?
            .values()
            .as_primitive::<Float32Type>()
            .clone();

        // Re-wrap the flat centroid values with the full vector dimension.
        let codebook =
            FixedSizeListArray::try_new_from_values(codebook, metadata.dimension as i32)?;

        let schema = reader.schema();
        let batch = reader.read_range(range, schema).await?;

        Self::new(
            codebook,
            batch,
            metadata.nbits,
            metadata.num_sub_vectors,
            metadata.dimension,
            distance_type,
            metadata.transposed,
            frag_reuse_index,
        )
    }
598}
599
600impl VectorStore for ProductQuantizationStorage {
601    type DistanceCalculator<'a> = PQDistCalculator;
602
    /// Expose the storage as a stream of record batches (always exactly one).
    fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch>> {
        Ok(std::iter::once(self.batch.clone()))
    }
606
    /// Appending is not supported for PQ storage.
    fn append_batch(&self, _batch: RecordBatch, _vector_column: &str) -> Result<Self> {
        unimplemented!()
    }
610
    /// Borrowed schema of the underlying batch.
    fn schema(&self) -> &SchemaRef {
        self.batch.schema_ref()
    }
614
    /// Downcast support.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
618
    /// Number of vectors stored.
    fn len(&self) -> usize {
        self.batch.num_rows()
    }
622
    /// Distance type used for lookups (Cosine was normalized to L2 in `new`).
    fn distance_type(&self) -> DistanceType {
        self.distance_type
    }
626
    /// Row id at storage offset `id`; panics if out of bounds.
    fn row_id(&self, id: u32) -> u64 {
        self.row_ids.values()[id as usize]
    }
630
    /// Iterator over all row ids in storage order.
    fn row_ids(&self) -> impl Iterator<Item = &u64> {
        self.row_ids.values().iter()
    }
634
    /// Build a distance calculator for an external query vector.
    ///
    /// Dispatches on the codebook element type; the query array is cast to
    /// the same element type. `_dist_q_c` is unused by PQ.
    fn dist_calculator(&self, query: ArrayRef, _dist_q_c: f32) -> Self::DistanceCalculator<'_> {
        let codebook = self.metadata.codebook.as_ref().unwrap();
        match codebook.value_type() {
            DataType::Float16 => PQDistCalculator::new(
                codebook
                    .values()
                    .as_primitive::<datatypes::Float16Type>()
                    .values(),
                self.metadata.nbits,
                self.metadata.num_sub_vectors,
                self.pq_code.clone(),
                query.as_primitive::<datatypes::Float16Type>().values(),
                self.distance_type,
            ),
            DataType::Float32 => PQDistCalculator::new(
                codebook
                    .values()
                    .as_primitive::<datatypes::Float32Type>()
                    .values(),
                self.metadata.nbits,
                self.metadata.num_sub_vectors,
                self.pq_code.clone(),
                query.as_primitive::<datatypes::Float32Type>().values(),
                self.distance_type,
            ),
            DataType::Float64 => PQDistCalculator::new(
                codebook
                    .values()
                    .as_primitive::<datatypes::Float64Type>()
                    .values(),
                self.metadata.nbits,
                self.metadata.num_sub_vectors,
                self.pq_code.clone(),
                query.as_primitive::<datatypes::Float64Type>().values(),
                self.distance_type,
            ),
            _ => unimplemented!("Unsupported data type: {:?}", codebook.value_type()),
        }
    }
674
    /// Build a distance calculator whose query is the reconstructed
    /// (centroid) vector of the stored vector at offset `id`.
    fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> {
        let codes = get_pq_code(
            self.pq_code.values(),
            self.metadata.nbits,
            self.metadata.num_sub_vectors,
            id,
        );
        let codebook = self.metadata.codebook.as_ref().unwrap();
        match codebook.value_type() {
            DataType::Float16 => {
                let codebook = codebook
                    .values()
                    .as_primitive::<datatypes::Float16Type>()
                    .values();
                // Reconstruct the stored vector from its centroids.
                let query = get_centroids(
                    codebook,
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    codes,
                );
                PQDistCalculator::new(
                    codebook,
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.pq_code.clone(),
                    &query,
                    self.distance_type,
                )
            }
            DataType::Float32 => {
                let codebook = codebook
                    .values()
                    .as_primitive::<datatypes::Float32Type>()
                    .values();
                let query = get_centroids(
                    codebook,
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    codes,
                );
                PQDistCalculator::new(
                    codebook,
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.pq_code.clone(),
                    &query,
                    self.distance_type,
                )
            }
            DataType::Float64 => {
                let codebook = codebook
                    .values()
                    .as_primitive::<datatypes::Float64Type>()
                    .values();
                let query = get_centroids(
                    codebook,
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    codes,
                );
                PQDistCalculator::new(
                    codebook,
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.pq_code.clone(),
                    &query,
                    self.distance_type,
                )
            }
            _ => unimplemented!("Unsupported data type: {:?}", codebook.value_type()),
        }
    }
750
    /// Approximate distance between two stored vectors `u` and `v`: both are
    /// reconstructed from their centroids and compared directly.
    fn dist_between(&self, u: u32, v: u32) -> f32 {
        // this is a fast way to compute distance between two vectors in the same storage.
        // it doesn't construct the distance table.
        let pq_codes = self.pq_code.values();
        let u_codes = get_pq_code(
            pq_codes,
            self.metadata.nbits,
            self.metadata.num_sub_vectors,
            u,
        );
        let v_codes = get_pq_code(
            pq_codes,
            self.metadata.nbits,
            self.metadata.num_sub_vectors,
            v,
        );
        let codebook = self.metadata.codebook.as_ref().unwrap();

        // Dispatch on codebook element type; each arm reconstructs both
        // vectors and applies the configured distance function.
        match codebook.value_type() {
            DataType::Float16 => {
                let qu = get_centroids(
                    codebook
                        .values()
                        .as_primitive::<datatypes::Float16Type>()
                        .values(),
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    u_codes,
                );
                let qv = get_centroids(
                    codebook
                        .values()
                        .as_primitive::<datatypes::Float16Type>()
                        .values(),
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    v_codes,
                );
                self.distance_type.func()(&qu, &qv)
            }
            DataType::Float32 => {
                let qu = get_centroids(
                    codebook
                        .values()
                        .as_primitive::<datatypes::Float32Type>()
                        .values(),
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    u_codes,
                );
                let qv = get_centroids(
                    codebook
                        .values()
                        .as_primitive::<datatypes::Float32Type>()
                        .values(),
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    v_codes,
                );
                self.distance_type.func()(&qu, &qv)
            }
            DataType::Float64 => {
                let qu = get_centroids(
                    codebook
                        .values()
                        .as_primitive::<datatypes::Float64Type>()
                        .values(),
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    u_codes,
                );
                let qv = get_centroids(
                    codebook
                        .values()
                        .as_primitive::<datatypes::Float64Type>()
                        .values(),
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    v_codes,
                );
                self.distance_type.func()(&qu, &qv)
            }
            _ => unimplemented!("Unsupported data type: {:?}", codebook.value_type()),
        }
    }
842}
843
/// Distance calculator backed by PQ code.
///
/// Holds a precomputed query-to-centroid distance table so that the distance
/// to any stored vector reduces to per-sub-vector table lookups.
pub struct PQDistCalculator {
    // Flattened table: one run of `2^num_bits` entries per sub-vector
    // (see the `i * num_centroids + c` indexing in `distance`).
    distance_table: Vec<f32>,
    // Transposed PQ codes, shared with the storage.
    pq_code: Arc<UInt8Array>,
    num_sub_vectors: usize,
    num_bits: u32,
    distance_type: DistanceType,
}
852
impl PQDistCalculator {
    /// Precompute the query-to-centroid distance table.
    ///
    /// L2 and Cosine share an L2 table (cosine is served by normalized L2);
    /// Dot gets a dot-product table. Other distance types are unsupported.
    fn new<T: L2 + Dot>(
        codebook: &[T],
        num_bits: u32,
        num_sub_vectors: usize,
        pq_code: Arc<UInt8Array>,
        query: &[T],
        distance_type: DistanceType,
    ) -> Self {
        let distance_table = match distance_type {
            DistanceType::L2 | DistanceType::Cosine => {
                build_distance_table_l2(codebook, num_bits, num_sub_vectors, query)
            }
            DistanceType::Dot => {
                build_distance_table_dot(codebook, num_bits, num_sub_vectors, query)
            }
            _ => unimplemented!("DistanceType is not supported: {:?}", distance_type),
        };
        Self {
            distance_table,
            num_sub_vectors,
            pq_code,
            num_bits,
            distance_type,
        }
    }

    /// PQ code bytes of the vector at storage offset `id`, widened to
    /// `usize` for direct use as table indices.
    fn get_pq_code(&self, id: u32) -> impl Iterator<Item = usize> + '_ {
        get_pq_code(
            self.pq_code.values(),
            self.num_bits,
            self.num_sub_vectors,
            id,
        )
        .map(|v| v as usize)
    }
}
890
impl DistCalculator for PQDistCalculator {
    /// Distance from the query to the vector at row `id`, computed by
    /// summing one precomputed table entry per sub-vector.
    fn distance(&self, id: u32) -> f32 {
        // Each sub-vector owns a run of 2^num_bits entries in the table.
        let num_centroids = 2_usize.pow(self.num_bits);
        let pq_code = self.get_pq_code(id);
        let diff = self.num_sub_vectors as f32 - 1.0;
        let dist = if self.num_bits == 4 {
            // 4-bit codes: each byte packs two centroid indices,
            // low nibble = sub-vector 2i, high nibble = sub-vector 2i+1.
            pq_code
                .enumerate()
                .map(|(i, c)| {
                    let current_idx = c & 0x0F;
                    let next_idx = c >> 4;

                    self.distance_table[2 * i * num_centroids + current_idx]
                        + self.distance_table[(2 * i + 1) * num_centroids + next_idx]
                })
                .sum()
        } else {
            // 8-bit codes: one byte per sub-vector, direct table lookup.
            pq_code
                .enumerate()
                .map(|(i, c)| self.distance_table[i * num_centroids + c])
                .sum()
        };

        // NOTE(review): mirrors the "- (num_sub_vectors - 1)" correction in
        // distance_all's Dot branch; presumably each dot table entry carries
        // a +1 bias so the sum overshoots by m-1 — confirm in
        // build_distance_table_dot.
        if self.distance_type == DistanceType::Dot {
            dist - diff
        } else {
            dist
        }
    }

    /// Distances from the query to every stored vector.
    ///
    /// `k_hint` (how many results the caller ultimately wants) is forwarded
    /// to `compute_pq_distance`.
    fn distance_all(&self, k_hint: usize) -> Vec<f32> {
        match self.distance_type {
            DistanceType::L2 => compute_pq_distance(
                &self.distance_table,
                self.num_bits,
                self.num_sub_vectors,
                self.pq_code.values(),
                k_hint,
            ),
            DistanceType::Cosine => {
                // it seems we implemented cosine distance at some version,
                // but from now on, we should use normalized L2 distance.
                debug_assert!(
                    false,
                    "cosine distance should be converted to normalized L2 distance"
                );
                // L2 over normalized vectors:  ||x - y|| = x^2 + y^2 - 2 * xy = 1 + 1 - 2 * xy = 2 * (1 - xy)
                // Cosine distance: 1 - |xy| / (||x|| * ||y||) = 1 - xy / (x^2 * y^2) = 1 - xy / (1 * 1) = 1 - xy
                // Therefore, Cosine = L2 / 2
                let l2_dists = compute_pq_distance(
                    &self.distance_table,
                    self.num_bits,
                    self.num_sub_vectors,
                    self.pq_code.values(),
                    k_hint,
                );
                l2_dists.into_iter().map(|v| v / 2.0).collect()
            }
            DistanceType::Dot => {
                let dot_dists = compute_pq_distance(
                    &self.distance_table,
                    self.num_bits,
                    self.num_sub_vectors,
                    self.pq_code.values(),
                    k_hint,
                );
                // Same bias correction as in `distance` above.
                let diff = self.num_sub_vectors as f32 - 1.0;
                dot_dists.into_iter().map(|v| v - diff).collect()
            }
            _ => unimplemented!("distance type is not supported: {:?}", self.distance_type),
        }
    }
}
964
965fn get_pq_code(
966    pq_code: &[u8],
967    num_bits: u32,
968    num_sub_vectors: usize,
969    id: u32,
970) -> impl Iterator<Item = u8> + '_ {
971    let num_bytes = if num_bits == 4 {
972        num_sub_vectors / 2
973    } else {
974        num_sub_vectors
975    };
976
977    let num_vectors = pq_code.len() / num_bytes;
978    pq_code
979        .iter()
980        .skip(id as usize)
981        .step_by(num_vectors)
982        .copied()
983        .exact_size(num_bytes)
984}
985
986fn get_centroids<T: Clone>(
987    codebook: &[T],
988    num_bits: u32,
989    num_sub_vectors: usize,
990    dimension: usize,
991    codes: impl Iterator<Item = u8>,
992) -> Vec<T> {
993    // codebook[i][j] is the j-th centroid of the i-th sub-vector.
994    // the codebook is stored as a flat array, codebook[i * num_centroids + j] = codebook[i][j]
995
996    if num_bits == 4 {
997        return get_centroids_4bit(codebook, num_sub_vectors, dimension, codes);
998    }
999
1000    let num_centroids: usize = 2_usize.pow(8);
1001    let sub_vector_width = dimension / num_sub_vectors;
1002    let mut centroids = Vec::with_capacity(dimension);
1003    for (sub_vec_idx, centroid_idx) in codes.enumerate() {
1004        let centroid_idx = centroid_idx as usize;
1005        let centroid = &codebook[sub_vec_idx * num_centroids * sub_vector_width
1006            + centroid_idx * sub_vector_width
1007            ..sub_vec_idx * num_centroids * sub_vector_width
1008                + (centroid_idx + 1) * sub_vector_width];
1009        centroids.extend_from_slice(centroid);
1010    }
1011    centroids
1012}
1013
/// Reconstruct the quantized vector from 4-bit packed codes.
///
/// Every byte of `codes` carries two centroid indices — the low nibble for
/// sub-vector `2k`, the high nibble for sub-vector `2k + 1` — and each
/// sub-vector owns a 16-entry table of `sub_vector_width` values inside the
/// flat codebook.
fn get_centroids_4bit<T: Clone>(
    codebook: &[T],
    num_sub_vectors: usize,
    dimension: usize,
    codes: impl Iterator<Item = u8>,
) -> Vec<T> {
    const NUM_CENTROIDS: usize = 16;
    let sub_vector_width = dimension / num_sub_vectors;
    // One sub-vector's table spans this many codebook values.
    let table_stride = NUM_CENTROIDS * sub_vector_width;
    let mut centroids = Vec::with_capacity(dimension);
    for (pair_idx, packed) in codes.enumerate() {
        let lo = (packed & 0x0F) as usize;
        let hi = (packed >> 4) as usize;

        // Low nibble selects from sub-vector 2 * pair_idx's table.
        let base = 2 * pair_idx * table_stride;
        let start = base + lo * sub_vector_width;
        centroids.extend_from_slice(&codebook[start..start + sub_vector_width]);

        // High nibble selects from the following sub-vector's table.
        let base = base + table_stride;
        let start = base + hi * sub_vector_width;
        centroids.extend_from_slice(&codebook[start..start + sub_vector_width]);
    }
    centroids
}
1038
#[cfg(test)]
mod tests {
    use crate::vector::storage::StorageBuilder;

    use super::*;

    use arrow_array::{Float32Array, UInt32Array};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use lance_arrow::FixedSizeListArrayExt;
    use lance_core::ROW_ID_FIELD;
    use rand::Rng;

    const DIM: usize = 32;
    const TOTAL: usize = 512;
    const NUM_SUB_VECTORS: usize = 16;

    /// Build an in-memory PQ storage over `TOTAL` random vectors.
    ///
    /// Shared implementation for the two helpers below (they previously
    /// duplicated this body). When `with_extra_column` is set, the source
    /// batch carries an extra UInt32 column so tests can verify that
    /// non-essential columns are dropped by the storage pipeline.
    fn build_pq_storage(with_extra_column: bool) -> ProductQuantizationStorage {
        // Random codebook: 256 centroids of DIM dimensions each (8-bit PQ).
        let codebook = Float32Array::from_iter_values((0..256 * DIM).map(|_| rand::random()));
        let codebook = FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap();
        let pq = ProductQuantizer::new(NUM_SUB_VECTORS, 8, DIM, codebook, DistanceType::Dot);

        let mut fields = vec![
            Field::new(
                "vec",
                DataType::FixedSizeList(
                    Field::new_list_field(DataType::Float32, true).into(),
                    DIM as i32,
                ),
                true,
            ),
            ROW_ID_FIELD.clone(),
        ];
        let vectors = Float32Array::from_iter_values((0..TOTAL * DIM).map(|_| rand::random()));
        let row_ids = UInt64Array::from_iter_values((0..TOTAL).map(|v| v as u64));
        let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap();
        let mut columns: Vec<ArrayRef> = vec![Arc::new(fsl), Arc::new(row_ids)];
        if with_extra_column {
            fields.push(Field::new("extra", DataType::UInt32, true));
            columns.push(Arc::new(UInt32Array::from_iter_values(
                (0..TOTAL).map(|v| v as u32),
            )));
        }
        let batch = RecordBatch::try_new(ArrowSchema::new(fields).into(), columns).unwrap();

        StorageBuilder::new("vec".to_owned(), pq.distance_type, pq, None)
            .unwrap()
            .build(vec![batch])
            .unwrap()
    }

    // Async wrappers keep the original helper signatures used by the tests.
    async fn create_pq_storage() -> ProductQuantizationStorage {
        build_pq_storage(false)
    }

    async fn create_pq_storage_with_extra_column() -> ProductQuantizationStorage {
        build_pq_storage(true)
    }

    #[tokio::test]
    async fn test_build_pq_storage() {
        let storage = create_pq_storage().await;
        assert_eq!(storage.len(), TOTAL);
        assert_eq!(storage.metadata.num_sub_vectors, NUM_SUB_VECTORS);
        assert_eq!(
            storage.metadata.codebook.as_ref().unwrap().values().len(),
            256 * DIM
        );
        assert_eq!(storage.pq_code.len(), TOTAL * NUM_SUB_VECTORS);
        assert_eq!(storage.row_ids.len(), TOTAL);
    }

    #[tokio::test]
    async fn test_distance_all() {
        // distance_all must agree exactly with per-id distance().
        let storage = create_pq_storage().await;
        let query = Arc::new(Float32Array::from_iter_values((0..DIM).map(|v| v as f32)));
        let dist_calc = storage.dist_calculator(query, 0.0);
        let expected = (0..storage.len())
            .map(|id| dist_calc.distance(id as u32))
            .collect::<Vec<_>>();
        let distances = dist_calc.distance_all(100);
        assert_eq!(distances, expected);
    }

    #[tokio::test]
    async fn test_dist_between() {
        // The pairwise distance must be symmetric.
        let mut rng = rand::rng();
        let storage = create_pq_storage().await;
        let u = rng.random_range(0..storage.len() as u32);
        let v = rng.random_range(0..storage.len() as u32);
        let dist1 = storage.dist_between(u, v);
        let dist2 = storage.dist_between(v, u);
        assert_eq!(dist1, dist2);
    }

    #[tokio::test]
    async fn test_remap_with_extra_column() {
        let storage = create_pq_storage_with_extra_column().await;
        // First half of the rows are remapped to new ids, second half dropped.
        let mut mapping = HashMap::new();
        for i in 0..TOTAL / 2 {
            mapping.insert(i as u64, Some((TOTAL + i) as u64));
        }
        for i in TOTAL / 2..TOTAL {
            mapping.insert(i as u64, None);
        }
        let new_storage = storage.remap(&mapping).unwrap();
        assert_eq!(new_storage.len(), TOTAL / 2);
        assert_eq!(new_storage.row_ids.len(), TOTAL / 2);
        for (i, row_id) in new_storage.row_ids().enumerate() {
            assert_eq!(*row_id, (TOTAL + i) as u64);
        }
        // Only the essential columns survive remap: row id + PQ code.
        assert_eq!(new_storage.batch.num_columns(), 2);
        assert!(new_storage.batch.column_by_name(ROW_ID).is_some());
        assert!(new_storage.batch.column_by_name(PQ_CODE_COLUMN).is_some());
    }
}