Skip to main content

lance_index/vector/flat/
storage.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::sync::Arc;
5
6use super::index::FlatMetadata;
7use crate::frag_reuse::FragReuseIndex;
8use crate::vector::quantizer::QuantizerStorage;
9use crate::vector::storage::{DistCalculator, VectorStore};
10use crate::vector::utils::do_prefetch;
11use arrow::array::AsArray;
12use arrow::compute::concat_batches;
13use arrow::datatypes::{Float16Type, Float64Type, UInt8Type};
14use arrow_array::ArrowPrimitiveType;
15use arrow_array::{
16    Array, ArrayRef, FixedSizeListArray, RecordBatch, UInt64Array,
17    types::{Float32Type, UInt64Type},
18};
19use arrow_schema::{DataType, SchemaRef};
20use deepsize::DeepSizeOf;
21use lance_core::{Error, ROW_ID, Result};
22use lance_file::previous::reader::FileReader as PreviousFileReader;
23use lance_linalg::distance::hamming::hamming;
24use lance_linalg::distance::{Cosine, DistanceType, Dot, L2};
25
/// Name of the column that holds the vectors in the record batches used by the flat storages.
pub const FLAT_COLUMN: &str = "flat";
27
/// Flat (unquantized) vector storage for floating-point vectors.
///
/// All data are stored in memory.
#[derive(Debug, Clone)]
pub struct FlatFloatStorage {
    metadata: FlatMetadata,
    // The batch the storage was built from; it carries the `ROW_ID` and `FLAT_COLUMN` columns.
    batch: RecordBatch,
    distance_type: DistanceType,

