lance_index/vector/flat/
storage.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::sync::Arc;
5
6use super::index::FlatMetadata;
7use crate::frag_reuse::FragReuseIndex;
8use crate::vector::quantizer::QuantizerStorage;
9use crate::vector::storage::{DistCalculator, VectorStore};
10use crate::vector::utils::do_prefetch;
11use arrow::array::AsArray;
12use arrow::compute::concat_batches;
13use arrow::datatypes::UInt8Type;
14use arrow_array::ArrowPrimitiveType;
15use arrow_array::{
16    types::{Float32Type, UInt64Type},
17    Array, ArrayRef, FixedSizeListArray, RecordBatch, UInt64Array,
18};
19use arrow_schema::SchemaRef;
20use deepsize::DeepSizeOf;
21use lance_core::{Error, Result, ROW_ID};
22use lance_file::previous::reader::FileReader as PreviousFileReader;
23use lance_linalg::distance::hamming::hamming;
24use lance_linalg::distance::DistanceType;
25use snafu::location;
26
/// Name of the record-batch column that holds the flat (un-quantized) vectors.
pub const FLAT_COLUMN: &str = "flat";
28
/// In-memory storage for flat vectors whose distance calculator works on `f32` values.
///
/// All data are stored in memory. `batch` is the full record batch; `row_ids` and
/// `vectors` are typed handles to its row-id and [`FLAT_COLUMN`] columns.
#[derive(Debug, Clone)]
pub struct FlatFloatStorage {
    /// Metadata of the flat index (carries the vector dimension).
    metadata: FlatMetadata,
    /// The record batch this storage was built from.
    batch: RecordBatch,
    /// Distance type used when building distance calculators.
    distance_type: DistanceType,

    // helper fields
    /// Typed handle to the row-id column of `batch`.
    pub(super) row_ids: Arc<UInt64Array>,
    /// Typed handle to the vector column of `batch`.
    vectors: Arc<FixedSizeListArray>,
}
40
41impl DeepSizeOf for FlatFloatStorage {
42    fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
43        self.batch.get_array_memory_size()
44    }
45}
46
47#[async_trait::async_trait]
48impl QuantizerStorage for FlatFloatStorage {
49    type Metadata = FlatMetadata;
50
51    fn try_from_batch(
52        batch: RecordBatch,
53        metadata: &Self::Metadata,
54        distance_type: DistanceType,
55        frag_reuse_index: Option<Arc<FragReuseIndex>>,
56    ) -> Result<Self> {
57        let batch = if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
58            frag_reuse_index_ref.remap_row_ids_record_batch(batch, 0)?
59        } else {
60            batch
61        };
62
63        let row_ids = Arc::new(
64            batch
65                .column_by_name(ROW_ID)
66                .ok_or(Error::Schema {
67                    message: format!("column {} not found", ROW_ID),
68                    location: location!(),
69                })?
70                .as_primitive::<UInt64Type>()
71                .clone(),
72        );
73        let vectors = Arc::new(
74            batch
75                .column_by_name(FLAT_COLUMN)
76                .ok_or(Error::Schema {
77                    message: "column flat not found".to_string(),
78                    location: location!(),
79                })?
80                .as_fixed_size_list()
81                .clone(),
82        );
83        Ok(Self {
84            metadata: metadata.clone(),
85            batch,
86            distance_type,
87            row_ids,
88            vectors,
89        })
90    }
91
92    fn metadata(&self) -> &Self::Metadata {
93        &self.metadata
94    }
95
96    async fn load_partition(
97        _: &PreviousFileReader,
98        _: std::ops::Range<usize>,
99        _: DistanceType,
100        _: &Self::Metadata,
101        _: Option<Arc<FragReuseIndex>>,
102    ) -> Result<Self> {
103        unimplemented!("Flat will be used in new index builder which doesn't require this")
104    }
105}
106
107impl FlatFloatStorage {
108    // used for only testing
109    pub fn new(vectors: FixedSizeListArray, distance_type: DistanceType) -> Self {
110        let row_ids = Arc::new(UInt64Array::from_iter_values(0..vectors.len() as u64));
111        let vectors = Arc::new(vectors);
112
113        let batch = RecordBatch::try_from_iter_with_nullable(vec![
114            (ROW_ID, row_ids.clone() as ArrayRef, true),
115            (FLAT_COLUMN, vectors.clone() as ArrayRef, true),
116        ])
117        .unwrap();
118
119        Self {
120            metadata: FlatMetadata {
121                dim: vectors.value_length() as usize,
122            },
123            batch,
124            distance_type,
125            row_ids,
126            vectors,
127        }
128    }
129
130    pub fn vector(&self, id: u32) -> ArrayRef {
131        self.vectors.value(id as usize)
132    }
133}
134
135impl VectorStore for FlatFloatStorage {
136    type DistanceCalculator<'a> = FlatDistanceCal<'a, Float32Type>;
137
138    fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch>> {
139        Ok([self.batch.clone()].into_iter())
140    }
141
142    fn append_batch(&self, batch: RecordBatch, _vector_column: &str) -> Result<Self> {
143        // TODO: use chunked storage
144        let new_batch = concat_batches(&batch.schema(), vec![&self.batch, &batch].into_iter())?;
145        let mut storage = self.clone();
146        storage.batch = new_batch;
147        Ok(storage)
148    }
149
150    fn schema(&self) -> &SchemaRef {
151        self.batch.schema_ref()
152    }
153
154    fn as_any(&self) -> &dyn std::any::Any {
155        self
156    }
157
158    fn len(&self) -> usize {
159        self.vectors.len()
160    }
161
162    fn distance_type(&self) -> DistanceType {
163        self.distance_type
164    }
165
166    fn row_id(&self, id: u32) -> u64 {
167        self.row_ids.values()[id as usize]
168    }
169
170    fn row_ids(&self) -> impl Iterator<Item = &u64> {
171        self.row_ids.values().iter()
172    }
173
174    fn dist_calculator(&self, query: ArrayRef, _dist_q_c: f32) -> Self::DistanceCalculator<'_> {
175        Self::DistanceCalculator::new(self.vectors.as_ref(), query, self.distance_type)
176    }
177
178    fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> {
179        Self::DistanceCalculator::new(
180            self.vectors.as_ref(),
181            self.vectors.value(id as usize),
182            self.distance_type,
183        )
184    }
185}
186
/// In-memory storage for flat vectors whose distance calculator works on `u8` values.
///
/// All data are stored in memory. `batch` is the full record batch; `row_ids` and
/// `vectors` are typed handles to its row-id and [`FLAT_COLUMN`] columns.
#[derive(Debug, Clone)]
pub struct FlatBinStorage {
    /// Metadata of the flat index (carries the vector dimension).
    metadata: FlatMetadata,
    /// The record batch this storage was built from.
    batch: RecordBatch,
    /// Distance type reported by `distance_type()`.
    distance_type: DistanceType,

