lance_index/vector/sq/
storage.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::ops::Range;
5
6use arrow::datatypes::Float64Type;
7use arrow::{compute::concat_batches, datatypes::Float16Type};
8use arrow_array::{
9    cast::AsArray,
10    types::{Float32Type, UInt64Type, UInt8Type},
11    ArrayRef, RecordBatch, UInt64Array, UInt8Array,
12};
13use arrow_schema::{DataType, SchemaRef};
14use async_trait::async_trait;
15use deepsize::DeepSizeOf;
16use lance_core::{Error, Result, ROW_ID};
17use lance_file::reader::FileReader;
18use lance_io::object_store::ObjectStore;
19use lance_linalg::distance::{dot_distance, l2_distance_uint_scalar, DistanceType};
20use lance_table::format::SelfDescribingFileReader;
21use object_store::path::Path;
22use serde::{Deserialize, Serialize};
23use snafu::location;
24use std::sync::Arc;
25
26use super::{inverse_scalar_dist, scale_to_u8, ScalarQuantizer};
27use crate::frag_reuse::FragReuseIndex;
28use crate::{
29    vector::{
30        quantizer::{QuantizerMetadata, QuantizerStorage},
31        storage::{DistCalculator, VectorStore},
32        transform::Transformer,
33        SQ_CODE_COLUMN,
34    },
35    IndexMetadata, INDEX_METADATA_SCHEMA_KEY,
36};
37
/// Schema-metadata key under which the JSON-serialized
/// [`ScalarQuantizationMetadata`] is looked up when loading SQ storage.
pub const SQ_METADATA_KEY: &str = "lance:sq";
39
/// Parameters describing a scalar quantizer.
///
/// The struct is (de)serialized with serde; `QuantizerMetadata::load` parses
/// it from the JSON string stored under `SQ_METADATA_KEY` in the file schema
/// metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizationMetadata {
    /// Vector dimension.
    // NOTE(review): presumably the number of elements per vector — confirm
    // against the writer side.
    pub dim: usize,
    /// Number of bits, forwarded to `ScalarQuantizer::with_bounds`.
    pub num_bits: u16,
    /// Value bounds, forwarded to `ScalarQuantizer::with_bounds`.
    pub bounds: Range<f64>,
}
46
impl DeepSizeOf for ScalarQuantizationMetadata {
    /// Reports no child allocations: every field (`usize`, `u16`,
    /// `Range<f64>`) is stored inline in the struct.
    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
        0
    }
}
52
53#[async_trait]
54impl QuantizerMetadata for ScalarQuantizationMetadata {
55    async fn load(reader: &FileReader) -> Result<Self> {
56        let metadata_str = reader
57            .schema()
58            .metadata
59            .get(SQ_METADATA_KEY)
60            .ok_or(Error::Index {
61                message: format!(
62                    "Reading SQ metadata: metadata key {} not found",
63                    SQ_METADATA_KEY
64                ),
65                location: location!(),
66            })?;
67        serde_json::from_str(metadata_str).map_err(|_| Error::Index {
68            message: format!("Failed to parse index metadata: {}", metadata_str),
69            location: location!(),
70        })
71    }
72}
73
/// An immutable chunk of ScalarQuantizationStorage.
///
/// A chunk wraps one [`RecordBatch`] together with typed views of the two
/// columns that the storage reads: the row ids and the flattened SQ codes.
#[derive(Debug, Clone)]
struct SQStorageChunk {
    /// The batch this chunk was created from.
    batch: RecordBatch,

    /// Length of each fixed-size-list entry of the SQ code column,
    /// i.e. the number of code bytes per row.
    dim: usize,

