lance_index/scalar/flat.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::collections::HashMap;
use std::{any::Any, ops::Bound, sync::Arc};

use arrow_array::{
    cast::AsArray, types::UInt64Type, ArrayRef, BooleanArray, RecordBatch, UInt64Array,
};
use arrow_schema::{DataType, Field, Schema};
use async_trait::async_trait;

use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion_physical_expr::expressions::{in_list, lit, Column};
use deepsize::DeepSizeOf;
use lance_core::error::LanceOptionExt;
use lance_core::utils::address::RowAddress;
use lance_core::utils::mask::RowIdTreeMap;
use lance_core::{Error, Result, ROW_ID};
use roaring::RoaringBitmap;
use snafu::location;

use super::{btree::BTreeSubIndex, IndexStore, ScalarIndex};
use super::{AnyQuery, MetricsCollector, SargableQuery, SearchResult};
use crate::scalar::btree::{BTREE_IDS_COLUMN, BTREE_VALUES_COLUMN};
use crate::scalar::registry::VALUE_COLUMN_NAME;
use crate::scalar::{CreatedIndex, UpdateCriteria};
use crate::{Index, IndexType};
/// A flat index is just a batch of value/row-id pairs
///
/// The batch always has two columns.  The first column "values" contains
/// the values.  The second column "row_ids" contains the row ids
///
/// Evaluating a query requires O(N) time where N is the # of rows
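///
/// A serialized flat index batch might look like the following (an illustrative
/// sketch; real batches use the `BTREE_VALUES_COLUMN` and `BTREE_IDS_COLUMN`
/// names and are produced by [`FlatIndexMetadata`], not built by hand):
///
/// ```ignore
/// use std::sync::Arc;
/// use arrow_array::{ArrayRef, Int32Array, RecordBatch, UInt64Array};
/// use arrow_schema::{DataType, Field, Schema};
///
/// let schema = Arc::new(Schema::new(vec![
///     Field::new("values", DataType::Int32, true),
///     Field::new("row_ids", DataType::UInt64, true),
/// ]));
/// let batch = RecordBatch::try_new(
///     schema,
///     vec![
///         Arc::new(Int32Array::from(vec![10, 100, 1000])) as ArrayRef,
///         Arc::new(UInt64Array::from(vec![5, 0, 3])) as ArrayRef,
///     ],
/// )
/// .unwrap();
/// ```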
#[derive(Debug)]
pub struct FlatIndex {
    data: Arc<RecordBatch>,
    has_nulls: bool,
}

impl DeepSizeOf for FlatIndex {
    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
        self.data.get_array_memory_size()
    }
}

impl FlatIndex {
    fn values(&self) -> &ArrayRef {
        self.data.column(0)
    }

    fn ids(&self) -> &ArrayRef {
        self.data.column(1)
    }
}

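/// Remaps the row ids of a serialized flat index batch
///
/// Row ids mapped to `Some(new_id)` are rewritten, row ids mapped to `None` are
/// dropped (along with their values), and row ids absent from the mapping are
/// kept as-is.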
fn remap_batch(batch: RecordBatch, mapping: &HashMap<u64, Option<u64>>) -> Result<RecordBatch> {
    let row_ids = batch.column(1).as_primitive::<UInt64Type>();
    let val_idx_and_new_id = row_ids
        .values()
        .iter()
        .enumerate()
        .filter_map(|(idx, old_id)| {
            mapping
                .get(old_id)
                .copied()
                .unwrap_or(Some(*old_id))
                .map(|new_id| (idx, new_id))
        })
        .collect::<Vec<_>>();
    let new_ids = Arc::new(UInt64Array::from_iter_values(
        val_idx_and_new_id.iter().copied().map(|(_, new_id)| new_id),
    ));
    let new_val_indices = UInt64Array::from_iter_values(
        val_idx_and_new_id
            .into_iter()
            .map(|(val_idx, _)| val_idx as u64),
    );
    let new_vals = arrow_select::take::take(batch.column(0), &new_val_indices, None)?;
    Ok(RecordBatch::try_new(
        batch.schema(),
        vec![new_vals, new_ids],
    )?)
}

/// Trains a flat index from a record batch of values & ids by simply storing the batch
///
/// This allows the flat index to be used as a sub-index
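///
/// A minimal usage sketch (mirroring the tests below; assumes `batch` carries
/// the value column and row id column that `train` expects):
///
/// ```ignore
/// let metadata = FlatIndexMetadata::new(DataType::Int32);
/// let trained = metadata.train(batch).await?;
/// let index = metadata.load_subindex(trained).await?;
/// ```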
#[derive(Debug)]
pub struct FlatIndexMetadata {
    schema: Arc<Schema>,
}

impl DeepSizeOf for FlatIndexMetadata {
    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
        self.schema.metadata.deep_size_of_children(context)
            + self
                .schema
                .fields
                .iter()
                // This undercounts slightly because it doesn't account for the size of the
                // field data types
                .map(|f| {
                    std::mem::size_of::<Field>()
                        + f.name().deep_size_of_children(context)
                        + f.metadata().deep_size_of_children(context)
                })
                .sum::<usize>()
    }
}

impl FlatIndexMetadata {
    pub fn new(value_type: DataType) -> Self {
        let schema = Arc::new(Schema::new(vec![
            Field::new(BTREE_VALUES_COLUMN, value_type, true),
            Field::new(BTREE_IDS_COLUMN, DataType::UInt64, true),
        ]));
        Self { schema }
    }
}

#[async_trait]
impl BTreeSubIndex for FlatIndexMetadata {
    fn schema(&self) -> &Arc<Schema> {
        &self.schema
    }

    async fn train(&self, batch: RecordBatch) -> Result<RecordBatch> {
        // The data source may not call the columns "values" and "row_ids", so we
        // need to replace the schema
        Ok(RecordBatch::try_new(
            self.schema.clone(),
            vec![
                batch.column_by_name(VALUE_COLUMN_NAME).expect_ok()?.clone(),
                batch.column_by_name(ROW_ID).expect_ok()?.clone(),
            ],
        )?)
    }

    async fn load_subindex(&self, serialized: RecordBatch) -> Result<Arc<dyn ScalarIndex>> {
        let has_nulls = serialized.column(0).null_count() > 0;
        Ok(Arc::new(FlatIndex {
            data: Arc::new(serialized),
            has_nulls,
        }))
    }

    async fn remap_subindex(
        &self,
        serialized: RecordBatch,
        mapping: &HashMap<u64, Option<u64>>,
    ) -> Result<RecordBatch> {
        remap_batch(serialized, mapping)
    }

    async fn retrieve_data(&self, serialized: RecordBatch) -> Result<RecordBatch> {
        Ok(serialized)
    }
}

#[async_trait]
impl Index for FlatIndex {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
        self
    }

    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn crate::vector::VectorIndex>> {
        Err(Error::NotSupported {
            source: "FlatIndex is not a vector index".into(),
            location: location!(),
        })
    }

    fn index_type(&self) -> IndexType {
        IndexType::Scalar
    }

    async fn prewarm(&self) -> Result<()> {
        // There is nothing to pre-warm
        Ok(())
    }

    fn statistics(&self) -> Result<serde_json::Value> {
        Ok(serde_json::json!({
            "num_values": self.data.num_rows(),
        }))
    }

    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
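        // A row address packs the fragment id into its upper 32 bits, so the
        // distinct fragment ids covered by this index can be read directly off
        // the row ids.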
        let mut frag_ids = self
            .ids()
            .as_primitive::<UInt64Type>()
            .iter()
            .map(|row_id| RowAddress::from(row_id.unwrap()).fragment_id())
            .collect::<Vec<_>>();
        frag_ids.sort();
        frag_ids.dedup();
        Ok(RoaringBitmap::from_sorted_iter(frag_ids).unwrap())
    }
}

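// A minimal sketch of how a query is evaluated end-to-end (hypothetical setup;
// `NoOpMetricsCollector` is the metrics stub the tests below use):
//
//     let query = SargableQuery::Equals(ScalarValue::from(100));
//     let result = index.search(&query, &NoOpMetricsCollector).await?;
//     // `result` is SearchResult::Exact with the row ids whose value == 100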
#[async_trait]
impl ScalarIndex for FlatIndex {
    async fn search(
        &self,
        query: &dyn AnyQuery,
        metrics: &dyn MetricsCollector,
    ) -> Result<SearchResult> {
        metrics.record_comparisons(self.data.num_rows());
        let query = query.as_any().downcast_ref::<SargableQuery>().unwrap();
        // Since we have all the values in memory, we can use basic arrow-rs compute
        // functions to satisfy scalar queries.
        let mut predicate = match query {
            SargableQuery::Equals(value) => {
                if value.is_null() {
                    arrow::compute::is_null(self.values())?
                } else {
                    arrow_ord::cmp::eq(self.values(), &value.to_scalar()?)?
                }
            }
            SargableQuery::IsNull() => arrow::compute::is_null(self.values())?,
            SargableQuery::IsIn(values) => {
                let mut has_null = false;
                let choices = values
                    .iter()
                    .map(|val| {
                        has_null |= val.is_null();
                        lit(val.clone())
                    })
                    .collect::<Vec<_>>();
                let in_list_expr = in_list(
                    Arc::new(Column::new("values", 0)),
                    choices,
                    &false,
                    &self.data.schema(),
                )?;
                let result_col = in_list_expr.evaluate(&self.data)?;
                let predicate = result_col
                    .into_array(self.data.num_rows())?
                    .as_any()
                    .downcast_ref::<BooleanArray>()
                    .expect("InList evaluation should return boolean array")
                    .clone();

                // Arrow's in_list does not handle nulls, so we need to join them in
                // here if the user asked for them
                if has_null && self.has_nulls {
                    let nulls = arrow::compute::is_null(self.values())?;
                    arrow::compute::or(&predicate, &nulls)?
                } else {
                    predicate
                }
            }
            SargableQuery::Range(lower_bound, upper_bound) => match (lower_bound, upper_bound) {
                (Bound::Unbounded, Bound::Unbounded) => {
                    panic!("Scalar range query received with no upper or lower bound")
                }
                (Bound::Unbounded, Bound::Included(upper)) => {
                    arrow_ord::cmp::lt_eq(self.values(), &upper.to_scalar()?)?
                }
                (Bound::Unbounded, Bound::Excluded(upper)) => {
                    arrow_ord::cmp::lt(self.values(), &upper.to_scalar()?)?
                }
                (Bound::Included(lower), Bound::Unbounded) => {
                    arrow_ord::cmp::gt_eq(self.values(), &lower.to_scalar()?)?
                }
                (Bound::Included(lower), Bound::Included(upper)) => arrow::compute::and(
                    &arrow_ord::cmp::gt_eq(self.values(), &lower.to_scalar()?)?,
                    &arrow_ord::cmp::lt_eq(self.values(), &upper.to_scalar()?)?,
                )?,
                (Bound::Included(lower), Bound::Excluded(upper)) => arrow::compute::and(
                    &arrow_ord::cmp::gt_eq(self.values(), &lower.to_scalar()?)?,
                    &arrow_ord::cmp::lt(self.values(), &upper.to_scalar()?)?,
                )?,
                (Bound::Excluded(lower), Bound::Unbounded) => {
                    arrow_ord::cmp::gt(self.values(), &lower.to_scalar()?)?
                }
                (Bound::Excluded(lower), Bound::Included(upper)) => arrow::compute::and(
                    &arrow_ord::cmp::gt(self.values(), &lower.to_scalar()?)?,
                    &arrow_ord::cmp::lt_eq(self.values(), &upper.to_scalar()?)?,
                )?,
                (Bound::Excluded(lower), Bound::Excluded(upper)) => arrow::compute::and(
                    &arrow_ord::cmp::gt(self.values(), &lower.to_scalar()?)?,
                    &arrow_ord::cmp::lt(self.values(), &upper.to_scalar()?)?,
                )?,
            },
            SargableQuery::FullTextSearch(_) => return Err(Error::invalid_input(
                "full text search is not supported for the flat index; build an inverted index for it",
                location!(),
            )),
        };
        if self.has_nulls && matches!(query, SargableQuery::Range(_, _)) {
            // Arrow's comparison kernels do not return false for nulls.  They consider nulls to
            // be less than any value.  So we need to filter out the nulls manually.
            let valid_values = arrow::compute::is_not_null(self.values())?;
            predicate = arrow::compute::and(&valid_values, &predicate)?;
        }
        let matching_ids = arrow_select::filter::filter(self.ids(), &predicate)?;
        let matching_ids = matching_ids
            .as_any()
            .downcast_ref::<UInt64Array>()
            .expect("Result of arrow_select::filter::filter did not match input type");
        Ok(SearchResult::Exact(RowIdTreeMap::from_iter(
            matching_ids.values(),
        )))
    }

    fn can_remap(&self) -> bool {
        true
    }

    // This is dead code at the moment; remapping is exercised through
    // `BTreeSubIndex::remap_subindex` (see `remap_batch` above) instead
    async fn remap(
        &self,
        _mapping: &HashMap<u64, Option<u64>>,
        _dest_store: &dyn IndexStore,
    ) -> Result<CreatedIndex> {
        unimplemented!()
    }

    async fn update(
        &self,
        _new_data: SendableRecordBatchStream,
        _dest_store: &dyn IndexStore,
    ) -> Result<CreatedIndex> {
        // If this was desired, then you would need to merge new_data and data and write it back out
        unimplemented!()
    }

    fn update_criteria(&self) -> UpdateCriteria {
        unimplemented!()
    }

    fn derive_index_params(&self) -> Result<super::ScalarIndexParams> {
        // FlatIndex is used internally and doesn't have user-configurable parameters
        unimplemented!("FlatIndex is an internal index type and cannot be recreated")
    }
}

#[cfg(test)]
mod tests {
    use crate::metrics::NoOpMetricsCollector;

    use super::*;
    use arrow_array::types::Int32Type;
    use datafusion_common::ScalarValue;
    use lance_datagen::{array, gen_batch, RowCount};

    fn example_index() -> FlatIndex {
        let batch = gen_batch()
            .col(
                "values",
                array::cycle::<Int32Type>(vec![10, 100, 1000, 1234]),
            )
            .col("ids", array::cycle::<UInt64Type>(vec![5, 0, 3, 100]))
            .into_batch_rows(RowCount::from(4))
            .unwrap();

        FlatIndex {
            data: Arc::new(batch),
            has_nulls: false,
        }
    }

    async fn check_index(query: &SargableQuery, expected: &[u64]) {
        let index = example_index();
        let actual = index.search(query, &NoOpMetricsCollector).await.unwrap();
        let SearchResult::Exact(actual_row_ids) = actual else {
            panic!("Expected exact search result")
        };
        let expected = RowIdTreeMap::from_iter(expected);
        assert_eq!(actual_row_ids, expected);
    }

    #[tokio::test]
    async fn test_equality() {
        check_index(&SargableQuery::Equals(ScalarValue::from(100)), &[0]).await;
        check_index(&SargableQuery::Equals(ScalarValue::from(10)), &[5]).await;
        check_index(&SargableQuery::Equals(ScalarValue::from(5)), &[]).await;
    }

    #[tokio::test]
    async fn test_range() {
        check_index(
            &SargableQuery::Range(
                Bound::Included(ScalarValue::from(100)),
                Bound::Excluded(ScalarValue::from(1234)),
            ),
            &[0, 3],
        )
        .await;
        check_index(
            &SargableQuery::Range(Bound::Unbounded, Bound::Excluded(ScalarValue::from(1000))),
            &[5, 0],
        )
        .await;
        check_index(
            &SargableQuery::Range(Bound::Included(ScalarValue::from(0)), Bound::Unbounded),
            &[5, 0, 3, 100],
        )
        .await;
        check_index(
            &SargableQuery::Range(Bound::Included(ScalarValue::from(100000)), Bound::Unbounded),
            &[],
        )
        .await;
    }

    #[tokio::test]
    async fn test_is_in() {
        check_index(
            &SargableQuery::IsIn(vec![
                ScalarValue::from(100),
                ScalarValue::from(1234),
                ScalarValue::from(3000),
            ]),
            &[0, 100],
        )
        .await;
    }

    #[tokio::test]
    async fn test_remap() {
        let index = example_index();
        // 0 -> 2000
        // 3 -> delete
        // Keep remaining as is
        let mapping = HashMap::<u64, Option<u64>>::from_iter(vec![(0, Some(2000)), (3, None)]);
        let metadata = FlatIndexMetadata::new(DataType::Int32);
        let remapped = metadata
            .remap_subindex((*index.data).clone(), &mapping)
            .await
            .unwrap();

        let expected = gen_batch()
            .col("values", array::cycle::<Int32Type>(vec![10, 100, 1234]))
            .col("ids", array::cycle::<UInt64Type>(vec![5, 2000, 100]))
            .into_batch_rows(RowCount::from(3))
            .unwrap();
        assert_eq!(remapped, expected);
    }

    // It's possible, during compaction, that an entire page of values is deleted.  We just
    // serialize it as an empty record batch.
    #[tokio::test]
    async fn test_remap_to_nothing() {
        let index = example_index();
        let mapping = HashMap::<u64, Option<u64>>::from_iter(vec![
            (5, None),
            (0, None),
            (3, None),
            (100, None),
        ]);
        let metadata = FlatIndexMetadata::new(DataType::Int32);
        let remapped = metadata
            .remap_subindex((*index.data).clone(), &mapping)
            .await
            .unwrap();
        assert_eq!(remapped.num_rows(), 0);
    }
}
465}