Skip to main content

lance_index/vector/pq/
storage.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Product Quantization storage
5//!
6//! Used as storage backend for Graph based algorithms.
7
8use std::{cmp::min, collections::HashMap, sync::Arc};
9
10use arrow::datatypes::{self, UInt8Type};
11use arrow_array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray};
12use arrow_array::{
13    FixedSizeListArray, RecordBatch, UInt8Array, UInt64Array,
14    cast::AsArray,
15    types::{Float32Type, UInt64Type},
16};
17use arrow_schema::{DataType, SchemaRef};
18use async_trait::async_trait;
19use bytes::{Bytes, BytesMut};
20use deepsize::DeepSizeOf;
21use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt};
22use lance_core::{Error, ROW_ID, Result};
23use lance_file::previous::{
24    reader::FileReader as PreviousFileReader, writer::FileWriter as PreviousFileWriter,
25};
26use lance_io::{object_store::ObjectStore, utils::read_message};
27use lance_linalg::distance::{DistanceType, Dot, L2};
28use lance_table::utils::LanceIteratorExtension;
29use lance_table::{format::SelfDescribingFileReader, io::manifest::ManifestDescribing};
30use object_store::path::Path;
31use prost::Message;
32use serde::{Deserialize, Serialize};
33
34use super::ProductQuantizer;
35use super::distance::{build_distance_table_dot, build_distance_table_l2, compute_pq_distance};
36use crate::frag_reuse::FragReuseIndex;
37use crate::vector::graph::{OrderedFloat, OrderedNode};
38use crate::{
39    INDEX_METADATA_SCHEMA_KEY, IndexMetadata, pb,
40    vector::{
41        PQ_CODE_COLUMN,
42        pq::transform::PQTransformer,
43        quantizer::{QuantizerMetadata, QuantizerStorage},
44        storage::{DistCalculator, VectorStore},
45        transform::Transformer,
46    },
47};
48
/// Schema-metadata key under which the JSON-serialized
/// `ProductQuantizationMetadata` is looked up when loading a PQ storage file.
49pub const PQ_METADATA_KEY: &str = "lance:pq";
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct ProductQuantizationMetadata {
53    pub codebook_position: usize,
54    pub nbits: u32,
55    pub num_sub_vectors: usize,
56    pub dimension: usize,
57
58    #[serde(skip)]
59    pub codebook: Option<FixedSizeListArray>,
60
61    // empty for v1 format
62    // used for v3 format
63    // deprecated in later version
64    pub codebook_tensor: Vec<u8>,
65    pub transposed: bool,
66}
67
68impl DeepSizeOf for ProductQuantizationMetadata {
69    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
70        self.codebook
71            .as_ref()
72            .map(|codebook| codebook.get_array_memory_size())
73            .unwrap_or(0)
74    }
75}
76
77impl PartialEq for ProductQuantizationMetadata {
78    fn eq(&self, other: &Self) -> bool {
79        self.num_sub_vectors == other.num_sub_vectors
80            && self.nbits == other.nbits
81            && self.dimension == other.dimension
82            && self.codebook == other.codebook
83    }
84}
85
86#[async_trait]
87impl QuantizerMetadata for ProductQuantizationMetadata {
88    fn buffer_index(&self) -> Option<u32> {
89        if self.codebook_position > 0 {
90            // the global buffer index starts from 1
91            Some(self.codebook_position as u32)
92        } else {
93            None
94        }
95    }
96
97    fn set_buffer_index(&mut self, index: u32) {
98        self.codebook_position = index as usize;
99    }
100
101    fn parse_buffer(&mut self, bytes: Bytes) -> Result<()> {
102        debug_assert!(!bytes.is_empty());
103        debug_assert!(self.codebook.is_none());
104        let codebook_tensor: pb::Tensor = pb::Tensor::decode(bytes)?;
105        self.codebook = Some(FixedSizeListArray::try_from(&codebook_tensor)?);
106        Ok(())
107    }
108
109    fn extra_metadata(&self) -> Result<Option<Bytes>> {
110        if let Some(codebook) = &self.codebook {
111            let codebook_tensor: pb::Tensor = pb::Tensor::try_from(codebook)?;
112            let mut bytes = BytesMut::new();
113            codebook_tensor.encode(&mut bytes)?;
114            Ok(Some(bytes.freeze()))
115        } else if !self.codebook_tensor.is_empty() {
116            // Legacy format: codebook is stored inline in the metadata JSON.
117            // Return it as-is; it's already a protobuf-encoded Tensor that
118            // parse_buffer() can handle.
119            Ok(Some(Bytes::from(self.codebook_tensor.clone())))
120        } else {
121            Ok(None)
122        }
123    }
124
125    async fn load(reader: &PreviousFileReader) -> Result<Self> {
126        let metadata = reader
127            .schema()
128            .metadata
129            .get(PQ_METADATA_KEY)
130            .ok_or(Error::index(format!(
131                "Reading PQ storage: metadata key {} not found",
132                PQ_METADATA_KEY
133            )))?;
134        let mut metadata: Self = serde_json::from_str(metadata)
135            .map_err(|_| Error::index(format!("Failed to parse PQ metadata: {}", metadata)))?;
136
137        debug_assert!(metadata.codebook.is_none());
138        debug_assert!(metadata.codebook_tensor.is_empty());
139
140        let codebook_tensor: pb::Tensor =
141            read_message(reader.object_reader.as_ref(), metadata.codebook_position).await?;
142        metadata.codebook = Some(FixedSizeListArray::try_from(&codebook_tensor)?);
143        Ok(metadata)
144    }
145}
146
147/// Product Quantization Storage
148///
149/// It stores PQ code, as well as the row ID to the original vectors.
150///
151/// It is possible to store additional metadata to accelerate filtering later.
152#[derive(Clone, Debug)]
153pub struct ProductQuantizationStorage {
154    metadata: ProductQuantizationMetadata,
155    distance_type: DistanceType,
156    batch: RecordBatch,
157
158    // For easy access
159    pq_code: Arc<UInt8Array>,
160    row_ids: Arc<UInt64Array>,
161}
162
163impl DeepSizeOf for ProductQuantizationStorage {
164    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
165        self.batch.get_array_memory_size()
166            + self
167                .metadata
168                .codebook
169                .as_ref()
170                .map(|codebook| codebook.get_array_memory_size())
171                .unwrap_or(0)
172    }
173}
174
175impl PartialEq for ProductQuantizationStorage {
176    fn eq(&self, other: &Self) -> bool {
177        self.distance_type == other.distance_type
178            && self.metadata.eq(&other.metadata)
179            && self.batch.columns().eq(other.batch.columns())
180    }
181}
182
183impl ProductQuantizationStorage {
184    #[allow(clippy::too_many_arguments)]
185    pub fn new(
186        codebook: FixedSizeListArray,
187        mut batch: RecordBatch,
188        num_bits: u32,
189        num_sub_vectors: usize,
190        dimension: usize,
191        distance_type: DistanceType,
192        transposed: bool,
193        frag_reuse_index: Option<Arc<FragReuseIndex>>,
194    ) -> Result<Self> {
195        if batch.num_columns() != 2 {
196            log::warn!(
197                "PQ storage should have 2 columns, but got {} columns: {}",
198                batch.num_columns(),
199                batch.schema(),
200            );
201            batch = batch.project(&[
202                batch.schema().index_of(ROW_ID)?,
203                batch.schema().index_of(PQ_CODE_COLUMN)?,
204            ])?;
205        }
206
207        let Some(row_ids) = batch.column_by_name(ROW_ID) else {
208            return Err(Error::index(
209                "Row ID column not found from PQ storage".to_string(),
210            ));
211        };
212        let row_ids: Arc<UInt64Array> = row_ids
213            .as_primitive_opt::<UInt64Type>()
214            .ok_or(Error::index(
215                "Row ID column is not of type UInt64".to_string(),
216            ))?
217            .clone()
218            .into();
219
220        if !transposed {
221            let num_sub_vectors_in_byte = if num_bits == 4 {
222                num_sub_vectors / 2
223            } else {
224                num_sub_vectors
225            };
226            let pq_col = batch[PQ_CODE_COLUMN].as_fixed_size_list();
227            let transposed_code = transpose(
228                pq_col.values().as_primitive::<UInt8Type>(),
229                row_ids.len(),
230                num_sub_vectors_in_byte,
231            );
232            let pq_code_fsl = Arc::new(FixedSizeListArray::try_new_from_values(
233                transposed_code,
234                num_sub_vectors_in_byte as i32,
235            )?);
236            batch = batch.replace_column_by_name(PQ_CODE_COLUMN, pq_code_fsl)?;
237        }
238
239        let mut pq_code: Arc<UInt8Array> = batch[PQ_CODE_COLUMN]
240            .as_fixed_size_list()
241            .values()
242            .as_primitive()
243            .clone()
244            .into();
245
246        if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
247            let transposed_codes = pq_code.values();
248            let mut new_row_ids = Vec::with_capacity(row_ids.len());
249            let mut new_codes = Vec::with_capacity(row_ids.len() * num_sub_vectors);
250
251            let row_ids_values = row_ids.values();
252            for (i, row_id) in row_ids_values.iter().enumerate() {
253                if let Some(mapped_value) = frag_reuse_index_ref.remap_row_id(*row_id) {
254                    new_row_ids.push(mapped_value);
255                    new_codes.extend(get_pq_code(
256                        transposed_codes,
257                        num_bits,
258                        num_sub_vectors,
259                        i as u32,
260                    ));
261                }
262            }
263
264            let new_row_ids = Arc::new(UInt64Array::from(new_row_ids));
265            let new_codes = UInt8Array::from(new_codes);
266            batch = if new_row_ids.is_empty() {
267                RecordBatch::new_empty(batch.schema())
268            } else {
269                let num_bytes_in_code = new_codes.len() / new_row_ids.len();
270                let new_transposed_codes =
271                    transpose(&new_codes, new_row_ids.len(), num_bytes_in_code);
272                let codes_fsl = Arc::new(FixedSizeListArray::try_new_from_values(
273                    new_transposed_codes,
274                    num_bytes_in_code as i32,
275                )?);
276                RecordBatch::try_new(batch.schema(), vec![new_row_ids, codes_fsl])?
277            };
278            pq_code = batch[PQ_CODE_COLUMN]
279                .as_fixed_size_list()
280                .values()
281                .as_primitive::<UInt8Type>()
282                .clone()
283                .into();
284        }
285
286        let distance_type = match distance_type {
287            DistanceType::Cosine => DistanceType::L2,
288            _ => distance_type,
289        };
290        let metadata = ProductQuantizationMetadata {
291            codebook_position: 0,
292            nbits: num_bits,
293            num_sub_vectors,
294            dimension,
295            codebook: Some(codebook),
296            codebook_tensor: Vec::new(), // empty for v1 format
297            transposed: true,
298        };
299        Ok(Self {
300            metadata,
301            distance_type,
302            batch,
303            pq_code,
304            row_ids,
305        })
306    }
307
308    pub fn batch(&self) -> &RecordBatch {
309        &self.batch
310    }
311
312    /// Build a PQ storage from ProductQuantizer and a RecordBatch.
313    ///
314    /// Parameters
315    /// ----------
316    /// quantizer: ProductQuantizer
317    ///    The quantizer used to transform the vectors.
318    /// batch: RecordBatch
319    ///   The batch of vectors to be transformed.
320    /// vector_col: &str
321    ///   The name of the column containing the vectors.
322    pub async fn build(
323        quantizer: ProductQuantizer,
324        batch: &RecordBatch,
325        vector_col: &str,
326        frag_reuse_index: Option<Arc<FragReuseIndex>>,
327    ) -> Result<Self> {
328        let codebook = quantizer.codebook.clone();
329        let num_bits = quantizer.num_bits;
330        let dimension = quantizer.dimension;
331        let num_sub_vectors = quantizer.num_sub_vectors;
332        let metric_type = quantizer.distance_type;
333        let transform = PQTransformer::new(quantizer, vector_col, PQ_CODE_COLUMN);
334        let batch = transform.transform(batch)?;
335        Self::new(
336            codebook,
337            batch,
338            num_bits,
339            num_sub_vectors,
340            dimension,
341            metric_type,
342            false,
343            frag_reuse_index,
344        )
345    }
346
347    pub fn codebook(&self) -> &FixedSizeListArray {
348        self.metadata.codebook.as_ref().unwrap()
349    }
350
351    /// Load full PQ storage from disk.
352    ///
353    /// Parameters
354    /// ----------
355    /// object_store: &ObjectStore
356    ///   The object store to load the storage from.
357    /// path: &Path
358    ///  The path to the storage.
359    ///
360    /// Returns
361    /// --------
362    /// Self
363    ///
364    /// Currently it loads everything in memory.
365    /// TODO: support lazy loading later.
366    pub async fn load(
367        object_store: &ObjectStore,
368        path: &Path,
369        frag_reuse_index: Option<Arc<FragReuseIndex>>,
370    ) -> Result<Self> {
371        let reader = PreviousFileReader::try_new_self_described(object_store, path, None).await?;
372        let schema = reader.schema();
373
374        let metadata_str = schema
375            .metadata
376            .get(INDEX_METADATA_SCHEMA_KEY)
377            .ok_or(Error::index(format!(
378                "Reading PQ storage: index key {} not found",
379                INDEX_METADATA_SCHEMA_KEY
380            )))?;
381        let index_metadata: IndexMetadata = serde_json::from_str(metadata_str).map_err(|_| {
382            Error::index(format!("Failed to parse index metadata: {}", metadata_str))
383        })?;
384        let distance_type: DistanceType =
385            DistanceType::try_from(index_metadata.distance_type.as_str())?;
386
387        let metadata = ProductQuantizationMetadata::load(&reader).await?;
388        Self::load_partition(
389            &reader,
390            0..reader.len(),
391            distance_type,
392            &metadata,
393            frag_reuse_index,
394        )
395        .await
396    }
397
398    pub fn schema(&self) -> SchemaRef {
399        self.batch.schema()
400    }
401
402    pub fn get_row_ids(&self, ids: &[u32]) -> Vec<u64> {
403        ids.iter()
404            .map(|&id| self.row_ids.value(id as usize))
405            .collect()
406    }
407
408    /// Write the PQ storage as a Lance partition to disk,
409    /// and returns the number of rows written.
410    ///
411    pub async fn write_partition(
412        &self,
413        writer: &mut PreviousFileWriter<ManifestDescribing>,
414    ) -> Result<usize> {
415        let batch_size: usize = 10240; // TODO: make it configurable
416        for offset in (0..self.batch.num_rows()).step_by(batch_size) {
417            let length = min(batch_size, self.batch.num_rows() - offset);
418            let slice = self.batch.slice(offset, length);
419            writer.write(&[slice]).await?;
420        }
421        Ok(self.batch.num_rows())
422    }
423}
424
425pub fn transpose<T: ArrowPrimitiveType>(
426    original: &PrimitiveArray<T>,
427    num_rows: usize,
428    num_columns: usize,
429) -> PrimitiveArray<T>
430where
431    PrimitiveArray<T>: From<Vec<T::Native>>,
432{
433    if original.is_empty() {
434        return original.clone();
435    }
436
437    let mut transposed_codes = vec![T::default_value(); original.len()];
438    for (vec_idx, codes) in original.values().chunks_exact(num_columns).enumerate() {
439        for (sub_vec_idx, code) in codes.iter().enumerate() {
440            transposed_codes[sub_vec_idx * num_rows + vec_idx] = *code;
441        }
442    }
443
444    transposed_codes.into()
445}
446
447#[async_trait]
448impl QuantizerStorage for ProductQuantizationStorage {
449    type Metadata = ProductQuantizationMetadata;
450
451    fn try_from_batch(
452        batch: RecordBatch,
453        metadata: &Self::Metadata,
454        distance_type: DistanceType,
455        frag_reuse_index: Option<Arc<FragReuseIndex>>,
456    ) -> Result<Self>
457    where
458        Self: Sized,
459    {
460        let distance_type = match distance_type {
461            DistanceType::Cosine => DistanceType::L2,
462            _ => distance_type,
463        };
464
465        // now it supports only Float32Type
466        let codebook = match &metadata.codebook {
467            Some(codebook) => codebook.clone(),
468            None => {
469                // legacy format would contains codebook tensor but not codebook
470                debug_assert!(!metadata.codebook_tensor.is_empty());
471                let codebook_tensor = pb::Tensor::decode(metadata.codebook_tensor.as_slice())?;
472                FixedSizeListArray::try_from(&codebook_tensor)?
473            }
474        };
475
476        Self::new(
477            codebook,
478            batch,
479            metadata.nbits,
480            metadata.num_sub_vectors,
481            metadata.dimension,
482            distance_type,
483            metadata.transposed,
484            frag_reuse_index,
485        )
486    }
487
488    fn metadata(&self) -> &Self::Metadata {
489        &self.metadata
490    }
491
492    // we can't use the default implementation of remap,
493    // because PQ Storage transposed the PQ codes
494    fn remap(&self, mapping: &HashMap<u64, Option<u64>>) -> Result<Self> {
495        let transposed_codes = self.pq_code.values();
496        let mut new_row_ids = Vec::with_capacity(self.len());
497        let mut new_codes = Vec::with_capacity(self.len() * self.metadata.num_sub_vectors);
498
499        let row_ids = self.row_ids.values();
500        for (i, row_id) in row_ids.iter().enumerate() {
501            match mapping.get(row_id) {
502                Some(Some(new_id)) => {
503                    new_row_ids.push(*new_id);
504                    new_codes.extend(get_pq_code(
505                        transposed_codes,
506                        self.metadata.nbits,
507                        self.metadata.num_sub_vectors,
508                        i as u32,
509                    ));
510                }
511                Some(None) => {}
512                None => {
513                    new_row_ids.push(*row_id);
514                    new_codes.extend(get_pq_code(
515                        transposed_codes,
516                        self.metadata.nbits,
517                        self.metadata.num_sub_vectors,
518                        i as u32,
519                    ));
520                }
521            }
522        }
523
524        let new_row_ids = Arc::new(UInt64Array::from(new_row_ids));
525        let new_codes = UInt8Array::from(new_codes);
526        let batch = if new_row_ids.is_empty() {
527            RecordBatch::new_empty(self.schema())
528        } else {
529            let num_bytes_in_code = new_codes.len() / new_row_ids.len();
530            let new_transposed_codes = transpose(&new_codes, new_row_ids.len(), num_bytes_in_code);
531            let codes_fsl = Arc::new(FixedSizeListArray::try_new_from_values(
532                new_transposed_codes,
533                num_bytes_in_code as i32,
534            )?);
535            RecordBatch::try_new(self.schema(), vec![new_row_ids.clone(), codes_fsl])?
536        };
537        let transposed_codes = batch[PQ_CODE_COLUMN]
538            .as_fixed_size_list()
539            .values()
540            .as_primitive::<UInt8Type>()
541            .clone();
542
543        Ok(Self {
544            metadata: self.metadata.clone(),
545            distance_type: self.distance_type,
546            batch,
547            pq_code: Arc::new(transposed_codes),
548            row_ids: new_row_ids,
549        })
550    }
551
552    /// Load a partition of PQ storage from disk.
553    ///
554    /// Parameters
555    /// ----------
556    /// - *reader: &PreviousFileReader
557    async fn load_partition(
558        reader: &PreviousFileReader,
559        range: std::ops::Range<usize>,
560        distance_type: DistanceType,
561        metadata: &Self::Metadata,
562        frag_reuse_index: Option<Arc<FragReuseIndex>>,
563    ) -> Result<Self> {
564        // Hard coded to float32 for now
565        let codebook = metadata
566            .codebook
567            .as_ref()
568            .ok_or(Error::index(
569                "Codebook not found in PQ metadata".to_string(),
570            ))?
571            .values()
572            .as_primitive::<Float32Type>()
573            .clone();
574
575        let codebook =
576            FixedSizeListArray::try_new_from_values(codebook, metadata.dimension as i32)?;
577
578        let schema = reader.schema();
579        let batch = reader.read_range(range, schema).await?;
580
581        Self::new(
582            codebook,
583            batch,
584            metadata.nbits,
585            metadata.num_sub_vectors,
586            metadata.dimension,
587            distance_type,
588            metadata.transposed,
589            frag_reuse_index,
590        )
591    }
592}
593
594impl VectorStore for ProductQuantizationStorage {
595    type DistanceCalculator<'a> = PQDistCalculator;
596
597    fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch>> {
598        Ok(std::iter::once(self.batch.clone()))
599    }
600
601    fn append_batch(&self, _batch: RecordBatch, _vector_column: &str) -> Result<Self> {
602        unimplemented!()
603    }
604
605    fn schema(&self) -> &SchemaRef {
606        self.batch.schema_ref()
607    }
608
609    fn as_any(&self) -> &dyn std::any::Any {
610        self
611    }
612
613    fn len(&self) -> usize {
614        self.batch.num_rows()
615    }
616
617    fn distance_type(&self) -> DistanceType {
618        self.distance_type
619    }
620
621    fn row_id(&self, id: u32) -> u64 {
622        self.row_ids.values()[id as usize]
623    }
624
625    fn row_ids(&self) -> impl Iterator<Item = &u64> {
626        self.row_ids.values().iter()
627    }
628
629    fn dist_calculator(&self, query: ArrayRef, _dist_q_c: f32) -> Self::DistanceCalculator<'_> {
630        let codebook = self.metadata.codebook.as_ref().unwrap();
631        match codebook.value_type() {
632            DataType::Float16 => PQDistCalculator::new(
633                codebook
634                    .values()
635                    .as_primitive::<datatypes::Float16Type>()
636                    .values(),
637                self.metadata.nbits,
638                self.metadata.num_sub_vectors,
639                self.pq_code.clone(),
640                query.as_primitive::<datatypes::Float16Type>().values(),
641                self.distance_type,
642            ),
643            DataType::Float32 => PQDistCalculator::new(
644                codebook
645                    .values()
646                    .as_primitive::<datatypes::Float32Type>()
647                    .values(),
648                self.metadata.nbits,
649                self.metadata.num_sub_vectors,
650                self.pq_code.clone(),
651                query.as_primitive::<datatypes::Float32Type>().values(),
652                self.distance_type,
653            ),
654            DataType::Float64 => PQDistCalculator::new(
655                codebook
656                    .values()
657                    .as_primitive::<datatypes::Float64Type>()
658                    .values(),
659                self.metadata.nbits,
660                self.metadata.num_sub_vectors,
661                self.pq_code.clone(),
662                query.as_primitive::<datatypes::Float64Type>().values(),
663                self.distance_type,
664            ),
665            _ => unimplemented!("Unsupported data type: {:?}", codebook.value_type()),
666        }
667    }
668
669    fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> {
670        let codes = get_pq_code(
671            self.pq_code.values(),
672            self.metadata.nbits,
673            self.metadata.num_sub_vectors,
674            id,
675        );
676        let codebook = self.metadata.codebook.as_ref().unwrap();
677        match codebook.value_type() {
678            DataType::Float16 => {
679                let codebook = codebook
680                    .values()
681                    .as_primitive::<datatypes::Float16Type>()
682                    .values();
683                let query = get_centroids(
684                    codebook,
685                    self.metadata.nbits,
686                    self.metadata.num_sub_vectors,
687                    self.metadata.dimension,
688                    codes,
689                );
690                PQDistCalculator::new(
691                    codebook,
692                    self.metadata.nbits,
693                    self.metadata.num_sub_vectors,
694                    self.pq_code.clone(),
695                    &query,
696                    self.distance_type,
697                )
698            }
699            DataType::Float32 => {
700                let codebook = codebook
701                    .values()
702                    .as_primitive::<datatypes::Float32Type>()
703                    .values();
704                let query = get_centroids(
705                    codebook,
706                    self.metadata.nbits,
707                    self.metadata.num_sub_vectors,
708                    self.metadata.dimension,
709                    codes,
710                );
711                PQDistCalculator::new(
712                    codebook,
713                    self.metadata.nbits,
714                    self.metadata.num_sub_vectors,
715                    self.pq_code.clone(),
716                    &query,
717                    self.distance_type,
718                )
719            }
720            DataType::Float64 => {
721                let codebook = codebook
722                    .values()
723                    .as_primitive::<datatypes::Float64Type>()
724                    .values();
725                let query = get_centroids(
726                    codebook,
727                    self.metadata.nbits,
728                    self.metadata.num_sub_vectors,
729                    self.metadata.dimension,
730                    codes,
731                );
732                PQDistCalculator::new(
733                    codebook,
734                    self.metadata.nbits,
735                    self.metadata.num_sub_vectors,
736                    self.pq_code.clone(),
737                    &query,
738                    self.distance_type,
739                )
740            }
741            _ => unimplemented!("Unsupported data type: {:?}", codebook.value_type()),
742        }
743    }
744
745    fn dist_between(&self, u: u32, v: u32) -> f32 {
746        // this is a fast way to compute distance between two vectors in the same storage.
747        // it doesn't construct the distance table.
748        let pq_codes = self.pq_code.values();
749        let u_codes = get_pq_code(
750            pq_codes,
751            self.metadata.nbits,
752            self.metadata.num_sub_vectors,
753            u,
754        );
755        let v_codes = get_pq_code(
756            pq_codes,
757            self.metadata.nbits,
758            self.metadata.num_sub_vectors,
759            v,
760        );
761        let codebook = self.metadata.codebook.as_ref().unwrap();
762
763        match codebook.value_type() {
764            DataType::Float16 => {
765                let qu = get_centroids(
766                    codebook
767                        .values()
768                        .as_primitive::<datatypes::Float16Type>()
769                        .values(),
770                    self.metadata.nbits,
771                    self.metadata.num_sub_vectors,
772                    self.metadata.dimension,
773                    u_codes,
774                );
775                let qv = get_centroids(
776                    codebook
777                        .values()
778                        .as_primitive::<datatypes::Float16Type>()
779                        .values(),
780                    self.metadata.nbits,
781                    self.metadata.num_sub_vectors,
782                    self.metadata.dimension,
783                    v_codes,
784                );
785                self.distance_type.func()(&qu, &qv)
786            }
787            DataType::Float32 => {
788                let qu = get_centroids(
789                    codebook
790                        .values()
791                        .as_primitive::<datatypes::Float32Type>()
792                        .values(),
793                    self.metadata.nbits,
794                    self.metadata.num_sub_vectors,
795                    self.metadata.dimension,
796                    u_codes,
797                );
798                let qv = get_centroids(
799                    codebook
800                        .values()
801                        .as_primitive::<datatypes::Float32Type>()
802                        .values(),
803                    self.metadata.nbits,
804                    self.metadata.num_sub_vectors,
805                    self.metadata.dimension,
806                    v_codes,
807                );
808                self.distance_type.func()(&qu, &qv)
809            }
810            DataType::Float64 => {
811                let qu = get_centroids(
812                    codebook
813                        .values()
814                        .as_primitive::<datatypes::Float64Type>()
815                        .values(),
816                    self.metadata.nbits,
817                    self.metadata.num_sub_vectors,
818                    self.metadata.dimension,
819                    u_codes,
820                );
821                let qv = get_centroids(
822                    codebook
823                        .values()
824                        .as_primitive::<datatypes::Float64Type>()
825                        .values(),
826                    self.metadata.nbits,
827                    self.metadata.num_sub_vectors,
828                    self.metadata.dimension,
829                    v_codes,
830                );
831                self.distance_type.func()(&qu, &qv)
832            }
833            _ => unimplemented!("Unsupported data type: {:?}", codebook.value_type()),
834        }
835    }
836
837    fn prefers_candidate(&self, candidate: &OrderedNode, selected: &[OrderedNode]) -> bool {
838        selected
839            .iter()
840            .all(|other| candidate.dist < OrderedFloat(self.dist_between(candidate.id, other.id)))
841    }
842}
843
844/// Distance calculator backed by PQ code.
845pub struct PQDistCalculator {
846    distance_table: Vec<f32>,
847    pq_code: Arc<UInt8Array>,
848    num_sub_vectors: usize,
849    num_bits: u32,
850    distance_type: DistanceType,
851}
852
853impl PQDistCalculator {
854    fn new<T: L2 + Dot>(
855        codebook: &[T],
856        num_bits: u32,
857        num_sub_vectors: usize,
858        pq_code: Arc<UInt8Array>,
859        query: &[T],
860        distance_type: DistanceType,
861    ) -> Self {
862        let distance_table = match distance_type {
863            DistanceType::L2 | DistanceType::Cosine => {
864                build_distance_table_l2(codebook, num_bits, num_sub_vectors, query)
865            }
866            DistanceType::Dot => {
867                build_distance_table_dot(codebook, num_bits, num_sub_vectors, query)
868            }
869            _ => unimplemented!("DistanceType is not supported: {:?}", distance_type),
870        };
871        Self {
872            distance_table,
873            num_sub_vectors,
874            pq_code,
875            num_bits,
876            distance_type,
877        }
878    }
879
880    fn get_pq_code(&self, id: u32) -> impl Iterator<Item = usize> + '_ {
881        get_pq_code(
882            self.pq_code.values(),
883            self.num_bits,
884            self.num_sub_vectors,
885            id,
886        )
887        .map(|v| v as usize)
888    }
889}
890
891impl DistCalculator for PQDistCalculator {
892    fn distance(&self, id: u32) -> f32 {
893        let num_centroids = 2_usize.pow(self.num_bits);
894        let pq_code = self.get_pq_code(id);
895        let diff = self.num_sub_vectors as f32 - 1.0;
896        let dist = if self.num_bits == 4 {
897            pq_code
898                .enumerate()
899                .map(|(i, c)| {
900                    let current_idx = c & 0x0F;
901                    let next_idx = c >> 4;
902
903                    self.distance_table[2 * i * num_centroids + current_idx]
904                        + self.distance_table[(2 * i + 1) * num_centroids + next_idx]
905                })
906                .sum()
907        } else {
908            pq_code
909                .enumerate()
910                .map(|(i, c)| self.distance_table[i * num_centroids + c])
911                .sum()
912        };
913
914        if self.distance_type == DistanceType::Dot {
915            dist - diff
916        } else {
917            dist
918        }
919    }
920
921    fn distance_all(&self, k_hint: usize) -> Vec<f32> {
922        match self.distance_type {
923            DistanceType::L2 => compute_pq_distance(
924                &self.distance_table,
925                self.num_bits,
926                self.num_sub_vectors,
927                self.pq_code.values(),
928                k_hint,
929            ),
930            DistanceType::Cosine => {
931                // it seems we implemented cosine distance at some version,
932                // but from now on, we should use normalized L2 distance.
933                debug_assert!(
934                    false,
935                    "cosine distance should be converted to normalized L2 distance"
936                );
937                // L2 over normalized vectors:  ||x - y|| = x^2 + y^2 - 2 * xy = 1 + 1 - 2 * xy = 2 * (1 - xy)
938                // Cosine distance: 1 - |xy| / (||x|| * ||y||) = 1 - xy / (x^2 * y^2) = 1 - xy / (1 * 1) = 1 - xy
939                // Therefore, Cosine = L2 / 2
940                let l2_dists = compute_pq_distance(
941                    &self.distance_table,
942                    self.num_bits,
943                    self.num_sub_vectors,
944                    self.pq_code.values(),
945                    k_hint,
946                );
947                l2_dists.into_iter().map(|v| v / 2.0).collect()
948            }
949            DistanceType::Dot => {
950                let dot_dists = compute_pq_distance(
951                    &self.distance_table,
952                    self.num_bits,
953                    self.num_sub_vectors,
954                    self.pq_code.values(),
955                    k_hint,
956                );
957                let diff = self.num_sub_vectors as f32 - 1.0;
958                dot_dists.into_iter().map(|v| v - diff).collect()
959            }
960            _ => unimplemented!("distance type is not supported: {:?}", self.distance_type),
961        }
962    }
963}
964
965fn get_pq_code(
966    pq_code: &[u8],
967    num_bits: u32,
968    num_sub_vectors: usize,
969    id: u32,
970) -> impl Iterator<Item = u8> + '_ {
971    let num_bytes = if num_bits == 4 {
972        num_sub_vectors / 2
973    } else {
974        num_sub_vectors
975    };
976
977    let num_vectors = pq_code.len() / num_bytes;
978    pq_code
979        .iter()
980        .skip(id as usize)
981        .step_by(num_vectors)
982        .copied()
983        .exact_size(num_bytes)
984}
985
986fn get_centroids<T: Clone>(
987    codebook: &[T],
988    num_bits: u32,
989    num_sub_vectors: usize,
990    dimension: usize,
991    codes: impl Iterator<Item = u8>,
992) -> Vec<T> {
993    // codebook[i][j] is the j-th centroid of the i-th sub-vector.
994    // the codebook is stored as a flat array, codebook[i * num_centroids + j] = codebook[i][j]
995
996    if num_bits == 4 {
997        return get_centroids_4bit(codebook, num_sub_vectors, dimension, codes);
998    }
999
1000    let num_centroids: usize = 2_usize.pow(8);
1001    let sub_vector_width = dimension / num_sub_vectors;
1002    let mut centroids = Vec::with_capacity(dimension);
1003    for (sub_vec_idx, centroid_idx) in codes.enumerate() {
1004        let centroid_idx = centroid_idx as usize;
1005        let centroid = &codebook[sub_vec_idx * num_centroids * sub_vector_width
1006            + centroid_idx * sub_vector_width
1007            ..sub_vec_idx * num_centroids * sub_vector_width
1008                + (centroid_idx + 1) * sub_vector_width];
1009        centroids.extend_from_slice(centroid);
1010    }
1011    centroids
1012}
1013
/// Reconstructs a vector from 4-bit PQ codes packed two per byte.
///
/// Every byte of `codes` holds the code of sub-vector `2 * i` in its low
/// nibble and the code of sub-vector `2 * i + 1` in its high nibble, where `i`
/// is the position of the byte. Each sub-vector has 16 centroids of
/// `dimension / num_sub_vectors` values, stored contiguously in `codebook`.
///
/// # Panics
///
/// Panics if a code addresses a range outside `codebook`, or if
/// `num_sub_vectors` is zero.
fn get_centroids_4bit<T: Clone>(
    codebook: &[T],
    num_sub_vectors: usize,
    dimension: usize,
    codes: impl Iterator<Item = u8>,
) -> Vec<T> {
    const NUM_CENTROIDS: usize = 16;

    let width = dimension / num_sub_vectors;
    // Number of codebook values occupied by all centroids of one sub-vector.
    let stride = NUM_CENTROIDS * width;

    let mut reconstructed = Vec::with_capacity(dimension);
    for (byte_idx, packed) in codes.enumerate() {
        // Low nibble first, then high nibble.
        let nibbles = [usize::from(packed & 0x0F), usize::from(packed >> 4)];
        for (half, &centroid_idx) in nibbles.iter().enumerate() {
            let start = (2 * byte_idx + half) * stride + centroid_idx * width;
            reconstructed.extend_from_slice(&codebook[start..start + width]);
        }
    }
    reconstructed
}
1038
#[cfg(test)]
mod tests {
    use crate::vector::storage::StorageBuilder;

