// lance_index/scalar/rtree.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use crate::frag_reuse::FragReuseIndex;
5use crate::metrics::{MetricsCollector, NoOpMetricsCollector};
6use crate::scalar::expression::{GeoQueryParser, ScalarQueryParser};
7use crate::scalar::lance_format::LanceIndexStore;
8use crate::scalar::registry::{
9    ScalarIndexPlugin, TrainingCriteria, TrainingOrdering, TrainingRequest,
10};
11use crate::scalar::rtree::sort::Sorter;
12use crate::scalar::{
13    AnyQuery, BuiltinIndexType, CreatedIndex, GeoQuery, IndexReader, IndexReaderStream, IndexStore,
14    IndexWriter, ScalarIndex, ScalarIndexParams, SearchResult, UpdateCriteria,
15};
16use crate::vector::VectorIndex;
17use crate::{pb, Index, IndexType};
18use arrow_array::cast::AsArray;
19use arrow_array::types::UInt64Type;
20use arrow_array::UInt32Array;
21use arrow_array::{Array, BinaryArray, RecordBatch, UInt64Array};
22use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
23use async_trait::async_trait;
24use datafusion::execution::SendableRecordBatchStream;
25use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
26use datafusion_common::DataFusionError;
27use deepsize::DeepSizeOf;
28use futures::{stream, StreamExt, TryFutureExt, TryStreamExt};
29use geoarrow_array::array::{from_arrow_array, RectArray};
30use geoarrow_array::builder::RectBuilder;
31use geoarrow_array::{GeoArrowArray, GeoArrowArrayAccessor, IntoArrow};
32use geoarrow_schema::{Dimension, RectType};
33use lance_arrow::RecordBatchExt;
34use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache};
35use lance_core::utils::address::RowAddress;
36use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap, RowSetOps};
37use lance_core::utils::tempfile::TempDir;
38use lance_core::{Error, Result, ROW_ID};
39use lance_datafusion::chunker::chunk_concat_stream;
40pub use lance_geo::bbox::{bounding_box, total_bounds, BoundingBox};
41use lance_io::object_store::ObjectStore;
42use roaring::RoaringBitmap;
43use serde::{Deserialize, Serialize};
44use snafu::location;
45use sort::hilbert_sort::HilbertSorter;
46use std::any::Any;
47use std::collections::HashMap;
48use std::ops::Range;
49use std::sync::{Arc, LazyLock};
50
mod sort;

/// Default maximum number of entries per R-Tree page (leaf or inner level).
pub const DEFAULT_RTREE_PAGE_SIZE: u32 = 4096;
/// Version of the on-disk R-Tree index format.
const RTREE_INDEX_VERSION: u32 = 0;
/// Index file containing the serialized tree pages (leaf entries first,
/// then the inner levels — see `RTreeIndexPlugin::write_index`).
const RTREE_PAGES_NAME: &str = "page_data.lance";
/// Index file containing the serialized set of row addresses whose
/// geometry value is null.
const RTREE_NULLS_NAME: &str = "nulls.lance";
57
/// Arrow field for the bounding-box column: a non-nullable GeoArrow XY rect
/// named "bbox".
static BBOX_FIELD: LazyLock<Arc<ArrowField>> = LazyLock::new(|| {
    let bbox_type = RectType::new(Dimension::XY, Default::default());
    Arc::new(bbox_type.to_field("bbox", false))
});
/// Schema of the leaf-level data streams: (bbox, row id).
static BBOX_ROWID_SCHEMA: LazyLock<Arc<ArrowSchema>> = LazyLock::new(|| {
    let rowid_field = ArrowField::new(ROW_ID, DataType::UInt64, false);
    Arc::new(ArrowSchema::new(vec![
        BBOX_FIELD.clone(),
        rowid_field.into(),
    ]))
});
/// Schema of the persisted page file: (bbox, id). For leaf pages "id" holds
/// a row address; for inner pages it holds a child page id.
static RTREE_PAGE_SCHEMA: LazyLock<Arc<ArrowSchema>> = LazyLock::new(|| {
    let id_field = ArrowField::new("id", DataType::UInt64, false);
    Arc::new(ArrowSchema::new(vec![BBOX_FIELD.clone(), id_field.into()]))
});

/// Schema of the nulls file: a single binary column holding one serialized
/// `RowAddrTreeMap` of rows whose geometry is null.
static RTREE_NULLS_SCHEMA: LazyLock<Arc<ArrowSchema>> = LazyLock::new(|| {
    Arc::new(ArrowSchema::new(vec![ArrowField::new(
        "nulls",
        DataType::Binary,
        false,
    )]))
});
81
/// Metadata describing a persisted R-Tree index.
///
/// Persisted as string key/value pairs in the page file's schema metadata
/// (see `into_map`) and reconstructed on load via `From<&HashMap<_, _>>`.
#[derive(Debug, Clone, Serialize)]
pub struct RTreeMetadata {
    // Maximum number of entries per page.
    pub(crate) page_size: u32,
    // Total number of pages across all levels of the tree.
    pub(crate) num_pages: u64,
    // Number of non-null indexed items (leaf entries).
    pub(crate) num_items: usize,
    // Bounding box covering all indexed geometries.
    pub(crate) bbox: BoundingBox,
    // Starting row offset of each page within the page file. Derived from
    // `num_items` and `page_size`; not persisted.
    pub(crate) page_offsets: Vec<usize>,
}
90
91impl RTreeMetadata {
92    pub fn new(page_size: u32, num_pages: u64, num_items: usize, bbox: BoundingBox) -> Self {
93        let page_offsets = Self::calculate_page_offsets(num_items, page_size);
94        debug_assert_eq!(page_offsets.len(), num_pages as usize);
95        Self {
96            page_size,
97            num_pages,
98            num_items,
99            bbox,
100            page_offsets,
101        }
102    }
103
104    fn calculate_page_offsets(num_items: usize, page_size: u32) -> Vec<usize> {
105        let mut page_offsets = vec![];
106        let mut cur_level_items = num_items;
107        let mut cur_offset = 0;
108        while cur_level_items > 0 {
109            if cur_level_items <= page_size as usize {
110                page_offsets.push(cur_offset);
111                break;
112            }
113            for off in (0..cur_level_items).step_by(page_size as usize) {
114                page_offsets.push(cur_offset + off);
115            }
116            cur_offset += cur_level_items;
117            cur_level_items = cur_level_items.div_ceil(page_size as usize);
118        }
119
120        page_offsets
121    }
122
123    fn into_map(self) -> HashMap<String, String> {
124        HashMap::from_iter(vec![
125            ("page_size".to_owned(), self.page_size.to_string()),
126            ("num_pages".to_owned(), self.num_pages.to_string()),
127            ("num_items".to_owned(), self.num_items.to_string()),
128            ("bbox".to_owned(), serde_json::json!(self.bbox).to_string()),
129        ])
130    }
131}
132
133impl From<&HashMap<String, String>> for RTreeMetadata {
134    fn from(metadata: &HashMap<String, String>) -> Self {
135        let page_size = metadata
136            .get("page_size")
137            .map(|bs| bs.parse().unwrap_or(DEFAULT_RTREE_PAGE_SIZE))
138            .unwrap_or(DEFAULT_RTREE_PAGE_SIZE);
139        let num_pages = metadata
140            .get("num_pages")
141            .map(|bs| bs.parse().unwrap_or(0))
142            .unwrap_or(0);
143        let num_items = metadata
144            .get("num_items")
145            .map(|bs| bs.parse().unwrap_or(0))
146            .unwrap_or(0);
147        let bbox = metadata
148            .get("bbox")
149            .map(|bs| serde_json::from_str(bs).unwrap_or_default())
150            .unwrap_or_default();
151        Self::new(page_size, num_pages, num_items, bbox)
152    }
153}
154
155/// Extract bounding boxes from geometry columns
156pub fn extract_bounding_boxes(
157    geometry_array: &dyn Array,
158    geometry_field: &ArrowField,
159) -> Result<RectArray> {
160    let geo_array = from_arrow_array(geometry_array, geometry_field).map_err(|e| Error::Index {
161        message: format!("Construct GeoArrowArray from an Arrow Array failed: {}", e),
162        location: location!(),
163    })?;
164    let rect_array = bounding_box(geo_array.as_ref())?;
165
166    Ok(rect_array)
167}
168
/// Statistics gathered while scanning a stream of bounding boxes
/// (see `RTreeIndexPlugin::process_and_analyze_bbox_stream`).
struct BboxStreamStats {
    // Row addresses whose geometry value was null.
    null_map: RowAddrTreeMap,
    // Union of the bounding boxes of all non-null items.
    total_bbox: BoundingBox,
    // Number of non-null items
    num_items: usize,
}

