lance_index/scalar/btree.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{
5    any::Any,
6    cmp::Ordering,
7    collections::{BTreeMap, BinaryHeap, HashMap},
8    fmt::{Debug, Display},
9    ops::Bound,
10    sync::Arc,
11};
12
13use super::{
14    flat::FlatIndexMetadata, AnyQuery, IndexReader, IndexStore, IndexWriter, MetricsCollector,
15    SargableQuery, ScalarIndex, SearchResult,
16};
17use crate::frag_reuse::FragReuseIndex;
18use crate::{Index, IndexType};
19use arrow_array::{new_empty_array, Array, RecordBatch, UInt32Array};
20use arrow_schema::{DataType, Field, Schema, SortOptions};
21use async_trait::async_trait;
22use datafusion::physical_plan::{
23    sorts::sort_preserving_merge::SortPreservingMergeExec, stream::RecordBatchStreamAdapter,
24    union::UnionExec, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream,
25};
26use datafusion_common::{DataFusionError, ScalarValue};
27use datafusion_physical_expr::{expressions::Column, LexOrdering, PhysicalSortExpr};
28use deepsize::{Context, DeepSizeOf};
29use futures::{
30    future::BoxFuture,
31    stream::{self},
32    FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
33};
34use lance_core::{
35    utils::{
36        mask::RowIdTreeMap,
37        tokio::get_num_compute_intensive_cpus,
38        tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS},
39    },
40    Error, Result,
41};
42use lance_datafusion::{
43    chunker::chunk_concat_stream,
44    exec::{execute_plan, LanceExecutionOptions, OneShotExec},
45};
46use log::debug;
47use moka::sync::Cache;
48use roaring::RoaringBitmap;
49use serde::{Serialize, Serializer};
50use snafu::location;
51use tracing::info;
52
53const BTREE_LOOKUP_NAME: &str = "page_lookup.lance";
54const BTREE_PAGES_NAME: &str = "page_data.lance";
55pub const DEFAULT_BTREE_BATCH_SIZE: u64 = 4096;
56const BATCH_SIZE_META_KEY: &str = "batch_size";
57
58static CACHE_SIZE: std::sync::LazyLock<u64> = std::sync::LazyLock::new(|| {
59    std::env::var("LANCE_BTREE_CACHE_SIZE")
60        .ok()
61        .and_then(|s| s.parse().ok())
62        .unwrap_or(512 * 1024 * 1024)
63});
64
65/// Wraps a ScalarValue and implements Ord (ScalarValue only implements PartialOrd)
66#[derive(Clone, Debug)]
67pub struct OrderableScalarValue(pub ScalarValue);
68
69impl DeepSizeOf for OrderableScalarValue {
70    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
71        // ScalarValue::size already includes size_of::<ScalarValue>, so only count the children here
72        self.0.size() - std::mem::size_of::<ScalarValue>()
73    }
74}
75
76impl Display for OrderableScalarValue {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        std::fmt::Display::fmt(&self.0, f)
79    }
80}
81
82impl PartialEq for OrderableScalarValue {
83    fn eq(&self, other: &Self) -> bool {
84        self.0.eq(&other.0)
85    }
86}
87
88impl Eq for OrderableScalarValue {}
89
90impl PartialOrd for OrderableScalarValue {
91    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
92        Some(self.cmp(other))
93    }
94}
95
96// manual implementation of `Ord` that panics when asked to compare scalars of different types
97// and always puts nulls before non-nulls (this is consistent with Option<T>'s implementation
98// of Ord)
99//
100// TODO: Consider upstreaming this
101impl Ord for OrderableScalarValue {
102    fn cmp(&self, other: &Self) -> Ordering {
103        use ScalarValue::*;
104        // This purposely doesn't have a catch-all "(_, _)" so that
105        // any newly added enum variant will require editing this list
106        // or else face a compile error
107        match (&self.0, &other.0) {
108            (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => {
109                if p1.eq(p2) && s1.eq(s2) {
110                    v1.cmp(v2)
111                } else {
112                    // Two decimal values can only be compared if they have the same precision and scale.
113                    panic!("Attempt to compare decimals with unequal precision / scale")
114                }
115            }
116            (Decimal128(v1, _, _), Null) => {
117                if v1.is_none() {
118                    Ordering::Equal
119                } else {
120                    Ordering::Greater
121                }
122            }
123            (Decimal128(_, _, _), _) => panic!("Attempt to compare decimal with non-decimal"),
124            (Decimal256(v1, p1, s1), Decimal256(v2, p2, s2)) => {
125                if p1.eq(p2) && s1.eq(s2) {
126                    v1.cmp(v2)
127                } else {
128                    // Two decimal values can only be compared if they have the same precision and scale.
129                    panic!("Attempt to compare decimals with unequal precision / scale")
130                }
131            }
132            (Decimal256(v1, _, _), Null) => {
133                if v1.is_none() {
134                    Ordering::Equal
135                } else {
136                    Ordering::Greater
137                }
138            }
139            (Decimal256(_, _, _), _) => panic!("Attempt to compare decimal with non-decimal"),
140            (Boolean(v1), Boolean(v2)) => v1.cmp(v2),
141            (Boolean(v1), Null) => {
142                if v1.is_none() {
143                    Ordering::Equal
144                } else {
145                    Ordering::Greater
146                }
147            }
148            (Boolean(_), _) => panic!("Attempt to compare boolean with non-boolean"),
149            (Float32(v1), Float32(v2)) => match (v1, v2) {
150                (Some(f1), Some(f2)) => f1.total_cmp(f2),
151                (None, Some(_)) => Ordering::Less,
152                (Some(_), None) => Ordering::Greater,
153                (None, None) => Ordering::Equal,
154            },
155            (Float32(v1), Null) => {
156                if v1.is_none() {
157                    Ordering::Equal
158                } else {
159                    Ordering::Greater
160                }
161            }
162            (Float32(_), _) => panic!("Attempt to compare f32 with non-f32"),
163            (Float64(v1), Float64(v2)) => match (v1, v2) {
164                (Some(f1), Some(f2)) => f1.total_cmp(f2),
165                (None, Some(_)) => Ordering::Less,
166                (Some(_), None) => Ordering::Greater,
167                (None, None) => Ordering::Equal,
168            },
169            (Float64(v1), Null) => {
170                if v1.is_none() {
171                    Ordering::Equal
172                } else {
173                    Ordering::Greater
174                }
175            }
176            (Float64(_), _) => panic!("Attempt to compare f64 with non-f64"),
177            (Float16(v1), Float16(v2)) => match (v1, v2) {
178                (Some(f1), Some(f2)) => f1.total_cmp(f2),
179                (None, Some(_)) => Ordering::Less,
180                (Some(_), None) => Ordering::Greater,
181                (None, None) => Ordering::Equal,
182            },
183            (Float16(v1), Null) => {
184                if v1.is_none() {
185                    Ordering::Equal
186                } else {
187                    Ordering::Greater
188                }
189            }
190            (Float16(_), _) => panic!("Attempt to compare f16 with non-f16"),
191            (Int8(v1), Int8(v2)) => v1.cmp(v2),
192            (Int8(v1), Null) => {
193                if v1.is_none() {
194                    Ordering::Equal
195                } else {
196                    Ordering::Greater
197                }
198            }
199            (Int8(_), _) => panic!("Attempt to compare Int8 with non-Int8"),
200            (Int16(v1), Int16(v2)) => v1.cmp(v2),
201            (Int16(v1), Null) => {
202                if v1.is_none() {
203                    Ordering::Equal
204                } else {
205                    Ordering::Greater
206                }
207            }
208            (Int16(_), _) => panic!("Attempt to compare Int16 with non-Int16"),
209            (Int32(v1), Int32(v2)) => v1.cmp(v2),
210            (Int32(v1), Null) => {
211                if v1.is_none() {
212                    Ordering::Equal
213                } else {
214                    Ordering::Greater
215                }
216            }
217            (Int32(_), _) => panic!("Attempt to compare Int32 with non-Int32"),
218            (Int64(v1), Int64(v2)) => v1.cmp(v2),
219            (Int64(v1), Null) => {
220                if v1.is_none() {
221                    Ordering::Equal
222                } else {
223                    Ordering::Greater
224                }
225            }
226            (Int64(_), _) => panic!("Attempt to compare Int64 with non-Int64"),
227            (UInt8(v1), UInt8(v2)) => v1.cmp(v2),
228            (UInt8(v1), Null) => {
229                if v1.is_none() {
230                    Ordering::Equal
231                } else {
232                    Ordering::Greater
233                }
234            }
235            (UInt8(_), _) => panic!("Attempt to compare UInt8 with non-UInt8"),
236            (UInt16(v1), UInt16(v2)) => v1.cmp(v2),
237            (UInt16(v1), Null) => {
238                if v1.is_none() {
239                    Ordering::Equal
240                } else {
241                    Ordering::Greater
242                }
243            }
244            (UInt16(_), _) => panic!("Attempt to compare UInt16 with non-UInt16"),
245            (UInt32(v1), UInt32(v2)) => v1.cmp(v2),
246            (UInt32(v1), Null) => {
247                if v1.is_none() {
248                    Ordering::Equal
249                } else {
250                    Ordering::Greater
251                }
252            }
253            (UInt32(_), _) => panic!("Attempt to compare UInt32 with non-UInt32"),
254            (UInt64(v1), UInt64(v2)) => v1.cmp(v2),
255            (UInt64(v1), Null) => {
256                if v1.is_none() {
257                    Ordering::Equal
258                } else {
259                    Ordering::Greater
260                }
261            }
262            (UInt64(_), _) => panic!("Attempt to compare UInt64 with non-UInt64"),
263            (Utf8(v1) | Utf8View(v1) | LargeUtf8(v1), Utf8(v2) | Utf8View(v2) | LargeUtf8(v2)) => {
264                v1.cmp(v2)
265            }
266            (Utf8(v1) | Utf8View(v1) | LargeUtf8(v1), Null) => {
267                if v1.is_none() {
268                    Ordering::Equal
269                } else {
270                    Ordering::Greater
271                }
272            }
273            (Utf8(_) | Utf8View(_) | LargeUtf8(_), _) => {
274                panic!("Attempt to compare Utf8 with non-Utf8")
275            }
276            (
277                Binary(v1) | LargeBinary(v1) | BinaryView(v1),
278                Binary(v2) | LargeBinary(v2) | BinaryView(v2),
279            ) => v1.cmp(v2),
280            (Binary(v1) | LargeBinary(v1) | BinaryView(v1), Null) => {
281                if v1.is_none() {
282                    Ordering::Equal
283                } else {
284                    Ordering::Greater
285                }
286            }
287            (Binary(_) | LargeBinary(_) | BinaryView(_), _) => {
288                panic!("Attempt to compare Binary with non-Binary")
289            }
290            (FixedSizeBinary(_, v1), FixedSizeBinary(_, v2)) => v1.cmp(v2),
291            (FixedSizeBinary(_, v1), Null) => {
292                if v1.is_none() {
293                    Ordering::Equal
294                } else {
295                    Ordering::Greater
296                }
297            }
298            (FixedSizeBinary(_, _), _) => {
299                panic!("Attempt to compare FixedSizeBinary with non-FixedSizeBinary")
300            }
301            (FixedSizeList(left), FixedSizeList(right)) => {
302                if left.eq(right) {
303                    todo!()
304                } else {
305                    panic!(
306                        "Attempt to compare fixed size list elements with different widths/fields"
307                    )
308                }
309            }
310            (FixedSizeList(left), Null) => {
311                if left.is_null(0) {
312                    Ordering::Equal
313                } else {
314                    Ordering::Greater
315                }
316            }
317            (FixedSizeList(_), _) => {
318                panic!("Attempt to compare FixedSizeList with non-FixedSizeList")
319            }
320            (List(_), List(_)) => todo!(),
321            (List(left), Null) => {
322                if left.is_null(0) {
323                    Ordering::Equal
324                } else {
325                    Ordering::Greater
326                }
327            }
328            (List(_), _) => {
329                panic!("Attempt to compare List with non-List")
330            }
331            (LargeList(_), _) => todo!(),
332            (Map(_), Map(_)) => todo!(),
333            (Map(left), Null) => {
334                if left.is_null(0) {
335                    Ordering::Equal
336                } else {
337                    Ordering::Greater
338                }
339            }
340            (Map(_), _) => {
341                panic!("Attempt to compare Map with non-Map")
342            }
343            (Date32(v1), Date32(v2)) => v1.cmp(v2),
344            (Date32(v1), Null) => {
345                if v1.is_none() {
346                    Ordering::Equal
347                } else {
348                    Ordering::Greater
349                }
350            }
351            (Date32(_), _) => panic!("Attempt to compare Date32 with non-Date32"),
352            (Date64(v1), Date64(v2)) => v1.cmp(v2),
353            (Date64(v1), Null) => {
354                if v1.is_none() {
355                    Ordering::Equal
356                } else {
357                    Ordering::Greater
358                }
359            }
360            (Date64(_), _) => panic!("Attempt to compare Date64 with non-Date64"),
361            (Time32Second(v1), Time32Second(v2)) => v1.cmp(v2),
362            (Time32Second(v1), Null) => {
363                if v1.is_none() {
364                    Ordering::Equal
365                } else {
366                    Ordering::Greater
367                }
368            }
369            (Time32Second(_), _) => panic!("Attempt to compare Time32Second with non-Time32Second"),
370            (Time32Millisecond(v1), Time32Millisecond(v2)) => v1.cmp(v2),
371            (Time32Millisecond(v1), Null) => {
372                if v1.is_none() {
373                    Ordering::Equal
374                } else {
375                    Ordering::Greater
376                }
377            }
378            (Time32Millisecond(_), _) => {
379                panic!("Attempt to compare Time32Millisecond with non-Time32Millisecond")
380            }
381            (Time64Microsecond(v1), Time64Microsecond(v2)) => v1.cmp(v2),
382            (Time64Microsecond(v1), Null) => {
383                if v1.is_none() {
384                    Ordering::Equal
385                } else {
386                    Ordering::Greater
387                }
388            }
389            (Time64Microsecond(_), _) => {
390                panic!("Attempt to compare Time64Microsecond with non-Time64Microsecond")
391            }
392            (Time64Nanosecond(v1), Time64Nanosecond(v2)) => v1.cmp(v2),
393            (Time64Nanosecond(v1), Null) => {
394                if v1.is_none() {
395                    Ordering::Equal
396                } else {
397                    Ordering::Greater
398                }
399            }
400            (Time64Nanosecond(_), _) => {
401                panic!("Attempt to compare Time64Nanosecond with non-Time64Nanosecond")
402            }
403            (TimestampSecond(v1, _), TimestampSecond(v2, _)) => v1.cmp(v2),
404            (TimestampSecond(v1, _), Null) => {
405                if v1.is_none() {
406                    Ordering::Equal
407                } else {
408                    Ordering::Greater
409                }
410            }
411            (TimestampSecond(_, _), _) => {
412                panic!("Attempt to compare TimestampSecond with non-TimestampSecond")
413            }
414            (TimestampMillisecond(v1, _), TimestampMillisecond(v2, _)) => v1.cmp(v2),
415            (TimestampMillisecond(v1, _), Null) => {
416                if v1.is_none() {
417                    Ordering::Equal
418                } else {
419                    Ordering::Greater
420                }
421            }
422            (TimestampMillisecond(_, _), _) => {
423                panic!("Attempt to compare TimestampMillisecond with non-TimestampMillisecond")
424            }
425            (TimestampMicrosecond(v1, _), TimestampMicrosecond(v2, _)) => v1.cmp(v2),
426            (TimestampMicrosecond(v1, _), Null) => {
427                if v1.is_none() {
428                    Ordering::Equal
429                } else {
430                    Ordering::Greater
431                }
432            }
433            (TimestampMicrosecond(_, _), _) => {
434                panic!("Attempt to compare TimestampMicrosecond with non-TimestampMicrosecond")
435            }
436            (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => v1.cmp(v2),
437            (TimestampNanosecond(v1, _), Null) => {
438                if v1.is_none() {
439                    Ordering::Equal
440                } else {
441                    Ordering::Greater
442                }
443            }
444            (TimestampNanosecond(_, _), _) => {
445                panic!("Attempt to compare TimestampNanosecond with non-TimestampNanosecond")
446            }
447            (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.cmp(v2),
448            (IntervalYearMonth(v1), Null) => {
449                if v1.is_none() {
450                    Ordering::Equal
451                } else {
452                    Ordering::Greater
453                }
454            }
455            (IntervalYearMonth(_), _) => {
456                panic!("Attempt to compare IntervalYearMonth with non-IntervalYearMonth")
457            }
458            (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.cmp(v2),
459            (IntervalDayTime(v1), Null) => {
460                if v1.is_none() {
461                    Ordering::Equal
462                } else {
463                    Ordering::Greater
464                }
465            }
466            (IntervalDayTime(_), _) => {
467                panic!("Attempt to compare IntervalDayTime with non-IntervalDayTime")
468            }
469            (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.cmp(v2),
470            (IntervalMonthDayNano(v1), Null) => {
471                if v1.is_none() {
472                    Ordering::Equal
473                } else {
474                    Ordering::Greater
475                }
476            }
477            (IntervalMonthDayNano(_), _) => {
478                panic!("Attempt to compare IntervalMonthDayNano with non-IntervalMonthDayNano")
479            }
480            (DurationSecond(v1), DurationSecond(v2)) => v1.cmp(v2),
481            (DurationSecond(v1), Null) => {
482                if v1.is_none() {
483                    Ordering::Equal
484                } else {
485                    Ordering::Greater
486                }
487            }
488            (DurationSecond(_), _) => {
489                panic!("Attempt to compare DurationSecond with non-DurationSecond")
490            }
491            (DurationMillisecond(v1), DurationMillisecond(v2)) => v1.cmp(v2),
492            (DurationMillisecond(v1), Null) => {
493                if v1.is_none() {
494                    Ordering::Equal
495                } else {
496                    Ordering::Greater
497                }
498            }
499            (DurationMillisecond(_), _) => {
500                panic!("Attempt to compare DurationMillisecond with non-DurationMillisecond")
501            }
502            (DurationMicrosecond(v1), DurationMicrosecond(v2)) => v1.cmp(v2),
503            (DurationMicrosecond(v1), Null) => {
504                if v1.is_none() {
505                    Ordering::Equal
506                } else {
507                    Ordering::Greater
508                }
509            }
510            (DurationMicrosecond(_), _) => {
511                panic!("Attempt to compare DurationMicrosecond with non-DurationMicrosecond")
512            }
513            (DurationNanosecond(v1), DurationNanosecond(v2)) => v1.cmp(v2),
514            (DurationNanosecond(v1), Null) => {
515                if v1.is_none() {
516                    Ordering::Equal
517                } else {
518                    Ordering::Greater
519                }
520            }
521            (DurationNanosecond(_), _) => {
522                panic!("Attempt to compare DurationNanosecond with non-DurationNanosecond")
523            }
524            (Struct(_arr), Struct(_arr2)) => todo!(),
525            (Struct(arr), Null) => {
526                if arr.is_empty() {
527                    Ordering::Equal
528                } else {
529                    Ordering::Greater
530                }
531            }
532            (Struct(_arr), _) => panic!("Attempt to compare Struct with non-Struct"),
533            (Dictionary(_k1, _v1), Dictionary(_k2, _v2)) => todo!(),
534            (Dictionary(_, v1), Null) => Self(*v1.clone()).cmp(&Self(ScalarValue::Null)),
535            (Dictionary(_, _), _) => panic!("Attempt to compare Dictionary with non-Dictionary"),
536            // What would a btree of unions even look like?  May not be possible.
537            (Union(_, _, _), _) => todo!("Support for union scalars"),
538            (Null, Null) => Ordering::Equal,
539            (Null, _) => todo!(),
540        }
541    }
542}
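
// Added for illustration (not part of the original source): a minimal sketch of the
// ordering semantics defined above.  Nulls sort before non-null values, and floats use
// total ordering so NaN participates in comparisons instead of breaking them.  The
// module name and the constants below are arbitrary choices for the example.
#[cfg(test)]
mod orderable_scalar_value_example {
    use std::cmp::Ordering;

    use datafusion_common::ScalarValue;

    use super::OrderableScalarValue;

    #[test]
    fn nulls_sort_first_and_nan_is_ordered() {
        let null = OrderableScalarValue(ScalarValue::Int32(None));
        let zero = OrderableScalarValue(ScalarValue::Int32(Some(0)));
        // A null scalar compares less than any non-null scalar of the same type
        assert_eq!(null.cmp(&zero), Ordering::Less);

        let nan = OrderableScalarValue(ScalarValue::Float32(Some(f32::NAN)));
        let one = OrderableScalarValue(ScalarValue::Float32(Some(1.0)));
        // total_cmp places (positive) NaN after all other float values
        assert_eq!(nan.cmp(&one), Ordering::Greater);
    }
}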
543
544#[derive(Debug, DeepSizeOf, PartialEq, Eq)]
545struct PageRecord {
546    max: OrderableScalarValue,
547    page_number: u32,
548}
549
550trait BTreeMapExt<K, V> {
551    fn largest_node_less(&self, key: &K) -> Option<(&K, &V)>;
552}
553
554impl<K: Ord, V> BTreeMapExt<K, V> for BTreeMap<K, V> {
555    fn largest_node_less(&self, key: &K) -> Option<(&K, &V)> {
556        self.range((Bound::Unbounded, Bound::Excluded(key)))
557            .next_back()
558    }
559}
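
// Added for illustration (not part of the original source): `largest_node_less` returns
// the entry immediately to the left of a key, which `pages_between` below relies on to
// catch a page whose range straddles the query's lower bound.  The keys and values here
// are arbitrary choices for the example.
#[cfg(test)]
mod btree_map_ext_example {
    use std::collections::BTreeMap;

    use super::BTreeMapExt;

    #[test]
    fn finds_entry_strictly_less_than_key() {
        let tree: BTreeMap<i32, &str> = [(5, "page0"), (10, "page1")].into_iter().collect();
        // A query for 7 must still consider the page whose min is 5
        assert_eq!(tree.largest_node_less(&7), Some((&5, &"page0")));
        // "Strictly less": an exact key match looks one entry further to the left
        assert_eq!(tree.largest_node_less(&5), None);
    }
}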
560
561/// An in-memory structure that can quickly satisfy scalar queries using a btree of ScalarValue
562#[derive(Debug, DeepSizeOf, PartialEq, Eq)]
563pub struct BTreeLookup {
564    tree: BTreeMap<OrderableScalarValue, Vec<PageRecord>>,
565    /// Pages where the value may be null
566    null_pages: Vec<u32>,
567}
568
569impl BTreeLookup {
570    fn new(tree: BTreeMap<OrderableScalarValue, Vec<PageRecord>>, null_pages: Vec<u32>) -> Self {
571        Self { tree, null_pages }
572    }
573
574    // All pages that could have a value equal to val
575    fn pages_eq(&self, query: &OrderableScalarValue) -> Vec<u32> {
576        if query.0.is_null() {
577            self.pages_null()
578        } else {
579            self.pages_between((Bound::Included(query), Bound::Excluded(query)))
580        }
581    }
582
583    // All pages that could have a value equal to one of the values
584    fn pages_in(&self, values: impl IntoIterator<Item = OrderableScalarValue>) -> Vec<u32> {
585        let page_lists = values
586            .into_iter()
587            .map(|val| self.pages_eq(&val))
588            .collect::<Vec<_>>();
589        let total_size = page_lists.iter().map(|set| set.len()).sum();
590        let mut heap = BinaryHeap::with_capacity(total_size);
591        for page_list in page_lists {
592            heap.extend(page_list);
593        }
594        let mut all_pages = heap.into_sorted_vec();
595        all_pages.dedup();
596        all_pages
597    }
598
599    // All pages that could have a value in the range
600    fn pages_between(
601        &self,
602        range: (Bound<&OrderableScalarValue>, Bound<&OrderableScalarValue>),
603    ) -> Vec<u32> {
604        // We need to look a little to the left of the given range because the query might be 7
605        // and the first page might be something like 5-10.
606        let lower_bound = match range.0 {
607            Bound::Unbounded => Bound::Unbounded,
608            // It doesn't matter if the bound is exclusive or inclusive.  We are going to grab
609        // the largest node whose min is strictly less than the given bound.  Then we grab
610            // all nodes greater than or equal to that
611            //
612            // We have to peek a bit to the left because we might have something like a lower
613            // bound of 7 and there is a page [5-10] we want to search for.
614            Bound::Included(lower) => self
615                .tree
616                .largest_node_less(lower)
617                .map(|val| Bound::Included(val.0))
618                .unwrap_or(Bound::Unbounded),
619            Bound::Excluded(lower) => self
620                .tree
621                .largest_node_less(lower)
622                .map(|val| Bound::Included(val.0))
623                .unwrap_or(Bound::Unbounded),
624        };
625        let upper_bound = match range.1 {
626            Bound::Unbounded => Bound::Unbounded,
627            Bound::Included(upper) => Bound::Included(upper),
628            // Even if the upper bound is excluded we still need to include it.  This is because the
629            // query might be [x, x).  Our lower bound search might find some [a, x] bucket and we still
630            // want to include any [x, z] bucket.
631            //
632            // We could be slightly more accurate here and only include the upper bound if the lower bound
633            // is defined, inclusive, and equal to the upper bound.  However, let's keep it simple for now.  This
634            // should only affect the probably rare case that our query is a true range query and the value
635            // matches an upper bound.  This will all be moot if/when we merge pages.
636            Bound::Excluded(upper) => Bound::Included(upper),
637        };
638
639        match (lower_bound, upper_bound) {
640            (Bound::Excluded(lower), Bound::Excluded(upper))
641            | (Bound::Excluded(lower), Bound::Included(upper))
642            | (Bound::Included(lower), Bound::Excluded(upper)) => {
643                // It's not really clear what (Included(5), Excluded(5)) would mean so we
644                // interpret it as an empty range which matches rust's BTreeMap behavior
645                if lower >= upper {
646                    return vec![];
647                }
648            }
649            (Bound::Included(lower), Bound::Included(upper)) => {
650                if lower > upper {
651                    return vec![];
652                }
653            }
654            _ => {}
655        }
656
657        let candidates = self
658            .tree
659            .range((lower_bound, upper_bound))
660            .flat_map(|val| val.1);
661        match lower_bound {
662            Bound::Unbounded => candidates.map(|val| val.page_number).collect(),
663            Bound::Included(lower_bound) => candidates
664                .filter(|val| val.max.cmp(lower_bound) != Ordering::Less)
665                .map(|val| val.page_number)
666                .collect(),
667            Bound::Excluded(lower_bound) => candidates
668                .filter(|val| val.max.cmp(lower_bound) == Ordering::Greater)
669                .map(|val| val.page_number)
670                .collect(),
671        }
672    }
673
674    fn pages_null(&self) -> Vec<u32> {
675        self.null_pages.clone()
676    }
677}
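
// Added for illustration (not part of the original source): a sketch of how the lookup
// narrows a range query to candidate pages.  Page 0 covers values [1, 10] and page 1
// covers [11, 20]; a query for values > 15 only needs page 1, while a query for
// values >= 5 needs both.  The page layout here is an arbitrary choice for the example.
#[cfg(test)]
mod btree_lookup_example {
    use std::{collections::BTreeMap, ops::Bound};

    use datafusion_common::ScalarValue;

    use super::{BTreeLookup, OrderableScalarValue, PageRecord};

    fn val(v: i32) -> OrderableScalarValue {
        OrderableScalarValue(ScalarValue::Int32(Some(v)))
    }

    #[test]
    fn range_query_narrows_to_pages() {
        let mut tree = BTreeMap::new();
        tree.insert(
            val(1),
            vec![PageRecord {
                max: val(10),
                page_number: 0,
            }],
        );
        tree.insert(
            val(11),
            vec![PageRecord {
                max: val(20),
                page_number: 1,
            }],
        );
        let lookup = BTreeLookup::new(tree, vec![]);

        // values > 15 can only appear in page 1
        let gt_15 = lookup.pages_between((Bound::Excluded(&val(15)), Bound::Unbounded));
        assert_eq!(gt_15, vec![1]);

        // values >= 5 could appear in page 0 (which spans 1..=10) or page 1
        let ge_5 = lookup.pages_between((Bound::Included(&val(5)), Bound::Unbounded));
        assert_eq!(ge_5, vec![0, 1]);
    }
}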
678
679// Caches btree pages in memory
680#[derive(Debug)]
681struct BTreeCache(Cache<u32, Arc<dyn ScalarIndex>>);
682
683impl DeepSizeOf for BTreeCache {
684    fn deep_size_of_children(&self, _: &mut Context) -> usize {
685        self.0.iter().map(|(_, v)| v.deep_size_of()).sum()
686    }
687}
688
689// We only need to open a file reader for pages if we need to load a page.  If all
690// pages are cached we don't open it.  If we do open it we should only open it once.
691#[derive(Clone)]
692struct LazyIndexReader {
693    index_reader: Arc<tokio::sync::Mutex<Option<Arc<dyn IndexReader>>>>,
694    store: Arc<dyn IndexStore>,
695}
696
697impl LazyIndexReader {
698    fn new(store: Arc<dyn IndexStore>) -> Self {
699        Self {
700            index_reader: Arc::new(tokio::sync::Mutex::new(None)),
701            store,
702        }
703    }
704
705    async fn get(&self) -> Result<Arc<dyn IndexReader>> {
706        let mut reader = self.index_reader.lock().await;
707        if reader.is_none() {
708            let index_reader = self.store.open_index_file(BTREE_PAGES_NAME).await?;
709            *reader = Some(index_reader);
710        }
711        Ok(reader.as_ref().unwrap().clone())
712    }
713}
714
715/// A btree index satisfies scalar queries using a b tree
716///
717/// The upper layers of the btree are expected to be cached and, when loaded,
718/// are stored in a btree structure in memory.  The leaves of the btree are left
719/// to be searched by some other kind of index (currently a flat search).
720///
721/// This strikes a balance between an expensive memory structure containing all
722/// of the values and an expensive disk structure that can't be efficiently searched.
723///
724/// For example, given 1Bi values we can store 256Ki leaves of size 4Ki.  We only
725/// need memory space for 256Ki page entries (depends on the data type but usually a few MiB
726/// at most) and can narrow our search to 4Ki values.
727///
728/// Note: this is very similar to the IVF index except we store the IVF part in a btree
729/// for faster lookup
730#[derive(Clone, Debug, DeepSizeOf)]
731pub struct BTreeIndex {
732    page_lookup: Arc<BTreeLookup>,
733    page_cache: Arc<BTreeCache>,
734    store: Arc<dyn IndexStore>,
735    sub_index: Arc<dyn BTreeSubIndex>,
736    batch_size: u64,
737    fri: Option<Arc<FragReuseIndex>>,
738}
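
// Added for illustration (not part of the original source): the sizing arithmetic from
// the doc comment above.  One billion (2^30) values split into pages of
// DEFAULT_BTREE_BATCH_SIZE (4096 = 2^12) values leaves roughly 256Ki (2^18) page entries
// to keep in memory.
#[cfg(test)]
mod btree_sizing_example {
    use super::DEFAULT_BTREE_BATCH_SIZE;

    #[test]
    fn page_count_for_a_billion_values() {
        let num_values: u64 = 1 << 30;
        let num_pages = num_values / DEFAULT_BTREE_BATCH_SIZE;
        assert_eq!(num_pages, 1 << 18); // ~256Ki in-memory page entries
    }
}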
739
740impl BTreeIndex {
741    fn new(
742        tree: BTreeMap<OrderableScalarValue, Vec<PageRecord>>,
743        null_pages: Vec<u32>,
744        store: Arc<dyn IndexStore>,
745        sub_index: Arc<dyn BTreeSubIndex>,
746        batch_size: u64,
747        fri: Option<Arc<FragReuseIndex>>,
748    ) -> Self {
749        let page_lookup = Arc::new(BTreeLookup::new(tree, null_pages));
750        let page_cache = Arc::new(BTreeCache(
751            Cache::builder()
752                .max_capacity(*CACHE_SIZE)
753                .weigher(|_, v: &Arc<dyn ScalarIndex>| v.deep_size_of() as u32)
754                .build(),
755        ));
756        Self {
757            page_lookup,
758            page_cache,
759            store,
760            sub_index,
761            batch_size,
762            fri,
763        }
764    }
765
766    async fn lookup_page(
767        &self,
768        page_number: u32,
769        index_reader: LazyIndexReader,
770        metrics: &dyn MetricsCollector,
771    ) -> Result<Arc<dyn ScalarIndex>> {
772        if let Some(cached) = self.page_cache.0.get(&page_number) {
773            return Ok(cached);
774        }
775        metrics.record_part_load();
776        info!(target: TRACE_IO_EVENTS, r#type=IO_TYPE_LOAD_SCALAR_PART, index_type="btree", part_id=page_number);
777        let index_reader = index_reader.get().await?;
778        let mut serialized_page = index_reader
779            .read_record_batch(page_number as u64, self.batch_size)
780            .await?;
781        if let Some(fri_ref) = self.fri.as_ref() {
782            serialized_page = fri_ref.remap_row_ids_record_batch(serialized_page, 1)?;
783        }
784        let subindex = self.sub_index.load_subindex(serialized_page).await?;
785        self.page_cache.0.insert(page_number, subindex.clone());
786        Ok(subindex)
787    }
788
789    async fn search_page(
790        &self,
791        query: &SargableQuery,
792        page_number: u32,
793        index_reader: LazyIndexReader,
794        metrics: &dyn MetricsCollector,
795    ) -> Result<RowIdTreeMap> {
796        let subindex = self.lookup_page(page_number, index_reader, metrics).await?;
797        // TODO: If this is an IN query we can perhaps simplify the subindex query by restricting it to the
798        // values that might be in the page.  E.g. if we are searching for X IN [5, 3, 7] and five is in pages
799        // 1 and 2 and three is in page 2 and seven is in pages 8 and 9 then when we search page 2 we only need
800        // to search for X IN [5, 3]
801        match subindex.search(query, metrics).await? {
802            SearchResult::Exact(map) => Ok(map),
803            _ => Err(Error::Internal {
804                message: "BTree sub-indices need to return exact results".to_string(),
805                location: location!(),
806            }),
807        }
808    }
809
810    fn try_from_serialized(
811        data: RecordBatch,
812        store: Arc<dyn IndexStore>,
813        batch_size: u64,
814        fri: Option<Arc<FragReuseIndex>>,
815    ) -> Result<Self> {
816        let mut map = BTreeMap::<OrderableScalarValue, Vec<PageRecord>>::new();
817        let mut null_pages = Vec::<u32>::new();
818
819        if data.num_rows() == 0 {
820            let data_type = data.column(0).data_type().clone();
821            let sub_index = Arc::new(FlatIndexMetadata::new(data_type));
822            return Ok(Self::new(
823                map, null_pages, store, sub_index, batch_size, fri,
824            ));
825        }
826
827        let mins = data.column(0);
828        let maxs = data.column(1);
829        let null_counts = data
830            .column(2)
831            .as_any()
832            .downcast_ref::<UInt32Array>()
833            .unwrap();
834        let page_numbers = data
835            .column(3)
836            .as_any()
837            .downcast_ref::<UInt32Array>()
838            .unwrap();
839
840        for idx in 0..data.num_rows() {
841            let min = OrderableScalarValue(ScalarValue::try_from_array(&mins, idx)?);
842            let max = OrderableScalarValue(ScalarValue::try_from_array(&maxs, idx)?);
843            let null_count = null_counts.values()[idx];
844            let page_number = page_numbers.values()[idx];
845
846            // If the page is entirely null don't even bother putting it in the tree
847            if !max.0.is_null() {
848                map.entry(min)
849                    .or_default()
850                    .push(PageRecord { max, page_number });
851            }
852
853            if null_count > 0 {
854                null_pages.push(page_number);
855            }
856        }
857
858        let last_max = ScalarValue::try_from_array(&maxs, data.num_rows() - 1)?;
859        map.entry(OrderableScalarValue(last_max)).or_default();
860
861        let data_type = mins.data_type();
862
863        // TODO: Support other page types?
864        let sub_index = Arc::new(FlatIndexMetadata::new(data_type.clone()));
865
866        Ok(Self::new(
867            map, null_pages, store, sub_index, batch_size, fri,
868        ))
869    }
870
871    /// Create a stream of all the data in the index, in the same format used to train the index
872    async fn into_data_stream(self) -> Result<impl RecordBatchStream> {
873        let reader = self.store.open_index_file(BTREE_PAGES_NAME).await?;
874        let schema = self.sub_index.schema().clone();
875        let reader_stream = IndexReaderStream::new(reader, self.batch_size).await;
876        let batches = reader_stream
877            .map(|fut| fut.map_err(DataFusionError::from))
878            .buffered(self.store.io_parallelism())
879            .boxed();
880        Ok(RecordBatchStreamAdapter::new(schema, batches))
881    }
882}
883
884fn wrap_bound(bound: &Bound<ScalarValue>) -> Bound<OrderableScalarValue> {
885    match bound {
886        Bound::Unbounded => Bound::Unbounded,
887        Bound::Included(val) => Bound::Included(OrderableScalarValue(val.clone())),
888        Bound::Excluded(val) => Bound::Excluded(OrderableScalarValue(val.clone())),
889    }
890}
891
892fn serialize_with_display<T: Display, S: Serializer>(
893    value: &Option<T>,
894    serializer: S,
895) -> std::result::Result<S::Ok, S::Error> {
896    if let Some(value) = value {
897        serializer.collect_str(value)
898    } else {
899        serializer.collect_str("N/A")
900    }
901}
902
903#[derive(Serialize)]
904struct BTreeStatistics {
905    #[serde(serialize_with = "serialize_with_display")]
906    min: Option<OrderableScalarValue>,
907    #[serde(serialize_with = "serialize_with_display")]
908    max: Option<OrderableScalarValue>,
909    num_pages: u32,
910}
911
912#[async_trait]
913impl Index for BTreeIndex {
914    fn as_any(&self) -> &dyn Any {
915        self
916    }
917
918    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
919        self
920    }
921
922    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn crate::vector::VectorIndex>> {
923        Err(Error::NotSupported {
924            source: "BTreeIndex is not vector index".into(),
925            location: location!(),
926        })
927    }
928
929    async fn prewarm(&self) -> Result<()> {
930        // TODO: BTree can (and should) support pre-warming by loading the pages into memory
931        Ok(())
932    }
933
934    fn index_type(&self) -> IndexType {
935        IndexType::BTree
936    }
937
938    fn statistics(&self) -> Result<serde_json::Value> {
939        let min = self
940            .page_lookup
941            .tree
942            .first_key_value()
943            .map(|(k, _)| k.clone());
944        let max = self
945            .page_lookup
946            .tree
947            .last_key_value()
948            .map(|(k, _)| k.clone());
949        serde_json::to_value(&BTreeStatistics {
950            num_pages: self.page_lookup.tree.len() as u32,
951            min,
952            max,
953        })
954        .map_err(|err| err.into())
955    }
956
957    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
958        let mut frag_ids = RoaringBitmap::default();
959
960        let sub_index_reader = self.store.open_index_file(BTREE_PAGES_NAME).await?;
961        let mut reader_stream = IndexReaderStream::new(sub_index_reader, self.batch_size)
962            .await
963            .buffered(self.store.io_parallelism());
964        while let Some(serialized) = reader_stream.try_next().await? {
965            let page = self.sub_index.load_subindex(serialized).await?;
966            frag_ids |= page.calculate_included_frags().await?;
967        }
968
969        Ok(frag_ids)
970    }
971}
972
973#[async_trait]
974impl ScalarIndex for BTreeIndex {
975    async fn search(
976        &self,
977        query: &dyn AnyQuery,
978        metrics: &dyn MetricsCollector,
979    ) -> Result<SearchResult> {
980        let query = query.as_any().downcast_ref::<SargableQuery>().unwrap();
981        let pages = match query {
982            SargableQuery::Equals(val) => self
983                .page_lookup
984                .pages_eq(&OrderableScalarValue(val.clone())),
985            SargableQuery::Range(start, end) => self
986                .page_lookup
987                .pages_between((wrap_bound(start).as_ref(), wrap_bound(end).as_ref())),
988            SargableQuery::IsIn(values) => self
989                .page_lookup
990                .pages_in(values.iter().map(|val| OrderableScalarValue(val.clone()))),
991            SargableQuery::FullTextSearch(_) => return Err(Error::invalid_input(
992                "full text search is not supported for BTree index, build an inverted index for it",
993                location!(),
994            )),
995            SargableQuery::IsNull() => self.page_lookup.pages_null(),
996        };
997        let lazy_index_reader = LazyIndexReader::new(self.store.clone());
998        let page_tasks = pages
999            .into_iter()
1000            .map(|page_index| {
1001                self.search_page(query, page_index, lazy_index_reader.clone(), metrics)
1002                    .boxed()
1003            })
1004            .collect::<Vec<_>>();
1005        debug!("Searching {} btree pages", page_tasks.len());
1006        let row_ids = stream::iter(page_tasks)
1007            // I/O and compute are mixed here, but the important case is when the index is already
1008            // cached, so use the compute-intensive thread count
1009            .buffered(get_num_compute_intensive_cpus())
1010            .try_collect::<RowIdTreeMap>()
1011            .await?;
1012        Ok(SearchResult::Exact(row_ids))
1013    }
1014
1015    fn can_answer_exact(&self, _: &dyn AnyQuery) -> bool {
1016        true
1017    }
1018
1019    async fn load(
1020        store: Arc<dyn IndexStore>,
1021        fri: Option<Arc<FragReuseIndex>>,
1022    ) -> Result<Arc<Self>> {
1023        let page_lookup_file = store.open_index_file(BTREE_LOOKUP_NAME).await?;
1024        let num_rows_in_lookup = page_lookup_file.num_rows();
1025        let serialized_lookup = page_lookup_file
1026            .read_range(0..num_rows_in_lookup, None)
1027            .await?;
1028        let file_schema = page_lookup_file.schema();
1029        let batch_size = file_schema
1030            .metadata
1031            .get(BATCH_SIZE_META_KEY)
1032            .map(|bs| bs.parse().unwrap_or(DEFAULT_BTREE_BATCH_SIZE))
1033            .unwrap_or(DEFAULT_BTREE_BATCH_SIZE);
1034        Ok(Arc::new(Self::try_from_serialized(
1035            serialized_lookup,
1036            store,
1037            batch_size,
1038            fri,
1039        )?))
1040    }
1041
1042    async fn remap(
1043        &self,
1044        mapping: &HashMap<u64, Option<u64>>,
1045        dest_store: &dyn IndexStore,
1046    ) -> Result<()> {
1047        // Remap and write the pages
1048        let mut sub_index_file = dest_store
1049            .new_index_file(BTREE_PAGES_NAME, self.sub_index.schema().clone())
1050            .await?;
1051
1052        let sub_index_reader = self.store.open_index_file(BTREE_PAGES_NAME).await?;
1053        let mut reader_stream = IndexReaderStream::new(sub_index_reader, self.batch_size)
1054            .await
1055            .buffered(self.store.io_parallelism());
1056        while let Some(serialized) = reader_stream.try_next().await? {
1057            let remapped = self.sub_index.remap_subindex(serialized, mapping).await?;
1058            sub_index_file.write_record_batch(remapped).await?;
1059        }
1060
1061        sub_index_file.finish().await?;
1062
1063        // Copy the lookup file as-is
1064        self.store
1065            .copy_index_file(BTREE_LOOKUP_NAME, dest_store)
1066            .await
1067    }
1068
1069    async fn update(
1070        &self,
1071        new_data: SendableRecordBatchStream,
1072        dest_store: &dyn IndexStore,
1073    ) -> Result<()> {
1074        // Merge the existing index data with the new data and then retrain the index on the merged stream
1075        let merged_data_source = Box::new(BTreeUpdater::new(self.clone(), new_data));
1076        train_btree_index(
1077            merged_data_source,
1078            self.sub_index.as_ref(),
1079            dest_store,
1080            DEFAULT_BTREE_BATCH_SIZE as u32,
1081        )
1082        .await
1083    }
1084}
1085
1086struct BatchStats {
1087    min: ScalarValue,
1088    max: ScalarValue,
1089    null_count: u32,
1090}
1091
1092fn analyze_batch(batch: &RecordBatch) -> Result<BatchStats> {
1093    let values = batch.column(0);
1094    if values.is_empty() {
1095        return Err(Error::Internal {
1096            message: "received an empty batch in btree training".to_string(),
1097            location: location!(),
1098        });
1099    }
1100    let min = ScalarValue::try_from_array(&values, 0).map_err(|e| Error::Internal {
1101        message: format!("failed to get min value from batch: {}", e),
1102        location: location!(),
1103    })?;
1104    let max =
1105        ScalarValue::try_from_array(&values, values.len() - 1).map_err(|e| Error::Internal {
1106            message: format!("failed to get max value from batch: {}", e),
1107            location: location!(),
1108        })?;
1109
1110    Ok(BatchStats {
1111        min,
1112        max,
1113        null_count: values.null_count() as u32,
1114    })
1115}
1116
1117/// A trait that must be implemented by anything that wishes to act as a btree subindex
1118#[async_trait]
1119pub trait BTreeSubIndex: Debug + Send + Sync + DeepSizeOf {
1120    /// Trains the subindex on a single batch of data and serializes it to Arrow
1121    async fn train(&self, batch: RecordBatch) -> Result<RecordBatch>;
1122
1123    /// Deserialize a subindex from Arrow
1124    async fn load_subindex(&self, serialized: RecordBatch) -> Result<Arc<dyn ScalarIndex>>;
1125
1126    /// Retrieve the data used to originally train this page
1127    ///
1128    /// In order to perform an update we need to merge the old data in with the new data which
1129    /// means we need to access the old data.  Right now this is convenient for flat indices but
1130    /// we may need to take a different approach if we ever decide to use a sub-index other than
1131    /// flat
1132    async fn retrieve_data(&self, serialized: RecordBatch) -> Result<RecordBatch>;
1133
1134    /// The schema of the subindex when serialized to Arrow
1135    fn schema(&self) -> &Arc<Schema>;
1136
1137    /// Given a serialized page, deserialize it, remap the row ids, and re-serialize it
1138    async fn remap_subindex(
1139        &self,
1140        serialized: RecordBatch,
1141        mapping: &HashMap<u64, Option<u64>>,
1142    ) -> Result<RecordBatch>;
1143}
1144
1145struct EncodedBatch {
1146    stats: BatchStats,
1147    page_number: u32,
1148}
1149
1150async fn train_btree_page(
1151    batch: RecordBatch,
1152    batch_idx: u32,
1153    sub_index_trainer: &dyn BTreeSubIndex,
1154    writer: &mut dyn IndexWriter,
1155) -> Result<EncodedBatch> {
1156    let stats = analyze_batch(&batch)?;
1157    let trained = sub_index_trainer.train(batch).await?;
1158    writer.write_record_batch(trained).await?;
1159    Ok(EncodedBatch {
1160        stats,
1161        page_number: batch_idx,
1162    })
1163}
1164
1165fn btree_stats_as_batch(stats: Vec<EncodedBatch>, value_type: &DataType) -> Result<RecordBatch> {
1166    let mins = if stats.is_empty() {
1167        new_empty_array(value_type)
1168    } else {
1169        ScalarValue::iter_to_array(stats.iter().map(|stat| stat.stats.min.clone()))?
1170    };
1171    let maxs = if stats.is_empty() {
1172        new_empty_array(value_type)
1173    } else {
1174        ScalarValue::iter_to_array(stats.iter().map(|stat| stat.stats.max.clone()))?
1175    };
1176    let null_counts = UInt32Array::from_iter_values(stats.iter().map(|stat| stat.stats.null_count));
1177    let page_numbers = UInt32Array::from_iter_values(stats.iter().map(|stat| stat.page_number));
1178
1179    let schema = Arc::new(Schema::new(vec![
1180        // min and max can be null if the entire batch is null values
1181        Field::new("min", mins.data_type().clone(), true),
1182        Field::new("max", maxs.data_type().clone(), true),
1183        Field::new("null_count", null_counts.data_type().clone(), false),
1184        Field::new("page_idx", page_numbers.data_type().clone(), false),
1185    ]));
1186
1187    let columns = vec![
1188        mins,
1189        maxs,
1190        Arc::new(null_counts) as Arc<dyn Array>,
1191        Arc::new(page_numbers) as Arc<dyn Array>,
1192    ];
1193
1194    Ok(RecordBatch::try_new(schema, columns)?)
1195}
1196
1197#[async_trait]
1198pub trait TrainingSource: Send {
1199    /// Returns a stream of batches, ordered by the value column (in ascending order)
1200    ///
1201    /// Each batch should have chunk_size rows
1202    ///
1203    /// The schema for the batch is slightly flexible.
1204    /// The first column may have any name or type; it contains the values to index.
1205    /// The second column must be the row ids, which must be UInt64Type.
1206    async fn scan_ordered_chunks(
1207        self: Box<Self>,
1208        chunk_size: u32,
1209    ) -> Result<SendableRecordBatchStream>;
1210
1211    /// Returns a stream of batches
1212    ///
1213    /// Each batch should have chunk_size rows
1214    ///
1215    /// The schema for the batch is slightly flexible.
1216    /// The first column may have any name or type; it contains the values to index.
1217    /// The second column must be the row ids, which must be UInt64Type.
1218    async fn scan_unordered_chunks(
1219        self: Box<Self>,
1220        chunk_size: u32,
1221    ) -> Result<SendableRecordBatchStream>;
1222}
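
// Added for illustration (not part of the original source): the batch shape that
// `TrainingSource::scan_ordered_chunks` is documented to produce -- a sorted value column
// (any name and type) followed by a UInt64 row id column.  The column names and values
// below are arbitrary choices for the example.
#[cfg(test)]
mod training_batch_example {
    use std::sync::Arc;

    use arrow_array::{Array, Int32Array, RecordBatch, UInt64Array};
    use arrow_schema::{DataType, Field, Schema};

    #[test]
    fn training_batch_shape() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("value", DataType::Int32, true),
            Field::new("_rowid", DataType::UInt64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                // values, already sorted ascending
                Arc::new(Int32Array::from(vec![1, 2, 3])) as Arc<dyn Array>,
                // row ids addressed by the index
                Arc::new(UInt64Array::from(vec![10_u64, 11, 12])) as Arc<dyn Array>,
            ],
        )
        .unwrap();
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(*batch.column(1).data_type(), DataType::UInt64);
    }
}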
1223
1224/// Train a btree index from a stream of sorted page-size batches of values and row ids
1225///
1226/// Note: This is likely to change.  It is unreasonable to expect the caller to do the sorting
1227/// and re-chunking into page-size batches.  This is left for simplicity as this feature is still
1228/// a work in progress
1229pub async fn train_btree_index(
1230    data_source: Box<dyn TrainingSource + Send>,
1231    sub_index_trainer: &dyn BTreeSubIndex,
1232    index_store: &dyn IndexStore,
1233    batch_size: u32,
1234) -> Result<()> {
1235    let mut sub_index_file = index_store
1236        .new_index_file(BTREE_PAGES_NAME, sub_index_trainer.schema().clone())
1237        .await?;
1238    let mut encoded_batches = Vec::new();
1239    let mut batch_idx = 0;
1240    let mut batches_source = data_source.scan_ordered_chunks(batch_size).await?;
1241    let value_type = batches_source.schema().field(0).data_type().clone();
1242    while let Some(batch) = batches_source.try_next().await? {
1243        debug_assert_eq!(batch.num_columns(), 2);
1244        debug_assert_eq!(*batch.column(1).data_type(), DataType::UInt64);
1245        encoded_batches.push(
1246            train_btree_page(batch, batch_idx, sub_index_trainer, sub_index_file.as_mut()).await?,
1247        );
1248        batch_idx += 1;
1249    }
1250    sub_index_file.finish().await?;
1251    let record_batch = btree_stats_as_batch(encoded_batches, &value_type)?;
1252    let mut file_schema = record_batch.schema().as_ref().clone();
1253    file_schema
1254        .metadata
1255        .insert(BATCH_SIZE_META_KEY.to_string(), batch_size.to_string());
1256    let mut btree_index_file = index_store
1257        .new_index_file(BTREE_LOOKUP_NAME, Arc::new(file_schema))
1258        .await?;
1259    btree_index_file.write_record_batch(record_batch).await?;
1260    btree_index_file.finish().await?;
1261    Ok(())
1262}
1263
1264/// A source of training data created by merging existing data with new data
1265struct BTreeUpdater {
1266    index: BTreeIndex,
1267    new_data: SendableRecordBatchStream,
1268}
1269
1270impl BTreeUpdater {
1271    fn new(index: BTreeIndex, new_data: SendableRecordBatchStream) -> Self {
1272        Self { index, new_data }
1273    }
1274}
1275
1276impl BTreeUpdater {
1277    fn into_old_input(index: BTreeIndex) -> Arc<dyn ExecutionPlan> {
1278        let schema = index.sub_index.schema().clone();
1279        let batches = index.into_data_stream().into_stream().try_flatten().boxed();
1280        let stream = Box::pin(RecordBatchStreamAdapter::new(schema, batches));
1281        Arc::new(OneShotExec::new(stream))
1282    }
1283}
1284
1285#[async_trait]
1286impl TrainingSource for BTreeUpdater {
1287    async fn scan_ordered_chunks(
1288        self: Box<Self>,
1289        chunk_size: u32,
1290    ) -> Result<SendableRecordBatchStream> {
1291        let data_type = self.new_data.schema().field(0).data_type().clone();
1292        // Datafusion currently has bugs with spilling on string columns
1293        // See https://github.com/apache/datafusion/issues/10073
1294        //
1295        // Once we upgrade we can remove this
1296        let use_spilling = !matches!(data_type, DataType::Utf8 | DataType::LargeUtf8);
1297
1298        let new_input = Arc::new(OneShotExec::new(self.new_data));
1299        let old_input = Self::into_old_input(self.index);
1300        debug_assert_eq!(
1301            old_input.schema().flattened_fields().len(),
1302            new_input.schema().flattened_fields().len()
1303        );
1304        let sort_expr = PhysicalSortExpr {
1305            expr: Arc::new(Column::new("values", 0)),
1306            options: SortOptions {
1307                descending: false,
1308                nulls_first: true,
1309            },
1310        };
1311        // The UnionExec creates multiple partitions but the SortPreservingMergeExec merges
1312        // them back into a single partition.
1313        let all_data = Arc::new(UnionExec::new(vec![old_input, new_input]));
1314        let ordered = Arc::new(SortPreservingMergeExec::new(
1315            LexOrdering::new(vec![sort_expr]),
1316            all_data,
1317        ));
1318
1319        let unchunked = execute_plan(
1320            ordered,
1321            LanceExecutionOptions {
1322                use_spilling,
1323                ..Default::default()
1324            },
1325        )?;
1326        Ok(chunk_concat_stream(unchunked, chunk_size as usize))
1327    }
1328
1329    async fn scan_unordered_chunks(
1330        self: Box<Self>,
1331        _chunk_size: u32,
1332    ) -> Result<SendableRecordBatchStream> {
1333        // BTree indices will never use unordered scans
1334        unimplemented!()
1335    }
1336}
1337
1338/// A stream that reads the original training data back out of the index
1339///
1340/// This is used for updating the index
1341struct IndexReaderStream {
1342    reader: Arc<dyn IndexReader>,
1343    batch_size: u64,
1344    num_batches: u32,
1345    batch_idx: u32,
1346}
1347
1348impl IndexReaderStream {
1349    async fn new(reader: Arc<dyn IndexReader>, batch_size: u64) -> Self {
1350        let num_batches = reader.num_batches(batch_size).await;
1351        Self {
1352            reader,
1353            batch_size,
1354            num_batches,
1355            batch_idx: 0,
1356        }
1357    }
1358}
1359
1360impl Stream for IndexReaderStream {
1361    type Item = BoxFuture<'static, Result<RecordBatch>>;
1362
1363    fn poll_next(
1364        self: std::pin::Pin<&mut Self>,
1365        _cx: &mut std::task::Context<'_>,
1366    ) -> std::task::Poll<Option<Self::Item>> {
1367        let this = self.get_mut();
1368        if this.batch_idx >= this.num_batches {
1369            return std::task::Poll::Ready(None);
1370        }
1371        let batch_num = this.batch_idx;
1372        this.batch_idx += 1;
1373        let reader_copy = this.reader.clone();
1374        let batch_size = this.batch_size;
1375        let read_task = async move {
1376            reader_copy
1377                .read_record_batch(batch_num as u64, batch_size)
1378                .await
1379        }
1380        .boxed();
1381        std::task::Poll::Ready(Some(read_task))
1382    }
1383}
1384
1385#[cfg(test)]
1386mod tests {
1387    use std::{collections::HashMap, sync::Arc};
1388
1389    use arrow::datatypes::{Float32Type, Float64Type, Int32Type, UInt64Type};
1390    use arrow_array::FixedSizeListArray;
1391    use arrow_schema::DataType;
1392    use datafusion::{
1393        execution::{SendableRecordBatchStream, TaskContext},
1394        physical_plan::{sorts::sort::SortExec, stream::RecordBatchStreamAdapter, ExecutionPlan},
1395    };
1396    use datafusion_common::{DataFusionError, ScalarValue};
1397    use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr};
1398    use deepsize::DeepSizeOf;
1399    use futures::TryStreamExt;
1400    use lance_core::{cache::LanceCache, utils::mask::RowIdTreeMap};
1401    use lance_datafusion::{chunker::break_stream, datagen::DatafusionDatagenExt};
1402    use lance_datagen::{array, gen, ArrayGeneratorExt, BatchCount, RowCount};
1403    use lance_io::object_store::ObjectStore;
1404    use object_store::path::Path;
1405    use tempfile::tempdir;
1406
1407    use crate::{
1408        metrics::NoOpMetricsCollector,
1409        scalar::{
1410            btree::{BTreeIndex, BTREE_PAGES_NAME, DEFAULT_BTREE_BATCH_SIZE},
1411            flat::FlatIndexMetadata,
1412            lance_format::{tests::MockTrainingSource, LanceIndexStore},
1413            IndexStore, SargableQuery, ScalarIndex, SearchResult,
1414        },
1415    };
1416
1417    use super::{train_btree_index, OrderableScalarValue};
1418
1419    #[test]
1420    fn test_scalar_value_size() {
1421        let size_of_i32 = OrderableScalarValue(ScalarValue::Int32(Some(0))).deep_size_of();
1422        let size_of_many_i32 = OrderableScalarValue(ScalarValue::FixedSizeList(Arc::new(
1423            FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
1424                vec![Some(vec![Some(0); 128])],
1425                128,
1426            ),
1427        )))
1428        .deep_size_of();
1429
1430        // deep_size_of should account for the rust type overhead
1431        assert!(size_of_i32 > 4);
1432        assert!(size_of_many_i32 > 128 * 4);
1433    }
1434
1435    #[tokio::test]
1436    async fn test_null_ids() {
1437        let tmpdir = Arc::new(tempdir().unwrap());
1438        let test_store = Arc::new(LanceIndexStore::new(
1439            Arc::new(ObjectStore::local()),
1440            Path::from_filesystem_path(tmpdir.path()).unwrap(),
1441            Arc::new(LanceCache::no_cache()),
1442        ));
1443
1444        // Generate 50,000 rows of random data with 80% nulls
1445        let stream = gen()
1446            .col(
1447                "value",
1448                array::rand::<Float32Type>().with_nulls(&[true, false, false, false, false]),
1449            )
1450            .col("_rowid", array::step::<UInt64Type>())
1451            .into_df_stream(RowCount::from(5000), BatchCount::from(10));
1452        let data_source = Box::new(MockTrainingSource::from(stream));
1453        let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32);
1454
1455        train_btree_index(
1456            data_source,
1457            &sub_index_trainer,
1458            test_store.as_ref(),
1459            DEFAULT_BTREE_BATCH_SIZE as u32,
1460        )
1461        .await
1462        .unwrap();
1463
1464        let index = BTreeIndex::load(test_store.clone(), None).await.unwrap();
1465
1466        assert_eq!(index.page_lookup.null_pages.len(), 10);
1467
1468        let remap_dir = Arc::new(tempdir().unwrap());
1469        let remap_store = Arc::new(LanceIndexStore::new(
1470            Arc::new(ObjectStore::local()),
1471            Path::from_filesystem_path(remap_dir.path()).unwrap(),
1472            Arc::new(LanceCache::no_cache()),
1473        ));
1474
1475        // Remap with a no-op mapping.  The remapped index should be identical to the original
1476        index
1477            .remap(&HashMap::default(), remap_store.as_ref())
1478            .await
1479            .unwrap();
1480
1481        let remap_index = BTreeIndex::load(remap_store.clone(), None).await.unwrap();
1482
1483        assert_eq!(remap_index.page_lookup, index.page_lookup);
1484
1485        let original_pages = test_store.open_index_file(BTREE_PAGES_NAME).await.unwrap();
1486        let remapped_pages = remap_store.open_index_file(BTREE_PAGES_NAME).await.unwrap();
1487
1488        assert_eq!(original_pages.num_rows(), remapped_pages.num_rows());
1489
1490        let original_data = original_pages
1491            .read_record_batch(0, original_pages.num_rows() as u64)
1492            .await
1493            .unwrap();
1494        let remapped_data = remapped_pages
1495            .read_record_batch(0, remapped_pages.num_rows() as u64)
1496            .await
1497            .unwrap();
1498
1499        assert_eq!(original_data, remapped_data);
1500    }
1501
1502    #[tokio::test]
1503    async fn test_nan_ordering() {
1504        let tmpdir = Arc::new(tempdir().unwrap());
1505        let test_store = Arc::new(LanceIndexStore::new(
1506            Arc::new(ObjectStore::local()),
1507            Path::from_filesystem_path(tmpdir.path()).unwrap(),
1508            Arc::new(LanceCache::no_cache()),
1509        ));
1510
1511        let values = vec![
1512            0.0,
1513            1.0,
1514            2.0,
1515            3.0,
1516            f64::NAN,
1517            f64::NEG_INFINITY,
1518            f64::INFINITY,
1519        ];
1520
1521        // This is a bit overkill but we've had bugs in the past where DF's sort
1522        // didn't agree with Arrow's sort so we do an end-to-end test here
1523        // and use DF to sort the data like we would in a real dataset.
1524        let data = gen()
1525            .col("value", array::cycle::<Float64Type>(values.clone()))
1526            .col("_rowid", array::step::<UInt64Type>())
1527            .into_df_exec(RowCount::from(10), BatchCount::from(100));
1528        let schema = data.schema();
1529        let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap());
1530        let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data));
1531        let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap();
1532        let stream = break_stream(stream, 64);
1533        let stream = stream.map_err(DataFusionError::from);
1534        let stream =
1535            Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream;
1536        let data_source = Box::new(MockTrainingSource::from(stream));
1537
1538        let sub_index_trainer = FlatIndexMetadata::new(DataType::Float64);
1539
1540        train_btree_index(data_source, &sub_index_trainer, test_store.as_ref(), 64)
1541            .await
1542            .unwrap();
1543
1544        let index = BTreeIndex::load(test_store, None).await.unwrap();
1545
1546        for (idx, value) in values.into_iter().enumerate() {
1547            let query = SargableQuery::Equals(ScalarValue::Float64(Some(value)));
1548            let result = index.search(&query, &NoOpMetricsCollector).await.unwrap();
1549            assert_eq!(
1550                result,
1551                SearchResult::Exact(RowIdTreeMap::from_iter(((idx as u64)..1000).step_by(7)))
1552            );
1553        }
1554    }
1555}