    use super::*;

    use arrow_array::{Float32Array, UInt32Array};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use lance_arrow::FixedSizeListArrayExt;
    use lance_core::ROW_ID_FIELD;
    use rand::Rng;

    const DIM: usize = 32;
    const TOTAL: usize = 512;
    const NUM_SUB_VECTORS: usize = 16;

    /// `num_values` random `f32`s wrapped in a fixed-size-list array of width `DIM`.
    fn random_fsl(num_values: usize) -> FixedSizeListArray {
        let values = Float32Array::from_iter_values((0..num_values).map(|_| rand::random()));
        FixedSizeListArray::try_new_from_values(values, DIM as i32).unwrap()
    }

    /// An 8-bit dot-product quantizer over a random codebook.
    fn random_pq() -> ProductQuantizer {
        ProductQuantizer::new(
            NUM_SUB_VECTORS,
            8,
            DIM,
            random_fsl(256 * DIM),
            DistanceType::Dot,
        )
    }

    /// Schema field of the nullable `"vec"` column holding `DIM`-wide float vectors.
    fn vector_field() -> Field {
        let item = Field::new_list_field(DataType::Float32, true);
        Field::new(
            "vec",
            DataType::FixedSizeList(item.into(), DIM as i32),
            true,
        )
    }

    /// Row ids `0..TOTAL`.
    fn sequential_row_ids() -> UInt64Array {
        UInt64Array::from_iter_values((0..TOTAL).map(|v| v as u64))
    }