/// Cache key for pieces of an R-Tree index held in the session index cache.
#[derive(Debug, Clone)]
pub enum RTreeCacheKey {
    /// A single tree page, identified by its page id.
    Page(u64),
    /// The single batch holding the serialized null-row-address set.
    Nulls,
}
181
/// Cached value: one record batch read from the index files (either a tree
/// page or the nulls batch).
#[derive(Debug)]
pub struct RTreeCacheValue(Arc<RecordBatch>);

impl DeepSizeOf for RTreeCacheValue {
    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
        // Approximate heap usage with the Arrow-reported array memory size.
        self.0.get_array_memory_size()
    }
}

impl CacheKey for RTreeCacheKey {
    type ValueType = RTreeCacheValue;

    /// Cache key string: "page-{id}" for tree pages, "nulls" for the null set.
    fn key(&self) -> std::borrow::Cow<'_, str> {
        match self {
            Self::Page(page_id) => format!("page-{}", page_id).into(),
            Self::Nulls => "nulls".into(),
        }
    }
}
201
/// A scalar index over a geometry column, backed by a packed R-Tree whose
/// leaf entries are (bounding box, row address) pairs. Pages are laid out
/// leaf level first; see `RTreeIndexPlugin::write_index` for construction.
#[derive(Clone)]
pub struct RTreeIndex {
    // Tree shape and covering bbox, parsed from the page file's metadata.
    pub(crate) metadata: Arc<RTreeMetadata>,
    // Store the index files live in.
    store: Arc<dyn IndexStore>,
    // Optional row-address remapping applied to search results.
    frag_reuse_index: Option<Arc<FragReuseIndex>>,
    // Weak handle to the session cache holding pages and the null set.
    index_cache: WeakLanceCache,
    // Reader over RTREE_PAGES_NAME.
    pages_reader: Arc<dyn IndexReader>,
    // Reader over RTREE_NULLS_NAME.
    nulls_reader: Arc<dyn IndexReader>,
}

impl std::fmt::Debug for RTreeIndex {
    // Manual impl: only metadata and store are shown; the readers and cache
    // handle are omitted from debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RTreeIndex")
            .field("metadata", &self.metadata)
            .field("store", &self.store)
            .finish()
    }
}
220
impl RTreeIndex {
    /// Loads an R-Tree index from `store`: opens both index files and parses
    /// the tree metadata out of the page file's schema metadata.
    pub async fn load(
        store: Arc<dyn IndexStore>,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
        index_cache: &LanceCache,
    ) -> Result<Arc<Self>> {
        let pages_reader = store.open_index_file(RTREE_PAGES_NAME).await?;
        let metadata = RTreeMetadata::from(&pages_reader.schema().metadata);
        let nulls_reader = store.open_index_file(RTREE_NULLS_NAME).await?;

        Ok(Arc::new(Self {
            metadata: Arc::new(metadata),
            store,
            frag_reuse_index,
            index_cache: WeakLanceCache::from(index_cache),
            pages_reader,
            nulls_reader,
        }))
    }

    /// Returns the row range of page `page_idx` within the page file.
    ///
    /// Indexes past the end of `page_offsets` clamp to the file's row count,
    /// so the last page's range naturally ends at the end of the file.
    async fn page_range(&self, page_idx: u64) -> Result<Range<usize>> {
        let start = match self.metadata.page_offsets.get(page_idx as usize) {
            None => self.pages_reader.num_rows(),
            Some(start) => *start,
        };
        let end = match self.metadata.page_offsets.get((page_idx + 1) as usize) {
            None => self.pages_reader.num_rows(),
            Some(end) => *end,
        };
        Ok(start..end)
    }

