use std::path::{Path, PathBuf};
use std::sync::Arc;
use arrow::array::{
    Array, ArrayRef, DictionaryArray, FixedSizeBinaryArray, Int64Array, Int64BufferBuilder,
    RecordBatch, RecordBatchIterator, StringArray, UInt8Array, UInt32Array, UInt32BufferBuilder,
};
use arrow::buffer::ScalarBuffer;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use lance_index::DatasetIndexExt as _;
use re_chunk_store::Chunk;
use re_log_types::{ComponentPath, EntityPath, TimelineName};
use re_protos::cloud::v1alpha1::ext::{IndexConfig, IndexProperties};
use re_protos::common::v1alpha1::ext::SegmentId;
use re_types_core::ComponentIdentifier;
use crate::chunk_index::{
ArcCell, FIELD_CHUNK_ID, FIELD_INSTANCE, FIELD_INSTANCE_ID, FIELD_RERUN_SEGMENT_ID,
FIELD_RERUN_SEGMENT_LAYER, FIELD_TIMEPOINT,
};
use crate::store::{Dataset, Error as StoreError};
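/// The kind of Lance index backing a Rerun index.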
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum IndexType {
Inverted,
VectorIvfPq,
BTree,
}
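/// Arrow datatypes used by the index table for instance values and timepoints.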
pub struct IndexDataTypes {
pub instances: DataType,
pub timepoints: DataType,
}
impl From<&IndexProperties> for IndexType {
fn from(properties: &IndexProperties) -> Self {
match properties {
IndexProperties::Inverted { .. } => Self::Inverted,
IndexProperties::VectorIvfPq { .. } => Self::VectorIvfPq,
IndexProperties::Btree => Self::BTree,
}
}
}
impl super::Index {
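    /// Converts the given `(segment, layer, chunk)` triples into record batches
    /// and appends them to the backing Lance dataset.
    ///
    /// Chunks that don't contain the indexed timeline or component are skipped.
    /// If `checkout_latest` is set, the cached dataset handle is refreshed after
    /// the append.
    ///
    /// Illustrative call (identifiers are placeholders):
    /// ```ignore
    /// index
    ///     .store_chunks(vec![(segment_id, "base".to_owned(), chunk)], true)
    ///     .await?;
    /// ```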
pub async fn store_chunks(
&self,
chunks: Vec<(SegmentId, String, Arc<Chunk>)>,
checkout_latest: bool,
) -> Result<(), StoreError> {
let index_type: IndexType = (&self.config.properties).into();
let timeline = self.config.time_index;
let component = self.config.column.descriptor.component;
let batches = chunks
.into_iter()
.filter_map(move |(segment_id, layer, chunk)| {
Self::prepare_record_batch(
index_type,
&segment_id,
layer,
timeline,
component,
&chunk,
)
.transpose()
});
let mut lance: lance::Dataset = self.lance_dataset.cloned();
let mut iter = batches.peekable();
        match iter.peek() {
            Some(Ok(first)) => {
                let schema = first.schema();
                lance
                    .append(
                        RecordBatchIterator::new(iter, schema),
                        Some(Default::default()),
                    )
                    .await?;
                if checkout_latest {
                    lance.checkout_latest().await?;
                    self.lance_dataset.replace(lance);
                }
            }
            Some(Err(_)) => {
                // Surface the first batch-preparation error instead of masking it
                // behind the generic schema message below.
                if let Some(Err(err)) = iter.next() {
                    return Err(StoreError::IndexingError(err.to_string()));
                }
            }
            None => {
                return Err(StoreError::IndexingError(
                    "Cannot determine indexed data schema".to_owned(),
                ));
            }
        }
Ok(())
}
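    /// Deletes every indexed row belonging to the given `(segment, layer)` pairs,
    /// then re-optimizes the index to account for the deletions.
    ///
    /// Illustrative call (identifiers are placeholders):
    /// ```ignore
    /// index
    ///     .remove_layers(&[(segment_id, "base".to_owned())], true)
    ///     .await?;
    /// ```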
pub async fn remove_layers(
&self,
layers: &[(SegmentId, String)],
checkout_latest: bool,
) -> Result<(), StoreError> {
        let mut lance: lance::Dataset = self.lance_dataset.cloned();
        // An empty layer list would otherwise render an empty SQL predicate below,
        // which `lance.delete` rejects.
        if layers.is_empty() {
            if checkout_latest {
                lance.checkout_latest().await?;
                self.lance_dataset.replace(lance);
            }
            return Ok(());
        }
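        // Disabled alternative: build the delete predicate with DataFusion
        // expressions and unparse it to SQL. Kept behind `cfg!(false)` so it
        // continues to type-check; the active branch below formats the SQL by hand.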
let predicate = if cfg!(false) {
use datafusion::prelude::*;
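            // Combine the expressions pairwise so the resulting tree stays
            // balanced instead of degenerating into a deep left-leaning chain.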
fn balanced_binary_exprs(
mut exprs: Vec<datafusion::logical_expr::Expr>,
op: datafusion::logical_expr::Operator,
) -> Option<datafusion::logical_expr::Expr> {
while exprs.len() > 1 {
let mut exprs_next = Vec::with_capacity(exprs.len() / 2 + 1);
let mut exprs_prev = exprs.into_iter();
while let Some(left) = exprs_prev.next() {
if let Some(right) = exprs_prev.next() {
exprs_next.push(datafusion::prelude::binary_expr(left, op, right));
} else {
exprs_next.push(left);
}
}
exprs = exprs_next;
}
exprs.into_iter().next()
}
let predicates = layers
.iter()
.map(|(segment, layer)| {
(cast(col(FIELD_RERUN_SEGMENT_ID), DataType::Utf8).eq(lit(&segment.id)))
.and(cast(col(FIELD_RERUN_SEGMENT_LAYER), DataType::Utf8).eq(lit(layer)))
})
.collect();
let Some(predicate) =
balanced_binary_exprs(predicates, datafusion::logical_expr::Operator::Or)
else {
if checkout_latest {
lance.checkout_latest().await?;
self.lance_dataset.replace(lance);
}
return Ok(());
};
datafusion::sql::unparser::expr_to_sql(&predicate)?.to_string()
} else {
layers
.iter()
.map(|(segment, layer)| {
format!(
"(CAST({} AS string) = '{}' AND CAST({} AS string) = '{}')",
FIELD_RERUN_SEGMENT_ID,
segment.id.replace('\'', "''"),
FIELD_RERUN_SEGMENT_LAYER,
layer.replace('\'', "''"),
)
})
.collect::<Vec<_>>()
.join(" OR ")
};
lance.delete(&predicate).await?;
lance
.optimize_indices(&lance_index::optimize::OptimizeOptions::append())
.await?;
if checkout_latest {
lance.checkout_latest().await?;
self.lance_dataset.replace(lance);
}
Ok(())
}
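    /// Builds the record batch of index rows for a single chunk.
    ///
    /// Returns `Ok(None)` if the chunk doesn't contain the indexed timeline or
    /// component. Each output row carries the segment id, layer, chunk id,
    /// timepoint, instance id, and instance value, matching the schema created in
    /// `create_lance_dataset`.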
pub fn prepare_record_batch(
index_type: IndexType,
segment_id: &SegmentId,
layer: String,
timeline: TimelineName,
component: ComponentIdentifier,
chunk: &Arc<Chunk>,
) -> Result<Option<RecordBatch>, ArrowError> {
let Some(timeline) = chunk.timelines().get(&timeline) else {
return Ok(None);
};
let Some(component) = chunk.components().get(component) else {
return Ok(None);
};
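        // Scalar indices (inverted, btree) index every instance of every row
        // individually. A vector index treats each row as a single vector,
        // unless the values are non-numeric and cannot form a vector.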
let row_is_array_of_instances = match index_type {
IndexType::Inverted | IndexType::BTree => true,
IndexType::VectorIvfPq if !component.list_array.value_type().is_numeric() => true,
IndexType::VectorIvfPq => false,
};
let total_instances = if row_is_array_of_instances {
component
.list_array
.iter()
.map(|x| x.map(|x| x.len()).unwrap_or(0))
.sum()
} else {
component.list_array.len() - component.list_array.null_count()
};
        if total_instances == 0 {
            // Nothing to index (no rows, or only null values). This also keeps
            // `FixedSizeBinaryArray::try_from_iter` from erroring on an empty
            // iterator below.
            return Ok(None);
        }
        // Every row in this batch shares a single segment id and layer, so
        // dictionary-encode them: one value, referenced by `total_instances` keys.
        let dict_keys = UInt8Array::from_iter_values(std::iter::repeat_n(0, total_instances));
let segment_id_array = {
let segment_id_values = StringArray::from_iter_values([segment_id.id.as_str()]);
DictionaryArray::new(dict_keys.clone(), Arc::new(segment_id_values))
};
let layer_array = {
let layer_values = StringArray::from_iter_values([layer]);
DictionaryArray::new(dict_keys.clone(), Arc::new(layer_values))
};
let chunk_id_array = FixedSizeBinaryArray::try_from_iter(std::iter::repeat_n(
chunk.id().as_bytes(),
total_instances,
))?;
let instance_id_array: UInt32Array;
let timepoint_array: ArrayRef;
let instance_array: ArrayRef;
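        // Either flatten each row into one output row per instance, or keep one
        // output row per chunk row (each row being a single vector).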
if row_is_array_of_instances {
let mut timepoints = Int64BufferBuilder::new(total_instances);
let mut instance_ids = UInt32BufferBuilder::new(total_instances);
let mut instances = Vec::new();
for (row_num, instance) in component.list_array.iter().enumerate() {
let Some(instance) = instance else {
continue;
};
timepoints.append_n(instance.len(), timeline.times_raw()[row_num]);
for i in 0..instance.len() as u32 {
instance_ids.append(i);
}
instances.push(instance);
}
timepoint_array = arrow::compute::cast(
&Int64Array::new(ScalarBuffer::from(timepoints), None),
&timeline.timeline().datatype(),
)?;
let instance_arrays: Vec<&dyn Array> = instances.iter().map(|x| x.as_ref()).collect();
instance_array = re_arrow_util::concat_arrays(instance_arrays.as_slice())?;
instance_id_array = UInt32Array::new(ScalarBuffer::from(instance_ids), None);
} else {
if component.list_array.null_count() == 0 {
instance_array = Arc::new(component.list_array.clone());
timepoint_array = Arc::new(timeline.times_array().clone());
            } else {
                // Keep only rows whose component value is non-null so instances
                // and timepoints stay aligned.
                let non_nulls = arrow::compute::is_not_null(&component.list_array)?;
                let list_array: ArrayRef = Arc::new(component.list_array.clone());
                instance_array = re_arrow_util::filter_array(&list_array, &non_nulls);
                let times_array: ArrayRef = Arc::new(timeline.times_array().clone());
                timepoint_array = re_arrow_util::filter_array(&times_array, &non_nulls);
            }
let mut instance_ids = UInt32BufferBuilder::new(total_instances);
instance_ids.append_n(total_instances, 0);
instance_id_array = UInt32Array::new(ScalarBuffer::from(instance_ids), None);
}
        let batch = RecordBatch::try_from_iter([
            (
                FIELD_RERUN_SEGMENT_ID,
                Arc::new(segment_id_array) as ArrayRef,
            ),
            (FIELD_RERUN_SEGMENT_LAYER, Arc::new(layer_array) as ArrayRef),
            (FIELD_CHUNK_ID, Arc::new(chunk_id_array) as ArrayRef),
            // `timepoint_array` and `instance_array` are already `ArrayRef`s;
            // wrapping them in another `Arc` would only add indirection.
            (FIELD_TIMEPOINT, timepoint_array),
            (FIELD_INSTANCE_ID, Arc::new(instance_id_array) as ArrayRef),
            (FIELD_INSTANCE, instance_array),
        ])?;
Ok(Some(batch))
}
}
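/// Creates a new Lance-backed index for the column described by `config`,
/// inferring the instance and timepoint datatypes from the dataset's chunks.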
pub async fn create_index(
dataset: &Dataset,
config: &IndexConfig,
path: PathBuf,
) -> Result<super::Index, StoreError> {
let index_type: IndexType = (&config.properties).into();
let types: IndexDataTypes = find_datatypes(
dataset,
index_type,
&config.column.entity_path,
&config.column.descriptor.component,
&config.time_index,
)
.ok_or_else(|| {
StoreError::ComponentPathNotFound(ComponentPath::new(
config.column.entity_path.clone(),
config.column.descriptor.component,
))
})?;
let mut lance_table = create_lance_dataset(&path, types).await?;
create_lance_index(&mut lance_table, &config.properties).await?;
Ok(super::Index {
lance_dataset: ArcCell::new(lance_table),
config: config.clone(),
})
}
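/// Writes an empty Lance dataset at `path` with the index table schema:
/// dictionary-encoded segment id and layer, chunk id, timepoint, instance id,
/// and the (nullable) instance value itself.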
async fn create_lance_dataset(
path: &Path,
types: IndexDataTypes,
) -> Result<lance::Dataset, StoreError> {
let non_nullable = false;
    #[expect(clippy::disallowed_methods)]
    let schema = Arc::new(
        Schema::new(vec![
Field::new_dictionary(
FIELD_RERUN_SEGMENT_ID,
DataType::UInt8,
DataType::Utf8,
non_nullable,
)
.with_dict_is_ordered(true),
Field::new_dictionary(
FIELD_RERUN_SEGMENT_LAYER,
DataType::UInt8,
DataType::Utf8,
non_nullable,
)
.with_dict_is_ordered(true),
            // `with_dict_is_ordered` is only meaningful for dictionary fields, so
            // the chunk-id column doesn't set it.
            Field::new(FIELD_CHUNK_ID, DataType::FixedSizeBinary(16), non_nullable),
            Field::new(FIELD_TIMEPOINT, types.timepoints, non_nullable),
            Field::new(FIELD_INSTANCE_ID, DataType::UInt32, non_nullable),
            // Only the instance column may contain nulls.
            Field::new(FIELD_INSTANCE, types.instances, true),
]),
);
let batch = RecordBatch::new_empty(schema.clone());
let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema.clone());
let dataset = lance::Dataset::write(batches, path.to_string_lossy().as_ref(), None).await?;
Ok(dataset)
}
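/// Creates the actual Lance index on the instance column, translating our
/// `IndexProperties` into Lance index parameters.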
async fn create_lance_index(
lance_table: &mut lance::Dataset,
properties: &IndexProperties,
) -> Result<(), StoreError> {
    use lance::index::vector::VectorIndexParams;
    use lance_index::scalar::{InvertedIndexParams, ScalarIndexParams};
    // Alias Lance's `IndexType` so it doesn't shadow this module's own `IndexType`.
    use lance_index::{DatasetIndexExt as _, IndexParams, IndexType as LanceIndexType};
    use lance_linalg::distance::MetricType;
    use re_protos::cloud::v1alpha1::VectorDistanceMetric;
    // The concrete parameter types differ per index kind, so box them behind the
    // common `IndexParams` trait; references to match-arm temporaries would not
    // outlive the `match`.
    let (index_type, index_params): (LanceIndexType, Box<dyn IndexParams>) = match properties {
        IndexProperties::Inverted {
            store_position,
            base_tokenizer,
        } => (
            LanceIndexType::Inverted,
            Box::new(
                InvertedIndexParams::default()
                    .with_position(*store_position)
                    .base_tokenizer(base_tokenizer.clone()),
            ),
        ),
IndexProperties::VectorIvfPq {
target_partition_num_rows,
num_sub_vectors,
metric,
} => {
let ivf_params = lance_index::vector::ivf::IvfBuildParams {
target_partition_size: target_partition_num_rows.map(|v| v as usize),
..Default::default()
};
let pq_params = lance_index::vector::pq::PQBuildParams {
num_sub_vectors: *num_sub_vectors as usize,
..Default::default()
};
let lance_metric = match metric {
VectorDistanceMetric::Unspecified => {
return Err(StoreError::IndexingError(
"Unspecified distance metric".to_owned(),
));
}
VectorDistanceMetric::L2 => MetricType::L2,
VectorDistanceMetric::Cosine => MetricType::Cosine,
VectorDistanceMetric::Dot => MetricType::Dot,
VectorDistanceMetric::Hamming => MetricType::Hamming,
};
            (
                LanceIndexType::Vector,
                Box::new(VectorIndexParams::with_ivf_pq_params(
                    lance_metric,
                    ivf_params,
                    pq_params,
                )),
            )
        }
        IndexProperties::Btree => (
            LanceIndexType::BTree,
            Box::new(ScalarIndexParams::default()),
        ),
    };
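    // Index creation can fail benignly: the index may already exist, or there
    // may not be enough rows yet to train it; treat those cases as success.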
match lance_table
.create_index(&["instance"], index_type, None, index_params, false)
.await
{
Ok(_) => Ok(()),
Err(lance::Error::Index { message, .. }) if message.contains("already exists") => Ok(()),
Err(lance::Error::Index { ref message, .. })
if message.contains("Not enough rows to train PQ")
|| message.contains("KMeans: can not train") =>
{
tracing::warn!("not enough rows to train index yet");
Ok(())
}
Err(lance::Error::NotSupported { source, .. })
if source
.to_string()
.contains("empty vector indices with train=False") =>
{
tracing::warn!("not enough rows to train index yet");
Ok(())
}
Err(err) => Err(err),
}?;
Ok(())
}
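/// Scans the dataset for the first chunk containing the indexed entity path,
/// component, and timeline, and returns the Arrow datatypes the index table
/// needs.
///
/// For a vector index over numeric values, the list datatype itself is indexed
/// (one vector per row); otherwise the element datatype is used (one row per
/// instance).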
fn find_datatypes(
dataset: &Dataset,
index_type: IndexType,
entity_path: &EntityPath,
component: &ComponentIdentifier,
timeline_name: &TimelineName,
) -> Option<IndexDataTypes> {
for segment in dataset.segments().values() {
for layer in segment.layers().values() {
let chunk_store = layer.store_handle().read();
for chunk in chunk_store.iter_physical_chunks() {
if chunk.entity_path() == entity_path
&& let Some(component) = chunk.components().0.get(component)
&& let Some(timeline) = chunk.timelines().get(timeline_name)
{
let instance_type = if index_type == IndexType::VectorIvfPq
&& component.list_array.value_type().is_numeric()
{
component.list_array.data_type().clone()
} else {
component.list_array.value_type()
};
return Some(IndexDataTypes {
instances: instance_type,
timepoints: timeline.timeline().datatype(),
});
}
}
}
}
None
}