    /// Quantizes the `"vec"` column of `batch` with `pq` into a PQ storage.
    fn build_storage(pq: ProductQuantizer, batch: RecordBatch) -> ProductQuantizationStorage {
        StorageBuilder::new("vec".to_owned(), pq.distance_type, pq, None)
            .unwrap()
            .build(vec![batch])
            .unwrap()
    }

    /// Storage built from `TOTAL` random vectors with only vector and row-id columns.
    async fn create_pq_storage() -> ProductQuantizationStorage {
        let pq = random_pq();

        let schema = ArrowSchema::new(vec![vector_field(), ROW_ID_FIELD.clone()]);
        let batch = RecordBatch::try_new(
            schema.into(),
            vec![
                Arc::new(random_fsl(TOTAL * DIM)),
                Arc::new(sequential_row_ids()),
            ],
        )
        .unwrap();

        build_storage(pq, batch)
    }

    /// Same as `create_pq_storage`, but the input batch carries an additional
    /// `"extra"` `UInt32` column.
    async fn create_pq_storage_with_extra_column() -> ProductQuantizationStorage {
        let pq = random_pq();

        let schema = ArrowSchema::new(vec![
            vector_field(),
            ROW_ID_FIELD.clone(),
            Field::new("extra", DataType::UInt32, true),
        ]);
        let extra_column = UInt32Array::from_iter_values((0..TOTAL).map(|v| v as u32));
        let batch = RecordBatch::try_new(
            schema.into(),
            vec![
                Arc::new(random_fsl(TOTAL * DIM)),
                Arc::new(sequential_row_ids()),
                Arc::new(extra_column),
            ],
        )
        .unwrap();

        build_storage(pq, batch)
    }