    /// Collects the row addresses of all leaf entries whose bounding box
    /// intersects `bbox`.
    ///
    /// Traverses the tree iteratively with an explicit stack, starting from
    /// the root (the last page). Pages are fetched through the index cache.
    async fn search_bbox(
        &self,
        bbox: BoundingBox,
        metrics: &dyn MetricsCollector,
    ) -> Result<RowAddrTreeMap> {
        // Fast path: empty index, or the query box misses the whole tree.
        if self.metadata.num_items == 0 || !self.metadata.bbox.rect_intersects(&bbox) {
            return Ok(RowAddrTreeMap::default());
        }

        let mut row_addrs = RowAddrTreeMap::new();
        // The root page is written last, so it has the highest page id.
        let mut stack = vec![self.metadata.num_pages - 1];

        while let Some(page_idx) = stack.pop() {
            let range = self.page_range(page_idx).await?;
            // Leaf pages occupy the first `num_items` rows of the file;
            // any page starting below that offset is a leaf.
            let is_leaf = range.start < self.metadata.num_items;
            let batch = self
                .index_cache
                .get_or_insert_with_key(RTreeCacheKey::Page(page_idx), move || async move {
                    let batch = self.pages_reader.read_range(range, None).await?;
                    metrics.record_part_load();
                    Ok(RTreeCacheValue(Arc::new(batch)))
                })
                .await
                .map(|v| v.0.clone())?;

            let bbox_array =
                extract_bounding_boxes(batch.column(0).as_ref(), batch.schema().field(0))?;
            // Column 1 is a row address for leaf pages, a child page id otherwise.
            let rowaddr_or_pageid_array = batch
                .column(1)
                .as_any()
                .downcast_ref::<UInt64Array>()
                .unwrap();

            for i in 0..bbox_array.len() {
                let rect = bbox_array.value(i).unwrap();
                if bbox.rect_intersects(&rect) {
                    if is_leaf {
                        let row_addr = rowaddr_or_pageid_array.value(i);
                        row_addrs.insert(row_addr);
                    } else {
                        // Descend into the intersecting child page.
                        let page_id = rowaddr_or_pageid_array.value(i);
                        stack.push(page_id);
                    }
                }
            }
        }

        Ok(row_addrs)
    }

    /// Loads (through the cache) the set of row addresses whose geometry is
    /// null, deserializing it from the single-row nulls file.
    async fn search_null(&self, metrics: &dyn MetricsCollector) -> Result<RowAddrTreeMap> {
        let batch = self
            .index_cache
            .get_or_insert_with_key(RTreeCacheKey::Nulls, move || async move {
                // Only one row
                let batch = self.nulls_reader.read_range(0..1, None).await?;
                metrics.record_part_load();
                Ok(RTreeCacheValue(Arc::new(batch)))
            })
            .await
            .map(|v| v.0.clone())?;

        let null_map = match batch.num_rows() {
            0 => RowAddrTreeMap::default(),
            1 => {
                let bytes = batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<BinaryArray>()
                    .unwrap()
                    .value(0);
                RowAddrTreeMap::deserialize_from(bytes)?
            }
            _ => {
                // write_nulls always writes exactly one row.
                unreachable!()
            }
        };
        Ok(null_map)
    }

    /// Create a stream of all the data in the index, in the format (bbox, row_id)
    ///
    /// Reads only the first `num_items` rows of the page file (the leaf
    /// level) and relabels the columns under `BBOX_ROWID_SCHEMA`.
    async fn into_data_stream(self) -> Result<SendableRecordBatchStream> {
        let reader = self.store.open_index_file(RTREE_PAGES_NAME).await?;
        let reader_stream = IndexReaderStream::new_with_limit(
            reader,
            self.metadata.page_size as u64,
            self.metadata.num_items as u64,
        )
        .await;
        let batches = reader_stream
            .map(|fut| {
                fut.map_ok(|batch| {
                    RecordBatch::try_new(BBOX_ROWID_SCHEMA.clone(), batch.columns().into()).unwrap()
                })
            })
            .map(|fut| fut.map_err(DataFusionError::from))
            .buffered(self.store.io_parallelism())
            .boxed();
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            BBOX_ROWID_SCHEMA.clone(),
            batches,
        )))
    }

    /// Merges this index's existing leaf data with `new_input`.
    ///
    /// `select` interleaves the two streams in no particular order; the
    /// combined stream is re-sorted later during training (see `update`).
    async fn combine_old_new(
        self,
        new_input: SendableRecordBatchStream,
    ) -> Result<SendableRecordBatchStream> {
        let old_input = self.into_data_stream().await?;
        debug_assert_eq!(
            old_input.schema().flattened_fields().len(),
            new_input.schema().flattened_fields().len()
        );

        let merged = futures::stream::select(old_input, new_input);

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            BBOX_ROWID_SCHEMA.clone(),
            merged,
        )))
    }
}
375
376impl DeepSizeOf for RTreeIndex {
377    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
378        let mut total_size = 0;
379
380        total_size += self.store.deep_size_of_children(context);
381
382        total_size
383    }
384}
385
#[async_trait]
impl Index for RTreeIndex {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
        self
    }

    /// R-Tree is a scalar index; converting to a vector index always fails.
    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn VectorIndex>> {
        Err(Error::NotSupported {
            source: "RTreeIndex is not vector index".into(),
            location: location!(),
        })
    }

    /// Reports the index metadata (page size, page/item counts, bbox,
    /// page offsets) as JSON.
    fn statistics(&self) -> Result<serde_json::Value> {
        serde_json::to_value(self.metadata.clone()).map_err(|e| Error::Internal {
            message: format!("Error serializing statistics: {}", e),
            location: location!(),
        })
    }

    /// Eagerly loads every tree page and the null set into the index cache.
    async fn prewarm(&self) -> Result<()> {
        for page_id in 0..self.metadata.num_pages {
            let range = self.page_range(page_id).await?;
            let batch = Arc::new(self.pages_reader.read_range(range, None).await?);
            self.index_cache
                .insert_with_key(
                    &RTreeCacheKey::Page(page_id),
                    Arc::new(RTreeCacheValue(batch.clone())),
                )
                .await;
        }

        // The nulls file holds a single row (see write_nulls).
        let batch = self.nulls_reader.read_range(0..1, None).await?;
        self.index_cache
            .insert_with_key(
                &RTreeCacheKey::Nulls,
                Arc::new(RTreeCacheValue(Arc::new(batch))),
            )
            .await;

        Ok(())
    }

    fn index_type(&self) -> IndexType {
        IndexType::RTree
    }

    /// Streams the leaf (bbox, row id) data and collects the distinct
    /// fragment ids referenced by the indexed row addresses.
    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
        let mut frag_ids = RoaringBitmap::default();

        let mut reader_stream = self.clone().into_data_stream().await?;
        while let Some(page) = reader_stream.try_next().await? {
            let mut page_frag_ids = page
                .column(1)
                .as_primitive::<UInt64Type>()
                .iter()
                .flatten()
                .map(|row_addr| RowAddress::from(row_addr).fragment_id())
                .collect::<Vec<_>>();
            // from_sorted_iter requires sorted, de-duplicated input.
            page_frag_ids.sort();
            page_frag_ids.dedup();
            frag_ids |= RoaringBitmap::from_sorted_iter(page_frag_ids).unwrap();
        }
        Ok(frag_ids)
    }
}
456
#[async_trait]
impl ScalarIndex for RTreeIndex {
    /// Answers a geo query against the index.
    ///
    /// Intersect queries return `SearchResult::AtMost`: the tree matches on
    /// bounding-box intersection, which can over-approximate true geometry
    /// intersection (the plugin also reports `provides_exact_answer: false`).
    /// IsNull queries return the exact set of null rows.
    async fn search(
        &self,
        query: &dyn AnyQuery,
        metrics: &dyn MetricsCollector,
    ) -> Result<SearchResult> {
        let query = query.as_any().downcast_ref::<GeoQuery>().unwrap();
        match query {
            GeoQuery::IntersectQuery(query) => {
                // Reduce the query geometry to a single covering bbox and
                // probe the tree with it.
                let geo_array =
                    extract_bounding_boxes(query.value.to_array()?.as_ref(), &query.field)?;
                let bbox = total_bounds(&geo_array)?;
                let mut rowids = self.search_bbox(bbox, metrics).await?;
                let mut null_map = self.search_null(metrics).await?;

                // Remap row addresses if fragments have been reused/compacted.
                if let Some(fri) = &self.frag_reuse_index {
                    rowids = fri.remap_row_addrs_tree_map(&rowids);
                    null_map = fri.remap_row_addrs_tree_map(&null_map);
                }
                Ok(SearchResult::AtMost(NullableRowAddrSet::new(
                    rowids, null_map,
                )))
            }
            GeoQuery::IsNull => {
                let mut null_map = self.search_null(metrics).await?;

                if let Some(fri) = &self.frag_reuse_index {
                    null_map = fri.remap_row_addrs_tree_map(&null_map);
                }
                // The null set itself is the exact answer; there are no
                // "unknown" rows for an IsNull query.
                Ok(SearchResult::Exact(NullableRowAddrSet::new(
                    null_map,
                    RowAddrTreeMap::default(),
                )))
            }
        }
    }