    // Helper fields, references to the batch
    // These fields share the `Arc` pointer to the columns in batch,
    // so it does not take more memory.
    /// The `ROW_ID` column of `batch`.
    row_ids: UInt64Array,
    /// The flattened values of the `SQ_CODE_COLUMN` column of `batch`;
    /// row `i` occupies `[i * dim, (i + 1) * dim)`.
    sq_codes: UInt8Array,
}
87
88impl SQStorageChunk {
89    // Create a new chunk from a RecordBatch.
90    fn new(batch: RecordBatch) -> Result<Self> {
91        let row_ids = batch
92            .column_by_name(ROW_ID)
93            .ok_or(Error::Index {
94                message: "Row ID column not found in the batch".to_owned(),
95                location: location!(),
96            })?
97            .as_primitive::<UInt64Type>()
98            .clone();
99        let fsl = batch
100            .column_by_name(SQ_CODE_COLUMN)
101            .ok_or(Error::Index {
102                message: "SQ code column not found in the batch".to_owned(),
103                location: location!(),
104            })?
105            .as_fixed_size_list();
106        let dim = fsl.value_length() as usize;
107        let sq_codes = fsl
108            .values()
109            .as_primitive_opt::<UInt8Type>()
110            .ok_or(Error::Index {
111                message: "SQ code column is not FixedSizeList<u8>".to_owned(),
112                location: location!(),
113            })?
114            .clone();
115        Ok(Self {
116            batch,
117            dim,
118            row_ids,
119            sq_codes,
120        })
121    }
122
123    /// Returns vector dimension
124    fn dim(&self) -> usize {
125        self.dim
126    }
127
128    fn len(&self) -> usize {
129        self.row_ids.len()
130    }
131
132    fn schema(&self) -> &SchemaRef {
133        self.batch.schema_ref()
134    }
135
136    #[inline]
137    fn row_id(&self, id: u32) -> u64 {
138        self.row_ids.value(id as usize)
139    }
140
141    /// Get a slice of SQ code for id
142    #[inline]
143    fn sq_code_slice(&self, id: u32) -> &[u8] {
144        // assert!(id < self.len() as u32);
145        &self.sq_codes.values()[id as usize * self.dim..(id + 1) as usize * self.dim]
146    }
147}
148
impl DeepSizeOf for SQStorageChunk {
    /// Reports the array memory of `batch` only; `row_ids` and `sq_codes`
    /// are views over columns of that same batch, so they are not counted
    /// a second time.
    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
        self.batch.get_array_memory_size()
    }
}
154
/// Storage of scalar-quantized vectors, kept as a list of immutable chunks.
#[derive(Debug, Clone)]
pub struct ScalarQuantizationStorage {
    /// Quantizer built from the stored number of bits, dimension and bounds.
    quantizer: ScalarQuantizer,

    /// Distance type used when computing distances over this storage.
    distance_type: DistanceType,

