1use crate::frag_reuse::FragReuseIndex;
5use crate::metrics::{MetricsCollector, NoOpMetricsCollector};
6use crate::scalar::expression::{GeoQueryParser, ScalarQueryParser};
7use crate::scalar::lance_format::LanceIndexStore;
8use crate::scalar::registry::{
9 ScalarIndexPlugin, TrainingCriteria, TrainingOrdering, TrainingRequest,
10};
11use crate::scalar::rtree::sort::Sorter;
12use crate::scalar::{
13 AnyQuery, BuiltinIndexType, CreatedIndex, GeoQuery, IndexReader, IndexReaderStream, IndexStore,
14 IndexWriter, ScalarIndex, ScalarIndexParams, SearchResult, UpdateCriteria,
15};
16use crate::vector::VectorIndex;
17use crate::{pb, Index, IndexType};
18use arrow_array::cast::AsArray;
19use arrow_array::types::UInt64Type;
20use arrow_array::UInt32Array;
21use arrow_array::{Array, BinaryArray, RecordBatch, UInt64Array};
22use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
23use async_trait::async_trait;
24use datafusion::execution::SendableRecordBatchStream;
25use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
26use datafusion_common::DataFusionError;
27use deepsize::DeepSizeOf;
28use futures::{stream, StreamExt, TryFutureExt, TryStreamExt};
29use geoarrow_array::array::{from_arrow_array, RectArray};
30use geoarrow_array::builder::RectBuilder;
31use geoarrow_array::{GeoArrowArray, GeoArrowArrayAccessor, IntoArrow};
32use geoarrow_schema::{Dimension, RectType};
33use lance_arrow::RecordBatchExt;
34use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache};
35use lance_core::utils::address::RowAddress;
36use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap, RowSetOps};
37use lance_core::utils::tempfile::TempDir;
38use lance_core::{Error, Result, ROW_ID};
39use lance_datafusion::chunker::chunk_concat_stream;
40pub use lance_geo::bbox::{bounding_box, total_bounds, BoundingBox};
41use lance_io::object_store::ObjectStore;
42use roaring::RoaringBitmap;
43use serde::{Deserialize, Serialize};
44use snafu::location;
45use sort::hilbert_sort::HilbertSorter;
46use std::any::Any;
47use std::collections::HashMap;
48use std::ops::Range;
49use std::sync::{Arc, LazyLock};
50
51mod sort;
52
/// Default number of entries per R-Tree page (applies to leaf and internal pages).
pub const DEFAULT_RTREE_PAGE_SIZE: u32 = 4096;
/// On-disk format version of the R-Tree index.
const RTREE_INDEX_VERSION: u32 = 0;
/// File holding the serialized tree pages (leaf pages first, then each parent level).
const RTREE_PAGES_NAME: &str = "page_data.lance";
/// File holding the serialized set of row addresses whose geometry is null.
const RTREE_NULLS_NAME: &str = "nulls.lance";

/// Arrow field for the bounding-box (rect) column, named "bbox".
static BBOX_FIELD: LazyLock<Arc<ArrowField>> = LazyLock::new(|| {
    let bbox_type = RectType::new(Dimension::XY, Default::default());
    Arc::new(bbox_type.to_field("bbox", false))
});
/// Schema of the intermediate (bbox, row id) stream used during training and updates.
static BBOX_ROWID_SCHEMA: LazyLock<Arc<ArrowSchema>> = LazyLock::new(|| {
    let rowid_field = ArrowField::new(ROW_ID, DataType::UInt64, false);
    Arc::new(ArrowSchema::new(vec![
        BBOX_FIELD.clone(),
        rowid_field.into(),
    ]))
});
/// Schema of a persisted page: a bbox plus an id that is a row address in leaf
/// pages and a child page id in internal pages (see `search_bbox`).
static RTREE_PAGE_SCHEMA: LazyLock<Arc<ArrowSchema>> = LazyLock::new(|| {
    let id_field = ArrowField::new("id", DataType::UInt64, false);
    Arc::new(ArrowSchema::new(vec![BBOX_FIELD.clone(), id_field.into()]))
});

