Skip to main content

lance_index/vector/sq/
storage.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::ops::Range;
5
6use arrow::datatypes::Float64Type;
7use arrow::{compute::concat_batches, datatypes::Float16Type};
8use arrow_array::{
9    ArrayRef, RecordBatch, UInt8Array, UInt64Array,
10    cast::AsArray,
11    types::{Float32Type, UInt8Type, UInt64Type},
12};
13use arrow_schema::{DataType, SchemaRef};
14use async_trait::async_trait;
15use deepsize::DeepSizeOf;
16use lance_core::{Error, ROW_ID, Result};
17use lance_file::previous::reader::FileReader as PreviousFileReader;
18use lance_io::object_store::ObjectStore;
19use lance_linalg::distance::{DistanceType, dot_distance, l2_distance_uint_scalar};
20use lance_table::format::SelfDescribingFileReader;
21use object_store::path::Path;
22use serde::{Deserialize, Serialize};
23use std::sync::Arc;
24
25use super::{ScalarQuantizer, scale_to_u8};
26use crate::frag_reuse::FragReuseIndex;
27use crate::{
28    INDEX_METADATA_SCHEMA_KEY, IndexMetadata,
29    vector::{
30        SQ_CODE_COLUMN,
31        quantizer::{QuantizerMetadata, QuantizerStorage},
32        storage::{DistCalculator, VectorStore},
33        transform::Transformer,
34    },
35};
36
37pub const SQ_METADATA_KEY: &str = "lance:sq";
38
39#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
40pub struct ScalarQuantizationMetadata {
41    pub dim: usize,
42    pub num_bits: u16,
43    pub bounds: Range<f64>,
44}
45
46impl DeepSizeOf for ScalarQuantizationMetadata {
47    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
48        0
49    }
50}
51
52#[async_trait]
53impl QuantizerMetadata for ScalarQuantizationMetadata {
54    async fn load(reader: &PreviousFileReader) -> Result<Self> {
55        let metadata_str = reader
56            .schema()
57            .metadata
58            .get(SQ_METADATA_KEY)
59            .ok_or(Error::index(format!(
60                "Reading SQ metadata: metadata key {} not found",
61                SQ_METADATA_KEY
62            )))?;
63        serde_json::from_str(metadata_str)
64            .map_err(|_| Error::index(format!("Failed to parse index metadata: {}", metadata_str)))
65    }
66}
67
68/// An immutable chunk of ScalarQuantizationStorage.
69#[derive(Debug, Clone)]
70struct SQStorageChunk {
71    batch: RecordBatch,
72
73    dim: usize,
74
75    // Helper fields, references to the batch
76    // These fields share the `Arc` pointer to the columns in batch,
77    // so it does not take more memory.
78    row_ids: UInt64Array,
79    sq_codes: UInt8Array,
80}
81
82impl SQStorageChunk {
83    // Create a new chunk from a RecordBatch.
84    fn new(batch: RecordBatch) -> Result<Self> {
85        let row_ids = batch
86            .column_by_name(ROW_ID)
87            .ok_or(Error::index(
88                "Row ID column not found in the batch".to_owned(),
89            ))?
90            .as_primitive::<UInt64Type>()
91            .clone();
92        let fsl = batch
93            .column_by_name(SQ_CODE_COLUMN)
94            .ok_or(Error::index(
95                "SQ code column not found in the batch".to_owned(),
96            ))?
97            .as_fixed_size_list();
98        let dim = fsl.value_length() as usize;
99        let sq_codes = fsl
100            .values()
101            .as_primitive_opt::<UInt8Type>()
102            .ok_or(Error::index(
103                "SQ code column is not FixedSizeList<u8>".to_owned(),
104            ))?
105            .clone();
106        Ok(Self {
107            batch,
108            dim,
109            row_ids,
110            sq_codes,
111        })
112    }
113
114    /// Returns vector dimension
115    fn dim(&self) -> usize {
116        self.dim
117    }
118
119    fn len(&self) -> usize {
120        self.row_ids.len()
121    }
122
123    fn schema(&self) -> &SchemaRef {
124        self.batch.schema_ref()
125    }
126
127    #[inline]
128    fn row_id(&self, id: u32) -> u64 {
129        self.row_ids.value(id as usize)
130    }
131
132    /// Get a slice of SQ code for id
133    #[inline]
134    fn sq_code_slice(&self, id: u32) -> &[u8] {
135        // assert!(id < self.len() as u32);
136        &self.sq_codes.values()[id as usize * self.dim..(id + 1) as usize * self.dim]
137    }
138}
139
140impl DeepSizeOf for SQStorageChunk {
141    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
142        self.batch.get_array_memory_size()
143    }
144}
145
146#[derive(Debug, Clone)]
147pub struct ScalarQuantizationStorage {
148    quantizer: ScalarQuantizer,
149
150    distance_type: DistanceType,
151
152    /// Chunks of storage
153    offsets: Vec<u32>,
154    chunks: Vec<SQStorageChunk>,
155}
156
157impl DeepSizeOf for ScalarQuantizationStorage {
158    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
159        self.chunks
160            .iter()
161            .map(|c| c.deep_size_of_children(context))
162            .sum()
163    }
164}
165
/// Row-count threshold above which `ScalarQuantizationStorage::optimize` merges
/// all chunks into one.
const SQ_CHUNK_CAPACITY: usize = 1 << 10;
167
168impl ScalarQuantizationStorage {
169    pub fn try_new(
170        num_bits: u16,
171        distance_type: DistanceType,
172        bounds: Range<f64>,
173        batches: impl IntoIterator<Item = RecordBatch>,
174        frag_reuse_index: Option<Arc<FragReuseIndex>>,
175    ) -> Result<Self> {
176        let mut chunks = Vec::with_capacity(SQ_CHUNK_CAPACITY);
177        let mut offsets = Vec::with_capacity(SQ_CHUNK_CAPACITY + 1);
178        offsets.push(0);
179        for mut batch in batches.into_iter() {
180            if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
181                batch = frag_reuse_index_ref.remap_row_ids_record_batch(batch, 0)?
182            }
183            offsets.push(offsets.last().unwrap() + batch.num_rows() as u32);
184            let chunk = SQStorageChunk::new(batch)?;
185            chunks.push(chunk);
186        }
187        let quantizer = ScalarQuantizer::with_bounds(num_bits, chunks[0].dim(), bounds);
188
189        Ok(Self {
190            quantizer,
191            distance_type,
192            offsets,
193            chunks,
194        })
195    }
196
197    /// Get the chunk that covers the id.
198    ///
199    /// Returns:
200    /// `(offset, chunk)`
201    ///
202    /// We did not check out of range in this call. But the out of range will
203    /// panic once you access the data in the last [SQStorageChunk].
204    fn chunk(&self, id: u32) -> (u32, &SQStorageChunk) {
205        match self.offsets.binary_search(&id) {
206            Ok(o) => (self.offsets[o], &self.chunks[o]),
207            Err(o) => (self.offsets[o - 1], &self.chunks[o - 1]),
208        }
209    }
210
211    pub async fn load(
212        object_store: &ObjectStore,
213        path: &Path,
214        frag_reuse_index: Option<Arc<FragReuseIndex>>,
215    ) -> Result<Self> {
216        let reader = PreviousFileReader::try_new_self_described(object_store, path, None).await?;
217        let schema = reader.schema();
218
219        let metadata_str = schema
220            .metadata
221            .get(INDEX_METADATA_SCHEMA_KEY)
222            .ok_or(Error::index(format!(
223                "Reading SQ storage: index key {} not found",
224                INDEX_METADATA_SCHEMA_KEY
225            )))?;
226        let index_metadata: IndexMetadata = serde_json::from_str(metadata_str).map_err(|_| {
227            Error::index(format!("Failed to parse index metadata: {}", metadata_str))
228        })?;
229        let distance_type = DistanceType::try_from(index_metadata.distance_type.as_str())?;
230        let metadata = ScalarQuantizationMetadata::load(&reader).await?;
231
232        Self::load_partition(
233            &reader,
234            0..reader.len(),
235            distance_type,
236            &metadata,
237            frag_reuse_index,
238        )
239        .await
240    }
241
242    fn optimize(self) -> Result<Self> {
243        if self.len() <= SQ_CHUNK_CAPACITY {
244            Ok(self)
245        } else {
246            let mut new = self.clone();
247            let batch = concat_batches(
248                self.chunks[0].schema(),
249                self.chunks.iter().map(|c| &c.batch),
250            )?;
251            new.offsets = vec![0, batch.num_rows() as u32];
252            new.chunks = vec![SQStorageChunk::new(batch)?];
253            Ok(new)
254        }
255    }
256}
257
258#[async_trait]
259impl QuantizerStorage for ScalarQuantizationStorage {
260    type Metadata = ScalarQuantizationMetadata;
261
262    fn try_from_batch(
263        batch: RecordBatch,
264        metadata: &Self::Metadata,
265        distance_type: DistanceType,
266        frag_reuse_index: Option<Arc<FragReuseIndex>>,
267    ) -> Result<Self>
268    where
269        Self: Sized,
270    {
271        Self::try_new(
272            metadata.num_bits,
273            distance_type,
274            metadata.bounds.clone(),
275            [batch],
276            frag_reuse_index,
277        )
278    }
279
280    fn metadata(&self) -> &Self::Metadata {
281        &self.quantizer.metadata
282    }
283
284    /// Load a partition of SQ storage from disk.
285    ///
286    /// Parameters
287    /// ----------
288    /// - *reader: file reader
289    /// - *range: row range of the partition
290    /// - *metric_type: metric type of the vectors
291    /// - *metadata: scalar quantization metadata
292    async fn load_partition(
293        reader: &PreviousFileReader,
294        range: std::ops::Range<usize>,
295        distance_type: DistanceType,
296        metadata: &Self::Metadata,
297        frag_reuse_index: Option<Arc<FragReuseIndex>>,
298    ) -> Result<Self> {
299        let schema = reader.schema();
300        let batch = reader.read_range(range, schema).await?;
301
302        Self::try_new(
303            metadata.num_bits,
304            distance_type,
305            metadata.bounds.clone(),
306            [batch],
307            frag_reuse_index,
308        )
309    }
310}
311
312impl VectorStore for ScalarQuantizationStorage {
313    type DistanceCalculator<'a> = SQDistCalculator<'a>;
314
315    fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch>> {
316        Ok(self.chunks.iter().map(|c| c.batch.clone()))
317    }
318
319    fn append_batch(&self, batch: RecordBatch, vector_column: &str) -> Result<Self> {
320        // TODO: use chunked storage
321        let transformer = super::transform::SQTransformer::new(
322            self.quantizer.clone(),
323            vector_column.to_string(),
324            SQ_CODE_COLUMN.to_string(),
325        );
326
327        let new_batch = transformer.transform(&batch)?;
328
329        // self.quantizer.transform(data)
330        let mut storage = self.clone();
331        let offset = self.len() as u32;
332        let new_chunk = SQStorageChunk::new(new_batch)?;
333        storage.offsets.push(offset + new_chunk.len() as u32);
334        storage.chunks.push(new_chunk);
335
336        storage.optimize()
337    }
338
339    fn schema(&self) -> &SchemaRef {
340        self.chunks[0].schema()
341    }
342
343    fn as_any(&self) -> &dyn std::any::Any {
344        self
345    }
346
347    fn len(&self) -> usize {
348        *self.offsets.last().unwrap() as usize
349    }
350
351    /// Return the [DistanceType] of the vectors.
352    fn distance_type(&self) -> DistanceType {
353        self.distance_type
354    }
355
356    fn row_id(&self, id: u32) -> u64 {
357        let (offset, chunk) = self.chunk(id);
358        chunk.row_id(id - offset)
359    }
360
361    fn row_ids(&self) -> impl Iterator<Item = &u64> {
362        self.chunks.iter().flat_map(|c| c.row_ids.values())
363    }
364
365    /// Create a [DistCalculator] to compute the distance between the query.
366    ///
367    /// Using dist calculator can be more efficient as it can pre-compute some
368    /// values.
369    fn dist_calculator(&self, query: ArrayRef, _dist_q_c: f32) -> Self::DistanceCalculator<'_> {
370        SQDistCalculator::new(query, self, self.quantizer.bounds())
371    }
372
373    fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> {
374        let (offset, chunk) = self.chunk(id);
375        let query_sq_code = chunk.sq_code_slice(id - offset).to_vec();
376        let bounds = self.quantizer.bounds();
377        SQDistCalculator {
378            query_sq_code,
379            scale: sq_distance_scale(&bounds),
380            storage: self,
381        }
382    }
383}
384
/// Factor that converts a squared distance between 8-bit codes back into the
/// units of the original `bounds`: `(bounds width)^2 / 255^2`.
#[inline]
fn sq_distance_scale(bounds: &Range<f64>) -> f32 {
    let width = (bounds.end - bounds.start) as f32;
    let code_levels_sq = 255.0_f32 * 255.0_f32;
    width * width / code_levels_sq
}
390
391pub struct SQDistCalculator<'a> {
392    query_sq_code: Vec<u8>,
393    scale: f32,
394    storage: &'a ScalarQuantizationStorage,
395}
396
397impl<'a> SQDistCalculator<'a> {
398    fn new(query: ArrayRef, storage: &'a ScalarQuantizationStorage, bounds: Range<f64>) -> Self {
399        // This is okay-ish to use hand-rolled dynamic dispatch here
400        // since we search 10s-100s of partitions, we can afford the overhead
401        // this could be annoying at indexing time for HNSW, which requires constructing the
402        // dist calculator frequently. However, HNSW isn't first-class citizen in Lance yet. so be it.
403        let query_sq_code = match query.data_type() {
404            DataType::Float16 => {
405                scale_to_u8::<Float16Type>(query.as_primitive::<Float16Type>().values(), &bounds)
406            }
407            DataType::Float32 => {
408                scale_to_u8::<Float32Type>(query.as_primitive::<Float32Type>().values(), &bounds)
409            }
410            DataType::Float64 => {
411                scale_to_u8::<Float64Type>(query.as_primitive::<Float64Type>().values(), &bounds)
412            }
413            _ => {
414                panic!("Unsupported data type for ScalarQuantizationStorage");
415            }
416        };
417        Self {
418            query_sq_code,
419            scale: sq_distance_scale(&bounds),
420            storage,
421        }
422    }
423}
424
425impl DistCalculator for SQDistCalculator<'_> {
426    fn distance(&self, id: u32) -> f32 {
427        let (offset, chunk) = self.storage.chunk(id);
428        let sq_code = chunk.sq_code_slice(id - offset);
429        let dist = match self.storage.distance_type {
430            DistanceType::L2 | DistanceType::Cosine => {
431                l2_distance_uint_scalar(sq_code, &self.query_sq_code)
432            }
433            DistanceType::Dot => dot_distance(sq_code, &self.query_sq_code),
434            _ => panic!("We should not reach here: sq distance can only be L2 or Dot"),
435        };
436        dist * self.scale
437    }
438
439    fn distance_all(&self, _k_hint: usize) -> Vec<f32> {
440        match self.storage.distance_type {
441            DistanceType::L2 | DistanceType::Cosine => self
442                .storage
443                .chunks
444                .iter()
445                .flat_map(|c| {
446                    c.sq_codes
447                        .values()
448                        .chunks_exact(c.dim())
449                        .map(|sq_codes| l2_distance_uint_scalar(sq_codes, &self.query_sq_code))
450                })
451                .map(|dist| dist * self.scale)
452                .collect(),
453            DistanceType::Dot => self
454                .storage
455                .chunks
456                .iter()
457                .flat_map(|c| {
458                    c.sq_codes
459                        .values()
460                        .chunks_exact(c.dim())
461                        .map(|sq_codes| dot_distance(sq_codes, &self.query_sq_code))
462                })
463                .map(|dist| dist * self.scale)
464                .collect(),
465            _ => panic!("We should not reach here: sq distance can only be L2 or Dot"),
466        }
467    }
468
469    #[allow(unused_variables)]
470    fn prefetch(&self, id: u32) {
471        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
472        {
473            const CACHE_LINE_SIZE: usize = 64;
474
475            let (offset, chunk) = self.storage.chunk(id);
476            let dim = chunk.dim();
477            let base_ptr = chunk.sq_code_slice(id - offset).as_ptr();
478
479            unsafe {
480                // Loop over the sq_code to prefetch each cache line
481                for offset in (0..dim).step_by(CACHE_LINE_SIZE) {
482                    {
483                        use core::arch::x86_64::{_MM_HINT_T0, _mm_prefetch};
484                        _mm_prefetch(base_ptr.add(offset) as *const i8, _MM_HINT_T0);
485                    }
486                }
487            }
488        }
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    use std::iter::repeat_with;
    use std::sync::Arc;

    use arrow_array::FixedSizeListArray;
    use arrow_schema::{DataType, Field, Schema};
    use lance_arrow::FixedSizeListArrayExt;
    use lance_testing::datagen::generate_random_array;
    use rand::prelude::*;

    const DIM: usize = 64;

    /// Schema with a non-null `ROW_ID` column plus a non-null
    /// `FixedSizeList<item_type, DIM>` column named `list_name`.
    fn row_id_and_list_schema(list_name: &str, item_type: DataType) -> SchemaRef {
        let item = Arc::new(Field::new("item", item_type, true));
        Arc::new(Schema::new(vec![
            Field::new(ROW_ID, DataType::UInt64, false),
            Field::new(list_name, DataType::FixedSizeList(item, DIM as i32), false),
        ]))
    }

    /// An SQ batch with the given row ids and random code bytes.
    fn create_record_batch(row_ids: Range<u64>) -> RecordBatch {
        let mut rng = rand::rng();
        let row_ids = UInt64Array::from_iter_values(row_ids);
        let codes = UInt8Array::from_iter_values(
            repeat_with(|| rng.random::<u8>()).take(row_ids.len() * DIM),
        );
        let code_lists = FixedSizeListArray::try_new_from_values(codes, DIM as i32).unwrap();

        RecordBatch::try_new(
            row_id_and_list_schema(SQ_CODE_COLUMN, DataType::UInt8),
            vec![Arc::new(row_ids), Arc::new(code_lists)],
        )
        .unwrap()
    }

    #[test]
    fn test_get_chunks() {
        // Four chunks of 100 rows each: row ids 0..400.
        let batches = (0..4).map(|i| create_record_batch(i * 100..(i + 1) * 100));
        let storage =
            ScalarQuantizationStorage::try_new(8, DistanceType::L2, -0.7..0.7, batches, None)
                .unwrap();
        assert_eq!(storage.len(), 400);

        let (offset, chunk) = storage.chunk(0);
        assert_eq!(offset, 0);
        assert_eq!(chunk.row_id(20), 20);

        let (offset, _) = storage.chunk(50);
        assert_eq!(offset, 0);

        // Append 150 raw float vectors; they are quantized into a fifth chunk.
        let row_ids = UInt64Array::from_iter_values(100..250);
        let vectors = FixedSizeListArray::try_new_from_values(
            generate_random_array(row_ids.len() * DIM),
            DIM as i32,
        )
        .unwrap();
        let raw_batch = RecordBatch::try_new(
            row_id_and_list_schema("vector", DataType::Float32),
            vec![Arc::new(row_ids), Arc::new(vectors)],
        )
        .unwrap();
        let storage = storage.append_batch(raw_batch, "vector").unwrap();

        assert_eq!(storage.len(), 550);
        let (offset, chunk) = storage.chunk(112);
        assert_eq!(offset, 100);
        assert_eq!(chunk.row_id(10), 110);

        let (offset, chunk) = storage.chunk(432);
        assert_eq!(offset, 400);
        assert_eq!(chunk.row_id(5), 105);
    }
}