use arrow_array::types::*;
use arrow_array::{Array, RecordBatch};
use arrow_schema::DataType;
use crossbeam_skiplist::SkipMap;
use datafusion::common::ScalarValue;
use lance_core::{Error, Result};
use lance_index::scalar::btree::OrderableScalarValue;
use super::RowPosition;
/// A single (value, row) entry stored in the in-memory B-tree index.
///
/// The `Ord` impl orders keys by `value` first and `row_position` second,
/// so all rows sharing the same value are adjacent in the skip map.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct IndexKey {
    /// Indexed column value, wrapped to provide a total ordering.
    pub value: OrderableScalarValue,
    /// Position of the row this value came from.
    pub row_position: RowPosition,
}
impl PartialOrd for IndexKey {
    /// Delegates to the total order defined by the `Ord` impl.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for IndexKey {
    /// Orders by value first, breaking ties with the row position, so equal
    /// values sit next to each other in ascending row order.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.value
            .cmp(&other.value)
            .then_with(|| self.row_position.cmp(&other.row_position))
    }
}
/// An in-memory sorted index mapping column values to row positions.
///
/// Backed by a `crossbeam_skiplist::SkipMap` keyed on `(value, row_position)`
/// pairs; inserts and lookups go through `&self`, so the index can be shared
/// across threads without external locking.
#[derive(Debug)]
pub struct BTreeMemIndex {
    // (value, row_position) -> (); the key carries all payload, the map is
    // effectively a concurrent sorted set.
    lookup: SkipMap<IndexKey, ()>,
    // Schema field id of the indexed column.
    field_id: i32,
    // Name of the column to extract from inserted record batches.
    column_name: String,
}
impl BTreeMemIndex {
pub fn new(field_id: i32, column_name: String) -> Self {
Self {
lookup: SkipMap::new(),
field_id,
column_name,
}
}
pub fn field_id(&self) -> i32 {
self.field_id
}
pub fn insert(&self, batch: &RecordBatch, row_offset: u64) -> Result<()> {
let col_idx = batch
.schema()
.column_with_name(&self.column_name)
.map(|(idx, _)| idx)
.ok_or_else(|| {
Error::invalid_input(format!("Column '{}' not found in batch", self.column_name))
})?;
let column = batch.column(col_idx);
self.insert_array(column.as_ref(), row_offset)
}
fn insert_array(&self, array: &dyn Array, row_offset: u64) -> Result<()> {
macro_rules! insert_primitive {
($array_type:ty, $scalar_variant:ident) => {{
let typed_array = array
.as_any()
.downcast_ref::<arrow_array::PrimitiveArray<$array_type>>()
.unwrap();
for (row_idx, value) in typed_array.iter().enumerate() {
let row_position = row_offset + row_idx as u64;
let key = IndexKey {
value: OrderableScalarValue(ScalarValue::$scalar_variant(value)),
row_position,
};
self.lookup.insert(key, ());
}
}};
}
match array.data_type() {
DataType::Int8 => insert_primitive!(Int8Type, Int8),
DataType::Int16 => insert_primitive!(Int16Type, Int16),
DataType::Int32 => insert_primitive!(Int32Type, Int32),
DataType::Int64 => insert_primitive!(Int64Type, Int64),
DataType::UInt8 => insert_primitive!(UInt8Type, UInt8),
DataType::UInt16 => insert_primitive!(UInt16Type, UInt16),
DataType::UInt32 => insert_primitive!(UInt32Type, UInt32),
DataType::UInt64 => insert_primitive!(UInt64Type, UInt64),
DataType::Float32 => insert_primitive!(Float32Type, Float32),
DataType::Float64 => insert_primitive!(Float64Type, Float64),
DataType::Date32 => insert_primitive!(Date32Type, Date32),
DataType::Date64 => insert_primitive!(Date64Type, Date64),
DataType::Utf8 => {
let typed_array = array
.as_any()
.downcast_ref::<arrow_array::StringArray>()
.unwrap();
for (row_idx, value) in typed_array.iter().enumerate() {
let row_position = row_offset + row_idx as u64;
let key = IndexKey {
value: OrderableScalarValue(ScalarValue::Utf8(
value.map(|s| s.to_string()),
)),
row_position,
};
self.lookup.insert(key, ());
}
}
DataType::LargeUtf8 => {
let typed_array = array
.as_any()
.downcast_ref::<arrow_array::LargeStringArray>()
.unwrap();
for (row_idx, value) in typed_array.iter().enumerate() {
let row_position = row_offset + row_idx as u64;
let key = IndexKey {
value: OrderableScalarValue(ScalarValue::LargeUtf8(
value.map(|s| s.to_string()),
)),
row_position,
};
self.lookup.insert(key, ());
}
}
DataType::Boolean => {
let typed_array = array
.as_any()
.downcast_ref::<arrow_array::BooleanArray>()
.unwrap();
for (row_idx, value) in typed_array.iter().enumerate() {
let row_position = row_offset + row_idx as u64;
let key = IndexKey {
value: OrderableScalarValue(ScalarValue::Boolean(value)),
row_position,
};
self.lookup.insert(key, ());
}
}
_ => {
for row_idx in 0..array.len() {
let value = ScalarValue::try_from_array(array, row_idx)?;
let row_position = row_offset + row_idx as u64;
let key = IndexKey {
value: OrderableScalarValue(value),
row_position,
};
self.lookup.insert(key, ());
}
}
}
Ok(())
}
pub fn get(&self, value: &ScalarValue) -> Vec<RowPosition> {
let orderable = OrderableScalarValue(value.clone());
let start = IndexKey {
value: orderable.clone(),
row_position: 0,
};
let end = IndexKey {
value: orderable,
row_position: u64::MAX,
};
self.lookup
.range(start..=end)
.map(|entry| entry.key().row_position)
.collect()
}
pub fn len(&self) -> usize {
self.lookup.len()
}
pub fn is_empty(&self) -> bool {
self.lookup.is_empty()
}
pub fn column_name(&self) -> &str {
&self.column_name
}
pub fn snapshot(&self) -> Vec<(OrderableScalarValue, Vec<RowPosition>)> {
let mut result: Vec<(OrderableScalarValue, Vec<RowPosition>)> = Vec::new();
for entry in self.lookup.iter() {
let key = entry.key();
if let Some(last) = result.last_mut()
&& last.0 == key.value
{
last.1.push(key.row_position);
continue;
}
result.push((key.value.clone(), vec![key.row_position]));
}
result
}
pub fn data_type(&self) -> Option<arrow_schema::DataType> {
self.lookup
.front()
.map(|entry| entry.key().value.0.data_type())
}
pub fn to_training_batches(&self, batch_size: usize) -> Result<Vec<RecordBatch>> {
use arrow_schema::{DataType, Field, Schema};
use lance_core::ROW_ID;
use lance_index::scalar::registry::VALUE_COLUMN_NAME;
use std::sync::Arc;
if self.lookup.is_empty() {
return Ok(vec![]);
}
let first_entry = self.lookup.front().unwrap();
let data_type = first_entry.key().value.0.data_type();
let schema = Arc::new(Schema::new(vec![
Field::new(VALUE_COLUMN_NAME, data_type, true),
Field::new(ROW_ID, DataType::UInt64, false),
]));
let mut batches = Vec::new();
let mut values: Vec<ScalarValue> = Vec::with_capacity(batch_size);
let mut row_ids: Vec<u64> = Vec::with_capacity(batch_size);
for entry in self.lookup.iter() {
let key = entry.key();
values.push(key.value.0.clone());
row_ids.push(key.row_position);
if values.len() >= batch_size {
let batch = self.build_training_batch(&schema, &values, &row_ids)?;
batches.push(batch);
values.clear();
row_ids.clear();
}
}
if !values.is_empty() {
let batch = self.build_training_batch(&schema, &values, &row_ids)?;
batches.push(batch);
}
Ok(batches)
}
pub fn to_training_batches_reversed(
&self,
batch_size: usize,
total_rows: usize,
) -> Result<Vec<RecordBatch>> {
use arrow_schema::{DataType, Field, Schema};
use lance_core::ROW_ID;
use lance_index::scalar::registry::VALUE_COLUMN_NAME;
use std::sync::Arc;
if self.lookup.is_empty() {
return Ok(vec![]);
}
let first_entry = self.lookup.front().unwrap();
let data_type = first_entry.key().value.0.data_type();
let schema = Arc::new(Schema::new(vec![
Field::new(VALUE_COLUMN_NAME, data_type, true),
Field::new(ROW_ID, DataType::UInt64, false),
]));
let total_rows_u64 = total_rows as u64;
let mut batches = Vec::new();
let mut values: Vec<ScalarValue> = Vec::with_capacity(batch_size);
let mut row_ids: Vec<u64> = Vec::with_capacity(batch_size);
for entry in self.lookup.iter() {
let key = entry.key();
values.push(key.value.0.clone());
let reversed_position = total_rows_u64 - key.row_position - 1;
row_ids.push(reversed_position);
if values.len() >= batch_size {
let batch = self.build_training_batch(&schema, &values, &row_ids)?;
batches.push(batch);
values.clear();
row_ids.clear();
}
}
if !values.is_empty() {
let batch = self.build_training_batch(&schema, &values, &row_ids)?;
batches.push(batch);
}
Ok(batches)
}
fn build_training_batch(
&self,
schema: &std::sync::Arc<arrow_schema::Schema>,
values: &[ScalarValue],
row_ids: &[u64],
) -> Result<RecordBatch> {
use arrow_array::UInt64Array;
use std::sync::Arc;
let value_array = ScalarValue::iter_to_array(values.iter().cloned())?;
let row_id_array = Arc::new(UInt64Array::from(row_ids.to_vec()));
RecordBatch::try_new(schema.clone(), vec![value_array, row_id_array])
.map_err(|e| Error::io(format!("Failed to create training batch: {}", e)))
}
}
/// Configuration describing a B-tree index: its name plus the target column
/// identified by schema field id and column name.
#[derive(Debug, Clone)]
pub struct BTreeIndexConfig {
    /// Name of the index.
    pub name: String,
    /// Schema field id of the indexed column.
    pub field_id: i32,
    /// Name of the indexed column.
    pub column: String,
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use std::sync::Arc;

