Skip to main content

lance_index/vector/
quantizer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use core::fmt;
5use std::str::FromStr;
6use std::sync::Arc;
7use std::{collections::HashMap, fmt::Debug};
8
9use arrow::{array::AsArray, compute::concat_batches, datatypes::UInt64Type};
10use arrow_array::{Array, ArrayRef, FixedSizeListArray, RecordBatch, UInt32Array, UInt64Array};
11use arrow_schema::Field;
12use async_trait::async_trait;
13use bytes::Bytes;
14use deepsize::DeepSizeOf;
15use lance_arrow::RecordBatchExt;
16use lance_core::{Error, ROW_ID, Result};
17use lance_file::previous::reader::FileReader as PreviousFileReader;
18use lance_io::traits::Reader;
19use lance_linalg::distance::DistanceType;
20use lance_table::format::SelfDescribingFileReader;
21use serde::{Deserialize, Serialize};
22
23use super::flat::index::{FlatBinQuantizer, FlatQuantizer};
24use super::pq::ProductQuantizer;
25use super::{ivf::storage::IvfModel, sq::ScalarQuantizer, storage::VectorStore};
26use crate::frag_reuse::FragReuseIndex;
27use crate::vector::bq::builder::RabitQuantizer;
28use crate::{INDEX_METADATA_SCHEMA_KEY, IndexMetadata};
29
30pub trait Quantization:
31    Send
32    + Sync
33    + Clone
34    + Debug
35    + DeepSizeOf
36    + Into<Quantizer>
37    + TryFrom<Quantizer, Error = lance_core::Error>
38{
39    type BuildParams: QuantizerBuildParams + Send + Sync;
40    type Metadata: QuantizerMetadata + Send + Sync;
41    type Storage: QuantizerStorage<Metadata = Self::Metadata> + Debug;
42
43    fn build(
44        data: &dyn Array,
45        distance_type: DistanceType,
46        params: &Self::BuildParams,
47    ) -> Result<Self>;
48    fn retrain(&mut self, data: &dyn Array) -> Result<()>;
49    fn code_dim(&self) -> usize;
50    fn column(&self) -> &'static str;
51    fn use_residual(_: DistanceType) -> bool {
52        false
53    }
54    fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef>;
55    fn metadata_key() -> &'static str;
56    fn quantization_type() -> QuantizationType;
57    fn metadata(&self, _: Option<QuantizationMetadata>) -> Self::Metadata;
58    fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result<Quantizer>;
59    fn field(&self) -> Field;
60    fn extra_fields(&self) -> Vec<Field> {
61        vec![]
62    }
63}
64
/// Tag identifying a quantization scheme.
///
/// The `Display` impl renders the variants as `"FLAT"`, `"PQ"`, `"SQ"` and
/// `"RQ"` respectively.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationType {
    /// Flat; also the type reported by `Quantizer::quantization_type` for
    /// the `FlatBin` variant.
    Flat,
    /// Product quantization (`"PQ"`).
    Product,
    /// Scalar quantization (`"SQ"`).
    Scalar,
    /// Rabit quantization (`"RQ"`).
    Rabit,
}
72
73impl FromStr for QuantizationType {
74    type Err = Error;
75
76    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
77        match s {
78            "FLAT" => Ok(Self::Flat),
79            "PQ" => Ok(Self::Product),
80            "SQ" => Ok(Self::Scalar),
81            "RABIT" => Ok(Self::Rabit),
82            _ => Err(Error::index(format!("Unknown quantization type: {}", s))),
83        }
84    }
85}
86
87impl std::fmt::Display for QuantizationType {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        match self {
90            Self::Flat => write!(f, "FLAT"),
91            Self::Product => write!(f, "PQ"),
92            Self::Scalar => write!(f, "SQ"),
93            Self::Rabit => write!(f, "RQ"),
94        }
95    }
96}
97
/// Parameters supplied when building a quantizer (the bound required of
/// `Quantization::BuildParams`).
pub trait QuantizerBuildParams: Send + Sync {
    /// Sample size requested by these parameters.
    // NOTE(review): presumably the number of vectors sampled for training;
    // confirm against the callers.
    fn sample_size(&self) -> usize;
    /// Whether residuals should be used for the given distance type.
    ///
    /// The default implementation returns `false` for every distance type.
    fn use_residual(_: DistanceType) -> bool {
        false
    }
}
104
/// Lets the unit type act as build parameters; it always reports a sample
/// size of `0` and inherits the default `use_residual`.
impl QuantizerBuildParams for () {
    fn sample_size(&self) -> usize {
        0
    }
}
110
/// Quantization Method.
///
/// A closed set of the concrete quantizers; each variant wraps exactly one
/// quantizer implementation.
///
/// <section class="warning">
/// Internal use only. End-user does not use this directly.
/// </section>
#[derive(Debug, Clone, DeepSizeOf)]
pub enum Quantizer {
    /// Wraps a [`FlatQuantizer`].
    Flat(FlatQuantizer),
    /// Wraps a [`FlatBinQuantizer`].
    FlatBin(FlatBinQuantizer),
    /// Wraps a [`ProductQuantizer`].
    Product(ProductQuantizer),
    /// Wraps a [`ScalarQuantizer`].
    Scalar(ScalarQuantizer),
    /// Wraps a [`RabitQuantizer`].
    Rabit(RabitQuantizer),
}
124
125impl Quantizer {
126    pub fn code_dim(&self) -> usize {
127        match self {
128            Self::Flat(fq) => fq.code_dim(),
129            Self::FlatBin(fq) => fq.code_dim(),
130            Self::Product(pq) => pq.code_dim(),
131            Self::Scalar(sq) => sq.code_dim(),
132            Self::Rabit(rq) => rq.code_dim(),
133        }
134    }
135
136    pub fn column(&self) -> &'static str {
137        match self {
138            Self::Flat(fq) => fq.column(),
139            Self::FlatBin(fq) => fq.column(),
140            Self::Product(pq) => pq.column(),
141            Self::Scalar(sq) => sq.column(),
142            Self::Rabit(rq) => rq.column(),
143        }
144    }
145
146    pub fn metadata_key(&self) -> &'static str {
147        match self {
148            Self::Flat(_) => FlatQuantizer::metadata_key(),
149            Self::FlatBin(_) => FlatBinQuantizer::metadata_key(),
150            Self::Product(_) => ProductQuantizer::metadata_key(),
151            Self::Scalar(_) => ScalarQuantizer::metadata_key(),
152            Self::Rabit(_) => RabitQuantizer::metadata_key(),
153        }
154    }
155
156    pub fn quantization_type(&self) -> QuantizationType {
157        match self {
158            Self::Flat(_) => QuantizationType::Flat,
159            Self::FlatBin(_) => QuantizationType::Flat,
160            Self::Product(_) => QuantizationType::Product,
161            Self::Scalar(_) => QuantizationType::Scalar,
162            Self::Rabit(_) => QuantizationType::Rabit,
163        }
164    }
165
166    pub fn metadata(&self, args: Option<QuantizationMetadata>) -> Result<serde_json::Value> {
167        let metadata = match self {
168            Self::Flat(fq) => serde_json::to_value(fq.metadata(args))?,
169            Self::FlatBin(fq) => serde_json::to_value(fq.metadata(args))?,
170            Self::Product(pq) => serde_json::to_value(pq.metadata(args))?,
171            Self::Scalar(sq) => serde_json::to_value(sq.metadata(args))?,
172            Self::Rabit(rq) => serde_json::to_value(rq.metadata(args))?,
173        };
174        Ok(metadata)
175    }
176}
177
178impl From<ProductQuantizer> for Quantizer {
179    fn from(pq: ProductQuantizer) -> Self {
180        Self::Product(pq)
181    }
182}
183
184impl From<ScalarQuantizer> for Quantizer {
185    fn from(sq: ScalarQuantizer) -> Self {
186        Self::Scalar(sq)
187    }
188}
189
/// Optional extra inputs handed to `Quantization::metadata` /
/// `Quantizer::metadata`.
///
/// `Default` yields no codebook position, no codebook and `transposed == false`.
#[derive(Debug, Clone, Default)]
pub struct QuantizationMetadata {
    // For PQ
    // NOTE(review): presumably the location of the codebook within the index
    // file; confirm against the PQ metadata code.
    pub codebook_position: Option<usize>,
    // The codebook array itself, when supplied.
    pub codebook: Option<FixedSizeListArray>,
    // NOTE(review): presumably marks a transposed storage layout; confirm
    // against the PQ implementation.
    pub transposed: bool,
}
197
/// Metadata describing a quantizer.
///
/// Implementors are serializable and deserializable with `serde` and can be
/// loaded from a [`PreviousFileReader`]. The buffer-related methods have
/// no-op defaults for metadata that stores nothing in the global buffer.
#[async_trait]
pub trait QuantizerMetadata:
    fmt::Debug + Clone + Sized + DeepSizeOf + for<'a> Deserialize<'a> + Serialize
{
    /// The extra metadata index in global buffer.
    ///
    /// Defaults to `None`.
    fn buffer_index(&self) -> Option<u32> {
        None
    }

    /// Records the extra metadata index in global buffer.
    ///
    /// The default implementation ignores the value.
    fn set_buffer_index(&mut self, _: u32) {
        // do nothing
    }

    /// Parse the extra metadata bytes from global buffer,
    /// and set the metadata fields.
    ///
    /// The default implementation ignores the bytes and returns `Ok(())`.
    fn parse_buffer(&mut self, _bytes: Bytes) -> Result<()> {
        Ok(())
    }

    /// The metadata that should be stored in global buffer.
    ///
    /// Defaults to `Ok(None)`, i.e. nothing to store.
    fn extra_metadata(&self) -> Result<Option<Bytes>> {
        Ok(None)
    }

    /// Loads the metadata using `reader`.
    async fn load(reader: &PreviousFileReader) -> Result<Self>;
}
224
225#[async_trait::async_trait]
226pub trait QuantizerStorage: Clone + Sized + DeepSizeOf + VectorStore {
227    type Metadata: QuantizerMetadata;
228
229    /// Create a [QuantizerStorage] from a [RecordBatch].
230    /// The batch should consist of row IDs and quantized vector.
231    fn try_from_batch(
232        batch: RecordBatch,
233        metadata: &Self::Metadata,
234        distance_type: DistanceType,
235        frag_reuse_index: Option<Arc<FragReuseIndex>>,
236    ) -> Result<Self>;
237
238    fn metadata(&self) -> &Self::Metadata;
239
240    fn remap(&self, mapping: &HashMap<u64, Option<u64>>) -> Result<Self> {
241        let batches = self
242            .to_batches()?
243            .map(|b| {
244                let mut indices = Vec::with_capacity(b.num_rows());
245                let mut new_row_ids = Vec::with_capacity(b.num_rows());
246
247                let row_ids = b.column(0).as_primitive::<UInt64Type>().values();
248                for (i, row_id) in row_ids.iter().enumerate() {
249                    match mapping.get(row_id) {
250                        Some(Some(new_id)) => {
251                            indices.push(i as u32);
252                            new_row_ids.push(*new_id);
253                        }
254                        Some(None) => {}
255                        None => {
256                            indices.push(i as u32);
257                            new_row_ids.push(*row_id);
258                        }
259                    }
260                }
261
262                let indices = UInt32Array::from(indices);
263                let new_row_ids = UInt64Array::from(new_row_ids);
264                let b = b
265                    .take(&indices)?
266                    .replace_column_by_name(ROW_ID, Arc::new(new_row_ids))?;
267                Ok(b)
268            })
269            .collect::<Result<Vec<_>>>()?;
270
271        let batch = concat_batches(self.schema(), batches.iter())?;
272        Self::try_from_batch(batch, self.metadata(), self.distance_type(), None)
273    }
274
275    async fn load_partition(
276        reader: &PreviousFileReader,
277        range: std::ops::Range<usize>,
278        distance_type: DistanceType,
279        metadata: &Self::Metadata,
280        frag_reuse_index: Option<Arc<FragReuseIndex>>,
281    ) -> Result<Self>;
282}
283
/// Loader to load partitioned [VectorStore] from disk.
///
/// Created by `IvfQuantizationStorage::open`; partitions are then read on
/// demand with `load_partition`.
pub struct IvfQuantizationStorage<Q: Quantization> {
    // Reader over the file this loader was opened on.
    reader: PreviousFileReader,