    fn can_remap(&self) -> bool {
        false
    }

    /// Remapping is unsupported; compaction must retrain the index instead.
    async fn remap(
        &self,
        _mapping: &HashMap<u64, Option<u64>>,
        _dest_store: &dyn IndexStore,
    ) -> Result<CreatedIndex> {
        Err(Error::InvalidInput {
            source: "RTree does not support remap".into(),
            location: location!(),
        })
    }

    /// Incrementally updates the index: converts `new_data` to bounding
    /// boxes, spills and analyzes it, merges it with the existing leaf data
    /// and the existing null set, then retrains the tree into `dest_store`.
    async fn update(
        &self,
        new_data: SendableRecordBatchStream,
        dest_store: &dyn IndexStore,
    ) -> Result<CreatedIndex> {
        let bbox_data = RTreeIndexPlugin::convert_bbox_stream(new_data)?;
        // Spill to a local temp store while computing stream statistics.
        let tmpdir = Arc::new(TempDir::default());
        let spill_store = Arc::new(LanceIndexStore::new(
            Arc::new(ObjectStore::local()),
            tmpdir.obj_path(),
            Arc::new(LanceCache::no_cache()),
        ));
        let (new_bbox_data, stats) = RTreeIndexPlugin::process_and_analyze_bbox_stream(
            bbox_data,
            self.metadata.page_size,
            spill_store.clone(),
        )
        .await?;

        let merged_bbox_data = self.clone().combine_old_new(new_bbox_data).await?;

        let null_map = self.search_null(&NoOpMetricsCollector).await?;

        // Combined bounding box of old and new data.
        let mut new_bbox = BoundingBox::new();
        new_bbox.add_rect(&stats.total_bbox);
        new_bbox.add_rect(&self.metadata.bbox);

        let merge_stats = BboxStreamStats {
            null_map: RowAddrTreeMap::union_all(&[&null_map, &stats.null_map]),
            total_bbox: new_bbox,
            num_items: self.metadata.num_items + stats.num_items,
        };

        RTreeIndexPlugin::train_rtree_index(
            merged_bbox_data,
            merge_stats,
            self.metadata.page_size,
            dest_store,
        )
        .await?;

        Ok(CreatedIndex {
            index_details: prost_types::Any::from_msg(&pb::RTreeIndexDetails::default())?,
            index_version: RTREE_INDEX_VERSION,
        })
    }

    /// Updates need only the new data with row ids; no ordering is required
    /// because training re-sorts by Hilbert curve.
    fn update_criteria(&self) -> UpdateCriteria {
        UpdateCriteria::only_new_data(TrainingCriteria::new(TrainingOrdering::None).with_row_id())
    }

    /// Reconstructs the creation parameters (page size) for this index.
    fn derive_index_params(&self) -> Result<ScalarIndexParams> {
        let params = serde_json::to_value(RTreeParameters {
            page_size: Some(self.metadata.page_size),
        })?;
        Ok(ScalarIndexParams::for_builtin(BuiltinIndexType::RTree).with_params(&params))
    }
}
568
/// Parameters for a rtree index
#[derive(Debug, Serialize, Deserialize, Clone)]
struct RTreeParameters {
    /// The number of rows to include in each page.
    /// Falls back to `DEFAULT_RTREE_PAGE_SIZE` when `None`.
    pub page_size: Option<u32>,
}

/// Training request for building an R-Tree index.
pub struct RTreeTrainingRequest {
    // User-supplied parameters (page size).
    parameters: RTreeParameters,
    // Data requirements for training: row ids, no particular ordering.
    criteria: TrainingCriteria,
}

impl RTreeTrainingRequest {
    fn new(parameters: RTreeParameters) -> Self {
        Self {
            parameters,
            // No input ordering needed; training sorts by Hilbert curve.
            criteria: TrainingCriteria::new(TrainingOrdering::None).with_row_id(),
        }
    }
}

impl Default for RTreeTrainingRequest {
    fn default() -> Self {
        Self::new(RTreeParameters {
            page_size: Some(DEFAULT_RTREE_PAGE_SIZE),
        })
    }
}