    // helper fields
    // Typed copies of the `ROW_ID` and `FLAT_COLUMN` columns of `batch`.
    pub(super) row_ids: Arc<UInt64Array>,
    vectors: Arc<FixedSizeListArray>,
}
39
40impl DeepSizeOf for FlatFloatStorage {
41    fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
42        self.batch.get_array_memory_size()
43    }
44}
45
46#[async_trait::async_trait]
47impl QuantizerStorage for FlatFloatStorage {
48    type Metadata = FlatMetadata;
49
50    fn try_from_batch(
51        batch: RecordBatch,
52        metadata: &Self::Metadata,
53        distance_type: DistanceType,
54        frag_reuse_index: Option<Arc<FragReuseIndex>>,
55    ) -> Result<Self> {
56        let batch = if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
57            frag_reuse_index_ref.remap_row_ids_record_batch(batch, 0)?
58        } else {
59            batch
60        };
61
62        let row_ids = Arc::new(
63            batch
64                .column_by_name(ROW_ID)
65                .ok_or(Error::schema(format!("column {} not found", ROW_ID)))?
66                .as_primitive::<UInt64Type>()
67                .clone(),
68        );
69        let vectors = Arc::new(
70            batch
71                .column_by_name(FLAT_COLUMN)
72                .ok_or(Error::schema("column flat not found".to_string()))?
73                .as_fixed_size_list()
74                .clone(),
75        );
76        Ok(Self {
77            metadata: metadata.clone(),
78            batch,
79            distance_type,
80            row_ids,
81            vectors,
82        })
83    }
84
85    fn metadata(&self) -> &Self::Metadata {
86        &self.metadata
87    }
88
89    async fn load_partition(
90        _: &PreviousFileReader,
91        _: std::ops::Range<usize>,
92        _: DistanceType,
93        _: &Self::Metadata,
94        _: Option<Arc<FragReuseIndex>>,
95    ) -> Result<Self> {
96        unimplemented!("Flat will be used in new index builder which doesn't require this")
97    }
98}
99
100impl FlatFloatStorage {
101    // used for only testing
102    pub fn new(vectors: FixedSizeListArray, distance_type: DistanceType) -> Self {
103        let row_ids = Arc::new(UInt64Array::from_iter_values(0..vectors.len() as u64));
104        let vectors = Arc::new(vectors);
105
106        let batch = RecordBatch::try_from_iter_with_nullable(vec![
107            (ROW_ID, row_ids.clone() as ArrayRef, true),
108            (FLAT_COLUMN, vectors.clone() as ArrayRef, true),
109        ])
110        .unwrap();
111
112        Self {
113            metadata: FlatMetadata {
114                dim: vectors.value_length() as usize,
115            },
116            batch,
117            distance_type,
118            row_ids,
119            vectors,
120        }
121    }
122
123    pub fn vector(&self, id: u32) -> ArrayRef {
124        self.vectors.value(id as usize)
125    }
126}
127
128impl VectorStore for FlatFloatStorage {
129    type DistanceCalculator<'a> = FlatFloatDistanceCalc<'a>;
130
131    fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch>> {
132        Ok([self.batch.clone()].into_iter())
133    }
134
135    fn append_batch(&self, batch: RecordBatch, _vector_column: &str) -> Result<Self> {
136        // TODO: use chunked storage
137        let new_batch = concat_batches(&batch.schema(), vec![&self.batch, &batch].into_iter())?;
138        let mut storage = self.clone();
139        storage.row_ids = Arc::new(
140            new_batch
141                .column_by_name(ROW_ID)
142                .ok_or(Error::schema(format!("column {} not found", ROW_ID)))?
143                .as_primitive::<UInt64Type>()
144                .clone(),
145        );
146        storage.vectors = Arc::new(
147            new_batch
148                .column_by_name(FLAT_COLUMN)
149                .ok_or(Error::schema("column flat not found".to_string()))?
150                .as_fixed_size_list()
151                .clone(),
152        );
153        storage.batch = new_batch;
154        Ok(storage)
155    }
156
157    fn schema(&self) -> &SchemaRef {
158        self.batch.schema_ref()
159    }
160
161    fn as_any(&self) -> &dyn std::any::Any {
162        self
163    }
164
165    fn len(&self) -> usize {
166        self.vectors.len()
167    }
168
169    fn distance_type(&self) -> DistanceType {
170        self.distance_type
171    }
172
173    fn row_id(&self, id: u32) -> u64 {
174        self.row_ids.values()[id as usize]
175    }
176
177    fn row_ids(&self) -> impl Iterator<Item = &u64> {
178        self.row_ids.values().iter()
179    }
180
181    fn dist_calculator(&self, query: ArrayRef, _dist_q_c: f32) -> Self::DistanceCalculator<'_> {
182        Self::DistanceCalculator::new(self.vectors.as_ref(), query, self.distance_type)
183    }
184
185    fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> {
186        Self::DistanceCalculator::new(
187            self.vectors.as_ref(),
188            self.vectors.value(id as usize),
189            self.distance_type,
190        )
191    }
192}
193
/// Flat (unquantized) vector storage for binary vectors stored as `UInt8` values.
///
/// All data are stored in memory.
#[derive(Debug, Clone)]
pub struct FlatBinStorage {
    metadata: FlatMetadata,
    // The batch the storage was built from; it carries the `ROW_ID` and `FLAT_COLUMN` columns.
    batch: RecordBatch,
    distance_type: DistanceType,

