1use std::sync::Arc;
5
6use super::index::FlatMetadata;
7use crate::frag_reuse::FragReuseIndex;
8use crate::vector::quantizer::QuantizerStorage;
9use crate::vector::storage::{DistCalculator, VectorStore};
10use crate::vector::utils::do_prefetch;
11use arrow::array::AsArray;
12use arrow::compute::concat_batches;
13use arrow::datatypes::{Float16Type, Float64Type, UInt8Type};
14use arrow_array::ArrowPrimitiveType;
15use arrow_array::{
16 Array, ArrayRef, FixedSizeListArray, RecordBatch, UInt64Array,
17 types::{Float32Type, UInt64Type},
18};
19use arrow_schema::{DataType, SchemaRef};
20use deepsize::DeepSizeOf;
21use lance_core::{Error, ROW_ID, Result};
22use lance_file::previous::reader::FileReader as PreviousFileReader;
23use lance_linalg::distance::hamming::hamming;
24use lance_linalg::distance::{Cosine, DistanceType, Dot, L2};
25
/// Name of the record-batch column that holds the stored vectors
/// (read back as a fixed-size-list array by the storages in this file).
pub const FLAT_COLUMN: &str = "flat";
27
/// In-memory storage of un-quantized vectors kept in a single [`RecordBatch`].
///
/// Distances are computed through [`FlatFloatDistanceCalc`], which supports
/// `Float16`, `Float32` and `Float64` element types.
#[derive(Debug, Clone)]
pub struct FlatFloatStorage {
    /// Storage metadata (carries the vector dimension `dim`).
    metadata: FlatMetadata,
    /// Backing batch containing the `ROW_ID` and [`FLAT_COLUMN`] columns.
    batch: RecordBatch,
    /// Distance type handed to the distance calculators created by this storage.
    distance_type: DistanceType,