impl TrainingRequest for RTreeTrainingRequest {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn criteria(&self) -> &TrainingCriteria {
        &self.criteria
    }
}
607
/// Scalar-index plugin implementing R-Tree training, writing and query
/// parsing. Stateless; all work happens in the associated functions below.
#[derive(Debug, Default)]
pub struct RTreeIndexPlugin;
610
impl RTreeIndexPlugin {
    /// Validates the training input schema: exactly two fields, one of which
    /// is a UInt64 row-id column named `ROW_ID`.
    fn validate_schema(schema: &ArrowSchema) -> Result<()> {
        if schema.fields().len() != 2 {
            return Err(Error::InvalidInput {
                source: "RTree index schema must have exactly two fields".into(),
                location: location!(),
            });
        }

        // NOTE(review): the lookup is by name (ROW_ID), not by position, so
        // the "Second field" wording in the error below may be misleading.
        let row_id_field = schema.field_with_name(ROW_ID)?;
        if *row_id_field.data_type() != DataType::UInt64 {
            return Err(Error::InvalidInput {
                source: "Second field in RTree index schema must be of type UInt64".into(),
                location: location!(),
            });
        }
        Ok(())
    }

    /// Maps a (geometry, row id) stream into a (bbox, row id) stream by
    /// reducing each geometry to its bounding rect.
    fn convert_bbox_stream(source: SendableRecordBatchStream) -> Result<SendableRecordBatchStream> {
        let bbox_stream = source
            .map_err(DataFusionError::into)
            .and_then(move |batch| async move {
                let schema = batch.schema();
                let geometry_field = schema.field(0);
                let geometry_array = batch.column(0);
                let bbox_array = extract_bounding_boxes(geometry_array, geometry_field)?;

                // NOTE(review): the per-batch bbox field is nullable (true)
                // while BBOX_ROWID_SCHEMA declares it non-nullable; nulls are
                // filtered out downstream in process_and_analyze_bbox_stream.
                let bbox_schema = Arc::new(ArrowSchema::new(vec![
                    bbox_array.extension_type().clone().to_field("bbox", true),
                    ArrowField::new(ROW_ID, DataType::UInt64, false),
                ]));
                RecordBatch::try_new(
                    bbox_schema,
                    vec![bbox_array.into_array_ref(), batch.column(1).clone()],
                )
                .map_err(DataFusionError::from)
            });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            BBOX_ROWID_SCHEMA.clone(),
            bbox_stream,
        )))
    }

    /// Processes a bounding box data stream, separating null and non-null elements, and collects
    /// statistics about non-null elements.
    ///
    /// Spills the non-null rows to "analyze.tmp" in `spill_store` (the input
    /// must be fully consumed to compute totals), then re-reads them as a
    /// stream chunked at `page_size`. Returns the stream plus the collected
    /// null set, total bbox and non-null item count.
    async fn process_and_analyze_bbox_stream(
        mut data: SendableRecordBatchStream,
        page_size: u32,
        spill_store: Arc<LanceIndexStore>,
    ) -> Result<(SendableRecordBatchStream, BboxStreamStats)> {
        let mut null_rowaddrs = RowAddrTreeMap::new();
        let mut total_bbox = BoundingBox::new();
        let mut num_non_null_rows = 0;

        let schema = data.schema();

        let mut writer = spill_store
            .new_index_file("analyze.tmp", BBOX_ROWID_SCHEMA.clone())
            .await?;

        while let Some(batch) = data.try_next().await? {
            let bbox_array = extract_bounding_boxes(&batch.column(0), batch.schema().field(0))?;
            let rowaddr_array = batch
                .column(1)
                .as_any()
                .downcast_ref::<UInt64Array>()
                .unwrap();

            // Grow the running total bbox by this batch's extents.
            total_bbox.add_geo_arrow_array(&bbox_array)?;

            let num_rows = bbox_array.len();

            let mut non_null_indexes = vec![];

            // Partition rows: null geometries go to the null set, the rest
            // are kept for the tree.
            for i in 0..num_rows {
                if bbox_array.is_null(i) {
                    let rowaddr = rowaddr_array.value(i);
                    null_rowaddrs.insert(rowaddr);
                } else {
                    non_null_indexes.push(i as u32);
                }
            }

            let new_batch = if non_null_indexes.is_empty() {
                // all nulls, skip write
                continue;
            } else if non_null_indexes.len() == num_rows {
                batch
            } else {
                // Mixed batch: keep only the non-null rows.
                batch.take(&UInt32Array::from(non_null_indexes))?
            };

            num_non_null_rows += new_batch.num_rows();
            writer.write_record_batch(new_batch).await?;
        }
        writer.finish().await?;
        let reader = spill_store.open_index_file("analyze.tmp").await?;
        let stream = IndexReaderStream::new(reader, page_size as u64)
            .await
            .map(|fut| fut.map_err(DataFusionError::from))
            .buffered(spill_store.io_parallelism())
            .boxed();
        let new_data = RecordBatchStreamAdapter::new(schema.clone(), stream);

        Ok((
            Box::pin(new_data),
            BboxStreamStats {
                null_map: null_rowaddrs,
                total_bbox,
                num_items: num_non_null_rows,
            },
        ))
    }

    /// Writes one page (one batch) to the page file and returns its summary
    /// entry (covering bbox + page id) for the next level up.
    async fn train_rtree_page(
        batch: RecordBatch,
        page_id: u64,
        writer: &mut dyn IndexWriter,
    ) -> Result<EncodedBatch> {
        let geo_array = extract_bounding_boxes(batch.column(0).as_ref(), batch.schema().field(0))?;
        let bbox = total_bounds(&geo_array)?;
        // Relabel the columns under the page schema (bbox, id).
        let new_batch = RecordBatch::try_new(
            RTREE_PAGE_SCHEMA.clone(),
            vec![batch.column(0).clone(), batch.column(1).clone()],
        )?;
        writer.write_record_batch(new_batch).await?;
        Ok(EncodedBatch { bbox, page_id })
    }

    /// Turns the per-page summaries of one level into a record batch stream
    /// (chunked at `batch_size`) that serves as input for the next level.
    fn encoded_batches_into_batch_stream(
        batches: Vec<EncodedBatch>,
        batch_size: u32,
    ) -> SendableRecordBatchStream {
        let batches = batches
            .chunks(batch_size as usize)
            .map(|chunk| {
                let bbox_type = RectType::new(Dimension::XY, Default::default());
                let mut bbox_builder = RectBuilder::with_capacity(bbox_type, chunk.len());
                let mut page_ids = UInt64Array::builder(chunk.len());

                for item in chunk {
                    bbox_builder.push_rect(Some(&item.bbox));
                    page_ids.append_value(item.page_id);
                }

                RecordBatch::try_new(
                    RTREE_PAGE_SCHEMA.clone(),
                    vec![
                        bbox_builder.finish().into_array_ref(),
                        Arc::new(page_ids.finish()),
                    ],
                )
                .unwrap()
            })
            .collect::<Vec<_>>();

        Box::pin(RecordBatchStreamAdapter::new(
            RTREE_PAGE_SCHEMA.clone(),
            stream::iter(batches).map(Ok).boxed(),
        ))
    }

    /// Writes the full tree, level by level, bottom-up: the sorted leaf data
    /// first, then each parent level built from the previous level's page
    /// summaries, until a level fits in a single (root) page. Finally stores
    /// the tree metadata in the page file's schema metadata.
    pub async fn write_index(
        sorted_data: SendableRecordBatchStream,
        num_items: usize,
        total_bbox: BoundingBox,
        store: &dyn IndexStore,
        page_size: u32,
    ) -> Result<()> {
        let mut page_idx: u64 = 0;
        let mut writer = store
            .new_index_file(RTREE_PAGES_NAME, RTREE_PAGE_SCHEMA.clone())
            .await?;

        if num_items > 0 {
            let mut current_level = Some((sorted_data, num_items));
            while let Some((mut data, num_items)) = current_level.take() {
                if num_items <= page_size as usize {
                    // Final (root) level: write the remaining batches as pages.
                    // NOTE(review): each incoming batch becomes its own page
                    // here; this assumes a level that fits in one page arrives
                    // as a single batch — confirm against the upstream chunking.
                    while let Some(batch) = data.try_next().await? {
                        Self::train_rtree_page(batch, page_idx, writer.as_mut()).await?;
                        page_idx += 1;
                    }
                } else {
                    // Re-chunk the level into exact page-size batches and
                    // collect each page's summary for the next level up.
                    let mut next_level = vec![];
                    let mut paged_source = chunk_concat_stream(data, page_size as usize);
                    while let Some(batch) = paged_source.try_next().await? {
                        let encoded_batch =
                            Self::train_rtree_page(batch, page_idx, writer.as_mut()).await?;
                        page_idx += 1;
                        next_level.push(encoded_batch);
                    }
                    if !next_level.is_empty() {
                        let next_num_items = next_level.len();
                        current_level = Some((
                            Self::encoded_batches_into_batch_stream(next_level, page_size),
                            next_num_items,
                        ));
                    }
                }
            }
        }

        writer
            .finish_with_metadata(
                RTreeMetadata::new(page_size, page_idx, num_items, total_bbox).into_map(),
            )
            .await?;

        Ok(())
    }

    /// Serializes the null-row-address set into the single-row nulls file.
    pub async fn write_nulls(store: &dyn IndexStore, null_map: RowAddrTreeMap) -> Result<()> {
        let mut writer = store
            .new_index_file(RTREE_NULLS_NAME, RTREE_NULLS_SCHEMA.clone())
            .await?;
        let mut bytes = Vec::new();
        null_map.serialize_into(&mut bytes)?;
        let batch = RecordBatch::try_new(
            RTREE_NULLS_SCHEMA.clone(),
            vec![Arc::new(BinaryArray::from_vec(vec![&bytes]))],
        )?;

        writer.write_record_batch(batch).await?;
        writer.finish().await
    }

    /// Full training pipeline: Hilbert-sorts the (bbox, row id) stream, then
    /// writes the tree pages and the null set into `store`.
    async fn train_rtree_index(
        bbox_data: SendableRecordBatchStream,
        stats: BboxStreamStats,
        page_size: u32,
        store: &dyn IndexStore,
    ) -> Result<()> {
        // new sorted stream
        let sorter = HilbertSorter::new(stats.total_bbox);
        let sorted_data = sorter.sort(bbox_data).await?;

        Self::write_index(
            sorted_data,
            stats.num_items,
            stats.total_bbox,
            store,
            page_size,
        )
        .await?;

        Self::write_nulls(store, stats.null_map).await?;

        Ok(())
    }
}
863
864#[async_trait]
865impl ScalarIndexPlugin for RTreeIndexPlugin {
866    fn name(&self) -> &str {
867        "RTree"
868    }
869
870    fn new_training_request(
871        &self,
872        params: &str,
873        _field: &ArrowField,
874    ) -> Result<Box<dyn TrainingRequest>> {
875        let params = serde_json::from_str::<RTreeParameters>(params)?;
876        Ok(Box::new(RTreeTrainingRequest::new(params)))
877    }
878
879    async fn train_index(
880        &self,
881        data: SendableRecordBatchStream,
882        index_store: &dyn IndexStore,
883        request: Box<dyn TrainingRequest>,
884        fragment_ids: Option<Vec<u32>>,
885    ) -> Result<CreatedIndex> {
886        if fragment_ids.is_some() {
887            return Err(Error::InvalidInput {
888                source: "RTree index does not support fragment training".into(),
889                location: location!(),
890            });
891        }
892
893        Self::validate_schema(&data.schema())?;
894
895        let request = request
896            .as_any()
897            .downcast_ref::<RTreeTrainingRequest>()
898            .unwrap();
899        let page_size = request
900            .parameters
901            .page_size
902            .unwrap_or(DEFAULT_RTREE_PAGE_SIZE);
903
904        let bbox_data = Self::convert_bbox_stream(data)?;
905        let tmpdir = Arc::new(TempDir::default());
906        let spill_store = Arc::new(LanceIndexStore::new(
907            Arc::new(ObjectStore::local()),
908            tmpdir.obj_path(),
909            Arc::new(LanceCache::no_cache()),
910        ));
911        let (bbox_data, stats) =
912            Self::process_and_analyze_bbox_stream(bbox_data, page_size, spill_store.clone())
913                .await?;
914
915        Self::train_rtree_index(bbox_data, stats, page_size, index_store).await?;
916
917        Ok(CreatedIndex {
918            index_details: prost_types::Any::from_msg(&pb::RTreeIndexDetails::default())?,
919            index_version: RTREE_INDEX_VERSION,
920        })
921    }
922
923    fn provides_exact_answer(&self) -> bool {
924        false
925    }
926
927    fn version(&self) -> u32 {
928        RTREE_INDEX_VERSION
929    }
930
931    fn new_query_parser(
932        &self,
933        index_name: String,
934        _index_details: &prost_types::Any,
935    ) -> Option<Box<dyn ScalarQueryParser>> {
936        Some(Box::new(GeoQueryParser::new(index_name)))
937    }
938
939    async fn load_index(
940        &self,
941        index_store: Arc<dyn IndexStore>,
942        _index_details: &prost_types::Any,
943        frag_reuse_index: Option<Arc<FragReuseIndex>>,
944        cache: &LanceCache,
945    ) -> Result<Arc<dyn ScalarIndex>> {
946        Ok(RTreeIndex::load(index_store, frag_reuse_index, cache).await? as Arc<dyn ScalarIndex>)
947    }
948}
949
/// Pairing of a bounding box with the id of an index page.
///
/// NOTE(review): presumably `bbox` is the union of the entries encoded into
/// page `page_id` — confirm against the code that constructs these (outside
/// this chunk).
struct EncodedBatch {
    /// Spatial bounds associated with this page.
    bbox: BoundingBox,
    /// Identifier of the page within the index file.
    page_id: u64,
}
954
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::NoOpMetricsCollector;
    use crate::scalar::registry::VALUE_COLUMN_NAME;
    use arrow_array::ArrayRef;
    use arrow_schema::Schema;
    use geo_types::{coord, Rect};
    use geoarrow_array::builder::{PointBuilder, RectBuilder};
    use geoarrow_schema::{Dimension, PointType, RectType};
    use lance_core::utils::tempfile::TempObjDir;
    use rand::Rng;

    /// Number of pages the metadata should report for `num_items` entries
    /// at the given `page_size`.
    fn expected_num_pages(num_items: usize, page_size: u32) -> u64 {
        RTreeMetadata::calculate_page_offsets(num_items, page_size).len() as u64
    }

    /// Wrap a geo array plus matching row ids into a single-batch stream,
    /// shaped the way index training expects (value column + row-id column).
    fn convert_bbox_rowid_batch_stream(
        geo_array: &dyn GeoArrowArray,
        row_id_array: ArrayRef,
    ) -> SendableRecordBatchStream {
        let schema = Arc::new(Schema::new(vec![
            geo_array.data_type().to_field(VALUE_COLUMN_NAME, true),
            ArrowField::new(ROW_ID, DataType::UInt64, false),
        ]));

        let batch =
            RecordBatch::try_new(schema.clone(), vec![geo_array.to_array_ref(), row_id_array])
                .unwrap();

        let stream = stream::once(async move { Ok(batch) });
        Box::pin(RecordBatchStreamAdapter::new(schema, stream))
    }

    /// Train an RTree index over `geo_array` (row ids 0..len) and load it
    /// back, sanity-checking the persisted item/page counts along the way.
    ///
    /// Returns the loaded index plus the store and tempdir that must outlive
    /// it.
    async fn train_index(
        geo_array: &dyn GeoArrowArray,
        page_size: Option<u32>,
    ) -> (Arc<RTreeIndex>, Arc<LanceIndexStore>, TempObjDir) {
        let page_size = page_size.unwrap_or(DEFAULT_RTREE_PAGE_SIZE);
        // Nulls are excluded from the spatial pages, so only count non-nulls.
        let mut num_items = 0;
        for i in 0..geo_array.len() {
            if !geo_array.is_null(i) {
                num_items += 1;
            }
        }

        let tmpdir = TempObjDir::default();
        let store = Arc::new(LanceIndexStore::new(
            Arc::new(ObjectStore::local()),
            tmpdir.clone(),
            Arc::new(LanceCache::no_cache()),
        ));

        let stream = convert_bbox_rowid_batch_stream(
            geo_array,
            Arc::new(UInt64Array::from(
                (0..geo_array.len() as u64).collect::<Vec<_>>(),
            )),
        );

        let plugin = RTreeIndexPlugin;
        plugin
            .train_index(
                stream,
                store.as_ref(),
                Box::new(RTreeTrainingRequest::new(RTreeParameters {
                    page_size: Some(page_size),
                })),
                None,
            )
            .await
            .unwrap();

        // Verify the metadata persisted in the pages file matches the input.
        let pages_reader = store.open_index_file(RTREE_PAGES_NAME).await.unwrap();
        let metadata = RTreeMetadata::from(&pages_reader.schema().metadata);
        assert_eq!(metadata.num_items, num_items);
        assert_eq!(metadata.num_pages, expected_num_pages(num_items, page_size));

        (
            RTreeIndex::load(store.clone(), None, &LanceCache::no_cache())
                .await
                .unwrap(),
            store,
            tmpdir,
        )
    }

    /// A bbox search must return exactly the rows whose rect intersects the
    /// query box (verified against a brute-force scan).
    #[tokio::test]
    async fn test_search_bbox() {
        let bbox_type = RectType::new(Dimension::XY, Default::default());

        let mut rng = rand::rng();
        let mut rect_builder = RectBuilder::new(bbox_type.clone());
        let num_items = 10000;
        let page_size = 16;

        // Random small rects (up to 10x10) scattered over [-1000, 1000]^2.
        for _ in 0..num_items {
            let x1 = rng.random_range(-1000.0..1000.0);
            let y1 = rng.random_range(-1000.0..1000.0);
            let x2 = rng.random_range(x1..x1 + 10.0);
            let y2 = rng.random_range(y1..y1 + 10.0);

            rect_builder.push_rect(Some(&Rect::new(
                coord! { x: x1, y: y1 },
                coord! { x: x2, y: y2 },
            )));
        }
        let rect_arr = rect_builder.finish();

        let (rtree_index, _store, _tmpdir) = train_index(&rect_arr, Some(page_size)).await;

        let mut search_bbox = BoundingBox::new();
        search_bbox.add_rect(&Rect::new(
            coord! { x: 10.5, y: 1.5 },
            coord! { x: 99.5, y: 200.5 },
        ));
        let row_ids = rtree_index
            .search_bbox(search_bbox, &NoOpMetricsCollector)
            .await
            .unwrap();

        // Brute-force reference: every rect intersecting the query box.
        let mut expected_row_ids = RowAddrTreeMap::new();
        for i in 0..rect_arr.len() {
            let mut bbox = BoundingBox::new();
            bbox.add_rect(&rect_arr.value(i).unwrap());
            if search_bbox.rect_intersects(&bbox) {
                expected_row_ids.insert(i as u64);
            }
        }
        assert_eq!(row_ids, expected_row_ids);
    }

    /// A null search must return exactly the rows whose value was null.
    #[tokio::test]
    async fn test_search_null() {
        let point_type = PointType::new(Dimension::XY, Default::default());

        let mut rng = rand::rng();
        let num_points = 10000;
        let null_probability = 0.001; // 0.1%

        let mut expected_nulls = Vec::new();
        let mut point_builder = PointBuilder::new(point_type.clone());

        for i in 0..num_points {
            if rng.random_bool(null_probability) {
                point_builder.push_null();
                expected_nulls.push(RowAddress::new_from_parts(0, i as u32));
            } else {
                let x = rng.random_range(-1000.0..1000.0);
                let y = rng.random_range(-1000.0..1000.0);
                point_builder.push_point(Some(&geo_types::point!(x: x, y: y)));
            }
        }
        let point_arr = point_builder.finish();

        let (rtree_index, _store, _tmpdir) = train_index(&point_arr, None).await;
        let row_addrs = rtree_index
            .search_null(&NoOpMetricsCollector)
            .await
            .unwrap();

        // Order of returned addresses is unspecified; compare sorted.
        let mut actual_nulls = row_addrs.row_addrs().unwrap().collect::<Vec<_>>();
        actual_nulls.sort();
        expected_nulls.sort();

        assert_eq!(actual_nulls, expected_nulls);
    }

    /// Updating an index with a second fragment must merge both fragments'
    /// rows (and nulls) into the new index.
    #[tokio::test]
    async fn test_update_and_search() {
        /// Build a rect array for `frag_id`, recording null positions into
        /// `nulls_addrs`.
        fn gen_data(num_items: u32, frag_id: u32, nulls_addrs: &mut RowAddrTreeMap) -> RectArray {
            let bbox_type = RectType::new(Dimension::XY, Default::default());

            let mut rng = rand::rng();
            let null_probability = 0.001;
            let mut rect_builder = RectBuilder::new(bbox_type);

            for i in 0..num_items {
                if rng.random_bool(null_probability) {
                    rect_builder.push_null();
                    nulls_addrs.insert(RowAddress::new_from_parts(frag_id, i).into());
                } else {
                    let x1 = rng.random_range(-1000.0..1000.0);
                    let y1 = rng.random_range(-1000.0..1000.0);
                    let x2 = rng.random_range(x1..x1 + 10.0);
                    let y2 = rng.random_range(y1..y1 + 10.0);

                    rect_builder.push_rect(Some(&Rect::new(
                        coord! { x: x1, y: y1 },
                        coord! { x: x2, y: y2 },
                    )));
                }
            }
            rect_builder.finish()
        }

        let mut nulls_addrs = RowAddrTreeMap::default();

        let frag_id = 0;
        let rect_arr = gen_data(10000, frag_id, &mut nulls_addrs);

        let (rtree_index, _store, _tmpdir) = train_index(&rect_arr, Some(16)).await;

        // The updated index is written to a fresh store.
        let tmpdir = TempObjDir::default();
        let new_store = Arc::new(LanceIndexStore::new(
            Arc::new(ObjectStore::local()),
            tmpdir.clone(),
            Arc::new(LanceCache::no_cache()),
        ));

        let new_frag_id = 1;
        // Fix: use the `new_frag_id` binding rather than a duplicate literal
        // so the fragment id cannot silently diverge from the row addresses.
        let new_rect_arr = gen_data(10000, new_frag_id, &mut nulls_addrs);
        let new_rowaddr_arr = (0..new_rect_arr.len())
            .map(|off| RowAddress::new_from_parts(new_frag_id, off as u32).into())
            .collect::<Vec<_>>();
        let stream = convert_bbox_rowid_batch_stream(
            &new_rect_arr,
            Arc::new(UInt64Array::from(new_rowaddr_arr.clone())),
        );
        rtree_index
            .update(stream, new_store.as_ref())
            .await
            .unwrap();

        let new_rtree_index = RTreeIndex::load(new_store.clone(), None, &LanceCache::no_cache())
            .await
            .unwrap();

        let mut search_bbox = BoundingBox::new();
        search_bbox.add_rect(&Rect::new(
            coord! { x: 10.5, y: 1.5 },
            coord! { x: 99.5, y: 200.5 },
        ));
        let row_addrs = new_rtree_index
            .search_bbox(search_bbox, &NoOpMetricsCollector)
            .await
            .unwrap();

        // Brute-force reference over both fragments.
        let mut expected_row_addrs = RowAddrTreeMap::new();
        for i in 0..rect_arr.len() {
            if !rect_arr.is_null(i) {
                let bbox = BoundingBox::new_with_rect(&rect_arr.value(i).unwrap());
                if search_bbox.rect_intersects(&bbox) {
                    expected_row_addrs.insert(i as u64);
                }
            }
        }
        for i in 0..new_rect_arr.len() {
            if !new_rect_arr.is_null(i) {
                let bbox = BoundingBox::new_with_rect(&new_rect_arr.value(i).unwrap());
                if search_bbox.rect_intersects(&bbox) {
                    expected_row_addrs.insert(new_rowaddr_arr.get(i).copied().unwrap());
                }
            }
        }

        assert_eq!(row_addrs, expected_row_addrs);

        // Nulls from both fragments must survive the update.
        let actual_nulls = new_rtree_index
            .search_null(&NoOpMetricsCollector)
            .await
            .unwrap();
        assert_eq!(actual_nulls, nulls_addrs);
    }

    /// After `prewarm`, every page and the nulls entry must be in the cache.
    #[tokio::test]
    async fn test_prewarm() {
        let point_type = PointType::new(Dimension::XY, Default::default());

        let mut rng = rand::rng();
        let num_points = 1000;
        let null_probability = 0.1;

        let mut point_builder = PointBuilder::new(point_type.clone());

        for _ in 0..num_points {
            if rng.random_bool(null_probability) {
                point_builder.push_null();
            } else {
                let x = rng.random_range(-1000.0..1000.0);
                let y = rng.random_range(-1000.0..1000.0);
                point_builder.push_point(Some(&geo_types::point!(x: x, y: y)));
            }
        }
        let point_arr = point_builder.finish();

        let (_, store, _tmpdir) = train_index(&point_arr, Some(32)).await;

        // Reload with a real cache so prewarm has somewhere to put pages.
        let cache = LanceCache::with_capacity(10 << 20);
        let rtree_index = RTreeIndex::load(store, None, &cache).await.unwrap();

        // Call prewarm
        rtree_index.prewarm().await.unwrap();

        for page_id in 0..rtree_index.metadata.num_pages {
            assert!(rtree_index
                .index_cache
                .get_with_key(&RTreeCacheKey::Page(page_id))
                .await
                .is_some())
        }

        assert!(rtree_index
            .index_cache
            .get_with_key(&RTreeCacheKey::Nulls)
            .await
            .is_some())
    }
}