1use std::collections::HashMap;
5use std::{any::Any, ops::Bound, sync::Arc};
6
7use arrow_array::{
8 cast::AsArray, types::UInt64Type, ArrayRef, BooleanArray, RecordBatch, UInt64Array,
9};
10use arrow_schema::{DataType, Field, Schema};
11use async_trait::async_trait;
12
13use datafusion::physical_plan::SendableRecordBatchStream;
14use datafusion_physical_expr::expressions::{in_list, lit, Column};
15use deepsize::DeepSizeOf;
16use lance_core::error::LanceOptionExt;
17use lance_core::utils::address::RowAddress;
18use lance_core::utils::mask::RowIdTreeMap;
19use lance_core::{Error, Result, ROW_ID};
20use roaring::RoaringBitmap;
21use snafu::location;
22
23use super::{btree::BTreeSubIndex, IndexStore, ScalarIndex};
24use super::{AnyQuery, MetricsCollector, SargableQuery, SearchResult};
25use crate::scalar::btree::{BTREE_IDS_COLUMN, BTREE_VALUES_COLUMN};
26use crate::scalar::registry::VALUE_COLUMN_NAME;
27use crate::scalar::{CreatedIndex, UpdateCriteria};
28use crate::{Index, IndexType};
29
30#[derive(Debug)]
37pub struct FlatIndex {
38 data: Arc<RecordBatch>,
39 has_nulls: bool,
40}
41
42impl DeepSizeOf for FlatIndex {
43 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
44 self.data.get_array_memory_size()
45 }
46}
47
48impl FlatIndex {
49 fn values(&self) -> &ArrayRef {
50 self.data.column(0)
51 }
52
53 fn ids(&self) -> &ArrayRef {
54 self.data.column(1)
55 }
56}
57
58fn remap_batch(batch: RecordBatch, mapping: &HashMap<u64, Option<u64>>) -> Result<RecordBatch> {
59 let row_ids = batch.column(1).as_primitive::<UInt64Type>();
60 let val_idx_and_new_id = row_ids
61 .values()
62 .iter()
63 .enumerate()
64 .filter_map(|(idx, old_id)| {
65 mapping
66 .get(old_id)
67 .copied()
68 .unwrap_or(Some(*old_id))
69 .map(|new_id| (idx, new_id))
70 })
71 .collect::<Vec<_>>();
72 let new_ids = Arc::new(UInt64Array::from_iter_values(
73 val_idx_and_new_id.iter().copied().map(|(_, new_id)| new_id),
74 ));
75 let new_val_indices = UInt64Array::from_iter_values(
76 val_idx_and_new_id
77 .into_iter()
78 .map(|(val_idx, _)| val_idx as u64),
79 );
80 let new_vals = arrow_select::take::take(batch.column(0), &new_val_indices, None)?;
81 Ok(RecordBatch::try_new(
82 batch.schema(),
83 vec![new_vals, new_ids],
84 )?)
85}
86
87#[derive(Debug)]
91pub struct FlatIndexMetadata {
92 schema: Arc<Schema>,
93}
94
95impl DeepSizeOf for FlatIndexMetadata {
96 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
97 self.schema.metadata.deep_size_of_children(context)
98 + self
99 .schema
100 .fields
101 .iter()
102 .map(|f| {
105 std::mem::size_of::<Field>()
106 + f.name().deep_size_of_children(context)
107 + f.metadata().deep_size_of_children(context)
108 })
109 .sum::<usize>()
110 }
111}
112
113impl FlatIndexMetadata {
114 pub fn new(value_type: DataType) -> Self {
115 let schema = Arc::new(Schema::new(vec![
116 Field::new(BTREE_VALUES_COLUMN, value_type, true),
117 Field::new(BTREE_IDS_COLUMN, DataType::UInt64, true),
118 ]));
119 Self { schema }
120 }
121}
122
123#[async_trait]
124impl BTreeSubIndex for FlatIndexMetadata {
125 fn schema(&self) -> &Arc<Schema> {
126 &self.schema
127 }
128
129 async fn train(&self, batch: RecordBatch) -> Result<RecordBatch> {
130 Ok(RecordBatch::try_new(
133 self.schema.clone(),
134 vec![
135 batch.column_by_name(VALUE_COLUMN_NAME).expect_ok()?.clone(),
136 batch.column_by_name(ROW_ID).expect_ok()?.clone(),
137 ],
138 )?)
139 }
140
141 async fn load_subindex(&self, serialized: RecordBatch) -> Result<Arc<dyn ScalarIndex>> {
142 let has_nulls = serialized.column(0).null_count() > 0;
143 Ok(Arc::new(FlatIndex {
144 data: Arc::new(serialized),
145 has_nulls,
146 }))
147 }
148
149 async fn remap_subindex(
150 &self,
151 serialized: RecordBatch,
152 mapping: &HashMap<u64, Option<u64>>,
153 ) -> Result<RecordBatch> {
154 remap_batch(serialized, mapping)
155 }
156
157 async fn retrieve_data(&self, serialized: RecordBatch) -> Result<RecordBatch> {
158 Ok(serialized)
159 }
160}
161
162#[async_trait]
163impl Index for FlatIndex {
164 fn as_any(&self) -> &dyn Any {
165 self
166 }
167
168 fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
169 self
170 }
171
172 fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn crate::vector::VectorIndex>> {
173 Err(Error::NotSupported {
174 source: "FlatIndex is not vector index".into(),
175 location: location!(),
176 })
177 }
178
179 fn index_type(&self) -> IndexType {
180 IndexType::Scalar
181 }
182
183 async fn prewarm(&self) -> Result<()> {
184 Ok(())
186 }
187
188 fn statistics(&self) -> Result<serde_json::Value> {
189 Ok(serde_json::json!({
190 "num_values": self.data.num_rows(),
191 }))
192 }
193
194 async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
195 let mut frag_ids = self
196 .ids()
197 .as_primitive::<UInt64Type>()
198 .iter()
199 .map(|row_id| RowAddress::from(row_id.unwrap()).fragment_id())
200 .collect::<Vec<_>>();
201 frag_ids.sort();
202 frag_ids.dedup();
203 Ok(RoaringBitmap::from_sorted_iter(frag_ids).unwrap())
204 }
205}
206
207#[async_trait]
208impl ScalarIndex for FlatIndex {
209 async fn search(
210 &self,
211 query: &dyn AnyQuery,
212 metrics: &dyn MetricsCollector,
213 ) -> Result<SearchResult> {
214 metrics.record_comparisons(self.data.num_rows());
215 let query = query.as_any().downcast_ref::<SargableQuery>().unwrap();
216 let mut predicate = match query {
219 SargableQuery::Equals(value) => {
220 if value.is_null() {
221 arrow::compute::is_null(self.values())?
222 } else {
223 arrow_ord::cmp::eq(self.values(), &value.to_scalar()?)?
224 }
225 }
226 SargableQuery::IsNull() => arrow::compute::is_null(self.values())?,
227 SargableQuery::IsIn(values) => {
228 let mut has_null = false;
229 let choices = values
230 .iter()
231 .map(|val| {
232 has_null |= val.is_null();
233 lit(val.clone())
234 })
235 .collect::<Vec<_>>();
236 let in_list_expr = in_list(
237 Arc::new(Column::new("values", 0)),
238 choices,
239 &false,
240 &self.data.schema(),
241 )?;
242 let result_col = in_list_expr.evaluate(&self.data)?;
243 let predicate = result_col
244 .into_array(self.data.num_rows())?
245 .as_any()
246 .downcast_ref::<BooleanArray>()
247 .expect("InList evaluation should return boolean array")
248 .clone();
249
250 if has_null && self.has_nulls {
252 let nulls = arrow::compute::is_null(self.values())?;
253 arrow::compute::or(&predicate, &nulls)?
254 } else {
255 predicate
256 }
257 }
258 SargableQuery::Range(lower_bound, upper_bound) => match (lower_bound, upper_bound) {
259 (Bound::Unbounded, Bound::Unbounded) => {
260 panic!("Scalar range query received with no upper or lower bound")
261 }
262 (Bound::Unbounded, Bound::Included(upper)) => {
263 arrow_ord::cmp::lt_eq(self.values(), &upper.to_scalar()?)?
264 }
265 (Bound::Unbounded, Bound::Excluded(upper)) => {
266 arrow_ord::cmp::lt(self.values(), &upper.to_scalar()?)?
267 }
268 (Bound::Included(lower), Bound::Unbounded) => {
269 arrow_ord::cmp::gt_eq(self.values(), &lower.to_scalar()?)?
270 }
271 (Bound::Included(lower), Bound::Included(upper)) => arrow::compute::and(
272 &arrow_ord::cmp::gt_eq(self.values(), &lower.to_scalar()?)?,
273 &arrow_ord::cmp::lt_eq(self.values(), &upper.to_scalar()?)?,
274 )?,
275 (Bound::Included(lower), Bound::Excluded(upper)) => arrow::compute::and(
276 &arrow_ord::cmp::gt_eq(self.values(), &lower.to_scalar()?)?,
277 &arrow_ord::cmp::lt(self.values(), &upper.to_scalar()?)?,
278 )?,
279 (Bound::Excluded(lower), Bound::Unbounded) => {
280 arrow_ord::cmp::gt(self.values(), &lower.to_scalar()?)?
281 }
282 (Bound::Excluded(lower), Bound::Included(upper)) => arrow::compute::and(
283 &arrow_ord::cmp::gt(self.values(), &lower.to_scalar()?)?,
284 &arrow_ord::cmp::lt_eq(self.values(), &upper.to_scalar()?)?,
285 )?,
286 (Bound::Excluded(lower), Bound::Excluded(upper)) => arrow::compute::and(
287 &arrow_ord::cmp::gt(self.values(), &lower.to_scalar()?)?,
288 &arrow_ord::cmp::lt(self.values(), &upper.to_scalar()?)?,
289 )?,
290 },
291 SargableQuery::FullTextSearch(_) => return Err(Error::invalid_input(
292 "full text search is not supported for flat index, build a inverted index for it",
293 location!(),
294 )),
295 };
296 if self.has_nulls && matches!(query, SargableQuery::Range(_, _)) {
297 let valid_values = arrow::compute::is_not_null(self.values())?;
300 predicate = arrow::compute::and(&valid_values, &predicate)?;
301 }
302 let matching_ids = arrow_select::filter::filter(self.ids(), &predicate)?;
303 let matching_ids = matching_ids
304 .as_any()
305 .downcast_ref::<UInt64Array>()
306 .expect("Result of arrow_select::filter::filter did not match input type");
307 Ok(SearchResult::Exact(RowIdTreeMap::from_iter(
308 matching_ids.values(),
309 )))
310 }
311
312 fn can_remap(&self) -> bool {
313 true
314 }
315
316 async fn remap(
318 &self,
319 _mapping: &HashMap<u64, Option<u64>>,
320 _dest_store: &dyn IndexStore,
321 ) -> Result<CreatedIndex> {
322 unimplemented!()
323 }
324
325 async fn update(
326 &self,
327 _new_data: SendableRecordBatchStream,
328 _dest_store: &dyn IndexStore,
329 ) -> Result<CreatedIndex> {
330 unimplemented!()
332 }
333
334 fn update_criteria(&self) -> UpdateCriteria {
335 unimplemented!()
336 }
337
338 fn derive_index_params(&self) -> Result<super::ScalarIndexParams> {
339 unimplemented!("FlatIndex is an internal index type and cannot be recreated")
341 }
342}
343
#[cfg(test)]
mod tests {
    use crate::metrics::NoOpMetricsCollector;