/// Schema of the nulls file: a single binary cell containing a serialized
/// row-address set (see `write_nulls`).
static RTREE_NULLS_SCHEMA: LazyLock<Arc<ArrowSchema>> = LazyLock::new(|| {
    Arc::new(ArrowSchema::new(vec![ArrowField::new(
        "nulls",
        DataType::Binary,
        false,
    )]))
});
81
/// Layout metadata for a persisted R-Tree index.
///
/// Serialized into the pages file's schema metadata via `into_map` and
/// reconstructed on load via `From<&HashMap<String, String>>`.
#[derive(Debug, Clone, Serialize)]
pub struct RTreeMetadata {
    // Maximum number of entries per page.
    pub(crate) page_size: u32,
    // Total number of pages across all tree levels.
    pub(crate) num_pages: u64,
    // Number of non-null indexed items (leaf entries).
    pub(crate) num_items: usize,
    // Bounding box covering all indexed geometries.
    pub(crate) bbox: BoundingBox,
    // Starting row offset of each page in the pages file; derived on
    // construction, not persisted.
    pub(crate) page_offsets: Vec<usize>,
}
90
91impl RTreeMetadata {
92 pub fn new(page_size: u32, num_pages: u64, num_items: usize, bbox: BoundingBox) -> Self {
93 let page_offsets = Self::calculate_page_offsets(num_items, page_size);
94 debug_assert_eq!(page_offsets.len(), num_pages as usize);
95 Self {
96 page_size,
97 num_pages,
98 num_items,
99 bbox,
100 page_offsets,
101 }
102 }
103
104 fn calculate_page_offsets(num_items: usize, page_size: u32) -> Vec<usize> {
105 let mut page_offsets = vec![];
106 let mut cur_level_items = num_items;
107 let mut cur_offset = 0;
108 while cur_level_items > 0 {
109 if cur_level_items <= page_size as usize {
110 page_offsets.push(cur_offset);
111 break;
112 }
113 for off in (0..cur_level_items).step_by(page_size as usize) {
114 page_offsets.push(cur_offset + off);
115 }
116 cur_offset += cur_level_items;
117 cur_level_items = cur_level_items.div_ceil(page_size as usize);
118 }
119
120 page_offsets
121 }
122
123 fn into_map(self) -> HashMap<String, String> {
124 HashMap::from_iter(vec![
125 ("page_size".to_owned(), self.page_size.to_string()),
126 ("num_pages".to_owned(), self.num_pages.to_string()),
127 ("num_items".to_owned(), self.num_items.to_string()),
128 ("bbox".to_owned(), serde_json::json!(self.bbox).to_string()),
129 ])
130 }
131}
132
133impl From<&HashMap<String, String>> for RTreeMetadata {
134 fn from(metadata: &HashMap<String, String>) -> Self {
135 let page_size = metadata
136 .get("page_size")
137 .map(|bs| bs.parse().unwrap_or(DEFAULT_RTREE_PAGE_SIZE))
138 .unwrap_or(DEFAULT_RTREE_PAGE_SIZE);
139 let num_pages = metadata
140 .get("num_pages")
141 .map(|bs| bs.parse().unwrap_or(0))
142 .unwrap_or(0);
143 let num_items = metadata
144 .get("num_items")
145 .map(|bs| bs.parse().unwrap_or(0))
146 .unwrap_or(0);
147 let bbox = metadata
148 .get("bbox")
149 .map(|bs| serde_json::from_str(bs).unwrap_or_default())
150 .unwrap_or_default();
151 Self::new(page_size, num_pages, num_items, bbox)
152 }
153}
154
155pub fn extract_bounding_boxes(
157 geometry_array: &dyn Array,
158 geometry_field: &ArrowField,
159) -> Result<RectArray> {
160 let geo_array = from_arrow_array(geometry_array, geometry_field).map_err(|e| Error::Index {
161 message: format!("Construct GeoArrowArray from an Arrow Array failed: {}", e),
162 location: location!(),
163 })?;
164 let rect_array = bounding_box(geo_array.as_ref())?;
165
166 Ok(rect_array)
167}
168
/// Aggregate statistics collected while scanning a (bbox, row id) stream.
struct BboxStreamStats {
    // Row addresses whose geometry value was null.
    null_map: RowAddrTreeMap,
    // Union of all non-null bounding boxes seen.
    total_bbox: BoundingBox,
    // Count of non-null rows.
    num_items: usize,
}
175
/// Cache key for R-Tree artifacts: either a single tree page (by page id) or
/// the null-row batch.
#[derive(Debug, Clone)]
pub enum RTreeCacheKey {
    Page(u64),
    Nulls,
}
181
/// Cached value: one record batch (a tree page, or the serialized nulls batch).
#[derive(Debug)]
pub struct RTreeCacheValue(Arc<RecordBatch>);

impl DeepSizeOf for RTreeCacheValue {
    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
        // Report the Arrow buffers' memory footprint instead of walking the
        // Arc graph, which deepsize cannot see through.
        self.0.get_array_memory_size()
    }
}
190
191impl CacheKey for RTreeCacheKey {
192 type ValueType = RTreeCacheValue;
193
194 fn key(&self) -> std::borrow::Cow<'_, str> {
195 match self {
196 Self::Page(page_id) => format!("page-{}", page_id).into(),
197 Self::Nulls => "nulls".into(),
198 }
199 }
200}
201
/// A persisted, paged R-Tree over bounding boxes, keyed by row address.
///
/// Pages are stored bottom-up in a single file: leaf pages (bbox -> row
/// address) first, then each internal level (bbox -> child page id), ending
/// with the root page. Null geometries are tracked in a separate nulls file.
#[derive(Clone)]
pub struct RTreeIndex {
    pub(crate) metadata: Arc<RTreeMetadata>,
    store: Arc<dyn IndexStore>,
    // Optional remapping applied to row addresses after fragment reuse.
    frag_reuse_index: Option<Arc<FragReuseIndex>>,
    // Weak reference to the session index cache for page/nulls batches.
    index_cache: WeakLanceCache,
    pages_reader: Arc<dyn IndexReader>,
    nulls_reader: Arc<dyn IndexReader>,
}
211
212impl std::fmt::Debug for RTreeIndex {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 f.debug_struct("RTreeIndex")
215 .field("metadata", &self.metadata)
216 .field("store", &self.store)
217 .finish()
218 }
219}
220
impl RTreeIndex {
    /// Opens an R-Tree index from `store`, reconstructing the layout metadata
    /// stored in the pages file's schema metadata.
    pub async fn load(
        store: Arc<dyn IndexStore>,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
        index_cache: &LanceCache,
    ) -> Result<Arc<Self>> {
        let pages_reader = store.open_index_file(RTREE_PAGES_NAME).await?;
        let metadata = RTreeMetadata::from(&pages_reader.schema().metadata);
        let nulls_reader = store.open_index_file(RTREE_NULLS_NAME).await?;

        Ok(Arc::new(Self {
            metadata: Arc::new(metadata),
            store,
            frag_reuse_index,
            index_cache: WeakLanceCache::from(index_cache),
            pages_reader,
            nulls_reader,
        }))
    }

    /// Returns the row range of page `page_idx` within the pages file.
    /// Offsets past the last page clamp to the end of the file.
    async fn page_range(&self, page_idx: u64) -> Result<Range<usize>> {
        let start = match self.metadata.page_offsets.get(page_idx as usize) {
            None => self.pages_reader.num_rows(),
            Some(start) => *start,
        };
        // The page ends where the next one starts (or at end-of-file for the
        // last page).
        let end = match self.metadata.page_offsets.get((page_idx + 1) as usize) {
            None => self.pages_reader.num_rows(),
            Some(end) => *end,
        };
        Ok(start..end)
    }

