lance_index/vector/
quantizer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use core::fmt;
5use std::str::FromStr;
6use std::sync::Arc;
7use std::{collections::HashMap, fmt::Debug};
8
9use arrow::{array::AsArray, compute::concat_batches, datatypes::UInt64Type};
10use arrow_array::{Array, ArrayRef, FixedSizeListArray, RecordBatch, UInt32Array, UInt64Array};
11use arrow_schema::Field;
12use async_trait::async_trait;
13use bytes::Bytes;
14use deepsize::DeepSizeOf;
15use lance_arrow::RecordBatchExt;
16use lance_core::{Error, Result, ROW_ID};
17use lance_file::previous::reader::FileReader as PreviousFileReader;
18use lance_io::traits::Reader;
19use lance_linalg::distance::DistanceType;
20use lance_table::format::SelfDescribingFileReader;
21use serde::{Deserialize, Serialize};
22use snafu::location;
23
24use super::flat::index::{FlatBinQuantizer, FlatQuantizer};
25use super::pq::ProductQuantizer;
26use super::{ivf::storage::IvfModel, sq::ScalarQuantizer, storage::VectorStore};
27use crate::frag_reuse::FragReuseIndex;
28use crate::vector::bq::builder::RabitQuantizer;
29use crate::{IndexMetadata, INDEX_METADATA_SCHEMA_KEY};
30
31pub trait Quantization:
32    Send
33    + Sync
34    + Clone
35    + Debug
36    + DeepSizeOf
37    + Into<Quantizer>
38    + TryFrom<Quantizer, Error = lance_core::Error>
39{
40    type BuildParams: QuantizerBuildParams + Send + Sync;
41    type Metadata: QuantizerMetadata + Send + Sync;
42    type Storage: QuantizerStorage<Metadata = Self::Metadata> + Debug;
43
44    fn build(
45        data: &dyn Array,
46        distance_type: DistanceType,
47        params: &Self::BuildParams,
48    ) -> Result<Self>;
49    fn retrain(&mut self, data: &dyn Array) -> Result<()>;
50    fn code_dim(&self) -> usize;
51    fn column(&self) -> &'static str;
52    fn use_residual(_: DistanceType) -> bool {
53        false
54    }
55    fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef>;
56    fn metadata_key() -> &'static str;
57    fn quantization_type() -> QuantizationType;
58    fn metadata(&self, _: Option<QuantizationMetadata>) -> Self::Metadata;
59    fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result<Quantizer>;
60    fn field(&self) -> Field;
61    fn extra_fields(&self) -> Vec<Field> {
62        vec![]
63    }
64}
65
/// Family of quantization applied to vectors.
///
/// See the `Display` and `FromStr` impls for its textual form.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QuantizationType {
    /// Flat quantization.
    Flat,
    /// Product quantization.
    Product,
    /// Scalar quantization.
    Scalar,
    /// Rabit quantization (wrapped by `Quantizer::Rabit`).
    Rabit,
}
73
74impl FromStr for QuantizationType {
75    type Err = Error;
76
77    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
78        match s {
79            "FLAT" => Ok(Self::Flat),
80            "PQ" => Ok(Self::Product),
81            "SQ" => Ok(Self::Scalar),
82            "RABIT" => Ok(Self::Rabit),
83            _ => Err(Error::Index {
84                message: format!("Unknown quantization type: {}", s),
85                location: location!(),
86            }),
87        }
88    }
89}
90
91impl std::fmt::Display for QuantizationType {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        match self {
94            Self::Flat => write!(f, "FLAT"),
95            Self::Product => write!(f, "PQ"),
96            Self::Scalar => write!(f, "SQ"),
97            Self::Rabit => write!(f, "RQ"),
98        }
99    }
100}
101
102pub trait QuantizerBuildParams: Send + Sync {
103    fn sample_size(&self) -> usize;
104    fn use_residual(_: DistanceType) -> bool {
105        false
106    }
107}
108
109impl QuantizerBuildParams for () {
110    fn sample_size(&self) -> usize {
111        0
112    }
113}
114
115/// Quantization Method.
116///
117/// <section class="warning">
118/// Internal use only. End-user does not use this directly.
119/// </section>
120#[derive(Debug, Clone, DeepSizeOf)]
121pub enum Quantizer {
122    Flat(FlatQuantizer),
123    FlatBin(FlatBinQuantizer),
124    Product(ProductQuantizer),
125    Scalar(ScalarQuantizer),
126    Rabit(RabitQuantizer),
127}
128
129impl Quantizer {
130    pub fn code_dim(&self) -> usize {
131        match self {
132            Self::Flat(fq) => fq.code_dim(),
133            Self::FlatBin(fq) => fq.code_dim(),
134            Self::Product(pq) => pq.code_dim(),
135            Self::Scalar(sq) => sq.code_dim(),
136            Self::Rabit(rq) => rq.code_dim(),
137        }
138    }
139
140    pub fn column(&self) -> &'static str {
141        match self {
142            Self::Flat(fq) => fq.column(),
143            Self::FlatBin(fq) => fq.column(),
144            Self::Product(pq) => pq.column(),
145            Self::Scalar(sq) => sq.column(),
146            Self::Rabit(rq) => rq.column(),
147        }
148    }
149
150    pub fn metadata_key(&self) -> &'static str {
151        match self {
152            Self::Flat(_) => FlatQuantizer::metadata_key(),
153            Self::FlatBin(_) => FlatBinQuantizer::metadata_key(),
154            Self::Product(_) => ProductQuantizer::metadata_key(),
155            Self::Scalar(_) => ScalarQuantizer::metadata_key(),
156            Self::Rabit(_) => RabitQuantizer::metadata_key(),
157        }
158    }
159
160    pub fn quantization_type(&self) -> QuantizationType {
161        match self {
162            Self::Flat(_) => QuantizationType::Flat,
163            Self::FlatBin(_) => QuantizationType::Flat,
164            Self::Product(_) => QuantizationType::Product,
165            Self::Scalar(_) => QuantizationType::Scalar,
166            Self::Rabit(_) => QuantizationType::Rabit,
167        }
168    }
169
170    pub fn metadata(&self, args: Option<QuantizationMetadata>) -> Result<serde_json::Value> {
171        let metadata = match self {
172            Self::Flat(fq) => serde_json::to_value(fq.metadata(args))?,
173            Self::FlatBin(fq) => serde_json::to_value(fq.metadata(args))?,
174            Self::Product(pq) => serde_json::to_value(pq.metadata(args))?,
175            Self::Scalar(sq) => serde_json::to_value(sq.metadata(args))?,
176            Self::Rabit(rq) => serde_json::to_value(rq.metadata(args))?,
177        };
178        Ok(metadata)
179    }
180}
181
182impl From<ProductQuantizer> for Quantizer {
183    fn from(pq: ProductQuantizer) -> Self {
184        Self::Product(pq)
185    }
186}
187
188impl From<ScalarQuantizer> for Quantizer {
189    fn from(sq: ScalarQuantizer) -> Self {
190        Self::Scalar(sq)
191    }
192}
193
194#[derive(Debug, Clone, Default)]
195pub struct QuantizationMetadata {
196    // For PQ
197    pub codebook_position: Option<usize>,
198    pub codebook: Option<FixedSizeListArray>,
199    pub transposed: bool,
200}
201
202#[async_trait]
203pub trait QuantizerMetadata:
204    fmt::Debug + Clone + Sized + DeepSizeOf + for<'a> Deserialize<'a> + Serialize
205{
206    // the extra metadata index in global buffer
207    fn buffer_index(&self) -> Option<u32> {
208        None
209    }
210
211    fn set_buffer_index(&mut self, _: u32) {
212        // do nothing
213    }
214
215    // parse the extra metadata bytes from global buffer,
216    // and set the metadata fields
217    fn parse_buffer(&mut self, _bytes: Bytes) -> Result<()> {
218        Ok(())
219    }
220
221    // the metadata that should be stored in global buffer
222    fn extra_metadata(&self) -> Result<Option<Bytes>> {
223        Ok(None)
224    }
225
226    async fn load(reader: &PreviousFileReader) -> Result<Self>;
227}
228
229#[async_trait::async_trait]
230pub trait QuantizerStorage: Clone + Sized + DeepSizeOf + VectorStore {
231    type Metadata: QuantizerMetadata;
232
233    /// Create a [QuantizerStorage] from a [RecordBatch].
234    /// The batch should consist of row IDs and quantized vector.
235    fn try_from_batch(
236        batch: RecordBatch,
237        metadata: &Self::Metadata,
238        distance_type: DistanceType,
239        frag_reuse_index: Option<Arc<FragReuseIndex>>,
240    ) -> Result<Self>;
241
242    fn metadata(&self) -> &Self::Metadata;
243
244    fn remap(&self, mapping: &HashMap<u64, Option<u64>>) -> Result<Self> {
245        let batches = self
246            .to_batches()?
247            .map(|b| {
248                let mut indices = Vec::with_capacity(b.num_rows());
249                let mut new_row_ids = Vec::with_capacity(b.num_rows());
250
251                let row_ids = b.column(0).as_primitive::<UInt64Type>().values();
252                for (i, row_id) in row_ids.iter().enumerate() {
253                    match mapping.get(row_id) {
254                        Some(Some(new_id)) => {
255                            indices.push(i as u32);
256                            new_row_ids.push(*new_id);
257                        }
258                        Some(None) => {}
259                        None => {
260                            indices.push(i as u32);
261                            new_row_ids.push(*row_id);
262                        }
263                    }
264                }
265
266                let indices = UInt32Array::from(indices);
267                let new_row_ids = UInt64Array::from(new_row_ids);
268                let b = b
269                    .take(&indices)?
270                    .replace_column_by_name(ROW_ID, Arc::new(new_row_ids))?;
271                Ok(b)
272            })
273            .collect::<Result<Vec<_>>>()?;
274
275        let batch = concat_batches(self.schema(), batches.iter())?;
276        Self::try_from_batch(batch, self.metadata(), self.distance_type(), None)
277    }
278
279    async fn load_partition(
280        reader: &PreviousFileReader,
281        range: std::ops::Range<usize>,
282        distance_type: DistanceType,
283        metadata: &Self::Metadata,
284        frag_reuse_index: Option<Arc<FragReuseIndex>>,
285    ) -> Result<Self>;
286}
287
288/// Loader to load partitioned [VectorStore] from disk.
289pub struct IvfQuantizationStorage<Q: Quantization> {
290    reader: PreviousFileReader,
291
292    distance_type: DistanceType,
293    quantizer: Quantizer,
294    metadata: Q::Metadata,
295
296    ivf: IvfModel,
297}
298
299impl<Q: Quantization> DeepSizeOf for IvfQuantizationStorage<Q> {
300    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
301        self.reader.deep_size_of_children(context)
302            + self.quantizer.deep_size_of_children(context)
303            + self.metadata.deep_size_of_children(context)
304            + self.ivf.deep_size_of_children(context)
305    }
306}
307
308impl<Q: Quantization> Clone for IvfQuantizationStorage<Q> {
309    fn clone(&self) -> Self {
310        Self {
311            reader: self.reader.clone(),
312            distance_type: self.distance_type,
313            quantizer: self.quantizer.clone(),
314            metadata: self.metadata.clone(),
315            ivf: self.ivf.clone(),
316        }
317    }
318}
319
320#[allow(dead_code)]
321impl<Q: Quantization> IvfQuantizationStorage<Q> {
322    /// Open a Loader.
323    ///
324    ///
325    pub async fn open(reader: Arc<dyn Reader>) -> Result<Self> {
326        let reader = PreviousFileReader::try_new_self_described_from_reader(reader, None).await?;
327        let schema = reader.schema();
328
329        let metadata_str = schema
330            .metadata
331            .get(INDEX_METADATA_SCHEMA_KEY)
332            .ok_or(Error::Index {
333                message: format!(
334                    "Reading quantization storage: index key {} not found",
335                    INDEX_METADATA_SCHEMA_KEY
336                ),
337                location: location!(),
338            })?;
339        let index_metadata: IndexMetadata =
340            serde_json::from_str(metadata_str).map_err(|_| Error::Index {
341                message: format!("Failed to parse index metadata: {}", metadata_str),
342                location: location!(),
343            })?;
344        let distance_type = DistanceType::try_from(index_metadata.distance_type.as_str())?;
345
346        let ivf_data = IvfModel::load(&reader).await?;
347
348        let metadata = Q::Metadata::load(&reader).await?;
349        let quantizer = Q::from_metadata(&metadata, distance_type)?;
350        Ok(Self {
351            reader,
352            distance_type,
353            quantizer,
354            metadata,
355            ivf: ivf_data,
356        })
357    }
358
359    pub fn distance_type(&self) -> DistanceType {
360        self.distance_type
361    }
362
363    pub fn quantizer(&self) -> &Quantizer {
364        &self.quantizer
365    }
366
367    pub fn metadata(&self) -> &Q::Metadata {
368        &self.metadata
369    }
370
371    /// Get the number of partitions in the storage.
372    pub fn num_partitions(&self) -> usize {
373        self.ivf.num_partitions()
374    }
375
376    /// Load one partition of vector storage.
377    ///
378    /// # Parameters
379    /// - `part_id`, partition id
380    ///
381    ///
382    pub async fn load_partition(&self, part_id: usize) -> Result<Q::Storage> {
383        let range = self.ivf.row_range(part_id);
384        Q::Storage::load_partition(
385            &self.reader,
386            range,
387            self.distance_type,
388            &self.metadata,
389            None,
390        )
391        .await
392    }
393}