    /// Two-column schema shared by all tests: non-null Int32 "id" and a
    /// nullable Utf8 "name".
    fn create_test_schema() -> Arc<ArrowSchema> {
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ];
        Arc::new(ArrowSchema::new(fields))
    }

    /// Builds a three-row batch whose ids are `start_id..start_id + 3`.
    fn create_test_batch(schema: &ArrowSchema, start_id: i32) -> RecordBatch {
        let ids = Int32Array::from(vec![start_id, start_id + 1, start_id + 2]);
        let names = StringArray::from(vec!["alice", "bob", "charlie"]);
        RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(ids), Arc::new(names)])
            .unwrap()
    }

    #[test]
    fn test_btree_index_insert_and_lookup() {
        let schema = create_test_schema();
        let index = BTreeMemIndex::new(0, "id".to_string());
        index.insert(&create_test_batch(&schema, 0), 0).unwrap();
        assert_eq!(index.len(), 3);
        // Each id maps back to exactly the row it was inserted at.
        for (id, expected_row) in [(0, 0u64), (1, 1)] {
            let hits = index.get(&ScalarValue::Int32(Some(id)));
            assert!(!hits.is_empty());
            assert_eq!(hits, vec![expected_row]);
        }
    }

    #[test]
    fn test_btree_index_multiple_batches() {
        let schema = create_test_schema();
        let index = BTreeMemIndex::new(0, "id".to_string());
        index.insert(&create_test_batch(&schema, 0), 0).unwrap();
        index.insert(&create_test_batch(&schema, 10), 3).unwrap();
        assert_eq!(index.len(), 6);
        // id 10 is the first row of the second batch => global position 3.
        let hits = index.get(&ScalarValue::Int32(Some(10)));
        assert!(!hits.is_empty());
        assert_eq!(hits, vec![3]);
    }

    #[test]
    fn test_btree_index_to_training_batches() {
        use lance_core::ROW_ID;
        use lance_index::scalar::registry::VALUE_COLUMN_NAME;
        let schema = create_test_schema();
        let index = BTreeMemIndex::new(0, "id".to_string());
        index.insert(&create_test_batch(&schema, 0), 0).unwrap();
        index.insert(&create_test_batch(&schema, 10), 3).unwrap();

        let batches = index.to_training_batches(100).unwrap();
        assert_eq!(batches.len(), 1);
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 6);
        assert_eq!(batch.schema().field(0).name(), VALUE_COLUMN_NAME);
        assert_eq!(batch.schema().field(1).name(), ROW_ID);

        // Values come out sorted ascending.
        let values = batch
            .column_by_name(VALUE_COLUMN_NAME)
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        for (i, expected) in [0, 1, 2, 10, 11, 12].into_iter().enumerate() {
            assert_eq!(values.value(i), expected);
        }

        // Row ids are the original insert positions, in order.
        let row_ids = batch
            .column_by_name(ROW_ID)
            .unwrap()
            .as_any()
            .downcast_ref::<arrow_array::UInt64Array>()
            .unwrap();
        for i in 0..6 {
            assert_eq!(row_ids.value(i), i as u64);
        }
    }

    #[test]
    fn test_btree_index_to_training_batches_reversed() {
        use lance_core::ROW_ID;
        use lance_index::scalar::registry::VALUE_COLUMN_NAME;
        let schema = create_test_schema();
        let index = BTreeMemIndex::new(0, "id".to_string());
        index.insert(&create_test_batch(&schema, 0), 0).unwrap();
        index.insert(&create_test_batch(&schema, 10), 3).unwrap();

        let batches = index.to_training_batches_reversed(100, 6).unwrap();
        assert_eq!(batches.len(), 1);
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 6);

        // Values stay in ascending sorted order.
        let values = batch
            .column_by_name(VALUE_COLUMN_NAME)
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        for (i, expected) in [0, 1, 2, 10, 11, 12].into_iter().enumerate() {
            assert_eq!(values.value(i), expected);
        }

        // Row ids are mirrored: position p becomes total_rows - p - 1.
        let row_ids = batch
            .column_by_name(ROW_ID)
            .unwrap()
            .as_any()
            .downcast_ref::<arrow_array::UInt64Array>()
            .unwrap();
        for i in 0..6 {
            assert_eq!(row_ids.value(i), (5 - i) as u64);
        }
    }

    #[test]
    fn test_btree_index_snapshot() {
        let schema = create_test_schema();
        let index = BTreeMemIndex::new(0, "id".to_string());
        index.insert(&create_test_batch(&schema, 0), 0).unwrap();
        let snapshot = index.snapshot();
        assert_eq!(snapshot.len(), 3);
        // Distinct values => one group per id, sorted ascending.
        for i in 0..3 {
            assert_eq!(snapshot[i].0.0, ScalarValue::Int32(Some(i as i32)));
        }
    }
}