    /// Collects the row addresses of all leaf entries whose bounding box
    /// intersects `bbox`, walking the tree from the root page.
    async fn search_bbox(
        &self,
        bbox: BoundingBox,
        metrics: &dyn MetricsCollector,
    ) -> Result<RowAddrTreeMap> {
        // Fast path: empty index, or the query lies outside the total bounds.
        if self.metadata.num_items == 0 || !self.metadata.bbox.rect_intersects(&bbox) {
            return Ok(RowAddrTreeMap::default());
        }

        let mut row_addrs = RowAddrTreeMap::new();
        // Depth-first traversal; the root is always the last page written.
        let mut stack = vec![self.metadata.num_pages - 1];

        while let Some(page_idx) = stack.pop() {
            let range = self.page_range(page_idx).await?;
            // Leaf pages occupy the first `num_items` rows of the file, so a
            // page starting before that boundary is a leaf.
            let is_leaf = range.start < self.metadata.num_items;
            let batch = self
                .index_cache
                .get_or_insert_with_key(RTreeCacheKey::Page(page_idx), move || async move {
                    let batch = self.pages_reader.read_range(range, None).await?;
                    metrics.record_part_load();
                    Ok(RTreeCacheValue(Arc::new(batch)))
                })
                .await
                .map(|v| v.0.clone())?;

            let bbox_array =
                extract_bounding_boxes(batch.column(0).as_ref(), batch.schema().field(0))?;
            // Column 1 holds row addresses in leaf pages and child page ids in
            // internal pages.
            let rowaddr_or_pageid_array = batch
                .column(1)
                .as_any()
                .downcast_ref::<UInt64Array>()
                .unwrap();

            for i in 0..bbox_array.len() {
                let rect = bbox_array.value(i).unwrap();
                if bbox.rect_intersects(&rect) {
                    if is_leaf {
                        let row_addr = rowaddr_or_pageid_array.value(i);
                        row_addrs.insert(row_addr);
                    } else {
                        // Descend into the intersecting child page.
                        let page_id = rowaddr_or_pageid_array.value(i);
                        stack.push(page_id);
                    }
                }
            }
        }

        Ok(row_addrs)
    }

    /// Loads (and caches) the set of row addresses whose geometry is null.
    async fn search_null(&self, metrics: &dyn MetricsCollector) -> Result<RowAddrTreeMap> {
        let batch = self
            .index_cache
            .get_or_insert_with_key(RTreeCacheKey::Nulls, move || async move {
                let batch = self.nulls_reader.read_range(0..1, None).await?;
                metrics.record_part_load();
                Ok(RTreeCacheValue(Arc::new(batch)))
            })
            .await
            .map(|v| v.0.clone())?;

        let null_map = match batch.num_rows() {
            0 => RowAddrTreeMap::default(),
            1 => {
                // Single binary cell holding the serialized row-address set
                // (written by write_nulls).
                let bytes = batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<BinaryArray>()
                    .unwrap()
                    .value(0);
                RowAddrTreeMap::deserialize_from(bytes)?
            }
            _ => {
                // write_nulls always emits exactly one row.
                unreachable!()
            }
        };
        Ok(null_map)
    }

    /// Streams the leaf level back out as (bbox, row id) batches of
    /// `page_size` rows — used to rebuild the tree during an update.
    async fn into_data_stream(self) -> Result<SendableRecordBatchStream> {
        let reader = self.store.open_index_file(RTREE_PAGES_NAME).await?;
        // Limit to `num_items` rows so only leaf entries are emitted.
        let reader_stream = IndexReaderStream::new_with_limit(
            reader,
            self.metadata.page_size as u64,
            self.metadata.num_items as u64,
        )
        .await;
        let batches = reader_stream
            .map(|fut| {
                fut.map_ok(|batch| {
                    // Re-label the page schema (bbox, "id") as (bbox, row id).
                    RecordBatch::try_new(BBOX_ROWID_SCHEMA.clone(), batch.columns().into()).unwrap()
                })
            })
            .map(|fut| fut.map_err(DataFusionError::from))
            .buffered(self.store.io_parallelism())
            .boxed();
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            BBOX_ROWID_SCHEMA.clone(),
            batches,
        )))
    }

    /// Interleaves this index's existing leaf data with `new_input`,
    /// producing a single stream suitable for retraining.
    async fn combine_old_new(
        self,
        new_input: SendableRecordBatchStream,
    ) -> Result<SendableRecordBatchStream> {
        let old_input = self.into_data_stream().await?;
        debug_assert_eq!(
            old_input.schema().flattened_fields().len(),
            new_input.schema().flattened_fields().len()
        );

        // select() polls both streams; ordering between old and new batches is
        // not significant because training re-sorts the data.
        let merged = futures::stream::select(old_input, new_input);

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            BBOX_ROWID_SCHEMA.clone(),
            merged,
        )))
    }
}
375
376impl DeepSizeOf for RTreeIndex {
377 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
378 let mut total_size = 0;
379
380 total_size += self.store.deep_size_of_children(context);
381
382 total_size
383 }
384}
385
#[async_trait]
impl Index for RTreeIndex {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
        self
    }

    /// R-Tree is a scalar index; conversion to a vector index always fails.
    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn VectorIndex>> {
        Err(Error::NotSupported {
            source: "RTreeIndex is not vector index".into(),
            location: location!(),
        })
    }

    /// Statistics are the serialized layout metadata (page size, counts, bbox).
    fn statistics(&self) -> Result<serde_json::Value> {
        serde_json::to_value(self.metadata.clone()).map_err(|e| Error::Internal {
            message: format!("Error serializing statistics: {}", e),
            location: location!(),
        })
    }

    /// Eagerly loads every tree page plus the nulls batch into the cache.
    async fn prewarm(&self) -> Result<()> {
        for page_id in 0..self.metadata.num_pages {
            let range = self.page_range(page_id).await?;
            let batch = Arc::new(self.pages_reader.read_range(range, None).await?);
            self.index_cache
                .insert_with_key(
                    &RTreeCacheKey::Page(page_id),
                    Arc::new(RTreeCacheValue(batch.clone())),
                )
                .await;
        }

        let batch = self.nulls_reader.read_range(0..1, None).await?;
        self.index_cache
            .insert_with_key(
                &RTreeCacheKey::Nulls,
                Arc::new(RTreeCacheValue(Arc::new(batch))),
            )
            .await;

        Ok(())
    }

    fn index_type(&self) -> IndexType {
        IndexType::RTree
    }

    /// Scans the leaf level and reports the set of fragment ids that contain
    /// at least one indexed row.
    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
        let mut frag_ids = RoaringBitmap::default();

        let mut reader_stream = self.clone().into_data_stream().await?;
        while let Some(page) = reader_stream.try_next().await? {
            // Column 1 of the leaf stream holds row addresses; extract each
            // address's fragment id.
            let mut page_frag_ids = page
                .column(1)
                .as_primitive::<UInt64Type>()
                .iter()
                .flatten()
                .map(|row_addr| RowAddress::from(row_addr).fragment_id())
                .collect::<Vec<_>>();
            // from_sorted_iter below requires sorted, deduplicated input.
            page_frag_ids.sort();
            page_frag_ids.dedup();
            frag_ids |= RoaringBitmap::from_sorted_iter(page_frag_ids).unwrap();
        }
        Ok(frag_ids)
    }
}
456
#[async_trait]
impl ScalarIndex for RTreeIndex {
    /// Answers a [`GeoQuery`] against the index.
    ///
    /// Intersect queries return `AtMost` — bbox intersection is a coarse
    /// filter, so candidates may include false positives — while IsNull
    /// queries are answered exactly from the nulls file.
    async fn search(
        &self,
        query: &dyn AnyQuery,
        metrics: &dyn MetricsCollector,
    ) -> Result<SearchResult> {
        let query = query.as_any().downcast_ref::<GeoQuery>().unwrap();
        match query {
            GeoQuery::IntersectQuery(query) => {
                // Reduce the query geometry to a single bounding box and probe
                // the tree with it.
                let geo_array =
                    extract_bounding_boxes(query.value.to_array()?.as_ref(), &query.field)?;
                let bbox = total_bounds(&geo_array)?;
                let mut rowids = self.search_bbox(bbox, metrics).await?;
                let mut null_map = self.search_null(metrics).await?;

                // Translate addresses if fragments were rewritten since training.
                if let Some(fri) = &self.frag_reuse_index {
                    rowids = fri.remap_row_addrs_tree_map(&rowids);
                    null_map = fri.remap_row_addrs_tree_map(&null_map);
                }
                Ok(SearchResult::AtMost(NullableRowAddrSet::new(
                    rowids, null_map,
                )))
            }
            GeoQuery::IsNull => {
                let mut null_map = self.search_null(metrics).await?;

                if let Some(fri) = &self.frag_reuse_index {
                    null_map = fri.remap_row_addrs_tree_map(&null_map);
                }
                Ok(SearchResult::Exact(NullableRowAddrSet::new(
                    null_map,
                    RowAddrTreeMap::default(),
                )))
            }
        }
    }