    /// Chunks of storage
    ///
    /// `offsets` always starts with `0` and has one more entry than `chunks`:
    /// `offsets[i]` is the storage-wide index of the first row of `chunks[i]`,
    /// and the last entry is the total number of rows.
    offsets: Vec<u32>,
    chunks: Vec<SQStorageChunk>,
}
165
166impl DeepSizeOf for ScalarQuantizationStorage {
167    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
168        self.chunks
169            .iter()
170            .map(|c| c.deep_size_of_children(context))
171            .sum()
172    }
173}
174
/// Row-count threshold used by `ScalarQuantizationStorage::optimize`: once the
/// storage holds more rows than this, all chunks are merged into a single one.
/// It is also used when pre-allocating the chunk vectors in `try_new`.
const SQ_CHUNK_CAPACITY: usize = 1024;
176
177impl ScalarQuantizationStorage {
178    pub fn try_new(
179        num_bits: u16,
180        distance_type: DistanceType,
181        bounds: Range<f64>,
182        batches: impl IntoIterator<Item = RecordBatch>,
183        fri: Option<Arc<FragReuseIndex>>,
184    ) -> Result<Self> {
185        let mut chunks = Vec::with_capacity(SQ_CHUNK_CAPACITY);
186        let mut offsets = Vec::with_capacity(SQ_CHUNK_CAPACITY + 1);
187        offsets.push(0);
188        for mut batch in batches.into_iter() {
189            if let Some(fri_ref) = fri.as_ref() {
190                batch = fri_ref.remap_row_ids_record_batch(batch, 0)?
191            }
192            offsets.push(offsets.last().unwrap() + batch.num_rows() as u32);
193            let chunk = SQStorageChunk::new(batch)?;
194            chunks.push(chunk);
195        }
196        let quantizer = ScalarQuantizer::with_bounds(num_bits, chunks[0].dim(), bounds);
197
198        Ok(Self {
199            quantizer,
200            distance_type,
201            offsets,
202            chunks,
203        })
204    }
205
206    /// Get the chunk that covers the id.
207    ///
208    /// Returns:
209    /// `(offset, chunk)`
210    ///
211    /// We did not check out of range in this call. But the out of range will
212    /// panic once you access the data in the last [SQStorageChunk].
213    fn chunk(&self, id: u32) -> (u32, &SQStorageChunk) {
214        match self.offsets.binary_search(&id) {
215            Ok(o) => (self.offsets[o], &self.chunks[o]),
216            Err(o) => (self.offsets[o - 1], &self.chunks[o - 1]),
217        }
218    }
219
220    pub async fn load(
221        object_store: &ObjectStore,
222        path: &Path,
223        fri: Option<Arc<FragReuseIndex>>,
224    ) -> Result<Self> {
225        let reader = FileReader::try_new_self_described(object_store, path, None).await?;
226        let schema = reader.schema();
227
228        let metadata_str = schema
229            .metadata
230            .get(INDEX_METADATA_SCHEMA_KEY)
231            .ok_or(Error::Index {
232                message: format!(
233                    "Reading SQ storage: index key {} not found",
234                    INDEX_METADATA_SCHEMA_KEY
235                ),
236                location: location!(),
237            })?;
238        let index_metadata: IndexMetadata =
239            serde_json::from_str(metadata_str).map_err(|_| Error::Index {
240                message: format!("Failed to parse index metadata: {}", metadata_str),
241                location: location!(),
242            })?;
243        let distance_type = DistanceType::try_from(index_metadata.distance_type.as_str())?;
244        let metadata = ScalarQuantizationMetadata::load(&reader).await?;
245
246        Self::load_partition(&reader, 0..reader.len(), distance_type, &metadata, fri).await
247    }
248
249    fn optimize(self) -> Result<Self> {
250        if self.len() <= SQ_CHUNK_CAPACITY {
251            Ok(self)
252        } else {
253            let mut new = self.clone();
254            let batch = concat_batches(
255                self.chunks[0].schema(),
256                self.chunks.iter().map(|c| &c.batch),
257            )?;
258            new.offsets = vec![0, batch.num_rows() as u32];
259            new.chunks = vec![SQStorageChunk::new(batch)?];
260            Ok(new)
261        }
262    }
263}
264
265#[async_trait]
266impl QuantizerStorage for ScalarQuantizationStorage {
267    type Metadata = ScalarQuantizationMetadata;
268
269    fn try_from_batch(
270        batch: RecordBatch,
271        metadata: &Self::Metadata,
272        distance_type: DistanceType,
273        fri: Option<Arc<FragReuseIndex>>,
274    ) -> Result<Self>
275    where
276        Self: Sized,
277    {
278        Self::try_new(
279            metadata.num_bits,
280            distance_type,
281            metadata.bounds.clone(),
282            [batch],
283            fri,
284        )
285    }
286
287    fn metadata(&self) -> &Self::Metadata {
288        &self.quantizer.metadata
289    }
290
291    /// Load a partition of SQ storage from disk.
292    ///
293    /// Parameters
294    /// ----------
295    /// - *reader: file reader
296    /// - *range: row range of the partition
297    /// - *metric_type: metric type of the vectors
298    /// - *metadata: scalar quantization metadata
299    async fn load_partition(
300        reader: &FileReader,
301        range: std::ops::Range<usize>,
302        distance_type: DistanceType,
303        metadata: &Self::Metadata,
304        fri: Option<Arc<FragReuseIndex>>,
305    ) -> Result<Self> {
306        let schema = reader.schema();
307        let batch = reader.read_range(range, schema).await?;
308
309        Self::try_new(
310            metadata.num_bits,
311            distance_type,
312            metadata.bounds.clone(),
313            [batch],
314            fri,
315        )
316    }
317}
318
319impl VectorStore for ScalarQuantizationStorage {
320    type DistanceCalculator<'a> = SQDistCalculator<'a>;
321
322    fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch>> {
323        Ok(self.chunks.iter().map(|c| c.batch.clone()))
324    }
325
326    fn append_batch(&self, batch: RecordBatch, vector_column: &str) -> Result<Self> {
327        // TODO: use chunked storage
328        let transformer = super::transform::SQTransformer::new(
329            self.quantizer.clone(),
330            vector_column.to_string(),
331            SQ_CODE_COLUMN.to_string(),
332        );
333
334        let new_batch = transformer.transform(&batch)?;
335
336        // self.quantizer.transform(data)
337        let mut storage = self.clone();
338        let offset = self.len() as u32;
339        let new_chunk = SQStorageChunk::new(new_batch)?;
340        storage.offsets.push(offset + new_chunk.len() as u32);
341        storage.chunks.push(new_chunk);
342
343        storage.optimize()
344    }
345
346    fn schema(&self) -> &SchemaRef {
347        self.chunks[0].schema()
348    }
349
350    fn as_any(&self) -> &dyn std::any::Any {
351        self
352    }
353
354    fn len(&self) -> usize {
355        *self.offsets.last().unwrap() as usize
356    }
357
358    /// Return the [DistanceType] of the vectors.
359    fn distance_type(&self) -> DistanceType {
360        self.distance_type
361    }
362
363    fn row_id(&self, id: u32) -> u64 {
364        let (offset, chunk) = self.chunk(id);
365        chunk.row_id(id - offset)
366    }
367
368    fn row_ids(&self) -> impl Iterator<Item = &u64> {
369        self.chunks.iter().flat_map(|c| c.row_ids.values())
370    }
371
372    /// Create a [DistCalculator] to compute the distance between the query.
373    ///
374    /// Using dist calculator can be more efficient as it can pre-compute some
375    /// values.
376    fn dist_calculator(&self, query: ArrayRef) -> Self::DistanceCalculator<'_> {
377        SQDistCalculator::new(query, self, self.quantizer.bounds())
378    }
379
380    fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> {
381        let (offset, chunk) = self.chunk(id);
382        let query_sq_code = chunk.sq_code_slice(id - offset).to_vec();
383        SQDistCalculator {
384            query_sq_code,
385            bounds: self.quantizer.bounds(),
386            storage: self,
387        }
388    }
389}
390
/// Distance calculator between one quantized query and the vectors of a
/// [`ScalarQuantizationStorage`].
pub struct SQDistCalculator<'a> {
    /// The query as SQ code bytes: either scaled from a float query with
    /// `scale_to_u8`, or copied from a code already in the storage.
    query_sq_code: Vec<u8>,
    /// Quantizer bounds, passed to `inverse_scalar_dist` when converting
    /// distances.
    bounds: Range<f64>,
    /// The storage whose vectors are compared against the query.
    storage: &'a ScalarQuantizationStorage,
}
396
397impl<'a> SQDistCalculator<'a> {
398    fn new(query: ArrayRef, storage: &'a ScalarQuantizationStorage, bounds: Range<f64>) -> Self {
399        // This is okay-ish to use hand-rolled dynamic dispatch here
400        // since we search 10s-100s of partitions, we can afford the overhead
401        // this could be annoying at indexing time for HNSW, which requires constructing the
402        // dist calculator frequently. However, HNSW isn't first-class citizen in Lance yet. so be it.
403        let query_sq_code = match query.data_type() {
404            DataType::Float16 => {
405                scale_to_u8::<Float16Type>(query.as_primitive::<Float16Type>().values(), &bounds)
406            }
407            DataType::Float32 => {
408                scale_to_u8::<Float32Type>(query.as_primitive::<Float32Type>().values(), &bounds)
409            }
410            DataType::Float64 => {
411                scale_to_u8::<Float64Type>(query.as_primitive::<Float64Type>().values(), &bounds)
412            }
413            _ => {
414                panic!("Unsupported data type for ScalarQuantizationStorage");
415            }
416        };
417        Self {
418            query_sq_code,
419            bounds,
420            storage,
421        }
422    }
423}
424
425impl DistCalculator for SQDistCalculator<'_> {
426    fn distance(&self, id: u32) -> f32 {
427        let (offset, chunk) = self.storage.chunk(id);
428        let sq_code = chunk.sq_code_slice(id - offset);
429        let dist = match self.storage.distance_type {
430            DistanceType::L2 | DistanceType::Cosine => {
431                l2_distance_uint_scalar(sq_code, &self.query_sq_code)
432            }
433            DistanceType::Dot => dot_distance(sq_code, &self.query_sq_code),
434            _ => panic!("We should not reach here: sq distance can only be L2 or Dot"),
435        };
436        inverse_scalar_dist(std::iter::once(dist), &self.bounds)[0]
437    }
438
439    fn distance_all(&self, _k_hint: usize) -> Vec<f32> {
440        match self.storage.distance_type {
441            DistanceType::L2 | DistanceType::Cosine => inverse_scalar_dist(
442                self.storage.chunks.iter().flat_map(|c| {
443                    c.sq_codes
444                        .values()
445                        .chunks_exact(c.dim())
446                        .map(|sq_codes| l2_distance_uint_scalar(sq_codes, &self.query_sq_code))
447                }),
448                &self.bounds,
449            ),
450            DistanceType::Dot => inverse_scalar_dist(
451                self.storage.chunks.iter().flat_map(|c| {
452                    c.sq_codes
453                        .values()
454                        .chunks_exact(c.dim())
455                        .map(|sq_codes| dot_distance(sq_codes, &self.query_sq_code))
456                }),
457                &self.bounds,
458            ),
459            _ => panic!("We should not reach here: sq distance can only be L2 or Dot"),
460        }
461    }
462
463    #[allow(unused_variables)]
464    fn prefetch(&self, id: u32) {
465        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
466        {
467            const CACHE_LINE_SIZE: usize = 64;
468
469            let (offset, chunk) = self.storage.chunk(id);
470            let dim = chunk.dim();
471            let base_ptr = chunk.sq_code_slice(id - offset).as_ptr();
472
473            unsafe {
474                // Loop over the sq_code to prefetch each cache line
475                for offset in (0..dim).step_by(CACHE_LINE_SIZE) {
476                    {
477                        use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
478                        _mm_prefetch(base_ptr.add(offset) as *const i8, _MM_HINT_T0);
479                    }
480                }
481            }
482        }
483    }
484}
485
#[cfg(test)]
mod tests {
    use super::*;