    // helper fields
    /// Typed handle to the row-id column of `batch`.
    pub(super) row_ids: Arc<UInt64Array>,
    /// Typed handle to the vector column of `batch`.
    vectors: Arc<FixedSizeListArray>,
}
198
199impl DeepSizeOf for FlatBinStorage {
200    fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
201        self.batch.get_array_memory_size()
202    }
203}
204
205#[async_trait::async_trait]
206impl QuantizerStorage for FlatBinStorage {
207    type Metadata = FlatMetadata;
208
209    fn try_from_batch(
210        batch: RecordBatch,
211        metadata: &Self::Metadata,
212        distance_type: DistanceType,
213        frag_reuse_index: Option<Arc<FragReuseIndex>>,
214    ) -> Result<Self> {
215        let batch = if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
216            frag_reuse_index_ref.remap_row_ids_record_batch(batch, 0)?
217        } else {
218            batch
219        };
220
221        let row_ids = Arc::new(
222            batch
223                .column_by_name(ROW_ID)
224                .ok_or(Error::Schema {
225                    message: format!("column {} not found", ROW_ID),
226                    location: location!(),
227                })?
228                .as_primitive::<UInt64Type>()
229                .clone(),
230        );
231        let vectors = Arc::new(
232            batch
233                .column_by_name(FLAT_COLUMN)
234                .ok_or(Error::Schema {
235                    message: "column flat not found".to_string(),
236                    location: location!(),
237                })?
238                .as_fixed_size_list()
239                .clone(),
240        );
241        Ok(Self {
242            metadata: metadata.clone(),
243            batch,
244            distance_type,
245            row_ids,
246            vectors,
247        })
248    }
249
250    fn metadata(&self) -> &Self::Metadata {
251        &self.metadata
252    }
253
254    async fn load_partition(
255        _: &PreviousFileReader,
256        _: std::ops::Range<usize>,
257        _: DistanceType,
258        _: &Self::Metadata,
259        _: Option<Arc<FragReuseIndex>>,
260    ) -> Result<Self> {
261        unimplemented!("Flat will be used in new index builder which doesn't require this")
262    }
263}
264
265impl FlatBinStorage {
266    // used for only testing
267    pub fn new(vectors: FixedSizeListArray, distance_type: DistanceType) -> Self {
268        let row_ids = Arc::new(UInt64Array::from_iter_values(0..vectors.len() as u64));
269        let vectors = Arc::new(vectors);
270
271        let batch = RecordBatch::try_from_iter_with_nullable(vec![
272            (ROW_ID, row_ids.clone() as ArrayRef, true),
273            (FLAT_COLUMN, vectors.clone() as ArrayRef, true),
274        ])
275        .unwrap();
276
277        Self {
278            metadata: FlatMetadata {
279                dim: vectors.value_length() as usize,
280            },
281            batch,
282            distance_type,
283            row_ids,
284            vectors,
285        }
286    }
287
288    pub fn vector(&self, id: u32) -> ArrayRef {
289        self.vectors.value(id as usize)
290    }
291}
292
293impl VectorStore for FlatBinStorage {
294    type DistanceCalculator<'a> = FlatDistanceCal<'a, UInt8Type>;
295
296    fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch>> {
297        Ok([self.batch.clone()].into_iter())
298    }
299
300    fn append_batch(&self, batch: RecordBatch, _vector_column: &str) -> Result<Self> {
301        // TODO: use chunked storage
302        let new_batch = concat_batches(&batch.schema(), vec![&self.batch, &batch].into_iter())?;
303        let mut storage = self.clone();
304        storage.batch = new_batch;
305        Ok(storage)
306    }
307
308    fn schema(&self) -> &SchemaRef {
309        self.batch.schema_ref()
310    }
311
312    fn as_any(&self) -> &dyn std::any::Any {
313        self
314    }
315
316    fn len(&self) -> usize {
317        self.vectors.len()
318    }
319
320    fn distance_type(&self) -> DistanceType {
321        self.distance_type
322    }
323
324    fn row_id(&self, id: u32) -> u64 {
325        self.row_ids.values()[id as usize]
326    }
327
328    fn row_ids(&self) -> impl Iterator<Item = &u64> {
329        self.row_ids.values().iter()
330    }
331
332    fn dist_calculator(&self, query: ArrayRef, _dist_q_c: f32) -> Self::DistanceCalculator<'_> {
333        Self::DistanceCalculator::new(self.vectors.as_ref(), query, self.distance_type)
334    }
335
336    fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> {
337        Self::DistanceCalculator::new(
338            self.vectors.as_ref(),
339            self.vectors.value(id as usize),
340            self.distance_type,
341        )
342    }
343}
344
/// Distance calculator over a flattened slice of vector values.
///
/// `vectors` holds every stored vector back to back, `dimension` values each, and
/// `query` is an owned copy of the query vector.
pub struct FlatDistanceCal<'a, T: ArrowPrimitiveType> {
    /// Flattened values of all stored vectors.
    vectors: &'a [T::Native],
    /// Owned copy of the query vector's values.
    query: Vec<T::Native>,
    /// Number of values per vector.
    dimension: usize,
    /// Function used to compute the distance between the query and one stored vector.
    #[allow(clippy::type_complexity)]
    distance_fn: fn(&[T::Native], &[T::Native]) -> f32,
}
352
353impl<'a> FlatDistanceCal<'a, Float32Type> {
354    fn new(vectors: &'a FixedSizeListArray, query: ArrayRef, distance_type: DistanceType) -> Self {
355        // Gained significant performance improvement by using strong typed primitive slice.
356        // TODO: to support other data types other than `f32`, make FlatDistanceCal a generic struct.
357        let flat_array = vectors.values().as_primitive::<Float32Type>();
358        let dimension = vectors.value_length() as usize;
359        Self {
360            vectors: flat_array.values(),
361            query: query.as_primitive::<Float32Type>().values().to_vec(),
362            dimension,
363            distance_fn: distance_type.func(),
364        }
365    }
366}
367
368impl<'a> FlatDistanceCal<'a, UInt8Type> {
369    fn new(vectors: &'a FixedSizeListArray, query: ArrayRef, _distance_type: DistanceType) -> Self {
370        // Gained significant performance improvement by using strong typed primitive slice.
371        // TODO: to support other data types other than `f32`, make FlatDistanceCal a generic struct.
372        let flat_array = vectors.values().as_primitive::<UInt8Type>();
373        let dimension = vectors.value_length() as usize;
374        Self {
375            vectors: flat_array.values(),
376            query: query.as_primitive::<UInt8Type>().values().to_vec(),
377            dimension,
378            distance_fn: hamming,
379        }
380    }
381}
382
383impl<T: ArrowPrimitiveType> FlatDistanceCal<'_, T> {
384    #[inline]
385    fn get_vector(&self, id: u32) -> &[T::Native] {
386        &self.vectors[self.dimension * id as usize..self.dimension * (id + 1) as usize]
387    }
388}
389
390impl<T: ArrowPrimitiveType> DistCalculator for FlatDistanceCal<'_, T> {
391    #[inline]
392    fn distance(&self, id: u32) -> f32 {
393        let vector = self.get_vector(id);
394        (self.distance_fn)(&self.query, vector)
395    }
396
397    fn distance_all(&self, _k_hint: usize) -> Vec<f32> {
398        let query = &self.query;
399        self.vectors
400            .chunks_exact(self.dimension)
401            .map(|vector| (self.distance_fn)(query, vector))
402            .collect()
403    }
404
405    #[inline]
406    fn prefetch(&self, id: u32) {
407        let vector = self.get_vector(id);
408        do_prefetch(vector.as_ptr_range())
409    }
410}