    /// Remapping is not supported; compaction must retrain the index.
    fn can_remap(&self) -> bool {
        false
    }

    async fn remap(
        &self,
        _mapping: &HashMap<u64, Option<u64>>,
        _dest_store: &dyn IndexStore,
    ) -> Result<CreatedIndex> {
        Err(Error::InvalidInput {
            source: "RTree does not support remap".into(),
            location: location!(),
        })
    }

    /// Incorporates `new_data` by rebuilding the whole tree: the new stream is
    /// converted to bboxes, spilled/analyzed locally, merged with the existing
    /// leaf data and null set, then retrained into `dest_store`.
    async fn update(
        &self,
        new_data: SendableRecordBatchStream,
        dest_store: &dyn IndexStore,
    ) -> Result<CreatedIndex> {
        let bbox_data = RTreeIndexPlugin::convert_bbox_stream(new_data)?;
        // Local temp store used only to spill the analyzed stream.
        let tmpdir = Arc::new(TempDir::default());
        let spill_store = Arc::new(LanceIndexStore::new(
            Arc::new(ObjectStore::local()),
            tmpdir.obj_path(),
            Arc::new(LanceCache::no_cache()),
        ));
        let (new_bbox_data, stats) = RTreeIndexPlugin::process_and_analyze_bbox_stream(
            bbox_data,
            self.metadata.page_size,
            spill_store.clone(),
        )
        .await?;

        let merged_bbox_data = self.clone().combine_old_new(new_bbox_data).await?;

        let null_map = self.search_null(&NoOpMetricsCollector).await?;

        // Combined bounds of the old index and the new data.
        let mut new_bbox = BoundingBox::new();
        new_bbox.add_rect(&stats.total_bbox);
        new_bbox.add_rect(&self.metadata.bbox);

        let merge_stats = BboxStreamStats {
            null_map: RowAddrTreeMap::union_all(&[&null_map, &stats.null_map]),
            total_bbox: new_bbox,
            num_items: self.metadata.num_items + stats.num_items,
        };

        RTreeIndexPlugin::train_rtree_index(
            merged_bbox_data,
            merge_stats,
            self.metadata.page_size,
            dest_store,
        )
        .await?;

        Ok(CreatedIndex {
            index_details: prost_types::Any::from_msg(&pb::RTreeIndexDetails::default())?,
            index_version: RTREE_INDEX_VERSION,
        })
    }

    /// Updates only need the new rows (with row ids); no ordering required.
    fn update_criteria(&self) -> UpdateCriteria {
        UpdateCriteria::only_new_data(TrainingCriteria::new(TrainingOrdering::None).with_row_id())
    }

    /// Reconstructs the creation parameters from the loaded metadata.
    fn derive_index_params(&self) -> Result<ScalarIndexParams> {
        let params = serde_json::to_value(RTreeParameters {
            page_size: Some(self.metadata.page_size),
        })?;
        Ok(ScalarIndexParams::for_builtin(BuiltinIndexType::RTree).with_params(&params))
    }
}
568
/// User-facing training parameters for the R-Tree index.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct RTreeParameters {
    // Entries per page; DEFAULT_RTREE_PAGE_SIZE is used when absent.
    pub page_size: Option<u32>,
}
575
/// Training request pairing the user parameters with the data criteria the
/// trainer requires.
pub struct RTreeTrainingRequest {
    parameters: RTreeParameters,
    criteria: TrainingCriteria,
}
580
581impl RTreeTrainingRequest {
582 fn new(parameters: RTreeParameters) -> Self {
583 Self {
584 parameters,
585 criteria: TrainingCriteria::new(TrainingOrdering::None).with_row_id(),
586 }
587 }
588}
589
590impl Default for RTreeTrainingRequest {
591 fn default() -> Self {
592 Self::new(RTreeParameters {
593 page_size: Some(DEFAULT_RTREE_PAGE_SIZE),
594 })
595 }
596}
597
impl TrainingRequest for RTreeTrainingRequest {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Data requirements for training: row ids attached, no ordering.
    fn criteria(&self) -> &TrainingCriteria {
        &self.criteria
    }
}
607
/// Plugin implementing training, loading, and query parsing for R-Tree indexes.
#[derive(Debug, Default)]
pub struct RTreeIndexPlugin;
610
impl RTreeIndexPlugin {
    /// Validates the training input schema: exactly two fields with a UInt64
    /// row-id column named `ROW_ID`.
    fn validate_schema(schema: &ArrowSchema) -> Result<()> {
        if schema.fields().len() != 2 {
            return Err(Error::InvalidInput {
                source: "RTree index schema must have exactly two fields".into(),
                location: location!(),
            });
        }

        // NOTE(review): the lookup is by the ROW_ID field *name*, not by
        // position; the "Second field" wording in the error below assumes the
        // row-id column is in position 1 — confirm.
        let row_id_field = schema.field_with_name(ROW_ID)?;
        if *row_id_field.data_type() != DataType::UInt64 {
            return Err(Error::InvalidInput {
                source: "Second field in RTree index schema must be of type UInt64".into(),
                location: location!(),
            });
        }
        Ok(())
    }