    use super::*;
    use arrow_array::types::Int32Type;
    use datafusion_common::ScalarValue;
    use lance_datagen::{array, gen_batch, RowCount};

    /// Builds a four-row index without nulls:
    /// values `[10, 100, 1000, 1234]` paired with ids `[5, 0, 3, 100]`.
    fn example_index() -> FlatIndex {
        let values = array::cycle::<Int32Type>(vec![10, 100, 1000, 1234]);
        let ids = array::cycle::<UInt64Type>(vec![5, 0, 3, 100]);
        let data = gen_batch()
            .col("values", values)
            .col("ids", ids)
            .into_batch_rows(RowCount::from(4))
            .unwrap();

        FlatIndex {
            data: Arc::new(data),
            has_nulls: false,
        }
    }

    /// Runs `query` against the example index and asserts that the result is
    /// exact and contains precisely the `expected` row ids.
    async fn check_index(query: &SargableQuery, expected: &[u64]) {
        let index = example_index();
        let result = index.search(query, &NoOpMetricsCollector).await.unwrap();
        match result {
            SearchResult::Exact(actual_row_ids) => {
                assert_eq!(actual_row_ids, RowIdTreeMap::from_iter(expected));
            }
            _ => panic!("Expected exact search result"),
        }
    }

    #[tokio::test]
    async fn test_equality() {
        let cases: Vec<(i32, Vec<u64>)> = vec![(100, vec![0]), (10, vec![5]), (5, vec![])];
        for (needle, expected) in &cases {
            let query = SargableQuery::Equals(ScalarValue::from(*needle));
            check_index(&query, expected).await;
        }
    }