    // helper fields
    // Typed copies of the `ROW_ID` and `FLAT_COLUMN` columns of `batch`.
    pub(super) row_ids: Arc<UInt64Array>,
    vectors: Arc<FixedSizeListArray>,
}
205
206impl DeepSizeOf for FlatBinStorage {
207    fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
208        self.batch.get_array_memory_size()
209    }
210}
211
212#[async_trait::async_trait]
213impl QuantizerStorage for FlatBinStorage {
214    type Metadata = FlatMetadata;
215
216    fn try_from_batch(
217        batch: RecordBatch,
218        metadata: &Self::Metadata,
219        distance_type: DistanceType,
220        frag_reuse_index: Option<Arc<FragReuseIndex>>,
221    ) -> Result<Self> {
222        let batch = if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
223            frag_reuse_index_ref.remap_row_ids_record_batch(batch, 0)?
224        } else {
225            batch
226        };
227
228        let row_ids = Arc::new(
229            batch
230                .column_by_name(ROW_ID)
231                .ok_or(Error::schema(format!("column {} not found", ROW_ID)))?
232                .as_primitive::<UInt64Type>()
233                .clone(),
234        );
235        let vectors = Arc::new(
236            batch
237                .column_by_name(FLAT_COLUMN)
238                .ok_or(Error::schema("column flat not found".to_string()))?
239                .as_fixed_size_list()
240                .clone(),
241        );
242        Ok(Self {
243            metadata: metadata.clone(),
244            batch,
245            distance_type,
246            row_ids,
247            vectors,
248        })
249    }
250
251    fn metadata(&self) -> &Self::Metadata {
252        &self.metadata
253    }
254
255    async fn load_partition(
256        _: &PreviousFileReader,
257        _: std::ops::Range<usize>,
258        _: DistanceType,
259        _: &Self::Metadata,
260        _: Option<Arc<FragReuseIndex>>,
261    ) -> Result<Self> {
262        unimplemented!("Flat will be used in new index builder which doesn't require this")
263    }
264}
265
266impl FlatBinStorage {
267    // used for only testing
268    pub fn new(vectors: FixedSizeListArray, distance_type: DistanceType) -> Self {
269        let row_ids = Arc::new(UInt64Array::from_iter_values(0..vectors.len() as u64));
270        let vectors = Arc::new(vectors);
271
272        let batch = RecordBatch::try_from_iter_with_nullable(vec![
273            (ROW_ID, row_ids.clone() as ArrayRef, true),
274            (FLAT_COLUMN, vectors.clone() as ArrayRef, true),
275        ])
276        .unwrap();
277
278        Self {
279            metadata: FlatMetadata {
280                dim: vectors.value_length() as usize,
281            },
282            batch,
283            distance_type,
284            row_ids,
285            vectors,
286        }
287    }
288
289    pub fn vector(&self, id: u32) -> ArrayRef {
290        self.vectors.value(id as usize)
291    }
292}
293
294impl VectorStore for FlatBinStorage {
295    type DistanceCalculator<'a> = FlatDistanceCal<'a, UInt8Type>;
296
297    fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch>> {
298        Ok([self.batch.clone()].into_iter())
299    }
300
301    fn append_batch(&self, batch: RecordBatch, _vector_column: &str) -> Result<Self> {
302        // TODO: use chunked storage
303        let new_batch = concat_batches(&batch.schema(), vec![&self.batch, &batch].into_iter())?;
304        let mut storage = self.clone();
305        storage.row_ids = Arc::new(
306            new_batch
307                .column_by_name(ROW_ID)
308                .ok_or(Error::schema(format!("column {} not found", ROW_ID)))?
309                .as_primitive::<UInt64Type>()
310                .clone(),
311        );
312        storage.vectors = Arc::new(
313            new_batch
314                .column_by_name(FLAT_COLUMN)
315                .ok_or(Error::schema("column flat not found".to_string()))?
316                .as_fixed_size_list()
317                .clone(),
318        );
319        storage.batch = new_batch;
320        Ok(storage)
321    }
322
323    fn schema(&self) -> &SchemaRef {
324        self.batch.schema_ref()
325    }
326
327    fn as_any(&self) -> &dyn std::any::Any {
328        self
329    }
330
331    fn len(&self) -> usize {
332        self.vectors.len()
333    }
334
335    fn distance_type(&self) -> DistanceType {
336        self.distance_type
337    }
338
339    fn row_id(&self, id: u32) -> u64 {
340        self.row_ids.values()[id as usize]
341    }
342
343    fn row_ids(&self) -> impl Iterator<Item = &u64> {
344        self.row_ids.values().iter()
345    }
346
347    fn dist_calculator(&self, query: ArrayRef, _dist_q_c: f32) -> Self::DistanceCalculator<'_> {
348        Self::DistanceCalculator::new_binary(self.vectors.as_ref(), query, self.distance_type)
349    }
350
351    fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> {
352        Self::DistanceCalculator::new_binary(
353            self.vectors.as_ref(),
354            self.vectors.value(id as usize),
355            self.distance_type,
356        )
357    }
358}
359
/// Distance calculator over a flat slice of primitive vector values.
///
/// Vector `i` occupies `vectors[i * dimension..(i + 1) * dimension]`.
pub struct FlatDistanceCal<'a, T: ArrowPrimitiveType> {
    // Values of all stored vectors, laid out back to back.
    vectors: &'a [T::Native],
    // Owned copy of the query vector's values.
    query: Vec<T::Native>,
    // Number of values per vector.
    dimension: usize,
    // Function applied to (query, stored vector) to produce a distance.
    #[allow(clippy::type_complexity)]
    distance_fn: fn(&[T::Native], &[T::Native]) -> f32,
}
367
368impl<'a, T> FlatDistanceCal<'a, T>
369where
370    T: ArrowPrimitiveType,
371    T::Native: L2 + Cosine + Dot,
372{
373    fn new(vectors: &'a FixedSizeListArray, query: ArrayRef, distance_type: DistanceType) -> Self {
374        // Gained significant performance improvement by using strong typed primitive slice.
375        let flat_array = vectors.values().as_primitive::<T>();
376        let dimension = vectors.value_length() as usize;
377        Self {
378            vectors: flat_array.values(),
379            query: query.as_primitive::<T>().values().to_vec(),
380            dimension,
381            distance_fn: distance_type.func(),
382        }
383    }
384}
385
386impl<'a> FlatDistanceCal<'a, UInt8Type> {
387    fn new_binary(
388        vectors: &'a FixedSizeListArray,
389        query: ArrayRef,
390        _distance_type: DistanceType,
391    ) -> Self {
392        // Gained significant performance improvement by using strong typed primitive slice.
393        // TODO: to support other data types other than `f32`, make FlatDistanceCal a generic struct.
394        let flat_array = vectors.values().as_primitive::<UInt8Type>();
395        let dimension = vectors.value_length() as usize;
396        Self {
397            vectors: flat_array.values(),
398            query: query.as_primitive::<UInt8Type>().values().to_vec(),
399            dimension,
400            distance_fn: hamming,
401        }
402    }
403}
404
405impl<T: ArrowPrimitiveType> FlatDistanceCal<'_, T> {
406    #[inline]
407    fn get_vector(&self, id: u32) -> &[T::Native] {
408        &self.vectors[self.dimension * id as usize..self.dimension * (id + 1) as usize]
409    }
410}
411
412impl<T: ArrowPrimitiveType> DistCalculator for FlatDistanceCal<'_, T> {
413    #[inline]
414    fn distance(&self, id: u32) -> f32 {
415        let vector = self.get_vector(id);
416        (self.distance_fn)(&self.query, vector)
417    }
418
419    fn distance_all(&self, _k_hint: usize) -> Vec<f32> {
420        let query = &self.query;
421        self.vectors
422            .chunks_exact(self.dimension)
423            .map(|vector| (self.distance_fn)(query, vector))
424            .collect()
425    }
426
427    #[inline]
428    fn prefetch(&self, id: u32) {
429        let vector = self.get_vector(id);
430        do_prefetch(vector.as_ptr_range())
431    }
432}
433
/// Distance calculator for flat float storage, with one variant per supported
/// floating-point element type.
pub enum FlatFloatDistanceCalc<'a> {
    Float16(FlatDistanceCal<'a, Float16Type>),
    Float32(FlatDistanceCal<'a, Float32Type>),
    Float64(FlatDistanceCal<'a, Float64Type>),
}
439
440impl<'a> FlatFloatDistanceCalc<'a> {
441    fn new(vectors: &'a FixedSizeListArray, query: ArrayRef, distance_type: DistanceType) -> Self {
442        match vectors.value_type() {
443            DataType::Float16 => Self::Float16(FlatDistanceCal::<Float16Type>::new(
444                vectors,
445                query,
446                distance_type,
447            )),
448            DataType::Float32 => Self::Float32(FlatDistanceCal::<Float32Type>::new(
449                vectors,
450                query,
451                distance_type,
452            )),
453            DataType::Float64 => Self::Float64(FlatDistanceCal::<Float64Type>::new(
454                vectors,
455                query,
456                distance_type,
457            )),
458            dt => panic!("flat float storage does not support data type {dt}"),
459        }
460    }
461}
462
463impl DistCalculator for FlatFloatDistanceCalc<'_> {
464    fn distance(&self, id: u32) -> f32 {
465        match self {
466            Self::Float16(calc) => calc.distance(id),
467            Self::Float32(calc) => calc.distance(id),
468            Self::Float64(calc) => calc.distance(id),
469        }
470    }
471
472    fn distance_all(&self, k_hint: usize) -> Vec<f32> {
473        match self {
474            Self::Float16(calc) => calc.distance_all(k_hint),
475            Self::Float32(calc) => calc.distance_all(k_hint),
476            Self::Float64(calc) => calc.distance_all(k_hint),
477        }
478    }
479
480    fn prefetch(&self, id: u32) {
481        match self {
482            Self::Float16(calc) => calc.prefetch(id),
483            Self::Float32(calc) => calc.prefetch(id),
484            Self::Float64(calc) => calc.prefetch(id),
485        }
486    }
487}
488
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_array::{Float16Array, Float64Array};
    use half::f16;
    use lance_arrow::FixedSizeListArrayExt;

    /// Two 2-dimensional vectors, `[1, 2]` and `[4, 6]`, laid out back to back.
    const STORED: [f32; 4] = [1.0, 2.0, 4.0, 6.0];
    /// Query equal to the first stored vector.
    const QUERY: [f32; 2] = [1.0, 2.0];

    fn f16_array(values: &[f32]) -> Float16Array {
        let converted: Vec<f16> = values.iter().copied().map(f16::from_f32).collect();
        Float16Array::from(converted)
    }

    fn f64_array(values: &[f32]) -> Float64Array {
        let converted: Vec<f64> = values.iter().copied().map(f64::from).collect();
        Float64Array::from(converted)
    }

    fn make_f16_storage() -> FlatFloatStorage {
        let vectors = FixedSizeListArray::try_new_from_values(f16_array(&STORED), 2).unwrap();
        FlatFloatStorage::new(vectors, DistanceType::L2)
    }

    fn make_f64_storage() -> FlatFloatStorage {
        let vectors = FixedSizeListArray::try_new_from_values(f64_array(&STORED), 2).unwrap();
        FlatFloatStorage::new(vectors, DistanceType::L2)
    }

    #[test]
    fn test_flat_float_storage_distance_f16() {
        let storage = make_f16_storage();
        let query: ArrayRef = Arc::new(f16_array(&QUERY));

        let distances = storage.dist_calculator(query, 0.0).distance_all(2);

        assert_eq!(distances.len(), 2);
        assert_eq!(distances[0], 0.0);
        assert!((distances[1] - 25.0).abs() < 1e-4);
    }

    #[test]
    fn test_flat_float_storage_distance_f64() {
        let storage = make_f64_storage();
        let query: ArrayRef = Arc::new(f64_array(&QUERY));

        let distances = storage.dist_calculator(query, 0.0).distance_all(2);

        assert_eq!(distances.len(), 2);
        assert_eq!(distances[0], 0.0);
        assert!((distances[1] - 25.0).abs() < 1e-6);
    }
}