    /// Maps a (geometry, row id) stream into a (bbox, row id) stream by
    /// reducing each geometry to its bounding rectangle.
    fn convert_bbox_stream(source: SendableRecordBatchStream) -> Result<SendableRecordBatchStream> {
        let bbox_stream = source
            .map_err(DataFusionError::into)
            .and_then(move |batch| async move {
                let schema = batch.schema();
                let geometry_field = schema.field(0);
                let geometry_array = batch.column(0);
                let bbox_array = extract_bounding_boxes(geometry_array, geometry_field)?;

                // Per-batch schema carries the concrete bbox extension type;
                // the adapter below advertises BBOX_ROWID_SCHEMA overall.
                let bbox_schema = Arc::new(ArrowSchema::new(vec![
                    bbox_array.extension_type().clone().to_field("bbox", true),
                    ArrowField::new(ROW_ID, DataType::UInt64, false),
                ]));
                RecordBatch::try_new(
                    bbox_schema,
                    vec![bbox_array.into_array_ref(), batch.column(1).clone()],
                )
                .map_err(DataFusionError::from)
            });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            BBOX_ROWID_SCHEMA.clone(),
            bbox_stream,
        )))
    }

    /// Single pass over the bbox stream that:
    /// - collects row addresses of null geometries into a set,
    /// - accumulates the total bounding box and the non-null row count,
    /// - spills the non-null rows to `spill_store` and returns a stream
    ///   reading them back re-chunked to `page_size` rows per batch.
    async fn process_and_analyze_bbox_stream(
        mut data: SendableRecordBatchStream,
        page_size: u32,
        spill_store: Arc<LanceIndexStore>,
    ) -> Result<(SendableRecordBatchStream, BboxStreamStats)> {
        let mut null_rowaddrs = RowAddrTreeMap::new();
        let mut total_bbox = BoundingBox::new();
        let mut num_non_null_rows = 0;

        let schema = data.schema();

        let mut writer = spill_store
            .new_index_file("analyze.tmp", BBOX_ROWID_SCHEMA.clone())
            .await?;

        while let Some(batch) = data.try_next().await? {
            let bbox_array = extract_bounding_boxes(&batch.column(0), batch.schema().field(0))?;
            let rowaddr_array = batch
                .column(1)
                .as_any()
                .downcast_ref::<UInt64Array>()
                .unwrap();

            total_bbox.add_geo_arrow_array(&bbox_array)?;

            let num_rows = bbox_array.len();

            let mut non_null_indexes = vec![];

            // Partition rows: nulls go to the address set, the rest are kept.
            for i in 0..num_rows {
                if bbox_array.is_null(i) {
                    let rowaddr = rowaddr_array.value(i);
                    null_rowaddrs.insert(rowaddr);
                } else {
                    non_null_indexes.push(i as u32);
                }
            }

            // Skip all-null batches; avoid the `take` copy when nothing was null.
            let new_batch = if non_null_indexes.is_empty() {
                continue;
            } else if non_null_indexes.len() == num_rows {
                batch
            } else {
                batch.take(&UInt32Array::from(non_null_indexes))?
            };

            num_non_null_rows += new_batch.num_rows();
            writer.write_record_batch(new_batch).await?;
        }
        writer.finish().await?;
        // Read the spilled rows back as page-sized batches.
        let reader = spill_store.open_index_file("analyze.tmp").await?;
        let stream = IndexReaderStream::new(reader, page_size as u64)
            .await
            .map(|fut| fut.map_err(DataFusionError::from))
            .buffered(spill_store.io_parallelism())
            .boxed();
        let new_data = RecordBatchStreamAdapter::new(schema.clone(), stream);

        Ok((
            Box::pin(new_data),
            BboxStreamStats {
                null_map: null_rowaddrs,
                total_bbox,
                num_items: num_non_null_rows,
            },
        ))
    }

    /// Writes one page's rows to `writer` and returns its total bbox plus the
    /// assigned page id, to be summarized into the parent level.
    async fn train_rtree_page(
        batch: RecordBatch,
        page_id: u64,
        writer: &mut dyn IndexWriter,
    ) -> Result<EncodedBatch> {
        let geo_array = extract_bounding_boxes(batch.column(0).as_ref(), batch.schema().field(0))?;
        let bbox = total_bounds(&geo_array)?;
        let new_batch = RecordBatch::try_new(
            RTREE_PAGE_SCHEMA.clone(),
            vec![batch.column(0).clone(), batch.column(1).clone()],
        )?;
        writer.write_record_batch(new_batch).await?;
        Ok(EncodedBatch { bbox, page_id })
    }

    /// Turns a finished level's (bbox, page id) summaries into a batch stream
    /// that serves as the next (parent) level's input.
    fn encoded_batches_into_batch_stream(
        batches: Vec<EncodedBatch>,
        batch_size: u32,
    ) -> SendableRecordBatchStream {
        let batches = batches
            .chunks(batch_size as usize)
            .map(|chunk| {
                let bbox_type = RectType::new(Dimension::XY, Default::default());
                let mut bbox_builder = RectBuilder::with_capacity(bbox_type, chunk.len());
                let mut page_ids = UInt64Array::builder(chunk.len());

                for item in chunk {
                    bbox_builder.push_rect(Some(&item.bbox));
                    page_ids.append_value(item.page_id);
                }

                RecordBatch::try_new(
                    RTREE_PAGE_SCHEMA.clone(),
                    vec![
                        bbox_builder.finish().into_array_ref(),
                        Arc::new(page_ids.finish()),
                    ],
                )
                .unwrap()
            })
            .collect::<Vec<_>>();

        Box::pin(RecordBatchStreamAdapter::new(
            RTREE_PAGE_SCHEMA.clone(),
            stream::iter(batches).map(Ok).boxed(),
        ))
    }

    /// Writes the tree bottom-up: leaf pages from `sorted_data`, then each
    /// level's page summaries become the next level's input, until a level
    /// fits in a single (root) page. Page ids follow write order, so the root
    /// is the last page. Layout metadata is stored in the file's schema
    /// metadata.
    pub async fn write_index(
        sorted_data: SendableRecordBatchStream,
        num_items: usize,
        total_bbox: BoundingBox,
        store: &dyn IndexStore,
        page_size: u32,
    ) -> Result<()> {
        let mut page_idx: u64 = 0;
        let mut writer = store
            .new_index_file(RTREE_PAGES_NAME, RTREE_PAGE_SCHEMA.clone())
            .await?;

        if num_items > 0 {
            let mut current_level = Some((sorted_data, num_items));
            while let Some((mut data, num_items)) = current_level.take() {
                if num_items <= page_size as usize {
                    // Level fits in one page: write it and stop (root level).
                    while let Some(batch) = data.try_next().await? {
                        Self::train_rtree_page(batch, page_idx, writer.as_mut()).await?;
                        page_idx += 1;
                    }
                } else {
                    let mut next_level = vec![];
                    // Re-chunk the level to exactly page_size rows per batch.
                    let mut paged_source = chunk_concat_stream(data, page_size as usize);
                    while let Some(batch) = paged_source.try_next().await? {
                        let encoded_batch =
                            Self::train_rtree_page(batch, page_idx, writer.as_mut()).await?;
                        page_idx += 1;
                        next_level.push(encoded_batch);
                    }
                    if !next_level.is_empty() {
                        let next_num_items = next_level.len();
                        current_level = Some((
                            Self::encoded_batches_into_batch_stream(next_level, page_size),
                            next_num_items,
                        ));
                    }
                }
            }
        }

        writer
            .finish_with_metadata(
                RTreeMetadata::new(page_size, page_idx, num_items, total_bbox).into_map(),
            )
            .await?;

        Ok(())
    }

    /// Persists the null row-address set as a single-row binary batch.
    pub async fn write_nulls(store: &dyn IndexStore, null_map: RowAddrTreeMap) -> Result<()> {
        let mut writer = store
            .new_index_file(RTREE_NULLS_NAME, RTREE_NULLS_SCHEMA.clone())
            .await?;
        let mut bytes = Vec::new();
        null_map.serialize_into(&mut bytes)?;
        let batch = RecordBatch::try_new(
            RTREE_NULLS_SCHEMA.clone(),
            vec![Arc::new(BinaryArray::from_vec(vec![&bytes]))],
        )?;

        writer.write_record_batch(batch).await?;
        writer.finish().await
    }

    /// End-to-end training: Hilbert-sorts the bbox stream for spatial
    /// locality, then writes the tree pages and the nulls file.
    async fn train_rtree_index(
        bbox_data: SendableRecordBatchStream,
        stats: BboxStreamStats,
        page_size: u32,
        store: &dyn IndexStore,
    ) -> Result<()> {
        let sorter = HilbertSorter::new(stats.total_bbox);
        let sorted_data = sorter.sort(bbox_data).await?;

        Self::write_index(
            sorted_data,
            stats.num_items,
            stats.total_bbox,
            store,
            page_size,
        )
        .await?;

        Self::write_nulls(store, stats.null_map).await?;

        Ok(())
    }
}
863
#[async_trait]
impl ScalarIndexPlugin for RTreeIndexPlugin {
    fn name(&self) -> &str {
        "RTree"
    }