    /// The `ROW_ID` column of `batch`.
    pub(super) row_ids: Arc<UInt64Array>,
    /// The [`FLAT_COLUMN`] column of `batch`, one fixed-size list per vector.
    vectors: Arc<FixedSizeListArray>,
}
39
40impl DeepSizeOf for FlatFloatStorage {
41 fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
42 self.batch.get_array_memory_size()
43 }
44}
45
46#[async_trait::async_trait]
47impl QuantizerStorage for FlatFloatStorage {
48 type Metadata = FlatMetadata;
49
50 fn try_from_batch(
51 batch: RecordBatch,
52 metadata: &Self::Metadata,
53 distance_type: DistanceType,
54 frag_reuse_index: Option<Arc<FragReuseIndex>>,
55 ) -> Result<Self> {
56 let batch = if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
57 frag_reuse_index_ref.remap_row_ids_record_batch(batch, 0)?
58 } else {
59 batch
60 };
61
62 let row_ids = Arc::new(
63 batch
64 .column_by_name(ROW_ID)
65 .ok_or(Error::schema(format!("column {} not found", ROW_ID)))?
66 .as_primitive::<UInt64Type>()
67 .clone(),
68 );
69 let vectors = Arc::new(
70 batch
71 .column_by_name(FLAT_COLUMN)
72 .ok_or(Error::schema("column flat not found".to_string()))?
73 .as_fixed_size_list()
74 .clone(),
75 );
76 Ok(Self {
77 metadata: metadata.clone(),
78 batch,
79 distance_type,
80 row_ids,
81 vectors,
82 })
83 }
84
85 fn metadata(&self) -> &Self::Metadata {
86 &self.metadata
87 }
88
89 async fn load_partition(
90 _: &PreviousFileReader,
91 _: std::ops::Range<usize>,
92 _: DistanceType,
93 _: &Self::Metadata,
94 _: Option<Arc<FragReuseIndex>>,
95 ) -> Result<Self> {
96 unimplemented!("Flat will be used in new index builder which doesn't require this")
97 }
98}
99
100impl FlatFloatStorage {
101 pub fn new(vectors: FixedSizeListArray, distance_type: DistanceType) -> Self {
103 let row_ids = Arc::new(UInt64Array::from_iter_values(0..vectors.len() as u64));
104 let vectors = Arc::new(vectors);
105
106 let batch = RecordBatch::try_from_iter_with_nullable(vec![
107 (ROW_ID, row_ids.clone() as ArrayRef, true),
108 (FLAT_COLUMN, vectors.clone() as ArrayRef, true),
109 ])
110 .unwrap();
111
112 Self {
113 metadata: FlatMetadata {
114 dim: vectors.value_length() as usize,
115 },
116 batch,
117 distance_type,
118 row_ids,
119 vectors,
120 }
121 }
122
123 pub fn vector(&self, id: u32) -> ArrayRef {
124 self.vectors.value(id as usize)
125 }
126}
127
128impl VectorStore for FlatFloatStorage {
129 type DistanceCalculator<'a> = FlatFloatDistanceCalc<'a>;
130
131 fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch>> {
132 Ok([self.batch.clone()].into_iter())
133 }
134
135 fn append_batch(&self, batch: RecordBatch, _vector_column: &str) -> Result<Self> {
136 let new_batch = concat_batches(&batch.schema(), vec![&self.batch, &batch].into_iter())?;
138 let mut storage = self.clone();
139 storage.row_ids = Arc::new(
140 new_batch
141 .column_by_name(ROW_ID)
142 .ok_or(Error::schema(format!("column {} not found", ROW_ID)))?
143 .as_primitive::<UInt64Type>()
144 .clone(),
145 );
146 storage.vectors = Arc::new(
147 new_batch
148 .column_by_name(FLAT_COLUMN)
149 .ok_or(Error::schema("column flat not found".to_string()))?
150 .as_fixed_size_list()
151 .clone(),
152 );
153 storage.batch = new_batch;
154 Ok(storage)
155 }
156
157 fn schema(&self) -> &SchemaRef {
158 self.batch.schema_ref()
159 }
160
161 fn as_any(&self) -> &dyn std::any::Any {
162 self
163 }
164
165 fn len(&self) -> usize {
166 self.vectors.len()
167 }
168
169 fn distance_type(&self) -> DistanceType {
170 self.distance_type
171 }
172
173 fn row_id(&self, id: u32) -> u64 {
174 self.row_ids.values()[id as usize]
175 }
176
177 fn row_ids(&self) -> impl Iterator<Item = &u64> {
178 self.row_ids.values().iter()
179 }
180
181 fn dist_calculator(&self, query: ArrayRef, _dist_q_c: f32) -> Self::DistanceCalculator<'_> {
182 Self::DistanceCalculator::new(self.vectors.as_ref(), query, self.distance_type)
183 }
184
185 fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> {
186 Self::DistanceCalculator::new(
187 self.vectors.as_ref(),
188 self.vectors.value(id as usize),
189 self.distance_type,
190 )
191 }
192}
193
/// In-memory storage of binary vectors kept in a single [`RecordBatch`].
///
/// Distances are computed through `FlatDistanceCal<UInt8Type>`, i.e. the
/// vector values are read as `UInt8`.
#[derive(Debug, Clone)]
pub struct FlatBinStorage {
    /// Storage metadata (carries the vector dimension `dim`).
    metadata: FlatMetadata,
    /// Backing batch containing the `ROW_ID` and [`FLAT_COLUMN`] columns.
    batch: RecordBatch,
    /// Distance type reported by this storage and passed to its calculators.
    distance_type: DistanceType,