    #[tokio::test]
    async fn test_build_pq_storage() {
        let storage = create_pq_storage().await;

        assert_eq!(storage.len(), TOTAL);
        assert_eq!(storage.row_ids.len(), TOTAL);
        assert_eq!(storage.pq_code.len(), TOTAL * NUM_SUB_VECTORS);
        assert_eq!(storage.metadata.num_sub_vectors, NUM_SUB_VECTORS);

        let codebook = storage.metadata.codebook.as_ref().unwrap();
        assert_eq!(codebook.values().len(), 256 * DIM);
    }

    #[tokio::test]
    async fn test_distance_all() {
        let storage = create_pq_storage().await;
        let query = Arc::new(Float32Array::from_iter_values((0..DIM).map(|v| v as f32)));
        let dist_calc = storage.dist_calculator(query, 0.0);

        // The bulk path must agree with the per-id path for every stored vector.
        let mut expected = Vec::with_capacity(storage.len());
        for id in 0..storage.len() {
            expected.push(dist_calc.distance(id as u32));
        }
        let distances = dist_calc.distance_all(100);
        assert_eq!(distances, expected);
    }

    #[tokio::test]
    async fn test_dist_between() {
        let mut rng = rand::rng();
        let storage = create_pq_storage().await;
        let u = rng.random_range(0..storage.len() as u32);
        let v = rng.random_range(0..storage.len() as u32);

        // The distance between two stored vectors must be symmetric.
        let forward = storage.dist_between(u, v);
        let backward = storage.dist_between(v, u);
        assert_eq!(forward, backward);
    }

    #[tokio::test]
    async fn test_remap_with_extra_column() {
        let storage = create_pq_storage_with_extra_column().await;

        // The first half of the rows moves to `TOTAL + i`; the second half is deleted.
        let mapping: HashMap<u64, Option<u64>> = (0..TOTAL)
            .map(|i| {
                let target = (i < TOTAL / 2).then(|| (TOTAL + i) as u64);
                (i as u64, target)
            })
            .collect();

        let new_storage = storage.remap(&mapping).unwrap();
        assert_eq!(new_storage.len(), TOTAL / 2);
        assert_eq!(new_storage.row_ids.len(), TOTAL / 2);
        for (i, row_id) in new_storage.row_ids().enumerate() {
            assert_eq!(*row_id, (TOTAL + i) as u64);
        }

        // Only the row-id and PQ-code columns remain; "extra" is gone.
        assert_eq!(new_storage.batch.num_columns(), 2);
        assert!(new_storage.batch.column_by_name(ROW_ID).is_some());
        assert!(new_storage.batch.column_by_name(PQ_CODE_COLUMN).is_some());
    }
}