1use core::fmt;
5use std::str::FromStr;
6use std::sync::Arc;
7use std::{collections::HashMap, fmt::Debug};
8
9use arrow::{array::AsArray, compute::concat_batches, datatypes::UInt64Type};
10use arrow_array::{Array, ArrayRef, FixedSizeListArray, RecordBatch, UInt32Array, UInt64Array};
11use arrow_schema::Field;
12use async_trait::async_trait;
13use bytes::Bytes;
14use deepsize::DeepSizeOf;
15use lance_arrow::RecordBatchExt;
16use lance_core::{Error, Result, ROW_ID};
17use lance_file::previous::reader::FileReader as PreviousFileReader;
18use lance_io::traits::Reader;
19use lance_linalg::distance::DistanceType;
20use lance_table::format::SelfDescribingFileReader;
21use serde::{Deserialize, Serialize};
22use snafu::location;
23
24use super::flat::index::{FlatBinQuantizer, FlatQuantizer};
25use super::pq::ProductQuantizer;
26use super::{ivf::storage::IvfModel, sq::ScalarQuantizer, storage::VectorStore};
27use crate::frag_reuse::FragReuseIndex;
28use crate::vector::bq::builder::RabitQuantizer;
29use crate::{IndexMetadata, INDEX_METADATA_SCHEMA_KEY};
30
31pub trait Quantization:
32 Send
33 + Sync
34 + Clone
35 + Debug
36 + DeepSizeOf
37 + Into<Quantizer>
38 + TryFrom<Quantizer, Error = lance_core::Error>
39{
40 type BuildParams: QuantizerBuildParams + Send + Sync;
41 type Metadata: QuantizerMetadata + Send + Sync;
42 type Storage: QuantizerStorage<Metadata = Self::Metadata> + Debug;
43
44 fn build(
45 data: &dyn Array,
46 distance_type: DistanceType,
47 params: &Self::BuildParams,
48 ) -> Result<Self>;
49 fn retrain(&mut self, data: &dyn Array) -> Result<()>;
50 fn code_dim(&self) -> usize;
51 fn column(&self) -> &'static str;
52 fn use_residual(_: DistanceType) -> bool {
53 false
54 }
55 fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef>;
56 fn metadata_key() -> &'static str;
57 fn quantization_type() -> QuantizationType;
58 fn metadata(&self, _: Option<QuantizationMetadata>) -> Self::Metadata;
59 fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result<Quantizer>;
60 fn field(&self) -> Field;
61 fn extra_fields(&self) -> Vec<Field> {
62 vec![]
63 }
64}
65
/// Tag naming the kind of quantization in use.
///
/// The variant order is significant for the implicit discriminants, so it is
/// kept as declared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationType {
    /// Displayed as `FLAT`.
    Flat,
    /// Displayed as `PQ`.
    Product,
    /// Displayed as `SQ`.
    Scalar,
    /// Displayed as `RQ`.
    Rabit,
}
73
74impl FromStr for QuantizationType {
75 type Err = Error;
76
77 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
78 match s {
79 "FLAT" => Ok(Self::Flat),
80 "PQ" => Ok(Self::Product),
81 "SQ" => Ok(Self::Scalar),
82 "RABIT" => Ok(Self::Rabit),
83 _ => Err(Error::Index {
84 message: format!("Unknown quantization type: {}", s),
85 location: location!(),
86 }),
87 }
88 }
89}
90
91impl std::fmt::Display for QuantizationType {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 match self {
94 Self::Flat => write!(f, "FLAT"),
95 Self::Product => write!(f, "PQ"),
96 Self::Scalar => write!(f, "SQ"),
97 Self::Rabit => write!(f, "RQ"),
98 }
99 }
100}
101
102pub trait QuantizerBuildParams: Send + Sync {
103 fn sample_size(&self) -> usize;
104 fn use_residual(_: DistanceType) -> bool {
105 false
106 }
107}
108
109impl QuantizerBuildParams for () {
110 fn sample_size(&self) -> usize {
111 0
112 }
113}
114
115#[derive(Debug, Clone, DeepSizeOf)]
121pub enum Quantizer {
122 Flat(FlatQuantizer),
123 FlatBin(FlatBinQuantizer),
124 Product(ProductQuantizer),
125 Scalar(ScalarQuantizer),
126 Rabit(RabitQuantizer),
127}
128
129impl Quantizer {
130 pub fn code_dim(&self) -> usize {
131 match self {
132 Self::Flat(fq) => fq.code_dim(),
133 Self::FlatBin(fq) => fq.code_dim(),
134 Self::Product(pq) => pq.code_dim(),
135 Self::Scalar(sq) => sq.code_dim(),
136 Self::Rabit(rq) => rq.code_dim(),
137 }
138 }
139
140 pub fn column(&self) -> &'static str {
141 match self {
142 Self::Flat(fq) => fq.column(),
143 Self::FlatBin(fq) => fq.column(),
144 Self::Product(pq) => pq.column(),
145 Self::Scalar(sq) => sq.column(),
146 Self::Rabit(rq) => rq.column(),
147 }
148 }
149
150 pub fn metadata_key(&self) -> &'static str {
151 match self {
152 Self::Flat(_) => FlatQuantizer::metadata_key(),
153 Self::FlatBin(_) => FlatBinQuantizer::metadata_key(),
154 Self::Product(_) => ProductQuantizer::metadata_key(),
155 Self::Scalar(_) => ScalarQuantizer::metadata_key(),
156 Self::Rabit(_) => RabitQuantizer::metadata_key(),
157 }
158 }
159
160 pub fn quantization_type(&self) -> QuantizationType {
161 match self {
162 Self::Flat(_) => QuantizationType::Flat,
163 Self::FlatBin(_) => QuantizationType::Flat,
164 Self::Product(_) => QuantizationType::Product,
165 Self::Scalar(_) => QuantizationType::Scalar,
166 Self::Rabit(_) => QuantizationType::Rabit,
167 }
168 }
169
170 pub fn metadata(&self, args: Option<QuantizationMetadata>) -> Result<serde_json::Value> {
171 let metadata = match self {
172 Self::Flat(fq) => serde_json::to_value(fq.metadata(args))?,
173 Self::FlatBin(fq) => serde_json::to_value(fq.metadata(args))?,
174 Self::Product(pq) => serde_json::to_value(pq.metadata(args))?,
175 Self::Scalar(sq) => serde_json::to_value(sq.metadata(args))?,
176 Self::Rabit(rq) => serde_json::to_value(rq.metadata(args))?,
177 };
178 Ok(metadata)
179 }
180}
181
182impl From<ProductQuantizer> for Quantizer {
183 fn from(pq: ProductQuantizer) -> Self {
184 Self::Product(pq)
185 }
186}
187
188impl From<ScalarQuantizer> for Quantizer {
189 fn from(sq: ScalarQuantizer) -> Self {
190 Self::Scalar(sq)
191 }
192}
193
194#[derive(Debug, Clone, Default)]
195pub struct QuantizationMetadata {
196 pub codebook_position: Option<usize>,
198 pub codebook: Option<FixedSizeListArray>,
199 pub transposed: bool,
200}
201
202#[async_trait]
203pub trait QuantizerMetadata:
204 fmt::Debug + Clone + Sized + DeepSizeOf + for<'a> Deserialize<'a> + Serialize
205{
206 fn buffer_index(&self) -> Option<u32> {
208 None
209 }
210
211 fn set_buffer_index(&mut self, _: u32) {
212 }
214
215 fn parse_buffer(&mut self, _bytes: Bytes) -> Result<()> {
218 Ok(())
219 }
220
221 fn extra_metadata(&self) -> Result<Option<Bytes>> {
223 Ok(None)
224 }
225
226 async fn load(reader: &PreviousFileReader) -> Result<Self>;
227}
228
229#[async_trait::async_trait]
230pub trait QuantizerStorage: Clone + Sized + DeepSizeOf + VectorStore {
231 type Metadata: QuantizerMetadata;
232
233 fn try_from_batch(
236 batch: RecordBatch,
237 metadata: &Self::Metadata,
238 distance_type: DistanceType,
239 frag_reuse_index: Option<Arc<FragReuseIndex>>,
240 ) -> Result<Self>;
241
242 fn metadata(&self) -> &Self::Metadata;
243
244 fn remap(&self, mapping: &HashMap<u64, Option<u64>>) -> Result<Self> {
245 let batches = self
246 .to_batches()?
247 .map(|b| {
248 let mut indices = Vec::with_capacity(b.num_rows());
249 let mut new_row_ids = Vec::with_capacity(b.num_rows());
250
251 let row_ids = b.column(0).as_primitive::<UInt64Type>().values();
252 for (i, row_id) in row_ids.iter().enumerate() {
253 match mapping.get(row_id) {
254 Some(Some(new_id)) => {
255 indices.push(i as u32);
256 new_row_ids.push(*new_id);
257 }
258 Some(None) => {}
259 None => {
260 indices.push(i as u32);
261 new_row_ids.push(*row_id);
262 }
263 }
264 }
265
266 let indices = UInt32Array::from(indices);
267 let new_row_ids = UInt64Array::from(new_row_ids);
268 let b = b
269 .take(&indices)?
270 .replace_column_by_name(ROW_ID, Arc::new(new_row_ids))?;
271 Ok(b)
272 })
273 .collect::<Result<Vec<_>>>()?;
274
275 let batch = concat_batches(self.schema(), batches.iter())?;
276 Self::try_from_batch(batch, self.metadata(), self.distance_type(), None)
277 }
278
279 async fn load_partition(
280 reader: &PreviousFileReader,
281 range: std::ops::Range<usize>,
282 distance_type: DistanceType,
283 metadata: &Self::Metadata,
284 frag_reuse_index: Option<Arc<FragReuseIndex>>,
285 ) -> Result<Self>;
286}
287
288pub struct IvfQuantizationStorage<Q: Quantization> {
290 reader: PreviousFileReader,
291
292 distance_type: DistanceType,
293 quantizer: Quantizer,
294 metadata: Q::Metadata,
295
296 ivf: IvfModel,
297}
298
299impl<Q: Quantization> DeepSizeOf for IvfQuantizationStorage<Q> {
300 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
301 self.reader.deep_size_of_children(context)
302 + self.quantizer.deep_size_of_children(context)
303 + self.metadata.deep_size_of_children(context)
304 + self.ivf.deep_size_of_children(context)
305 }
306}
307
308impl<Q: Quantization> Clone for IvfQuantizationStorage<Q> {
309 fn clone(&self) -> Self {
310 Self {
311 reader: self.reader.clone(),
312 distance_type: self.distance_type,
313 quantizer: self.quantizer.clone(),
314 metadata: self.metadata.clone(),
315 ivf: self.ivf.clone(),
316 }
317 }
318}
319
320#[allow(dead_code)]
321impl<Q: Quantization> IvfQuantizationStorage<Q> {
322 pub async fn open(reader: Arc<dyn Reader>) -> Result<Self> {
326 let reader = PreviousFileReader::try_new_self_described_from_reader(reader, None).await?;
327 let schema = reader.schema();
328
329 let metadata_str = schema
330 .metadata
331 .get(INDEX_METADATA_SCHEMA_KEY)
332 .ok_or(Error::Index {
333 message: format!(
334 "Reading quantization storage: index key {} not found",
335 INDEX_METADATA_SCHEMA_KEY
336 ),
337 location: location!(),
338 })?;
339 let index_metadata: IndexMetadata =
340 serde_json::from_str(metadata_str).map_err(|_| Error::Index {
341 message: format!("Failed to parse index metadata: {}", metadata_str),
342 location: location!(),
343 })?;
344 let distance_type = DistanceType::try_from(index_metadata.distance_type.as_str())?;
345
346 let ivf_data = IvfModel::load(&reader).await?;
347
348 let metadata = Q::Metadata::load(&reader).await?;
349 let quantizer = Q::from_metadata(&metadata, distance_type)?;
350 Ok(Self {
351 reader,
352 distance_type,
353 quantizer,
354 metadata,
355 ivf: ivf_data,
356 })
357 }
358
359 pub fn distance_type(&self) -> DistanceType {
360 self.distance_type
361 }
362
363 pub fn quantizer(&self) -> &Quantizer {
364 &self.quantizer
365 }
366
367 pub fn metadata(&self) -> &Q::Metadata {
368 &self.metadata
369 }
370
371 pub fn num_partitions(&self) -> usize {
373 self.ivf.num_partitions()
374 }
375
376 pub async fn load_partition(&self, part_id: usize) -> Result<Q::Storage> {
383 let range = self.ivf.row_range(part_id);
384 Q::Storage::load_partition(
385 &self.reader,
386 range,
387 self.distance_type,
388 &self.metadata,
389 None,
390 )
391 .await
392 }
393}