    #[tokio::test]
    async fn test_range() {
        let cases: Vec<(SargableQuery, Vec<u64>)> = vec![
            (
                SargableQuery::Range(
                    Bound::Included(ScalarValue::from(100)),
                    Bound::Excluded(ScalarValue::from(1234)),
                ),
                vec![0, 3],
            ),
            (
                SargableQuery::Range(Bound::Unbounded, Bound::Excluded(ScalarValue::from(1000))),
                vec![5, 0],
            ),
            (
                SargableQuery::Range(Bound::Included(ScalarValue::from(0)), Bound::Unbounded),
                vec![5, 0, 3, 100],
            ),
            (
                SargableQuery::Range(Bound::Included(ScalarValue::from(100000)), Bound::Unbounded),
                vec![],
            ),
        ];
        for (query, expected) in &cases {
            check_index(query, expected).await;
        }
    }

    #[tokio::test]
    async fn test_is_in() {
        let candidates = vec![
            ScalarValue::from(100),
            ScalarValue::from(1234),
            ScalarValue::from(3000),
        ];
        check_index(&SargableQuery::IsIn(candidates), &[0, 100]).await;
    }

    #[tokio::test]
    async fn test_remap() {
        let index = example_index();
        // Row id 0 moves to 2000, row id 3 is deleted, and the ids that are
        // not mentioned (5 and 100) stay as they are.
        let mapping: HashMap<u64, Option<u64>> = HashMap::from([(0, Some(2000)), (3, None)]);
        let metadata = FlatIndexMetadata::new(DataType::Int32);
        let remapped = metadata
            .remap_subindex(RecordBatch::clone(&index.data), &mapping)
            .await
            .unwrap();

        let expected_values = array::cycle::<Int32Type>(vec![10, 100, 1234]);
        let expected_ids = array::cycle::<UInt64Type>(vec![5, 2000, 100]);
        let expected = gen_batch()
            .col("values", expected_values)
            .col("ids", expected_ids)
            .into_batch_rows(RowCount::from(3))
            .unwrap();
        assert_eq!(remapped, expected);
    }

    #[tokio::test]
    async fn test_remap_to_nothing() {
        let index = example_index();
        // Every row id in the example index is mapped to "deleted".
        let mapping: HashMap<u64, Option<u64>> =
            HashMap::from([(5, None), (0, None), (3, None), (100, None)]);
        let metadata = FlatIndexMetadata::new(DataType::Int32);
        let remapped = metadata
            .remap_subindex(RecordBatch::clone(&index.data), &mapping)
            .await
            .unwrap();
        assert_eq!(remapped.num_rows(), 0);
    }
}