    /// The `ROW_ID` column of `batch`.
    pub(super) row_ids: Arc<UInt64Array>,
    /// The [`FLAT_COLUMN`] column of `batch`, one fixed-size list per vector.
    vectors: Arc<FixedSizeListArray>,
}
205
206impl DeepSizeOf for FlatBinStorage {
207 fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
208 self.batch.get_array_memory_size()
209 }
210}
211
212#[async_trait::async_trait]
213impl QuantizerStorage for FlatBinStorage {
214 type Metadata = FlatMetadata;
215
216 fn try_from_batch(
217 batch: RecordBatch,
218 metadata: &Self::Metadata,
219 distance_type: DistanceType,
220 frag_reuse_index: Option<Arc<FragReuseIndex>>,
221 ) -> Result<Self> {
222 let batch = if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
223 frag_reuse_index_ref.remap_row_ids_record_batch(batch, 0)?
224 } else {
225 batch
226 };
227
228 let row_ids = Arc::new(
229 batch
230 .column_by_name(ROW_ID)
231 .ok_or(Error::schema(format!("column {} not found", ROW_ID)))?
232 .as_primitive::<UInt64Type>()
233 .clone(),
234 );
235 let vectors = Arc::new(
236 batch
237 .column_by_name(FLAT_COLUMN)
238 .ok_or(Error::schema("column flat not found".to_string()))?
239 .as_fixed_size_list()
240 .clone(),
241 );
242 Ok(Self {
243 metadata: metadata.clone(),
244 batch,
245 distance_type,
246 row_ids,
247 vectors,
248 })
249 }
250
251 fn metadata(&self) -> &Self::Metadata {
252 &self.metadata
253 }
254
255 async fn load_partition(
256 _: &PreviousFileReader,
257 _: std::ops::Range<usize>,
258 _: DistanceType,
259 _: &Self::Metadata,
260 _: Option<Arc<FragReuseIndex>>,
261 ) -> Result<Self> {
262 unimplemented!("Flat will be used in new index builder which doesn't require this")
263 }
264}
265
266impl FlatBinStorage {
267 pub fn new(vectors: FixedSizeListArray, distance_type: DistanceType) -> Self {
269 let row_ids = Arc::new(UInt64Array::from_iter_values(0..vectors.len() as u64));
270 let vectors = Arc::new(vectors);
271
272 let batch = RecordBatch::try_from_iter_with_nullable(vec![
273 (ROW_ID, row_ids.clone() as ArrayRef, true),
274 (FLAT_COLUMN, vectors.clone() as ArrayRef, true),
275 ])
276 .unwrap();
277
278 Self {
279 metadata: FlatMetadata {
280 dim: vectors.value_length() as usize,
281 },
282 batch,
283 distance_type,
284 row_ids,
285 vectors,
286 }
287 }
288
289 pub fn vector(&self, id: u32) -> ArrayRef {
290 self.vectors.value(id as usize)
291 }
292}
293
294impl VectorStore for FlatBinStorage {
295 type DistanceCalculator<'a> = FlatDistanceCal<'a, UInt8Type>;
296
297 fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch>> {
298 Ok([self.batch.clone()].into_iter())
299 }
300
301 fn append_batch(&self, batch: RecordBatch, _vector_column: &str) -> Result<Self> {
302 let new_batch = concat_batches(&batch.schema(), vec![&self.batch, &batch].into_iter())?;
304 let mut storage = self.clone();
305 storage.row_ids = Arc::new(
306 new_batch
307 .column_by_name(ROW_ID)
308 .ok_or(Error::schema(format!("column {} not found", ROW_ID)))?
309 .as_primitive::<UInt64Type>()
310 .clone(),
311 );
312 storage.vectors = Arc::new(
313 new_batch
314 .column_by_name(FLAT_COLUMN)
315 .ok_or(Error::schema("column flat not found".to_string()))?
316 .as_fixed_size_list()
317 .clone(),
318 );
319 storage.batch = new_batch;
320 Ok(storage)
321 }
322
323 fn schema(&self) -> &SchemaRef {
324 self.batch.schema_ref()
325 }
326
327 fn as_any(&self) -> &dyn std::any::Any {
328 self
329 }
330
331 fn len(&self) -> usize {
332 self.vectors.len()
333 }
334
335 fn distance_type(&self) -> DistanceType {
336 self.distance_type
337 }
338
339 fn row_id(&self, id: u32) -> u64 {
340 self.row_ids.values()[id as usize]
341 }
342
343 fn row_ids(&self) -> impl Iterator<Item = &u64> {
344 self.row_ids.values().iter()
345 }
346
347 fn dist_calculator(&self, query: ArrayRef, _dist_q_c: f32) -> Self::DistanceCalculator<'_> {
348 Self::DistanceCalculator::new_binary(self.vectors.as_ref(), query, self.distance_type)
349 }
350
351 fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> {
352 Self::DistanceCalculator::new_binary(
353 self.vectors.as_ref(),
354 self.vectors.value(id as usize),
355 self.distance_type,
356 )
357 }
358}
359
/// Distance calculator that compares one query vector against vectors stored
/// back to back in a flat slice of native values.
pub struct FlatDistanceCal<'a, T: ArrowPrimitiveType> {
    /// Flattened values of all stored vectors; vector `i` occupies the range
    /// `[i * dimension, (i + 1) * dimension)`.
    vectors: &'a [T::Native],
    /// Owned copy of the query vector's values.
    query: Vec<T::Native>,
    /// Number of values per stored vector.
    dimension: usize,
    /// Function invoked as `distance_fn(query, vector)` to compute a distance.
    #[allow(clippy::type_complexity)]
    distance_fn: fn(&[T::Native], &[T::Native]) -> f32,
}
367
368impl<'a, T> FlatDistanceCal<'a, T>
369where
370 T: ArrowPrimitiveType,
371 T::Native: L2 + Cosine + Dot,
372{
373 fn new(vectors: &'a FixedSizeListArray, query: ArrayRef, distance_type: DistanceType) -> Self {
374 let flat_array = vectors.values().as_primitive::<T>();
376 let dimension = vectors.value_length() as usize;
377 Self {
378 vectors: flat_array.values(),
379 query: query.as_primitive::<T>().values().to_vec(),
380 dimension,
381 distance_fn: distance_type.func(),
382 }
383 }
384}
385
386impl<'a> FlatDistanceCal<'a, UInt8Type> {
387 fn new_binary(
388 vectors: &'a FixedSizeListArray,
389 query: ArrayRef,
390 _distance_type: DistanceType,
391 ) -> Self {
392 let flat_array = vectors.values().as_primitive::<UInt8Type>();
395 let dimension = vectors.value_length() as usize;
396 Self {
397 vectors: flat_array.values(),
398 query: query.as_primitive::<UInt8Type>().values().to_vec(),
399 dimension,
400 distance_fn: hamming,
401 }
402 }
403}
404
405impl<T: ArrowPrimitiveType> FlatDistanceCal<'_, T> {
406 #[inline]
407 fn get_vector(&self, id: u32) -> &[T::Native] {
408 &self.vectors[self.dimension * id as usize..self.dimension * (id + 1) as usize]
409 }
410}
411
412impl<T: ArrowPrimitiveType> DistCalculator for FlatDistanceCal<'_, T> {
413 #[inline]
414 fn distance(&self, id: u32) -> f32 {
415 let vector = self.get_vector(id);
416 (self.distance_fn)(&self.query, vector)
417 }
418
419 fn distance_all(&self, _k_hint: usize) -> Vec<f32> {
420 let query = &self.query;
421 self.vectors
422 .chunks_exact(self.dimension)
423 .map(|vector| (self.distance_fn)(query, vector))
424 .collect()
425 }
426
427 #[inline]
428 fn prefetch(&self, id: u32) {
429 let vector = self.get_vector(id);
430 do_prefetch(vector.as_ptr_range())
431 }
432}
433
/// Distance calculator for float vectors, with one variant per supported
/// element type.
pub enum FlatFloatDistanceCalc<'a> {
    /// Calculator over `Float16` values.
    Float16(FlatDistanceCal<'a, Float16Type>),
    /// Calculator over `Float32` values.
    Float32(FlatDistanceCal<'a, Float32Type>),
    /// Calculator over `Float64` values.
    Float64(FlatDistanceCal<'a, Float64Type>),
}
439
440impl<'a> FlatFloatDistanceCalc<'a> {
441 fn new(vectors: &'a FixedSizeListArray, query: ArrayRef, distance_type: DistanceType) -> Self {
442 match vectors.value_type() {
443 DataType::Float16 => Self::Float16(FlatDistanceCal::<Float16Type>::new(
444 vectors,
445 query,
446 distance_type,
447 )),
448 DataType::Float32 => Self::Float32(FlatDistanceCal::<Float32Type>::new(
449 vectors,
450 query,
451 distance_type,
452 )),
453 DataType::Float64 => Self::Float64(FlatDistanceCal::<Float64Type>::new(
454 vectors,
455 query,
456 distance_type,
457 )),
458 dt => panic!("flat float storage does not support data type {dt}"),
459 }
460 }
461}
462
463impl DistCalculator for FlatFloatDistanceCalc<'_> {
464 fn distance(&self, id: u32) -> f32 {
465 match self {
466 Self::Float16(calc) => calc.distance(id),
467 Self::Float32(calc) => calc.distance(id),
468 Self::Float64(calc) => calc.distance(id),
469 }
470 }
471
472 fn distance_all(&self, k_hint: usize) -> Vec<f32> {
473 match self {
474 Self::Float16(calc) => calc.distance_all(k_hint),
475 Self::Float32(calc) => calc.distance_all(k_hint),
476 Self::Float64(calc) => calc.distance_all(k_hint),
477 }
478 }
479
480 fn prefetch(&self, id: u32) {
481 match self {
482 Self::Float16(calc) => calc.prefetch(id),
483 Self::Float32(calc) => calc.prefetch(id),
484 Self::Float64(calc) => calc.prefetch(id),
485 }
486 }
487}
488
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_array::{Float16Array, Float64Array};
    use half::f16;
    use lance_arrow::FixedSizeListArrayExt;

    /// Converts `f32` values into a `Float16Array`.
    fn f16_array(values: &[f32]) -> Float16Array {
        let halves: Vec<f16> = values.iter().copied().map(f16::from_f32).collect();
        Float16Array::from(halves)
    }

    /// Two 2-dimensional `f16` vectors: `[1, 2]` and `[4, 6]`.
    fn make_f16_storage() -> FlatFloatStorage {
        let flat_values = f16_array(&[1.0, 2.0, 4.0, 6.0]);
        let vectors = FixedSizeListArray::try_new_from_values(flat_values, 2).unwrap();
        FlatFloatStorage::new(vectors, DistanceType::L2)
    }

    /// Two 2-dimensional `f64` vectors: `[1, 2]` and `[4, 6]`.
    fn make_f64_storage() -> FlatFloatStorage {
        let flat_values = Float64Array::from(vec![1.0, 2.0, 4.0, 6.0]);
        let vectors = FixedSizeListArray::try_new_from_values(flat_values, 2).unwrap();
        FlatFloatStorage::new(vectors, DistanceType::L2)
    }

    #[test]
    fn test_flat_float_storage_distance_f16() {
        let storage = make_f16_storage();
        let query: ArrayRef = Arc::new(f16_array(&[1.0, 2.0]));

        let distances = storage.dist_calculator(query, 0.0).distance_all(2);

        // The query equals the first vector; the second differs by (3, 4).
        assert_eq!(distances.len(), 2);
        assert_eq!(distances[0], 0.0);
        assert!((distances[1] - 25.0).abs() < 1e-4);
    }

    #[test]
    fn test_flat_float_storage_distance_f64() {
        let storage = make_f64_storage();
        let query: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0]));

        let distances = storage.dist_calculator(query, 0.0).distance_all(2);

        // The query equals the first vector; the second differs by (3, 4).
        assert_eq!(distances.len(), 2);
        assert_eq!(distances[0], 0.0);
        assert!((distances[1] - 25.0).abs() < 1e-6);
    }
}