Skip to main content

lance_index/scalar/
lance_format.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Utilities for serializing and deserializing scalar indices in the lance format
5
6use super::{IndexReader, IndexStore, IndexWriter};
7use arrow_array::RecordBatch;
8use arrow_schema::Schema;
9use async_trait::async_trait;
10use bytes::Bytes;
11use deepsize::DeepSizeOf;
12use futures::TryStreamExt;
13use lance_core::{Error, Result, cache::LanceCache};
14use lance_encoding::decoder::{DecoderPlugins, FilterExpression};
15use lance_encoding::version::LanceFileVersion;
16use lance_file::previous::{
17    reader::FileReader as PreviousFileReader,
18    writer::{FileWriter as PreviousFileWriter, ManifestProvider as PreviousManifestProvider},
19};
20use lance_file::reader::{self as current_reader, FileReaderOptions, ReaderProjection};
21use lance_file::writer as current_writer;
22use lance_io::scheduler::{ScanScheduler, SchedulerConfig};
23use lance_io::utils::CachedFileSize;
24use lance_io::{ReadBatchParams, object_store::ObjectStore};
25use lance_table::format::SelfDescribingFileReader;
26use lance_table::format::{IndexFile, list_index_files_with_sizes};
27use object_store::path::Path;
28use std::cmp::min;
29use std::collections::HashMap;
30use std::{any::Any, sync::Arc};
31
/// An index store that serializes scalar indices using the lance format
///
/// Scalar indices are made up of named collections of record batches.  This
/// struct relies on there being a dedicated directory for the index and stores
/// each collection in a file in the lance format.
///
/// Every file name passed to the [`IndexStore`] methods is resolved as a child of
/// `index_dir`. New files are written with the store's configured
/// [`LanceFileVersion`]. Files are opened with the current reader, falling back to
/// the previous (v1) reader when the current one reports a version conflict.
#[derive(Debug, Clone)]
pub struct LanceIndexStore {
    /// Object store used to create, copy and delete index files.
    object_store: Arc<ObjectStore>,
    /// Directory holding this index's files; file names are joined onto it.
    index_dir: Path,
    /// Cache handed to every file reader opened by this store.
    metadata_cache: Arc<LanceCache>,
    /// Scheduler used to open files for the current-format reader.
    scheduler: Arc<ScanScheduler>,
    /// Cached file sizes (file name relative to `index_dir` -> size in bytes).
    /// Files with an entry are opened without a HEAD call to discover their size.
    file_sizes: HashMap<String, u64>,
    /// File format version used when writing new index files.
    format_version: LanceFileVersion,
}
48
49impl DeepSizeOf for LanceIndexStore {
50    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
51        self.object_store.deep_size_of_children(context)
52            + self.index_dir.as_ref().deep_size_of_children(context)
53            + self.metadata_cache.deep_size_of_children(context)
54    }
55}
56
57impl LanceIndexStore {
58    /// Create a new index store at the given directory
59    pub fn new(
60        object_store: Arc<ObjectStore>,
61        index_dir: Path,
62        metadata_cache: Arc<LanceCache>,
63    ) -> Self {
64        Self::with_format_version(
65            object_store,
66            index_dir,
67            metadata_cache,
68            LanceFileVersion::V2_0,
69        )
70    }
71
72    /// Create a new index store at the given directory with a specific format version
73    pub fn with_format_version(
74        object_store: Arc<ObjectStore>,
75        index_dir: Path,
76        metadata_cache: Arc<LanceCache>,
77        format_version: LanceFileVersion,
78    ) -> Self {
79        let scheduler = ScanScheduler::new(
80            object_store.clone(),
81            SchedulerConfig::max_bandwidth(&object_store),
82        );
83        Self {
84            object_store,
85            index_dir,
86            metadata_cache,
87            scheduler,
88            file_sizes: HashMap::new(),
89            format_version,
90        }
91    }
92
93    /// Set cached file sizes to avoid HEAD calls when opening files.
94    ///
95    /// The map should contain relative paths (e.g., "index.idx") as keys
96    /// and file sizes in bytes as values.
97    pub fn with_file_sizes(mut self, file_sizes: HashMap<String, u64>) -> Self {
98        self.file_sizes = file_sizes;
99        self
100    }
101}
102
103#[async_trait]
104impl<M: PreviousManifestProvider + Send + Sync> IndexWriter for PreviousFileWriter<M> {
105    async fn write_record_batch(&mut self, batch: RecordBatch) -> Result<u64> {
106        let offset = self.tell().await?;
107        self.write(&[batch]).await?;
108        Ok(offset as u64)
109    }
110
111    async fn finish(&mut self) -> Result<()> {
112        Self::finish(self).await.map(|_| ())
113    }
114
115    async fn finish_with_metadata(&mut self, metadata: HashMap<String, String>) -> Result<()> {
116        Self::finish_with_metadata(self, &metadata)
117            .await
118            .map(|_| ())
119    }
120}
121
122#[async_trait]
123impl IndexWriter for current_writer::FileWriter {
124    async fn write_record_batch(&mut self, batch: RecordBatch) -> Result<u64> {
125        let offset = self.tell().await?;
126        self.write_batch(&batch).await?;
127        Ok(offset)
128    }
129
130    async fn add_global_buffer(&mut self, data: Bytes) -> Result<u32> {
131        Self::add_global_buffer(self, data).await
132    }
133
134    async fn finish(&mut self) -> Result<()> {
135        Self::finish(self).await.map(|_| ())
136    }
137
138    async fn finish_with_metadata(&mut self, metadata: HashMap<String, String>) -> Result<()> {
139        metadata.into_iter().for_each(|(k, v)| {
140            self.add_schema_metadata(k, v);
141        });
142        Self::finish(self).await.map(|_| ())
143    }
144}
145
#[async_trait]
impl IndexReader for PreviousFileReader {
    // Adapter exposing a legacy (v1) lance file through `IndexReader`.
    //
    // NOTE: several trait methods here share a name with an inherent method of
    // `PreviousFileReader` (`read_range`, `num_batches`, `schema`). Inside this impl,
    // `self.read_range(..)`, `self.num_batches()`, `self.schema()` and
    // `Self::schema(self)` resolve to the inherent methods because Rust prefers
    // inherent methods over trait methods. Routing one of them through
    // `IndexReader` instead would recurse or fail to type-check.

    /// Reads one stored v1 batch with every column of the file schema.
    ///
    /// `offset` is the batch index and is narrowed with `as i32`, so values above
    /// `i32::MAX` do not round-trip. `_batch_size` is ignored; the batch layout is
    /// presumably whatever the file was written with.
    async fn read_record_batch(&self, offset: u64, _batch_size: u64) -> Result<RecordBatch> {
        self.read_batch(offset as i32, ReadBatchParams::RangeFull, self.schema())
            .await
    }

    /// Reads the rows in `range`, restricted to the named columns when
    /// `projection` is `Some`. Errors from projecting the schema are propagated.
    async fn read_range(
        &self,
        range: std::ops::Range<usize>,
        projection: Option<&[&str]>,
    ) -> Result<RecordBatch> {
        let projection = match projection {
            Some(projection) => self.schema().project(projection)?,
            None => self.schema().clone(),
        };
        // Inherent `read_range(range, &Schema)`, not this trait method.
        self.read_range(range, &projection).await
    }

    /// Number of batches stored in the file, narrowed to `u32`; `_batch_size` is ignored.
    async fn num_batches(&self, _batch_size: u64) -> u32 {
        self.num_batches() as u32
    }

    /// Row count as reported by the reader's inherent `len()`.
    fn num_rows(&self) -> usize {
        self.len()
    }