    /// Parses user-supplied JSON parameters into a training request.
    fn new_training_request(
        &self,
        params: &str,
        _field: &ArrowField,
    ) -> Result<Box<dyn TrainingRequest>> {
        let params = serde_json::from_str::<RTreeParameters>(params)?;
        Ok(Box::new(RTreeTrainingRequest::new(params)))
    }

    /// Trains a new R-Tree from a (geometry, row id) stream and writes it to
    /// `index_store`. Fragment-scoped training is not supported.
    async fn train_index(
        &self,
        data: SendableRecordBatchStream,
        index_store: &dyn IndexStore,
        request: Box<dyn TrainingRequest>,
        fragment_ids: Option<Vec<u32>>,
    ) -> Result<CreatedIndex> {
        if fragment_ids.is_some() {
            return Err(Error::InvalidInput {
                source: "RTree index does not support fragment training".into(),
                location: location!(),
            });
        }

        Self::validate_schema(&data.schema())?;

        let request = request
            .as_any()
            .downcast_ref::<RTreeTrainingRequest>()
            .unwrap();
        let page_size = request
            .parameters
            .page_size
            .unwrap_or(DEFAULT_RTREE_PAGE_SIZE);

        // Local temp store used only to spill the analyzed bbox stream.
        let bbox_data = Self::convert_bbox_stream(data)?;
        let tmpdir = Arc::new(TempDir::default());
        let spill_store = Arc::new(LanceIndexStore::new(
            Arc::new(ObjectStore::local()),
            tmpdir.obj_path(),
            Arc::new(LanceCache::no_cache()),
        ));
        let (bbox_data, stats) =
            Self::process_and_analyze_bbox_stream(bbox_data, page_size, spill_store.clone())
                .await?;

        Self::train_rtree_index(bbox_data, stats, page_size, index_store).await?;

        Ok(CreatedIndex {
            index_details: prost_types::Any::from_msg(&pb::RTreeIndexDetails::default())?,
            index_version: RTREE_INDEX_VERSION,
        })
    }

    /// Bbox intersection is a coarse prefilter; results can contain false
    /// positives, so exact geometry checks must follow.
    fn provides_exact_answer(&self) -> bool {
        false
    }

    fn version(&self) -> u32 {
        RTREE_INDEX_VERSION
    }

    fn new_query_parser(
        &self,
        index_name: String,
        _index_details: &prost_types::Any,
    ) -> Option<Box<dyn ScalarQueryParser>> {
        Some(Box::new(GeoQueryParser::new(index_name)))
    }