    // Distance type parsed from the index metadata stored in the file schema.
    distance_type: DistanceType,
    // Quantizer rebuilt from `metadata` via `Q::from_metadata`.
    quantizer: Quantizer,
    // Quantizer metadata loaded from the file.
    metadata: Q::Metadata,

    // IVF model loaded from the file; supplies the partition count and the
    // row range of each partition.
    ivf: IvfModel,
}
294
295impl<Q: Quantization> DeepSizeOf for IvfQuantizationStorage<Q> {
296    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
297        self.reader.deep_size_of_children(context)
298            + self.quantizer.deep_size_of_children(context)
299            + self.metadata.deep_size_of_children(context)
300            + self.ivf.deep_size_of_children(context)
301    }
302}
303
304impl<Q: Quantization> Clone for IvfQuantizationStorage<Q> {
305    fn clone(&self) -> Self {
306        Self {
307            reader: self.reader.clone(),
308            distance_type: self.distance_type,
309            quantizer: self.quantizer.clone(),
310            metadata: self.metadata.clone(),
311            ivf: self.ivf.clone(),
312        }
313    }
314}
315
316#[allow(dead_code)]
317impl<Q: Quantization> IvfQuantizationStorage<Q> {
318    /// Open a Loader.
319    ///
320    ///
321    pub async fn open(reader: Arc<dyn Reader>) -> Result<Self> {
322        let reader = PreviousFileReader::try_new_self_described_from_reader(reader, None).await?;
323        let schema = reader.schema();
324
325        let metadata_str = schema
326            .metadata
327            .get(INDEX_METADATA_SCHEMA_KEY)
328            .ok_or(Error::index(format!(
329                "Reading quantization storage: index key {} not found",
330                INDEX_METADATA_SCHEMA_KEY
331            )))?;
332        let index_metadata: IndexMetadata = serde_json::from_str(metadata_str).map_err(|_| {
333            Error::index(format!("Failed to parse index metadata: {}", metadata_str))
334        })?;
335        let distance_type = DistanceType::try_from(index_metadata.distance_type.as_str())?;
336
337        let ivf_data = IvfModel::load(&reader).await?;
338
339        let metadata = Q::Metadata::load(&reader).await?;
340        let quantizer = Q::from_metadata(&metadata, distance_type)?;
341        Ok(Self {
342            reader,
343            distance_type,
344            quantizer,
345            metadata,
346            ivf: ivf_data,
347        })
348    }
349
350    pub fn distance_type(&self) -> DistanceType {
351        self.distance_type
352    }
353
354    pub fn quantizer(&self) -> &Quantizer {
355        &self.quantizer
356    }
357
358    pub fn metadata(&self) -> &Q::Metadata {
359        &self.metadata
360    }
361
362    /// Get the number of partitions in the storage.
363    pub fn num_partitions(&self) -> usize {
364        self.ivf.num_partitions()
365    }
366
367    /// Load one partition of vector storage.
368    ///
369    /// # Parameters
370    /// - `part_id`, partition id
371    ///
372    ///
373    pub async fn load_partition(&self, part_id: usize) -> Result<Q::Storage> {
374        let range = self.ivf.row_range(part_id);
375        Q::Storage::load_partition(
376            &self.reader,
377            range,
378            self.distance_type,
379            &self.metadata,
380            None,
381        )
382        .await
383    }
384}