    /// The file's lance schema (from the inherent `schema`).
    fn schema(&self) -> &lance_core::datatypes::Schema {
        Self::schema(self)
    }
}
177
178#[async_trait]
179impl IndexReader for current_reader::FileReader {
180    async fn read_record_batch(&self, offset: u64, batch_size: u64) -> Result<RecordBatch> {
181        let start = offset * batch_size;
182        let end = start + batch_size;
183        let end = end.min(self.num_rows());
184        self.read_range(start as usize..end as usize, None).await
185    }
186
187    async fn read_global_buffer(&self, n: u32) -> Result<Bytes> {
188        Self::read_global_buffer(self, n).await
189    }
190
191    async fn read_range(
192        &self,
193        range: std::ops::Range<usize>,
194        projection: Option<&[&str]>,
195    ) -> Result<RecordBatch> {
196        if range.is_empty() {
197            return Ok(RecordBatch::new_empty(Arc::new(
198                self.schema().as_ref().into(),
199            )));
200        }
201        let projection = if let Some(projection) = projection {
202            ReaderProjection::from_column_names(
203                self.metadata().version(),
204                self.schema(),
205                projection,
206            )?
207        } else {
208            ReaderProjection::from_whole_schema(self.schema(), self.metadata().version())
209        };
210        let batches = self
211            .read_stream_projected(
212                ReadBatchParams::Range(range),
213                u32::MAX,
214                u32::MAX,
215                projection,
216                FilterExpression::no_filter(),
217            )
218            .await?
219            .try_collect::<Vec<_>>()
220            .await?;
221        assert_eq!(batches.len(), 1);
222        Ok(batches[0].clone())
223    }
224
225    // V2 format has removed the row group concept,
226    // so here we assume each batch is with 4096 rows.
227    async fn num_batches(&self, batch_size: u64) -> u32 {
228        Self::num_rows(self).div_ceil(batch_size) as u32
229    }
230
231    fn num_rows(&self) -> usize {
232        Self::num_rows(self) as usize
233    }
234
235    fn schema(&self) -> &lance_core::datatypes::Schema {
236        Self::schema(self)
237    }
238}
239
240#[async_trait]
241impl IndexStore for LanceIndexStore {
242    fn as_any(&self) -> &dyn Any {
243        self
244    }
245
246    fn clone_arc(&self) -> Arc<dyn IndexStore> {
247        Arc::new(self.clone())
248    }
249
250    fn io_parallelism(&self) -> usize {
251        self.object_store.io_parallelism()
252    }
253
254    async fn new_index_file(
255        &self,
256        name: &str,
257        schema: Arc<Schema>,
258    ) -> Result<Box<dyn IndexWriter>> {
259        let path = self.index_dir.child(name);
260        let schema = schema.as_ref().try_into()?;
261        let writer = self.object_store.create(&path).await?;
262        let writer = current_writer::FileWriter::try_new(
263            writer,
264            schema,
265            current_writer::FileWriterOptions {
266                format_version: Some(self.format_version),
267                ..Default::default()
268            },
269        )?;
270        Ok(Box::new(writer))
271    }
272
273    async fn open_index_file(&self, name: &str) -> Result<Arc<dyn IndexReader>> {
274        let path = self.index_dir.child(name);
275        // Use cached file size if available, otherwise unknown (requires HEAD call)
276        let cached_size = self
277            .file_sizes
278            .get(name)
279            .map(|&size| CachedFileSize::new(size))
280            .unwrap_or_else(CachedFileSize::unknown);
281        let file_scheduler = self.scheduler.open_file(&path, &cached_size).await?;
282        match current_reader::FileReader::try_open(
283            file_scheduler,
284            None,
285            Arc::<DecoderPlugins>::default(),
286            &self.metadata_cache,
287            FileReaderOptions::default(),
288        )
289        .await
290        {
291            Ok(reader) => Ok(Arc::new(reader)),
292            Err(e) => {
293                // If the error is a version conflict we can try to read the file with v1 reader
294                if let Error::VersionConflict { .. } = e {
295                    let path = self.index_dir.child(name);
296                    let file_reader = PreviousFileReader::try_new_self_described(
297                        &self.object_store,
298                        &path,
299                        Some(&self.metadata_cache),
300                    )
301                    .await?;
302                    Ok(Arc::new(file_reader))
303                } else {
304                    Err(e)
305                }
306            }
307        }
308    }
309
310    async fn copy_index_file(&self, name: &str, dest_store: &dyn IndexStore) -> Result<()> {
311        let path = self.index_dir.child(name);
312
313        let other_store = dest_store.as_any().downcast_ref::<Self>();
314        match other_store {
315            Some(dest_store) if dest_store.object_store.scheme() == self.object_store.scheme() => {
316                // If both this store and the destination are lance stores we can use object_store's copy
317                // This does blindly assume that both stores are using the same underlying object_store
318                // but there is no easy way to verify this and it happens to always be true at the moment
319                let dest_path = dest_store.index_dir.child(name);
320                self.object_store.copy(&path, &dest_path).await
321            }
322            _ => {
323                let reader = self.open_index_file(name).await?;
324                let mut writer = dest_store
325                    .new_index_file(name, Arc::new(reader.schema().into()))
326                    .await?;
327
328                for offset in (0..reader.num_rows()).step_by(4096) {
329                    let next_offset = min(offset + 4096, reader.num_rows());
330                    let batch = reader.read_range(offset..next_offset, None).await?;
331                    writer.write_record_batch(batch).await?;
332                }
333                writer.finish().await?;
334
335                Ok(())
336            }
337        }
338    }
339
340    async fn rename_index_file(&self, name: &str, new_name: &str) -> Result<()> {
341        let path = self.index_dir.child(name);
342        let new_path = self.index_dir.child(new_name);
343        self.object_store.copy(&path, &new_path).await?;
344        self.object_store.delete(&path).await
345    }
346
347    async fn delete_index_file(&self, name: &str) -> Result<()> {
348        let path = self.index_dir.child(name);
349        self.object_store.delete(&path).await
350    }
351
352    async fn list_files_with_sizes(&self) -> Result<Vec<IndexFile>> {
353        list_index_files_with_sizes(&self.object_store, &self.index_dir).await
354    }
355}
356
357#[cfg(test)]
358mod tests {
359
360    use std::{collections::HashMap, ops::Bound};
361
362    use crate::metrics::NoOpMetricsCollector;
363    use crate::pbold;
364    use crate::scalar::bitmap::BitmapIndexPlugin;
365    use crate::scalar::btree::{BTreeIndexPlugin, BTreeParameters};
366    use crate::scalar::label_list::LabelListIndexPlugin;
367    use crate::scalar::registry::{ScalarIndexPlugin, VALUE_COLUMN_NAME};
368    use crate::scalar::{
369        LabelListQuery, SargableQuery, ScalarIndex, SearchResult,
370        bitmap::BitmapIndex,
371        btree::{DEFAULT_BTREE_BATCH_SIZE, train_btree_index},
372    };
373
374    use super::*;
375    use arrow::{buffer::ScalarBuffer, datatypes::UInt8Type};
376    use arrow_array::{
377        ListArray, RecordBatchIterator, RecordBatchReader, StringArray, UInt64Array,
378        cast::AsArray,
379        types::{Int32Type, UInt64Type},
380    };
381    use arrow_schema::Schema as ArrowSchema;
382    use arrow_schema::{DataType, Field, TimeUnit};
383    use arrow_select::take::TakeOptions;
384    use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
385    use datafusion_common::ScalarValue;
386    use futures::FutureExt;
387    use lance_core::ROW_ID;
388    use lance_core::utils::mask::{RowAddrTreeMap, RowSetOps};
389    use lance_core::utils::tempfile::TempDir;
390    use lance_datagen::{ArrayGeneratorExt, BatchCount, ByteCount, RowCount, array, gen_batch};
391
392    fn test_store(tempdir: &TempDir) -> Arc<dyn IndexStore> {
393        let test_path = tempdir.obj_path();
394        let (object_store, test_path) = ObjectStore::from_uri(test_path.as_ref())
395            .now_or_never()
396            .unwrap()
397            .unwrap();
398        let cache = Arc::new(lance_core::cache::LanceCache::with_capacity(
399            128 * 1024 * 1024,
400        ));
401        Arc::new(LanceIndexStore::new(object_store, test_path, cache))
402    }
403
404    async fn train_index(
405        index_store: &Arc<dyn IndexStore>,
406        data: impl RecordBatchReader + Send + Sync + 'static,
407        custom_batch_size: Option<u64>,
408    ) {
409        let batch_size = custom_batch_size.unwrap_or(DEFAULT_BTREE_BATCH_SIZE);
410        let params = BTreeParameters {
411            zone_size: Some(batch_size),
412            range_id: None,
413        };
414        let params = serde_json::to_string(&params).unwrap();
415        let btree_plugin = BTreeIndexPlugin;
416        let data = lance_datafusion::utils::reader_to_stream(Box::new(data));
417        let request = btree_plugin
418            .new_training_request(
419                &params,
420                &Field::new(VALUE_COLUMN_NAME, DataType::Int32, false),
421            )
422            .unwrap();
423        btree_plugin
424            .train_index(
425                data,
426                index_store.as_ref(),
427                request,
428                None,
429                crate::progress::noop_progress(),
430            )
431            .await
432            .unwrap();
433    }
434
435    fn default_details<T: prost::Message + prost::Name + std::default::Default>() -> prost_types::Any
436    {
437        prost_types::Any::from_msg(&T::default()).unwrap()
438    }
439
440    #[tokio::test]
441    async fn test_global_buffer_round_trip() {
442        let tempdir = TempDir::default();
443        let index_store = test_store(&tempdir);
444
445        let mut writer = index_store
446            .new_index_file("global-buffer.lance", Arc::new(Schema::empty()))
447            .await
448            .unwrap();
449        let expected = bytes::Bytes::from_static(b"scalar-global-buffer");
450        let buffer_idx = writer.add_global_buffer(expected.clone()).await.unwrap();
451        writer.finish().await.unwrap();
452
453        let reader = index_store
454            .open_index_file("global-buffer.lance")
455            .await
456            .unwrap();
457        let actual = reader.read_global_buffer(buffer_idx).await.unwrap();
458
459        assert_eq!(actual, expected);
460    }
461
462    #[tokio::test]
463    async fn test_basic_btree() {
464        let tempdir = TempDir::default();
465        let index_store = test_store(&tempdir);
466        let data = gen_batch()
467            .col(VALUE_COLUMN_NAME, array::step::<Int32Type>())
468            .col(ROW_ID, array::step::<UInt64Type>())
469            .into_reader_rows(RowCount::from(4096), BatchCount::from(100));
470        train_index(&index_store, data, None).await;
471        let index = BTreeIndexPlugin
472            .load_index(
473                index_store,
474                &default_details::<pbold::BTreeIndexDetails>(),
475                None,
476                &LanceCache::no_cache(),
477            )
478            .await
479            .unwrap();
480
481        let result = index
482            .search(
483                &SargableQuery::Equals(ScalarValue::Int32(Some(10000))),
484                &NoOpMetricsCollector,
485            )
486            .await
487            .unwrap();
488
489        assert!(result.is_exact());
490        let row_ids = result.row_addrs().true_rows();
491        assert_eq!(Some(1), row_ids.len());
492        assert!(row_ids.contains(10000));
493
494        let result = index
495            .search(
496                &SargableQuery::Range(
497                    Bound::Unbounded,
498                    Bound::Excluded(ScalarValue::Int32(Some(-100))),
499                ),
500                &NoOpMetricsCollector,
501            )
502            .await
503            .unwrap();
504
505        assert!(result.is_exact());
506        let row_addrs = result.row_addrs().true_rows();
507
508        assert_eq!(Some(0), row_addrs.len());
509
510        let result = index
511            .search(
512                &SargableQuery::Range(
513                    Bound::Unbounded,
514                    Bound::Excluded(ScalarValue::Int32(Some(100))),
515                ),
516                &NoOpMetricsCollector,
517            )
518            .await
519            .unwrap();
520
521        assert!(result.is_exact());
522        let row_addrs = result.row_addrs().true_rows();
523
524        assert_eq!(Some(100), row_addrs.len());
525    }
526
527    #[tokio::test]
528    async fn test_btree_update() {
529        let index_dir = TempDir::default();
530        let index_store = test_store(&index_dir);
531        let data = gen_batch()
532            .col(VALUE_COLUMN_NAME, array::step::<Int32Type>())
533            .col(ROW_ID, array::step::<UInt64Type>())
534            .into_reader_rows(RowCount::from(4096), BatchCount::from(100));
535        train_index(&index_store, data, None).await;
536        let index = BTreeIndexPlugin
537            .load_index(
538                index_store,
539                &default_details::<pbold::BTreeIndexDetails>(),
540                None,
541                &LanceCache::no_cache(),
542            )
543            .await
544            .unwrap();
545
546        let data = gen_batch()
547            .col(
548                VALUE_COLUMN_NAME,
549                array::step_custom::<Int32Type>(4096 * 100, 1),
550            )
551            .col(ROW_ID, array::step_custom::<UInt64Type>(4096 * 100, 1))
552            .into_reader_rows(RowCount::from(4096), BatchCount::from(100));
553
554        let updated_index_dir = TempDir::default();
555        let updated_index_store = test_store(&updated_index_dir);
556        index
557            .update(
558                lance_datafusion::utils::reader_to_stream(Box::new(data)),
559                updated_index_store.as_ref(),
560                None,
561            )
562            .await
563            .unwrap();
564        let updated_index = BTreeIndexPlugin
565            .load_index(
566                updated_index_store,
567                &default_details::<pbold::BTreeIndexDetails>(),
568                None,
569                &LanceCache::no_cache(),
570            )
571            .await
572            .unwrap();
573
574        let result = updated_index
575            .search(
576                &SargableQuery::Equals(ScalarValue::Int32(Some(10000))),
577                &NoOpMetricsCollector,
578            )
579            .await
580            .unwrap();
581
582        assert!(result.is_exact());
583        let row_addrs = result.row_addrs().true_rows();
584
585        assert_eq!(Some(1), row_addrs.len());
586        assert!(row_addrs.contains(10000));
587
588        let result = updated_index
589            .search(
590                &SargableQuery::Equals(ScalarValue::Int32(Some(500_000))),
591                &NoOpMetricsCollector,
592            )
593            .await
594            .unwrap();
595
596        assert!(result.is_exact());
597        let row_addrs = result.row_addrs().true_rows();
598
599        assert_eq!(Some(1), row_addrs.len());
600        assert!(row_addrs.contains(500_000));
601    }
602
603    async fn check(index: &Arc<dyn ScalarIndex>, query: SargableQuery, expected: &[u64]) {
604        let results = index.search(&query, &NoOpMetricsCollector).await.unwrap();
605        assert!(results.is_exact());
606        let expected_arr = RowAddrTreeMap::from_iter(expected);
607        assert_eq!(&results.row_addrs().true_rows(), &expected_arr);
608    }
609
610    #[tokio::test]
611    async fn test_btree_with_gaps() {
612        let tempdir = TempDir::default();
613        let index_store = test_store(&tempdir);
614        let batch_one = gen_batch()
615            .col(
616                VALUE_COLUMN_NAME,
617                array::cycle::<Int32Type>(vec![0, 1, 4, 5]),
618            )
619            .col(ROW_ID, array::cycle::<UInt64Type>(vec![0, 1, 2, 3]))
620            .into_batch_rows(RowCount::from(4));
621        let batch_two = gen_batch()
622            .col(
623                VALUE_COLUMN_NAME,
624                array::cycle::<Int32Type>(vec![10, 11, 11, 15]),
625            )
626            .col(ROW_ID, array::cycle::<UInt64Type>(vec![40, 50, 60, 70]))
627            .into_batch_rows(RowCount::from(4));
628        let batch_three = gen_batch()
629            .col(
630                VALUE_COLUMN_NAME,
631                array::cycle::<Int32Type>(vec![15, 15, 15, 15]),
632            )
633            .col(ROW_ID, array::cycle::<UInt64Type>(vec![400, 500, 600, 700]))
634            .into_batch_rows(RowCount::from(4));
635        let batch_four = gen_batch()
636            .col(
637                VALUE_COLUMN_NAME,
638                array::cycle::<Int32Type>(vec![15, 16, 20, 20]),
639            )
640            .col(
641                ROW_ID,
642                array::cycle::<UInt64Type>(vec![4000, 5000, 6000, 7000]),
643            )
644            .into_batch_rows(RowCount::from(4));
645        let batches = vec![batch_one, batch_two, batch_three, batch_four];
646        let schema = Arc::new(Schema::new(vec![
647            Field::new(VALUE_COLUMN_NAME, DataType::Int32, false),
648            Field::new(ROW_ID, DataType::UInt64, false),
649        ]));
650        let data = RecordBatchIterator::new(batches, schema);
651        train_index(&index_store, data, Some(4)).await;
652        let index = BTreeIndexPlugin
653            .load_index(
654                index_store,
655                &default_details::<pbold::BTreeIndexDetails>(),
656                None,
657                &LanceCache::no_cache(),
658            )
659            .await
660            .unwrap();
661
662        // The above should create four pages
663        //
664        // 0 - 5
665        // 10 - 15
666        // 15 - 15
667        // 15 - 20
668        //
669        // This will help us test various indexing corner cases
670
671        // No results (off the left side)
672        check(
673            &index,
674            SargableQuery::Equals(ScalarValue::Int32(Some(-3))),
675            &[],
676        )
677        .await;
678
679        check(
680            &index,
681            SargableQuery::Range(
682                Bound::Unbounded,
683                Bound::Included(ScalarValue::Int32(Some(-3))),
684            ),
685            &[],
686        )
687        .await;
688
689        check(
690            &index,
691            SargableQuery::Range(
692                Bound::Included(ScalarValue::Int32(Some(-10))),
693                Bound::Included(ScalarValue::Int32(Some(-3))),
694            ),
695            &[],
696        )
697        .await;
698
699        // Hitting the middle of a bucket
700        check(
701            &index,
702            SargableQuery::Equals(ScalarValue::Int32(Some(4))),
703            &[2],
704        )
705        .await;
706
707        // Hitting a gap between two buckets
708        check(
709            &index,
710            SargableQuery::Equals(ScalarValue::Int32(Some(7))),
711            &[],
712        )
713        .await;
714
715        // Hitting the lowest of the overlapping buckets
716        check(
717            &index,
718            SargableQuery::Equals(ScalarValue::Int32(Some(11))),
719            &[50, 60],
720        )
721        .await;
722
723        // Hitting the 15 shared on all three buckets
724        check(
725            &index,
726            SargableQuery::Equals(ScalarValue::Int32(Some(15))),
727            &[70, 400, 500, 600, 700, 4000],
728        )
729        .await;
730
731        // Hitting the upper part of the three overlapping buckets
732        check(
733            &index,
734            SargableQuery::Equals(ScalarValue::Int32(Some(20))),
735            &[6000, 7000],
736        )
737        .await;
738
739        // Ranges that capture multiple buckets
740        check(
741            &index,
742            SargableQuery::Range(
743                Bound::Unbounded,
744                Bound::Included(ScalarValue::Int32(Some(11))),
745            ),
746            &[0, 1, 2, 3, 40, 50, 60],
747        )
748        .await;
749
750        check(
751            &index,
752            SargableQuery::Range(
753                Bound::Unbounded,
754                Bound::Excluded(ScalarValue::Int32(Some(11))),
755            ),
756            &[0, 1, 2, 3, 40],
757        )
758        .await;
759
760        check(
761            &index,
762            SargableQuery::Range(
763                Bound::Included(ScalarValue::Int32(Some(4))),
764                Bound::Unbounded,
765            ),
766            &[
767                2, 3, 40, 50, 60, 70, 400, 500, 600, 700, 4000, 5000, 6000, 7000,
768            ],
769        )
770        .await;
771
772        check(
773            &index,
774            SargableQuery::Range(
775                Bound::Included(ScalarValue::Int32(Some(4))),
776                Bound::Included(ScalarValue::Int32(Some(11))),
777            ),
778            &[2, 3, 40, 50, 60],
779        )
780        .await;
781
782        check(
783            &index,
784            SargableQuery::Range(
785                Bound::Included(ScalarValue::Int32(Some(4))),
786                Bound::Excluded(ScalarValue::Int32(Some(11))),
787            ),
788            &[2, 3, 40],
789        )
790        .await;
791
792        check(
793            &index,
794            SargableQuery::Range(
795                Bound::Excluded(ScalarValue::Int32(Some(4))),
796                Bound::Unbounded,
797            ),
798            &[
799                3, 40, 50, 60, 70, 400, 500, 600, 700, 4000, 5000, 6000, 7000,
800            ],
801        )
802        .await;
803
804        check(
805            &index,
806            SargableQuery::Range(
807                Bound::Excluded(ScalarValue::Int32(Some(4))),
808                Bound::Included(ScalarValue::Int32(Some(11))),
809            ),
810            &[3, 40, 50, 60],
811        )
812        .await;
813
814        check(
815            &index,
816            SargableQuery::Range(
817                Bound::Excluded(ScalarValue::Int32(Some(4))),
818                Bound::Excluded(ScalarValue::Int32(Some(11))),
819            ),
820            &[3, 40],
821        )
822        .await;
823
824        check(
825            &index,
826            SargableQuery::Range(
827                Bound::Excluded(ScalarValue::Int32(Some(-50))),
828                Bound::Excluded(ScalarValue::Int32(Some(1000))),
829            ),
830            &[
831                0, 1, 2, 3, 40, 50, 60, 70, 400, 500, 600, 700, 4000, 5000, 6000, 7000,
832            ],
833        )
834        .await;
835    }
836
837    #[tokio::test]
838    async fn test_btree_types() {
839        for data_type in &[
840            DataType::Boolean,
841            DataType::Int32,
842            DataType::Utf8,
843            DataType::Float32,
844            DataType::Date32,
845            DataType::Timestamp(TimeUnit::Nanosecond, None),
846            DataType::Date64,
847            DataType::Date32,
848            DataType::Time64(TimeUnit::Nanosecond),
849            DataType::Time32(TimeUnit::Second),
850            DataType::FixedSizeBinary(16),
851            // Not supported today, error from datafusion:
852            // Min/max accumulator not implemented for Duration(Nanosecond)
853            // DataType::Duration(TimeUnit::Nanosecond),
854        ] {
855            let tempdir = TempDir::default();
856            let index_store = test_store(&tempdir);
857            let data: RecordBatch = gen_batch()
858                .col(VALUE_COLUMN_NAME, array::rand_type(data_type))
859                .col(ROW_ID, array::step::<UInt64Type>())
860                .into_batch_rows(RowCount::from(4096 * 3))
861                .unwrap();
862
863            let sample_value = ScalarValue::try_from_array(data.column(0), 0).unwrap();
864            let sample_row_id = data.column(1).as_primitive::<UInt64Type>().value(0);
865
866            let sort_indices = arrow::compute::sort_to_indices(data.column(0), None, None).unwrap();
867            let sorted_values = arrow_select::take::take(
868                data.column(0),
869                &sort_indices,
870                Some(TakeOptions {
871                    check_bounds: false,
872                }),
873            )
874            .unwrap();
875            let sorted_row_ids = arrow_select::take::take(
876                data.column(1),
877                &sort_indices,
878                Some(TakeOptions {
879                    check_bounds: false,
880                }),
881            )
882            .unwrap();
883            let sorted_batch =
884                RecordBatch::try_new(data.schema().clone(), vec![sorted_values, sorted_row_ids])
885                    .unwrap();
886
887            let batch_one = sorted_batch.slice(0, 4096);
888            let batch_two = sorted_batch.slice(4096, 4096);
889            let batch_three = sorted_batch.slice(8192, 4096);
890            let training_data = RecordBatchIterator::new(
891                vec![batch_one, batch_two, batch_three].into_iter().map(Ok),
892                data.schema().clone(),
893            );
894
895            train_index(&index_store, training_data, None).await;
896            let index = BTreeIndexPlugin
897                .load_index(
898                    index_store,
899                    &default_details::<pbold::BTreeIndexDetails>(),
900                    None,
901                    &LanceCache::no_cache(),
902                )
903                .await
904                .unwrap();
905
906            let result = index
907                .search(&SargableQuery::Equals(sample_value), &NoOpMetricsCollector)
908                .await
909                .unwrap();
910
911            assert!(result.is_exact());
912            let row_addrs = result.row_addrs().true_rows();
913
914            // The random data may have had duplicates so there might be more than 1 result
915            // but even for boolean we shouldn't match the entire thing
916            assert!(!row_addrs.is_empty());
917            assert!(row_addrs.len().unwrap() < data.num_rows() as u64);
918            assert!(row_addrs.contains(sample_row_id));
919        }
920    }
921
922    #[tokio::test]
923    async fn btree_entire_null_page() {
924        let tempdir = TempDir::default();
925        let index_store = test_store(&tempdir);
926        let batch = gen_batch()
927            .col(
928                VALUE_COLUMN_NAME,
929                array::rand_utf8(ByteCount::from(0), false).with_nulls(&[true]),
930            )
931            .col(ROW_ID, array::step::<UInt64Type>())
932            .into_batch_rows(RowCount::from(4096));
933        assert_eq!(
934            batch.as_ref().unwrap()[VALUE_COLUMN_NAME].null_count(),
935            4096
936        );
937        let batches = vec![batch];
938        let schema = Arc::new(Schema::new(vec![
939            Field::new(VALUE_COLUMN_NAME, DataType::Utf8, true),
940            Field::new(ROW_ID, DataType::UInt64, false),
941        ]));
942        let data = RecordBatchIterator::new(batches, schema);
943        let data = lance_datafusion::utils::reader_to_stream(Box::new(data));
944
945        train_btree_index(
946            data,
947            index_store.as_ref(),
948            DEFAULT_BTREE_BATCH_SIZE,
949            None,
950            None,
951        )
952        .await
953        .unwrap();
954
955        let index = BTreeIndexPlugin
956            .load_index(
957                index_store,
958                &default_details::<pbold::BTreeIndexDetails>(),
959                None,
960                &LanceCache::no_cache(),
961            )
962            .await
963            .unwrap();
964
965        let result = index
966            .search(
967                &SargableQuery::Equals(ScalarValue::Utf8(Some("foo".to_string()))),
968                &NoOpMetricsCollector,
969            )
970            .await
971            .unwrap();
972
973        assert!(result.is_exact());
974        let row_addrs = result.row_addrs().true_rows();
975
976        assert!(row_addrs.is_empty());
977
978        let result = index
979            .search(&SargableQuery::IsNull(), &NoOpMetricsCollector)
980            .await
981            .unwrap();
982        assert!(result.is_exact());
983        let row_addrs = result.row_addrs().true_rows();
984        assert_eq!(row_addrs.len(), Some(4096));
985    }
986
987    async fn train_bitmap(
988        index_store: &Arc<dyn IndexStore>,
989        data: impl RecordBatchReader + Send + Sync + 'static,
990    ) {
991        // Sort the data by value column (nulls first) to match the production
992        // scanner behavior (TrainingOrdering::Values).
993        let schema = data.schema();
994        let batches: Vec<_> = data
995            .into_iter()
996            .collect::<std::result::Result<Vec<_>, _>>()
997            .unwrap();
998        let combined = arrow::compute::concat_batches(&schema, &batches).unwrap();
999        let options = arrow::compute::SortOptions {
1000            descending: false,
1001            nulls_first: true,
1002        };
1003        let indices =
1004            arrow::compute::sort_to_indices(combined.column(0), Some(options), None).unwrap();
1005        let sorted_columns: Vec<_> = combined
1006            .columns()
1007            .iter()
1008            .map(|col| arrow::compute::take(col.as_ref(), &indices, None).unwrap())
1009            .collect();
1010        let sorted_batch = RecordBatch::try_new(schema.clone(), sorted_columns).unwrap();
1011        let stream = Box::pin(RecordBatchStreamAdapter::new(
1012            schema,
1013            futures::stream::once(async move { Ok(sorted_batch) }),
1014        ));
1015
1016        let request = BitmapIndexPlugin
1017            .new_training_request("{}", &Field::new(VALUE_COLUMN_NAME, DataType::Int32, false))
1018            .unwrap();
1019        BitmapIndexPlugin
1020            .train_index(
1021                stream,
1022                index_store.as_ref(),
1023                request,
1024                None,
1025                crate::progress::noop_progress(),
1026            )
1027            .await
1028            .unwrap();
1029    }
1030
1031    #[tokio::test]
1032    async fn test_bitmap_working() {
1033        let tempdir = TempDir::default();
1034        let index_store = test_store(&tempdir);
1035
1036        let schema = Arc::new(ArrowSchema::new(vec![
1037            Field::new(VALUE_COLUMN_NAME, DataType::Utf8, true),
1038            Field::new(ROW_ID, DataType::UInt64, false),
1039        ]));
1040
1041        let batch1 = RecordBatch::try_new(
1042            schema.clone(),
1043            vec![
1044                Arc::new(StringArray::from(vec![Some("abcd"), None, Some("abcd")])),
1045                Arc::new(UInt64Array::from(vec![1, 2, 3])),
1046            ],
1047        )
1048        .unwrap();
1049
1050        let batch2 = RecordBatch::try_new(
1051            schema.clone(),
1052            vec![
1053                Arc::new(StringArray::from(vec![
1054                    Some("apple"),
1055                    Some("hello"),
1056                    Some("abcd"),
1057                ])),
1058                Arc::new(UInt64Array::from(vec![4, 5, 6])),
1059            ],
1060        )
1061        .unwrap();
1062
1063        let batches = vec![batch1, batch2];
1064        let data = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
1065        train_bitmap(&index_store, data).await;
1066
1067        let index = BitmapIndex::load(index_store, None, &LanceCache::no_cache())
1068            .await
1069            .unwrap();
1070
1071        let result = index
1072            .search(
1073                &SargableQuery::Equals(ScalarValue::Utf8(None)),
1074                &NoOpMetricsCollector,
1075            )
1076            .await
1077            .unwrap();
1078
1079        assert!(result.is_exact());
1080        let row_addrs = result.row_addrs().true_rows();
1081        assert_eq!(Some(1), row_addrs.len());
1082        assert!(row_addrs.contains(2));
1083
1084        let result = index
1085            .search(
1086                &SargableQuery::Equals(ScalarValue::Utf8(Some("abcd".to_string()))),
1087                &NoOpMetricsCollector,
1088            )
1089            .await
1090            .unwrap();
1091
1092        assert!(result.is_exact());
1093        let row_addrs = result.row_addrs().true_rows();
1094        assert_eq!(Some(3), row_addrs.len());
1095        assert!(row_addrs.contains(1));
1096        assert!(row_addrs.contains(3));
1097        assert!(row_addrs.contains(6));
1098    }
1099
1100    #[tokio::test]
1101    async fn test_basic_bitmap() {
1102        let tempdir = TempDir::default();
1103        let index_store = test_store(&tempdir);
1104        let data = gen_batch()
1105            .col(VALUE_COLUMN_NAME, array::step::<Int32Type>())
1106            .col(ROW_ID, array::step::<UInt64Type>())
1107            .into_reader_rows(RowCount::from(4096), BatchCount::from(100));
1108        train_bitmap(&index_store, data).await;
1109        let index = BitmapIndex::load(index_store, None, &LanceCache::no_cache())
1110            .await
1111            .unwrap();
1112
1113        let result = index
1114            .search(
1115                &SargableQuery::Equals(ScalarValue::Int32(Some(10000))),
1116                &NoOpMetricsCollector,
1117            )
1118            .await
1119            .unwrap();
1120
1121        assert!(result.is_exact());
1122        let row_addrs = result.row_addrs().true_rows();
1123        assert_eq!(Some(1), row_addrs.len());
1124        assert!(row_addrs.contains(10000));
1125
1126        let result = index
1127            .search(
1128                &SargableQuery::Range(
1129                    Bound::Unbounded,
1130                    Bound::Excluded(ScalarValue::Int32(Some(-100))),
1131                ),
1132                &NoOpMetricsCollector,
1133            )
1134            .await
1135            .unwrap();
1136
1137        assert!(result.is_exact());
1138        let row_addrs = result.row_addrs().true_rows();
1139        assert!(row_addrs.is_empty());
1140
1141        let result = index
1142            .search(
1143                &SargableQuery::Range(
1144                    Bound::Unbounded,
1145                    Bound::Excluded(ScalarValue::Int32(Some(100))),
1146                ),
1147                &NoOpMetricsCollector,
1148            )
1149            .await
1150            .unwrap();
1151
1152        assert!(result.is_exact());
1153        let row_addrs = result.row_addrs().true_rows();
1154        assert_eq!(Some(100), row_addrs.len());
1155    }
1156
1157    async fn check_bitmap(index: &BitmapIndex, query: SargableQuery, expected: &[u64]) {
1158        let results = index.search(&query, &NoOpMetricsCollector).await.unwrap();
1159        assert!(results.is_exact());
1160        let expected_arr = RowAddrTreeMap::from_iter(expected);
1161        assert_eq!(&results.row_addrs().true_rows(), &expected_arr);
1162    }
1163
1164    #[tokio::test]
1165    async fn test_bitmap_with_gaps() {
1166        let tempdir = TempDir::default();
1167        let index_store = test_store(&tempdir);
1168        let batch_one = gen_batch()
1169            .col(
1170                VALUE_COLUMN_NAME,
1171                array::cycle::<Int32Type>(vec![0, 1, 4, 5]),
1172            )
1173            .col(ROW_ID, array::cycle::<UInt64Type>(vec![0, 1, 2, 3]))
1174            .into_batch_rows(RowCount::from(4));
1175        let batch_two = gen_batch()
1176            .col(
1177                VALUE_COLUMN_NAME,
1178                array::cycle::<Int32Type>(vec![10, 11, 11, 15]),
1179            )
1180            .col(ROW_ID, array::cycle::<UInt64Type>(vec![40, 50, 60, 70]))
1181            .into_batch_rows(RowCount::from(4));
1182        let batch_three = gen_batch()
1183            .col(
1184                VALUE_COLUMN_NAME,
1185                array::cycle::<Int32Type>(vec![15, 15, 15, 15]),
1186            )
1187            .col(ROW_ID, array::cycle::<UInt64Type>(vec![400, 500, 600, 700]))
1188            .into_batch_rows(RowCount::from(4));
1189        let batch_four = gen_batch()
1190            .col(
1191                VALUE_COLUMN_NAME,
1192                array::cycle::<Int32Type>(vec![15, 16, 20, 20]),
1193            )
1194            .col(
1195                ROW_ID,
1196                array::cycle::<UInt64Type>(vec![4000, 5000, 6000, 7000]),
1197            )
1198            .into_batch_rows(RowCount::from(4));
1199        let batches = vec![batch_one, batch_two, batch_three, batch_four];
1200        let schema = Arc::new(Schema::new(vec![
1201            Field::new(VALUE_COLUMN_NAME, DataType::Int32, false),
1202            Field::new(ROW_ID, DataType::UInt64, false),
1203        ]));
1204        let data = RecordBatchIterator::new(batches, schema);
1205        train_bitmap(&index_store, data).await;
1206        let index = BitmapIndex::load(index_store, None, &LanceCache::no_cache())
1207            .await
1208            .unwrap();
1209
1210        // The above should create four pages
1211        //
1212        // 0 - 5
1213        // 10 - 15
1214        // 15 - 15
1215        // 15 - 20
1216        //
1217        // This will help us test various indexing corner cases
1218
1219        // No results (off the left side)
1220        check_bitmap(
1221            &index,
1222            SargableQuery::Equals(ScalarValue::Int32(Some(-3))),
1223            &[],
1224        )
1225        .await;
1226
1227        check_bitmap(
1228            &index,
1229            SargableQuery::Range(
1230                Bound::Unbounded,
1231                Bound::Included(ScalarValue::Int32(Some(-3))),
1232            ),
1233            &[],
1234        )
1235        .await;
1236
1237        check_bitmap(
1238            &index,
1239            SargableQuery::Range(
1240                Bound::Included(ScalarValue::Int32(Some(-10))),
1241                Bound::Included(ScalarValue::Int32(Some(-3))),
1242            ),
1243            &[],
1244        )
1245        .await;
1246
1247        // Hitting the middle of a bucket
1248        check_bitmap(
1249            &index,
1250            SargableQuery::Equals(ScalarValue::Int32(Some(4))),
1251            &[2],
1252        )
1253        .await;
1254
1255        // Hitting a gap between two buckets
1256        check_bitmap(
1257            &index,
1258            SargableQuery::Equals(ScalarValue::Int32(Some(7))),
1259            &[],
1260        )
1261        .await;
1262
1263        // Hitting the lowest of the overlapping buckets
1264        check_bitmap(
1265            &index,
1266            SargableQuery::Equals(ScalarValue::Int32(Some(11))),
1267            &[50, 60],
1268        )
1269        .await;
1270
1271        // Hitting the 15 shared on all three buckets
1272        check_bitmap(
1273            &index,
1274            SargableQuery::Equals(ScalarValue::Int32(Some(15))),
1275            &[70, 400, 500, 600, 700, 4000],
1276        )
1277        .await;
1278
1279        // Hitting the upper part of the three overlapping buckets
1280        check_bitmap(
1281            &index,
1282            SargableQuery::Equals(ScalarValue::Int32(Some(20))),
1283            &[6000, 7000],
1284        )
1285        .await;
1286
1287        // Ranges that capture multiple buckets
1288        check_bitmap(
1289            &index,
1290            SargableQuery::Range(
1291                Bound::Unbounded,
1292                Bound::Included(ScalarValue::Int32(Some(11))),
1293            ),
1294            &[0, 1, 2, 3, 40, 50, 60],
1295        )
1296        .await;
1297
1298        check_bitmap(
1299            &index,
1300            SargableQuery::Range(
1301                Bound::Unbounded,
1302                Bound::Excluded(ScalarValue::Int32(Some(11))),
1303            ),
1304            &[0, 1, 2, 3, 40],
1305        )
1306        .await;
1307
1308        check_bitmap(
1309            &index,
1310            SargableQuery::Range(
1311                Bound::Included(ScalarValue::Int32(Some(4))),
1312                Bound::Unbounded,
1313            ),
1314            &[
1315                2, 3, 40, 50, 60, 70, 400, 500, 600, 700, 4000, 5000, 6000, 7000,
1316            ],
1317        )
1318        .await;
1319
1320        check_bitmap(
1321            &index,
1322            SargableQuery::Range(
1323                Bound::Included(ScalarValue::Int32(Some(4))),
1324                Bound::Included(ScalarValue::Int32(Some(11))),
1325            ),
1326            &[2, 3, 40, 50, 60],
1327        )
1328        .await;
1329
1330        check_bitmap(
1331            &index,
1332            SargableQuery::Range(
1333                Bound::Included(ScalarValue::Int32(Some(4))),
1334                Bound::Excluded(ScalarValue::Int32(Some(11))),
1335            ),
1336            &[2, 3, 40],
1337        )
1338        .await;
1339
1340        check_bitmap(
1341            &index,
1342            SargableQuery::Range(
1343                Bound::Excluded(ScalarValue::Int32(Some(4))),
1344                Bound::Unbounded,
1345            ),
1346            &[
1347                3, 40, 50, 60, 70, 400, 500, 600, 700, 4000, 5000, 6000, 7000,
1348            ],
1349        )
1350        .await;
1351
1352        check_bitmap(
1353            &index,
1354            SargableQuery::Range(
1355                Bound::Excluded(ScalarValue::Int32(Some(4))),
1356                Bound::Included(ScalarValue::Int32(Some(11))),
1357            ),
1358            &[3, 40, 50, 60],
1359        )
1360        .await;
1361
1362        check_bitmap(
1363            &index,
1364            SargableQuery::Range(
1365                Bound::Excluded(ScalarValue::Int32(Some(4))),
1366                Bound::Excluded(ScalarValue::Int32(Some(11))),
1367            ),
1368            &[3, 40],
1369        )
1370        .await;
1371
1372        check_bitmap(
1373            &index,
1374            SargableQuery::Range(
1375                Bound::Excluded(ScalarValue::Int32(Some(-50))),
1376                Bound::Excluded(ScalarValue::Int32(Some(1000))),
1377            ),
1378            &[
1379                0, 1, 2, 3, 40, 50, 60, 70, 400, 500, 600, 700, 4000, 5000, 6000, 7000,
1380            ],
1381        )
1382        .await;
1383    }
1384
1385    #[tokio::test]
1386    async fn test_bitmap_update() {
1387        let index_dir = TempDir::default();
1388        let index_store = test_store(&index_dir);
1389        let data = gen_batch()
1390            .col(VALUE_COLUMN_NAME, array::step::<Int32Type>())
1391            .col(ROW_ID, array::step::<UInt64Type>())
1392            .into_reader_rows(RowCount::from(4096), BatchCount::from(1));
1393        train_bitmap(&index_store, data).await;
1394        let index = BitmapIndex::load(index_store, None, &LanceCache::no_cache())
1395            .await
1396            .unwrap();
1397
1398        let data = gen_batch()
1399            .col(VALUE_COLUMN_NAME, array::step_custom::<Int32Type>(4096, 1))
1400            .col(ROW_ID, array::step_custom::<UInt64Type>(4096, 1))
1401            .into_reader_rows(RowCount::from(4096), BatchCount::from(1));
1402
1403        let updated_index_dir = TempDir::default();
1404        let updated_index_store = test_store(&updated_index_dir);
1405        index
1406            .update(
1407                lance_datafusion::utils::reader_to_stream(Box::new(data)),
1408                updated_index_store.as_ref(),
1409                None,
1410            )
1411            .await
1412            .unwrap();
1413        let updated_index = BitmapIndex::load(updated_index_store, None, &LanceCache::no_cache())
1414            .await
1415            .unwrap();
1416
1417        let result = updated_index
1418            .search(
1419                &SargableQuery::Equals(ScalarValue::Int32(Some(5000))),
1420                &NoOpMetricsCollector,
1421            )
1422            .await
1423            .unwrap();
1424
1425        assert!(result.is_exact());
1426        let row_addrs = result.row_addrs().true_rows();
1427        assert_eq!(Some(1), row_addrs.len());
1428        assert!(row_addrs.contains(5000));
1429    }
1430
1431    #[tokio::test]
1432    async fn test_bitmap_remap() {
1433        let index_dir = TempDir::default();
1434        let index_store = test_store(&index_dir);
1435        let data = gen_batch()
1436            .col(VALUE_COLUMN_NAME, array::step::<Int32Type>())
1437            .col(ROW_ID, array::step::<UInt64Type>())
1438            .into_reader_rows(RowCount::from(50), BatchCount::from(1));
1439        train_bitmap(&index_store, data).await;
1440        let index = BitmapIndex::load(index_store, None, &LanceCache::no_cache())
1441            .await
1442            .unwrap();
1443
1444        let mapping = (0..50)
1445            .map(|i| {
1446                let map_result = if i == 5 {
1447                    Some(65)
1448                } else if i == 7 {
1449                    None
1450                } else {
1451                    Some(i)
1452                };
1453                (i, map_result)
1454            })
1455            .collect::<HashMap<_, _>>();
1456
1457        let remapped_dir = TempDir::default();
1458        let remapped_store = test_store(&remapped_dir);
1459        index
1460            .remap(&mapping, remapped_store.as_ref())
1461            .await
1462            .unwrap();
1463        let remapped_index = BitmapIndex::load(remapped_store, None, &LanceCache::no_cache())
1464            .await
1465            .unwrap();
1466
1467        // Remapped to new value
1468        assert!(
1469            remapped_index
1470                .search(
1471                    &SargableQuery::Equals(ScalarValue::Int32(Some(5))),
1472                    &NoOpMetricsCollector
1473                )
1474                .await
1475                .unwrap()
1476                .row_addrs()
1477                .selected(65)
1478        );
1479        // Deleted
1480        assert!(
1481            remapped_index
1482                .search(
1483                    &SargableQuery::Equals(ScalarValue::Int32(Some(7))),
1484                    &NoOpMetricsCollector
1485                )
1486                .await
1487                .unwrap()
1488                .row_addrs()
1489                .is_empty()
1490        );
1491        // Not remapped
1492        assert!(
1493            remapped_index
1494                .search(
1495                    &SargableQuery::Equals(ScalarValue::Int32(Some(3))),
1496                    &NoOpMetricsCollector
1497                )
1498                .await
1499                .unwrap()
1500                .row_addrs()
1501                .selected(3)
1502        );
1503    }
1504
1505    async fn train_tag(
1506        index_store: &Arc<dyn IndexStore>,
1507        data: impl RecordBatchReader + Send + Sync + 'static,
1508    ) {
1509        let data = lance_datafusion::utils::reader_to_stream(Box::new(data));
1510        let request = LabelListIndexPlugin
1511            .new_training_request(
1512                "{}",
1513                &Field::new(
1514                    VALUE_COLUMN_NAME,
1515                    DataType::List(Arc::new(Field::new("item", DataType::UInt8, false))),
1516                    false,
1517                ),
1518            )
1519            .unwrap();
1520        LabelListIndexPlugin
1521            .train_index(
1522                data,
1523                index_store.as_ref(),
1524                request,
1525                None,
1526                crate::progress::noop_progress(),
1527            )
1528            .await
1529            .unwrap();
1530    }
1531
1532    #[tokio::test]
1533    async fn test_label_list_index() {
1534        let tempdir = TempDir::default();
1535        let index_store = test_store(&tempdir);
1536        let data = gen_batch()
1537            .col(
1538                VALUE_COLUMN_NAME,
1539                array::rand_type(&DataType::List(Arc::new(Field::new(
1540                    "item",
1541                    DataType::UInt8,
1542                    false,
1543                )))),
1544            )
1545            .col(ROW_ID, array::step::<UInt64Type>())
1546            .into_batch_rows(RowCount::from(40960))
1547            .unwrap();
1548
1549        let batch_reader = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
1550
1551        // This is probably enough data that we can be assured each tag is used at least once
1552        train_tag(&index_store, batch_reader).await;
1553
1554        // We scan through each list, if it was a match we run match_fn to check
1555        // if the match was correct if it was not a match we run no_match_fn to check
1556        // if the no-match was correct
1557        type MatchFn = Box<dyn Fn(&ScalarBuffer<u8>) -> bool>;
1558        let check = |query: LabelListQuery, match_fn: MatchFn, no_match_fn: MatchFn| {
1559            let index_store = index_store.clone();
1560            let data = data.clone();
1561            async move {
1562                let index = LabelListIndexPlugin
1563                    .load_index(
1564                        index_store,
1565                        &default_details::<pbold::LabelListIndexDetails>(),
1566                        None,
1567                        &LanceCache::no_cache(),
1568                    )
1569                    .await
1570                    .unwrap();
1571                let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
1572                assert!(result.is_exact());
1573                let row_addrs = result.row_addrs().true_rows();
1574
1575                let row_addrs_set = row_addrs
1576                    .row_addrs()
1577                    .unwrap()
1578                    .map(u64::from)
1579                    .collect::<std::collections::HashSet<_>>();
1580
1581                for (list, row_id) in data
1582                    .column(0)
1583                    .as_list::<i32>()
1584                    .iter()
1585                    .zip(data.column(1).as_primitive::<UInt64Type>())
1586                {
1587                    let list = list.unwrap();
1588                    let row_id = row_id.unwrap();
1589                    let vals = list.as_primitive::<UInt8Type>().values();
1590                    if row_addrs_set.contains(&row_id) {
1591                        assert!(match_fn(vals));
1592                    } else {
1593                        assert!(no_match_fn(vals));
1594                    }
1595                }
1596            }
1597        };
1598
1599        // Simple check for 1 value (doesn't matter intersection vs union)
1600        check(
1601            LabelListQuery::HasAnyLabel(vec![ScalarValue::UInt8(Some(1))]),
1602            Box::new(|vals| vals.contains(&1)),
1603            Box::new(|vals| !vals.contains(&1)),
1604        )
1605        .await;
1606        check(
1607            LabelListQuery::HasAllLabels(vec![ScalarValue::UInt8(Some(1))]),
1608            Box::new(|vals| vals.contains(&1)),
1609            Box::new(|vals| !vals.contains(&1)),
1610        )
1611        .await;
1612        // Set intersection
1613        check(
1614            LabelListQuery::HasAllLabels(vec![
1615                ScalarValue::UInt8(Some(1)),
1616                ScalarValue::UInt8(Some(2)),
1617            ]),
1618            // Match must have 1 and 2
1619            Box::new(|vals| vals.contains(&1) && vals.contains(&2)),
1620            // No-match must either not have 1 or not have 2
1621            Box::new(|vals| !vals.contains(&1) || !vals.contains(&2)),
1622        )
1623        .await;
1624        // Set union
1625        check(
1626            LabelListQuery::HasAnyLabel(vec![
1627                ScalarValue::UInt8(Some(1)),
1628                ScalarValue::UInt8(Some(2)),
1629            ]),
1630            // Match either have 1 or have 2
1631            Box::new(|vals| vals.contains(&1) || vals.contains(&2)),
1632            // No-match must not have 1 and not have 2
1633            Box::new(|vals| !vals.contains(&1) && !vals.contains(&2)),
1634        )
1635        .await;
1636    }
1637
1638    #[tokio::test]
1639    async fn test_label_list_null_handling() {
1640        let tempdir = TempDir::default();
1641        let index_store = test_store(&tempdir);
1642
1643        // Create test data with null items within lists:
1644        // Row 0: [1, 2] - no nulls
1645        // Row 1: [3, null] - has a null item
1646        // Row 2: [4] - no nulls
1647        let list_array = ListArray::from_iter_primitive::<UInt8Type, _, _>(vec![
1648            Some(vec![Some(1), Some(2)]),
1649            Some(vec![Some(3), None]),
1650            Some(vec![Some(4)]),
1651        ]);
1652        let row_ids = UInt64Array::from_iter_values(0..3);
1653        // Create schema with nullable list items to match the ListArray
1654        let schema = Arc::new(Schema::new(vec![
1655            Field::new(
1656                VALUE_COLUMN_NAME,
1657                DataType::List(Arc::new(Field::new("item", DataType::UInt8, true))),
1658                true,
1659            ),
1660            Field::new(ROW_ID, DataType::UInt64, false),
1661        ]));
1662        let batch = RecordBatch::try_new(
1663            schema.clone(),
1664            vec![Arc::new(list_array), Arc::new(row_ids)],
1665        )
1666        .unwrap();
1667
1668        let batch_reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
1669        train_tag(&index_store, batch_reader).await;
1670
1671        let index = LabelListIndexPlugin
1672            .load_index(
1673                index_store,
1674                &default_details::<pbold::LabelListIndexDetails>(),
1675                None,
1676                &LanceCache::no_cache(),
1677            )
1678            .await
1679            .unwrap();
1680
1681        // Test: Search for lists containing value 1
1682        // Row 0: [1, 2] - contains 1 → TRUE
1683        // Row 1: [3, null] - null elements are ignored → FALSE
1684        // Row 2: [4] - doesn't contain 1 → FALSE
1685        let query = LabelListQuery::HasAnyLabel(vec![ScalarValue::UInt8(Some(1))]);
1686        let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
1687
1688        match result {
1689            SearchResult::Exact(row_ids) => {
1690                let actual_rows: Vec<u64> = row_ids
1691                    .true_rows()
1692                    .row_addrs()
1693                    .unwrap()
1694                    .map(u64::from)
1695                    .collect();
1696                assert_eq!(
1697                    actual_rows,
1698                    vec![0],
1699                    "Should find row 0 where list contains 1"
1700                );
1701
1702                assert!(
1703                    row_ids.null_rows().is_empty(),
1704                    "null_row_ids should be empty when null elements are ignored"
1705                );
1706            }
1707            _ => panic!("Expected Exact search result"),
1708        }
1709    }
1710
1711    #[tokio::test]
1712    async fn test_label_list_bitmap_only_layout_is_compatible() {
1713        let tempdir = TempDir::default();
1714        let index_store = test_store(&tempdir);
1715
1716        // Simulate an older released layout that only had the bitmap lookup file.
1717        let values = arrow_array::UInt8Array::from(vec![1, 2]);
1718        let row_ids = UInt64Array::from(vec![0, 2]);
1719        let schema = Arc::new(Schema::new(vec![
1720            Field::new(VALUE_COLUMN_NAME, DataType::UInt8, true),
1721            Field::new(ROW_ID, DataType::UInt64, false),
1722        ]));
1723        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(values), Arc::new(row_ids)])
1724            .unwrap();
1725
1726        BitmapIndexPlugin::train_bitmap_index(
1727            lance_datafusion::utils::reader_to_stream(Box::new(RecordBatchIterator::new(
1728                vec![Ok(batch)],
1729                schema,
1730            ))),
1731            index_store.as_ref(),
1732        )
1733        .await
1734        .unwrap();
1735
1736        let index = LabelListIndexPlugin
1737            .load_index(
1738                index_store,
1739                &default_details::<pbold::LabelListIndexDetails>(),
1740                None,
1741                &LanceCache::no_cache(),
1742            )
1743            .await
1744            .unwrap();
1745
1746        let query = LabelListQuery::HasAnyLabel(vec![ScalarValue::UInt8(Some(1))]);
1747        let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
1748
1749        match result {
1750            SearchResult::Exact(row_ids) => {
1751                assert!(row_ids.null_rows().is_empty());
1752                let actual_rows: Vec<u64> = row_ids
1753                    .true_rows()
1754                    .row_addrs()
1755                    .unwrap()
1756                    .map(u64::from)
1757                    .collect();
1758                assert_eq!(actual_rows, vec![0]);
1759            }
1760            _ => panic!("Expected Exact search result"),
1761        }
1762    }
1763}