    use std::iter::repeat_with;
    use std::sync::Arc;

    use arrow_array::FixedSizeListArray;
    use arrow_schema::{DataType, Field, Schema};
    use lance_arrow::FixedSizeListArrayExt;
    use lance_testing::datagen::generate_random_array;
    use rand::prelude::*;

    /// Dimension of every vector / SQ code used by these tests.
    const DIM: usize = 64;

    /// Schema with a `ROW_ID` column followed by a
    /// `FixedSizeList<item_type, DIM>` column named `list_column`.
    fn two_column_schema(list_column: &str, item_type: DataType) -> SchemaRef {
        let item = Arc::new(Field::new("item", item_type, true));
        Arc::new(Schema::new(vec![
            Field::new(ROW_ID, DataType::UInt64, false),
            Field::new(
                list_column,
                DataType::FixedSizeList(item, DIM as i32),
                false,
            ),
        ]))
    }

    /// Batch of random SQ codes whose row ids are exactly `row_ids`.
    fn create_record_batch(row_ids: Range<u64>) -> RecordBatch {
        let mut rng = rand::thread_rng();
        let ids = UInt64Array::from_iter_values(row_ids);
        let num_codes = ids.len() * DIM;
        let codes = UInt8Array::from_iter_values(repeat_with(|| rng.gen::<u8>()).take(num_codes));
        let code_list = FixedSizeListArray::try_new_from_values(codes, DIM as i32).unwrap();

        RecordBatch::try_new(
            two_column_schema(SQ_CODE_COLUMN, DataType::UInt8),
            vec![Arc::new(ids), Arc::new(code_list)],
        )
        .unwrap()
    }

