1use core::fmt;
5use std::str::FromStr;
6use std::sync::Arc;
7use std::{collections::HashMap, fmt::Debug};
8
9use arrow::{array::AsArray, compute::concat_batches, datatypes::UInt64Type};
10use arrow_array::{Array, ArrayRef, FixedSizeListArray, RecordBatch, UInt32Array, UInt64Array};
11use arrow_schema::Field;
12use async_trait::async_trait;
13use bytes::Bytes;
14use deepsize::DeepSizeOf;
15use lance_arrow::RecordBatchExt;
16use lance_core::{Error, ROW_ID, Result};
17use lance_file::previous::reader::FileReader as PreviousFileReader;
18use lance_io::traits::Reader;
19use lance_linalg::distance::DistanceType;
20use lance_table::format::SelfDescribingFileReader;
21use serde::{Deserialize, Serialize};
22
23use super::flat::index::{FlatBinQuantizer, FlatQuantizer};
24use super::pq::ProductQuantizer;
25use super::{ivf::storage::IvfModel, sq::ScalarQuantizer, storage::VectorStore};
26use crate::frag_reuse::FragReuseIndex;
27use crate::vector::bq::builder::RabitQuantizer;
28use crate::{INDEX_METADATA_SCHEMA_KEY, IndexMetadata};
29
30pub trait Quantization:
31 Send
32 + Sync
33 + Clone
34 + Debug
35 + DeepSizeOf
36 + Into<Quantizer>
37 + TryFrom<Quantizer, Error = lance_core::Error>
38{
39 type BuildParams: QuantizerBuildParams + Send + Sync;
40 type Metadata: QuantizerMetadata + Send + Sync;
41 type Storage: QuantizerStorage<Metadata = Self::Metadata> + Debug;
42
43 fn build(
44 data: &dyn Array,
45 distance_type: DistanceType,
46 params: &Self::BuildParams,
47 ) -> Result<Self>;
48 fn retrain(&mut self, data: &dyn Array) -> Result<()>;
49 fn code_dim(&self) -> usize;
50 fn column(&self) -> &'static str;
51 fn use_residual(_: DistanceType) -> bool {
52 false
53 }
54 fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef>;
55 fn metadata_key() -> &'static str;
56 fn quantization_type() -> QuantizationType;
57 fn metadata(&self, _: Option<QuantizationMetadata>) -> Self::Metadata;
58 fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result<Quantizer>;
59 fn field(&self) -> Field;
60 fn extra_fields(&self) -> Vec<Field> {
61 vec![]
62 }
63}
64
/// The family of quantization used by a vector index.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QuantizationType {
    /// No compression (flat); also reported by the binary flat quantizer.
    Flat,
    /// Product quantization.
    Product,
    /// Scalar quantization.
    Scalar,
    /// RaBitQ quantization.
    Rabit,
}
72
73impl FromStr for QuantizationType {
74 type Err = Error;
75
76 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
77 match s {
78 "FLAT" => Ok(Self::Flat),
79 "PQ" => Ok(Self::Product),
80 "SQ" => Ok(Self::Scalar),
81 "RABIT" => Ok(Self::Rabit),
82 _ => Err(Error::index(format!("Unknown quantization type: {}", s))),
83 }
84 }
85}
86
87impl std::fmt::Display for QuantizationType {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 match self {
90 Self::Flat => write!(f, "FLAT"),
91 Self::Product => write!(f, "PQ"),
92 Self::Scalar => write!(f, "SQ"),
93 Self::Rabit => write!(f, "RQ"),
94 }
95 }
96}
97
98pub trait QuantizerBuildParams: Send + Sync {
99 fn sample_size(&self) -> usize;
100 fn use_residual(_: DistanceType) -> bool {
101 false
102 }
103}
104
105impl QuantizerBuildParams for () {
106 fn sample_size(&self) -> usize {
107 0
108 }
109}
110
111#[derive(Debug, Clone, DeepSizeOf)]
117pub enum Quantizer {
118 Flat(FlatQuantizer),
119 FlatBin(FlatBinQuantizer),
120 Product(ProductQuantizer),
121 Scalar(ScalarQuantizer),
122 Rabit(RabitQuantizer),
123}
124
125impl Quantizer {
126 pub fn code_dim(&self) -> usize {
127 match self {
128 Self::Flat(fq) => fq.code_dim(),
129 Self::FlatBin(fq) => fq.code_dim(),
130 Self::Product(pq) => pq.code_dim(),
131 Self::Scalar(sq) => sq.code_dim(),
132 Self::Rabit(rq) => rq.code_dim(),
133 }
134 }
135
136 pub fn column(&self) -> &'static str {
137 match self {
138 Self::Flat(fq) => fq.column(),
139 Self::FlatBin(fq) => fq.column(),
140 Self::Product(pq) => pq.column(),
141 Self::Scalar(sq) => sq.column(),
142 Self::Rabit(rq) => rq.column(),
143 }
144 }
145
146 pub fn metadata_key(&self) -> &'static str {
147 match self {
148 Self::Flat(_) => FlatQuantizer::metadata_key(),
149 Self::FlatBin(_) => FlatBinQuantizer::metadata_key(),
150 Self::Product(_) => ProductQuantizer::metadata_key(),
151 Self::Scalar(_) => ScalarQuantizer::metadata_key(),
152 Self::Rabit(_) => RabitQuantizer::metadata_key(),
153 }
154 }
155
156 pub fn quantization_type(&self) -> QuantizationType {
157 match self {
158 Self::Flat(_) => QuantizationType::Flat,
159 Self::FlatBin(_) => QuantizationType::Flat,
160 Self::Product(_) => QuantizationType::Product,
161 Self::Scalar(_) => QuantizationType::Scalar,
162 Self::Rabit(_) => QuantizationType::Rabit,
163 }
164 }
165
166 pub fn metadata(&self, args: Option<QuantizationMetadata>) -> Result<serde_json::Value> {
167 let metadata = match self {
168 Self::Flat(fq) => serde_json::to_value(fq.metadata(args))?,
169 Self::FlatBin(fq) => serde_json::to_value(fq.metadata(args))?,
170 Self::Product(pq) => serde_json::to_value(pq.metadata(args))?,
171 Self::Scalar(sq) => serde_json::to_value(sq.metadata(args))?,
172 Self::Rabit(rq) => serde_json::to_value(rq.metadata(args))?,
173 };
174 Ok(metadata)
175 }
176}
177
178impl From<ProductQuantizer> for Quantizer {
179 fn from(pq: ProductQuantizer) -> Self {
180 Self::Product(pq)
181 }
182}
183
184impl From<ScalarQuantizer> for Quantizer {
185 fn from(sq: ScalarQuantizer) -> Self {
186 Self::Scalar(sq)
187 }
188}
189
190#[derive(Debug, Clone, Default)]
191pub struct QuantizationMetadata {
192 pub codebook_position: Option<usize>,
194 pub codebook: Option<FixedSizeListArray>,
195 pub transposed: bool,
196}
197
198#[async_trait]
199pub trait QuantizerMetadata:
200 fmt::Debug + Clone + Sized + DeepSizeOf + for<'a> Deserialize<'a> + Serialize
201{
202 fn buffer_index(&self) -> Option<u32> {
204 None
205 }
206
207 fn set_buffer_index(&mut self, _: u32) {
208 }
210
211 fn parse_buffer(&mut self, _bytes: Bytes) -> Result<()> {
214 Ok(())
215 }
216
217 fn extra_metadata(&self) -> Result<Option<Bytes>> {
219 Ok(None)
220 }
221
222 async fn load(reader: &PreviousFileReader) -> Result<Self>;
223}
224
225#[async_trait::async_trait]
226pub trait QuantizerStorage: Clone + Sized + DeepSizeOf + VectorStore {
227 type Metadata: QuantizerMetadata;
228
229 fn try_from_batch(
232 batch: RecordBatch,
233 metadata: &Self::Metadata,
234 distance_type: DistanceType,
235 frag_reuse_index: Option<Arc<FragReuseIndex>>,
236 ) -> Result<Self>;
237
238 fn metadata(&self) -> &Self::Metadata;
239
240 fn remap(&self, mapping: &HashMap<u64, Option<u64>>) -> Result<Self> {
241 let batches = self
242 .to_batches()?
243 .map(|b| {
244 let mut indices = Vec::with_capacity(b.num_rows());
245 let mut new_row_ids = Vec::with_capacity(b.num_rows());
246
247 let row_ids = b.column(0).as_primitive::<UInt64Type>().values();
248 for (i, row_id) in row_ids.iter().enumerate() {
249 match mapping.get(row_id) {
250 Some(Some(new_id)) => {
251 indices.push(i as u32);
252 new_row_ids.push(*new_id);
253 }
254 Some(None) => {}
255 None => {
256 indices.push(i as u32);
257 new_row_ids.push(*row_id);
258 }
259 }
260 }
261
262 let indices = UInt32Array::from(indices);
263 let new_row_ids = UInt64Array::from(new_row_ids);
264 let b = b
265 .take(&indices)?
266 .replace_column_by_name(ROW_ID, Arc::new(new_row_ids))?;
267 Ok(b)
268 })
269 .collect::<Result<Vec<_>>>()?;
270
271 let batch = concat_batches(self.schema(), batches.iter())?;
272 Self::try_from_batch(batch, self.metadata(), self.distance_type(), None)
273 }
274
275 async fn load_partition(
276 reader: &PreviousFileReader,
277 range: std::ops::Range<usize>,
278 distance_type: DistanceType,
279 metadata: &Self::Metadata,
280 frag_reuse_index: Option<Arc<FragReuseIndex>>,
281 ) -> Result<Self>;
282}
283
284pub struct IvfQuantizationStorage<Q: Quantization> {
286 reader: PreviousFileReader,
287
288 distance_type: DistanceType,
289 quantizer: Quantizer,
290 metadata: Q::Metadata,
291
292 ivf: IvfModel,
293}
294
295impl<Q: Quantization> DeepSizeOf for IvfQuantizationStorage<Q> {
296 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
297 self.reader.deep_size_of_children(context)
298 + self.quantizer.deep_size_of_children(context)
299 + self.metadata.deep_size_of_children(context)
300 + self.ivf.deep_size_of_children(context)
301 }
302}
303
304impl<Q: Quantization> Clone for IvfQuantizationStorage<Q> {
305 fn clone(&self) -> Self {
306 Self {
307 reader: self.reader.clone(),
308 distance_type: self.distance_type,
309 quantizer: self.quantizer.clone(),
310 metadata: self.metadata.clone(),
311 ivf: self.ivf.clone(),
312 }
313 }
314}
315
316#[allow(dead_code)]
317impl<Q: Quantization> IvfQuantizationStorage<Q> {
318 pub async fn open(reader: Arc<dyn Reader>) -> Result<Self> {
322 let reader = PreviousFileReader::try_new_self_described_from_reader(reader, None).await?;
323 let schema = reader.schema();
324
325 let metadata_str = schema
326 .metadata
327 .get(INDEX_METADATA_SCHEMA_KEY)
328 .ok_or(Error::index(format!(
329 "Reading quantization storage: index key {} not found",
330 INDEX_METADATA_SCHEMA_KEY
331 )))?;
332 let index_metadata: IndexMetadata = serde_json::from_str(metadata_str).map_err(|_| {
333 Error::index(format!("Failed to parse index metadata: {}", metadata_str))
334 })?;
335 let distance_type = DistanceType::try_from(index_metadata.distance_type.as_str())?;
336
337 let ivf_data = IvfModel::load(&reader).await?;
338
339 let metadata = Q::Metadata::load(&reader).await?;
340 let quantizer = Q::from_metadata(&metadata, distance_type)?;
341 Ok(Self {
342 reader,
343 distance_type,
344 quantizer,
345 metadata,
346 ivf: ivf_data,
347 })
348 }
349
350 pub fn distance_type(&self) -> DistanceType {
351 self.distance_type
352 }
353
354 pub fn quantizer(&self) -> &Quantizer {
355 &self.quantizer
356 }
357
358 pub fn metadata(&self) -> &Q::Metadata {
359 &self.metadata
360 }
361
362 pub fn num_partitions(&self) -> usize {
364 self.ivf.num_partitions()
365 }
366
367 pub async fn load_partition(&self, part_id: usize) -> Result<Q::Storage> {
374 let range = self.ivf.row_range(part_id);
375 Q::Storage::load_partition(
376 &self.reader,
377 range,
378 self.distance_type,
379 &self.metadata,
380 None,
381 )
382 .await
383 }
384}