    async fn load_index(
        &self,
        index_store: Arc<dyn IndexStore>,
        _index_details: &prost_types::Any,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
        cache: &LanceCache,
    ) -> Result<Arc<dyn ScalarIndex>> {
        Ok(RTreeIndex::load(index_store, frag_reuse_index, cache).await? as Arc<dyn ScalarIndex>)
    }
}
949
/// Summary of a page that has been written out: its total bounding box and
/// assigned page id, used to build the parent level of the tree.
struct EncodedBatch {
    bbox: BoundingBox,
    page_id: u64,
}
954
955#[cfg(test)]
956mod tests {
957 use super::*;
958 use crate::metrics::NoOpMetricsCollector;
959 use crate::scalar::registry::VALUE_COLUMN_NAME;
960 use arrow_array::ArrayRef;
961 use arrow_schema::Schema;
962 use geo_types::{coord, Rect};
963 use geoarrow_array::builder::{PointBuilder, RectBuilder};
964 use geoarrow_schema::{Dimension, PointType, RectType};
965 use lance_core::utils::tempfile::TempObjDir;
966 use rand::Rng;
967
968 fn expected_num_pages(num_items: usize, page_size: u32) -> u64 {
969 RTreeMetadata::calculate_page_offsets(num_items, page_size).len() as u64
970 }
971
972 fn convert_bbox_rowid_batch_stream(
973 geo_array: &dyn GeoArrowArray,
974 row_id_array: ArrayRef,
975 ) -> SendableRecordBatchStream {
976 let schema = Arc::new(Schema::new(vec![
977 geo_array.data_type().to_field(VALUE_COLUMN_NAME, true),
978 ArrowField::new(ROW_ID, DataType::UInt64, false),
979 ]));
980
981 let batch =
982 RecordBatch::try_new(schema.clone(), vec![geo_array.to_array_ref(), row_id_array])
983 .unwrap();
984
985 let stream = stream::once(async move { Ok(batch) });
986 Box::pin(RecordBatchStreamAdapter::new(schema, stream))
987 }
988
989 async fn train_index(
990 geo_array: &dyn GeoArrowArray,
991 page_size: Option<u32>,
992 ) -> (Arc<RTreeIndex>, Arc<LanceIndexStore>, TempObjDir) {
993 let page_size = page_size.unwrap_or(DEFAULT_RTREE_PAGE_SIZE);
994 let mut num_items = 0;
995 for i in 0..geo_array.len() {
996 if !geo_array.is_null(i) {
997 num_items += 1;
998 }
999 }
1000
1001 let tmpdir = TempObjDir::default();
1002 let store = Arc::new(LanceIndexStore::new(
1003 Arc::new(ObjectStore::local()),
1004 tmpdir.clone(),
1005 Arc::new(LanceCache::no_cache()),
1006 ));
1007
1008 let stream = convert_bbox_rowid_batch_stream(
1009 geo_array,
1010 Arc::new(UInt64Array::from(
1011 (0..geo_array.len() as u64).collect::<Vec<_>>(),
1012 )),
1013 );
1014
1015 let plugin = RTreeIndexPlugin;
1016 plugin
1017 .train_index(
1018 stream,
1019 store.as_ref(),
1020 Box::new(RTreeTrainingRequest::new(RTreeParameters {
1021 page_size: Some(page_size),
1022 })),
1023 None,
1024 )
1025 .await
1026 .unwrap();
1027
1028 let pages_reader = store.open_index_file(RTREE_PAGES_NAME).await.unwrap();
1029 let metadata = RTreeMetadata::from(&pages_reader.schema().metadata);
1030 assert_eq!(metadata.num_items, num_items);
1031 assert_eq!(metadata.num_pages, expected_num_pages(num_items, page_size));
1032
1033 (
1034 RTreeIndex::load(store.clone(), None, &LanceCache::no_cache())
1035 .await
1036 .unwrap(),
1037 store,
1038 tmpdir,
1039 )
1040 }
1041
1042 #[tokio::test]
1043 async fn test_search_bbox() {
1044 let bbox_type = RectType::new(Dimension::XY, Default::default());
1045
1046 let mut rng = rand::rng();
1047 let mut rect_builder = RectBuilder::new(bbox_type.clone());
1048 let num_items = 10000;
1049 let page_size = 16;
1050
1051 for _ in 0..num_items {
1052 let x1 = rng.random_range(-1000.0..1000.0);
1053 let y1 = rng.random_range(-1000.0..1000.0);
1054 let x2 = rng.random_range(x1..x1 + 10.0);
1055 let y2 = rng.random_range(y1..y1 + 10.0);
1056
1057 rect_builder.push_rect(Some(&Rect::new(
1058 coord! { x: x1, y: y1 },
1059 coord! { x: x2, y: y2 },
1060 )));
1061 }
1062 let rect_arr = rect_builder.finish();
1063
1064 let (rtree_index, _store, _tmpdir) = train_index(&rect_arr, Some(page_size)).await;
1065
1066 let mut search_bbox = BoundingBox::new();
1067 search_bbox.add_rect(&Rect::new(
1068 coord! { x: 10.5, y: 1.5 },
1069 coord! { x: 99.5, y: 200.5 },
1070 ));
1071 let row_ids = rtree_index
1072 .search_bbox(search_bbox, &NoOpMetricsCollector)
1073 .await
1074 .unwrap();
1075
1076 let mut expected_row_ids = RowAddrTreeMap::new();
1077 for i in 0..rect_arr.len() {
1078 let mut bbox = BoundingBox::new();
1079 bbox.add_rect(&rect_arr.value(i).unwrap());
1080 if search_bbox.rect_intersects(&bbox) {
1081 expected_row_ids.insert(i as u64);
1082 }
1083 }
1084 assert_eq!(row_ids, expected_row_ids);
1085 }
1086
1087 #[tokio::test]
1088 async fn test_search_null() {
1089 let point_type = PointType::new(Dimension::XY, Default::default());
1090
1091 let mut rng = rand::rng();
1092 let num_points = 10000;
1093 let null_probability = 0.001; let mut expected_nulls = Vec::new();
1096 let mut point_builder = PointBuilder::new(point_type.clone());
1097
1098 for i in 0..num_points {
1099 if rng.random_bool(null_probability) {
1100 point_builder.push_null();
1101 expected_nulls.push(RowAddress::new_from_parts(0, i as u32));
1102 } else {
1103 let x = rng.random_range(-1000.0..1000.0);
1104 let y = rng.random_range(-1000.0..1000.0);
1105 point_builder.push_point(Some(&geo_types::point!(x: x, y: y)));
1106 }
1107 }
1108 let point_arr = point_builder.finish();
1109
1110 let (rtree_index, _store, _tmpdir) = train_index(&point_arr, None).await;
1111 let row_addrs = rtree_index
1112 .search_null(&NoOpMetricsCollector)
1113 .await
1114 .unwrap();
1115
1116 let mut actual_nulls = row_addrs.row_addrs().unwrap().collect::<Vec<_>>();
1117 actual_nulls.sort();
1118 expected_nulls.sort();
1119
1120 assert_eq!(actual_nulls, expected_nulls);
1121 }
1122
1123 #[tokio::test]
1124 async fn test_update_and_search() {
1125 fn gen_data(num_items: u32, frag_id: u32, nulls_addrs: &mut RowAddrTreeMap) -> RectArray {
1126 let bbox_type = RectType::new(Dimension::XY, Default::default());
1127
1128 let mut rng = rand::rng();
1129 let null_probability = 0.001;
1130 let mut rect_builder = RectBuilder::new(bbox_type);
1131
1132 for i in 0..num_items {
1133 if rng.random_bool(null_probability) {
1134 rect_builder.push_null();
1135 nulls_addrs.insert(RowAddress::new_from_parts(frag_id, i).into());
1136 } else {
1137 let x1 = rng.random_range(-1000.0..1000.0);
1138 let y1 = rng.random_range(-1000.0..1000.0);
1139 let x2 = rng.random_range(x1..x1 + 10.0);
1140 let y2 = rng.random_range(y1..y1 + 10.0);
1141
1142 rect_builder.push_rect(Some(&Rect::new(
1143 coord! { x: x1, y: y1 },
1144 coord! { x: x2, y: y2 },
1145 )));
1146 }
1147 }
1148 rect_builder.finish()
1149 }
1150
1151 let mut nulls_addrs = RowAddrTreeMap::default();
1152
1153 let frag_id = 0;
1154 let rect_arr = gen_data(10000, frag_id, &mut nulls_addrs);
1155
1156 let (rtree_index, _store, _tmpdir) = train_index(&rect_arr, Some(16)).await;
1157
1158 let tmpdir = TempObjDir::default();
1159 let new_store = Arc::new(LanceIndexStore::new(
1160 Arc::new(ObjectStore::local()),
1161 tmpdir.clone(),
1162 Arc::new(LanceCache::no_cache()),
1163 ));
1164
1165 let new_frag_id = 1;
1166 let new_rect_arr = gen_data(10000, 1, &mut nulls_addrs);
1167 let new_rowaddr_arr = (0..new_rect_arr.len())
1168 .map(|off| RowAddress::new_from_parts(new_frag_id, off as u32).into())
1169 .collect::<Vec<_>>();
1170 let stream = convert_bbox_rowid_batch_stream(
1171 &new_rect_arr,
1172 Arc::new(UInt64Array::from(new_rowaddr_arr.clone())),
1173 );
1174 rtree_index
1175 .update(stream, new_store.as_ref())
1176 .await
1177 .unwrap();
1178
1179 let new_rtree_index = RTreeIndex::load(new_store.clone(), None, &LanceCache::no_cache())
1180 .await
1181 .unwrap();
1182
1183 let mut search_bbox = BoundingBox::new();
1184 search_bbox.add_rect(&Rect::new(
1185 coord! { x: 10.5, y: 1.5 },
1186 coord! { x: 99.5, y: 200.5 },
1187 ));
1188 let row_addrs = new_rtree_index
1189 .search_bbox(search_bbox, &NoOpMetricsCollector)
1190 .await
1191 .unwrap();
1192
1193 let mut expected_row_addrs = RowAddrTreeMap::new();
1194 for i in 0..rect_arr.len() {
1195 if !rect_arr.is_null(i) {
1196 let bbox = BoundingBox::new_with_rect(&rect_arr.value(i).unwrap());
1197 if search_bbox.rect_intersects(&bbox) {
1198 expected_row_addrs.insert(i as u64);
1199 }
1200 }
1201 }
1202 for i in 0..new_rect_arr.len() {
1203 if !new_rect_arr.is_null(i) {
1204 let bbox = BoundingBox::new_with_rect(&new_rect_arr.value(i).unwrap());
1205 if search_bbox.rect_intersects(&bbox) {
1206 expected_row_addrs.insert(new_rowaddr_arr.get(i).copied().unwrap());
1207 }
1208 }
1209 }
1210
1211 assert_eq!(row_addrs, expected_row_addrs);
1212
1213 let actual_nulls = new_rtree_index
1214 .search_null(&NoOpMetricsCollector)
1215 .await
1216 .unwrap();
1217 assert_eq!(actual_nulls, nulls_addrs);
1218 }
1219
1220 #[tokio::test]
1221 async fn test_prewarm() {
1222 let point_type = PointType::new(Dimension::XY, Default::default());
1223
1224 let mut rng = rand::rng();
1225 let num_points = 1000;
1226 let null_probability = 0.1;
1227
1228 let mut point_builder = PointBuilder::new(point_type.clone());
1229
1230 for _ in 0..num_points {
1231 if rng.random_bool(null_probability) {
1232 point_builder.push_null();
1233 } else {
1234 let x = rng.random_range(-1000.0..1000.0);
1235 let y = rng.random_range(-1000.0..1000.0);
1236 point_builder.push_point(Some(&geo_types::point!(x: x, y: y)));
1237 }
1238 }
1239 let point_arr = point_builder.finish();
1240
1241 let (_, store, _tmpdir) = train_index(&point_arr, Some(32)).await;
1242
1243 let cache = LanceCache::with_capacity(10 << 20);
1244 let rtree_index = RTreeIndex::load(store, None, &cache).await.unwrap();
1245
1246 rtree_index.prewarm().await.unwrap();
1248
1249 for page_id in 0..rtree_index.metadata.num_pages {
1250 assert!(rtree_index
1251 .index_cache
1252 .get_with_key(&RTreeCacheKey::Page(page_id))
1253 .await
1254 .is_some())
1255 }
1256
1257 assert!(rtree_index
1258 .index_cache
1259 .get_with_key(&RTreeCacheKey::Nulls)
1260 .await
1261 .is_some())
1262 }
1263}