    #[test]
    fn test_get_chunks() {
        // Four chunks of 100 rows each, with row ids 0..400.
        let initial_batches =
            (0..4).map(|start| create_record_batch(start * 100..(start + 1) * 100));
        let storage = ScalarQuantizationStorage::try_new(
            8,
            DistanceType::L2,
            -0.7..0.7,
            initial_batches,
            None,
        )
        .unwrap();

        assert_eq!(storage.len(), 400);

        let (first_offset, first_chunk) = storage.chunk(0);
        assert_eq!(first_offset, 0);
        assert_eq!(first_chunk.row_id(20), 20);

        let (mid_offset, _) = storage.chunk(50);
        assert_eq!(mid_offset, 0);

        // Append 150 float vectors carrying row ids 100..250.
        let appended_ids = UInt64Array::from_iter_values(100..250);
        let floats = generate_random_array(appended_ids.len() * DIM);
        let vectors = FixedSizeListArray::try_new_from_values(floats, DIM as i32).unwrap();
        let float_batch = RecordBatch::try_new(
            two_column_schema("vector", DataType::Float32),
            vec![Arc::new(appended_ids), Arc::new(vectors)],
        )
        .unwrap();
        let storage = storage.append_batch(float_batch, "vector").unwrap();

        assert_eq!(storage.len(), 550);

        // Index 112 still falls into the second original chunk.
        let (offset, chunk) = storage.chunk(112);
        assert_eq!(offset, 100);
        assert_eq!(chunk.row_id(10), 110);

        // Index 432 falls into the appended chunk, which starts at 400.
        let (offset, chunk) = storage.chunk(432);
        assert_eq!(offset, 400);
        assert_eq!(chunk.row_id(5), 105);
    }
}
573}