use arrow::buffer::{OffsetBuffer, ScalarBuffer};
use arrow_array::{ListArray, RecordBatch};
use arrow_schema::{Field, Schema};
use async_trait::async_trait;
use datafusion::functions::string::contains::ContainsFunc;
use datafusion::functions_nested::array_has;
use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion_common::{Column, scalar::ScalarValue};
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::{any::Any, ops::Bound, sync::Arc};
use datafusion_expr::Expr;
use datafusion_expr::expr::ScalarFunction;
use deepsize::DeepSizeOf;
use inverted::query::{FtsQuery, FtsQueryNode, FtsSearchParams, MatchQuery, fill_fts_query_column};
use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap};
use lance_core::{Error, Result};
use roaring::RoaringBitmap;
use serde::Serialize;
use crate::metrics::MetricsCollector;
use crate::scalar::registry::TrainingCriteria;
use crate::{Index, IndexParams, IndexType};
pub mod bitmap;
pub mod bloomfilter;
pub mod btree;
pub mod expression;
pub mod inverted;
pub mod json;
pub mod label_list;
pub mod lance_format;
pub mod ngram;
pub mod registry;
#[cfg(feature = "geo")]
pub mod rtree;
pub mod zoned;
pub mod zonemap;
use crate::frag_reuse::FragReuseIndex;
pub use inverted::tokenizer::InvertedIndexParams;
use lance_datafusion::udf::CONTAINS_TOKENS_UDF;
/// Sentinel index name shared by all scalar index parameter sets
/// (returned by `ScalarIndexParams::index_name`).
pub const LANCE_SCALAR_INDEX: &str = "__lance_scalar_index";
/// The scalar index implementations that ship built into Lance
/// (as opposed to index types registered externally via the registry).
#[derive(Debug, Clone, PartialEq, Eq, DeepSizeOf)]
pub enum BuiltinIndexType {
    BTree,
    Bitmap,
    LabelList,
    NGram,
    ZoneMap,
    BloomFilter,
    // RTree support is feature-gated (`geo`), but the variant always exists.
    RTree,
    Inverted,
}
impl BuiltinIndexType {
    /// Returns the canonical lowercase name for this index type, as used in
    /// [`ScalarIndexParams::index_type`].
    ///
    /// The return type is `&'static str` (the strings are literals); this is
    /// strictly more general than the previous `&self`-bound `&str` and is
    /// backward compatible for all callers.
    pub fn as_str(&self) -> &'static str {
        // Arms kept in enum declaration order for easy auditing against the
        // `TryFrom<IndexType>` impl below.
        match self {
            Self::BTree => "btree",
            Self::Bitmap => "bitmap",
            Self::LabelList => "labellist",
            Self::NGram => "ngram",
            Self::ZoneMap => "zonemap",
            Self::BloomFilter => "bloomfilter",
            Self::RTree => "rtree",
            Self::Inverted => "inverted",
        }
    }
}
impl TryFrom<IndexType> for BuiltinIndexType {
    type Error = Error;

    /// Maps a generic [`IndexType`] onto the corresponding builtin scalar
    /// index type.
    ///
    /// # Errors
    /// Returns an index error for any non-scalar variant (the catch-all below
    /// covers `IndexType` variants not visible in this module, e.g. vector
    /// index types).
    fn try_from(value: IndexType) -> Result<Self> {
        match value {
            IndexType::BTree => Ok(Self::BTree),
            IndexType::Bitmap => Ok(Self::Bitmap),
            IndexType::LabelList => Ok(Self::LabelList),
            IndexType::NGram => Ok(Self::NGram),
            IndexType::ZoneMap => Ok(Self::ZoneMap),
            IndexType::Inverted => Ok(Self::Inverted),
            IndexType::BloomFilter => Ok(Self::BloomFilter),
            IndexType::RTree => Ok(Self::RTree),
            _ => Err(Error::index("Invalid index type".to_string())),
        }
    }
}
/// Parameters used to request creation of a scalar index.
#[derive(Debug, Clone, PartialEq)]
pub struct ScalarIndexParams {
    /// Name of the index type; builtin values come from
    /// [`BuiltinIndexType::as_str`], other strings may name registered
    /// plugin index types.
    pub index_type: String,
    /// Optional type-specific parameters, serialized as a JSON string
    /// (see [`ScalarIndexParams::with_params`]).
    pub params: Option<String>,
}
impl Default for ScalarIndexParams {
    /// The default scalar index is a B-tree with no extra parameters.
    fn default() -> Self {
        Self::for_builtin(BuiltinIndexType::BTree)
    }
}
impl ScalarIndexParams {
    /// Creates parameters for one of the builtin index types, with no
    /// type-specific configuration.
    pub fn for_builtin(index_type: BuiltinIndexType) -> Self {
        // Delegate to `new` so there is a single construction path.
        Self::new(index_type.as_str().to_string())
    }

    /// Creates parameters for the named index type (builtin or plugin).
    pub fn new(index_type: String) -> Self {
        Self {
            index_type,
            params: None,
        }
    }

    /// Attaches type-specific parameters, serialized to a JSON string.
    ///
    /// # Panics
    /// Panics if `params` cannot be serialized to JSON (e.g. a map with
    /// non-string keys); index parameter types are expected to always be
    /// JSON-serializable.
    pub fn with_params<ParamsType: Serialize>(mut self, params: &ParamsType) -> Self {
        self.params = Some(
            serde_json::to_string(params)
                .expect("scalar index params must be JSON-serializable"),
        );
        self
    }
}
impl IndexParams for ScalarIndexParams {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// All scalar index parameter sets report the shared sentinel name
    /// [`LANCE_SCALAR_INDEX`]; the concrete type lives in `index_type`.
    fn index_name(&self) -> &str {
        LANCE_SCALAR_INDEX
    }
}
impl IndexParams for InvertedIndexParams {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Inverted index params use their own name rather than the scalar
    /// sentinel.
    fn index_name(&self) -> &str {
        "INVERTED"
    }
}
/// A writer for serializing index data to an [`IndexStore`] file.
#[async_trait]
pub trait IndexWriter: Send {
    /// Writes one batch of index data.
    ///
    /// Returns a `u64` handle for the batch — NOTE(review): presumably the
    /// batch's offset or ordinal within the file; confirm against
    /// implementations.
    async fn write_record_batch(&mut self, batch: RecordBatch) -> Result<u64>;
    /// Finalizes the file with no extra metadata.
    async fn finish(&mut self) -> Result<()>;
    /// Finalizes the file, attaching the given key/value metadata.
    async fn finish_with_metadata(&mut self, metadata: HashMap<String, String>) -> Result<()>;
}
/// A reader for index data previously written via an [`IndexWriter`].
#[async_trait]
pub trait IndexReader: Send + Sync {
    /// Reads batch `n`, where the file is viewed as consecutive batches of
    /// `batch_size` rows (the final batch may be short — confirm against
    /// implementations).
    async fn read_record_batch(&self, n: u64, batch_size: u64) -> Result<RecordBatch>;
    /// Reads the given row range, optionally projecting to the named columns
    /// (`None` reads all columns).
    async fn read_range(
        &self,
        range: std::ops::Range<usize>,
        projection: Option<&[&str]>,
    ) -> Result<RecordBatch>;
    /// Number of batches the file contains when read with `batch_size` rows
    /// per batch.
    async fn num_batches(&self, batch_size: u64) -> u32;
    /// Total number of rows of index data.
    fn num_rows(&self) -> usize;
    /// Schema of the stored index data.
    fn schema(&self) -> &lance_core::datatypes::Schema;
}
/// Abstraction over the storage holding an index's files.
#[async_trait]
pub trait IndexStore: std::fmt::Debug + Send + Sync + DeepSizeOf {
    fn as_any(&self) -> &dyn Any;
    /// How many I/O operations this store can usefully run in parallel.
    fn io_parallelism(&self) -> usize;
    /// Creates a new index file with the given name and schema, returning a
    /// writer for it.
    async fn new_index_file(&self, name: &str, schema: Arc<Schema>)
        -> Result<Box<dyn IndexWriter>>;
    /// Opens an existing index file for reading.
    async fn open_index_file(&self, name: &str) -> Result<Arc<dyn IndexReader>>;
    /// Copies the named file from this store into `dest_store`.
    async fn copy_index_file(&self, name: &str, dest_store: &dyn IndexStore) -> Result<()>;
    /// Renames a file within this store.
    async fn rename_index_file(&self, name: &str, new_name: &str) -> Result<()>;
    /// Deletes the named file from this store.
    async fn delete_index_file(&self, name: &str) -> Result<()>;
}
/// A type-erased query against a scalar index.
pub trait AnyQuery: std::fmt::Debug + Any + Send + Sync {
    /// Upcast to `Any`, enabling downcasts (used by [`AnyQuery::dyn_eq`] and
    /// index implementations that accept `&dyn AnyQuery`).
    fn as_any(&self) -> &dyn Any;
    /// Human-readable rendering of the query against column `col`.
    fn format(&self, col: &str) -> String;
    /// Converts the query to a DataFusion expression over column `col`.
    fn to_expr(&self, col: String) -> Expr;
    /// Equality between type-erased queries: true only when `other` is the
    /// same concrete type and compares equal.
    fn dyn_eq(&self, other: &dyn AnyQuery) -> bool;
}
// Allow `==` between `dyn AnyQuery` trait objects by delegating to `dyn_eq`.
impl PartialEq for dyn AnyQuery {
    fn eq(&self, other: &Self) -> bool {
        self.dyn_eq(other)
    }
}
/// A full-text search query plus execution options.
#[derive(Debug, Clone, PartialEq)]
pub struct FullTextSearchQuery {
    /// The query tree to execute.
    pub query: FtsQuery,
    /// Maximum number of results; `None` passes no limit to the search
    /// parameters.
    pub limit: Option<i64>,
    /// WAND pruning factor; when unset, 1.0 is used
    /// (see [`FullTextSearchQuery::params`]).
    pub wand_factor: Option<f32>,
}
impl FullTextSearchQuery {
    /// Creates a [`MatchQuery`] over the given query string, with default
    /// options.
    pub fn new(query: String) -> Self {
        // Delegate to `new_query` so all constructors share one path.
        Self::new_query(MatchQuery::new(query).into())
    }

    /// Creates a fuzzy [`MatchQuery`] for `term`.
    ///
    /// `max_distance` is passed through as the fuzziness setting; `None`
    /// leaves fuzziness to the query's default.
    pub fn new_fuzzy(term: String, max_distance: Option<u32>) -> Self {
        Self::new_query(MatchQuery::new(term).with_fuzziness(max_distance).into())
    }

    /// Wraps an already-constructed query tree with default options.
    pub fn new_query(query: FtsQuery) -> Self {
        Self {
            query,
            limit: None,
            wand_factor: None,
        }
    }

    /// Assigns `column` to the query nodes via [`fill_fts_query_column`].
    ///
    /// # Errors
    /// Propagates any error from [`fill_fts_query_column`].
    pub fn with_column(mut self, column: String) -> Result<Self> {
        self.query = fill_fts_query_column(&self.query, &[column], true)?;
        Ok(self)
    }

    /// Same as [`Self::with_column`] but with multiple columns.
    pub fn with_columns(mut self, columns: &[String]) -> Result<Self> {
        self.query = fill_fts_query_column(&self.query, columns, true)?;
        Ok(self)
    }

    /// Sets (or clears) the maximum number of results.
    pub fn limit(mut self, limit: Option<i64>) -> Self {
        self.limit = limit;
        self
    }

    /// Sets (or clears) the WAND pruning factor.
    pub fn wand_factor(mut self, wand_factor: Option<f32>) -> Self {
        self.wand_factor = wand_factor;
        self
    }

    /// The set of columns referenced by the query tree.
    pub fn columns(&self) -> HashSet<String> {
        self.query.columns()
    }

    /// Converts the stored options into [`FtsSearchParams`].
    pub fn params(&self) -> FtsSearchParams {
        // A non-positive limit is treated as "no limit". The previous
        // `limit as usize` cast would wrap a negative i64 to a huge usize,
        // which only behaved like "no limit" by accident.
        let limit = self.limit.and_then(|limit| usize::try_from(limit).ok());
        FtsSearchParams::new()
            .with_limit(limit)
            .with_wand_factor(self.wand_factor.unwrap_or(1.0))
    }
}
/// Queries that a "sargable" (search-argument-able) scalar index can answer.
#[derive(Debug, Clone, PartialEq)]
pub enum SargableQuery {
    /// Values within the given lower/upper bounds.
    Range(Bound<ScalarValue>, Bound<ScalarValue>),
    /// Values equal to any of the listed values.
    IsIn(Vec<ScalarValue>),
    /// Values equal to the given value.
    Equals(ScalarValue),
    /// Rows matching a full-text search query.
    FullTextSearch(FullTextSearchQuery),
    /// Rows where the value is null.
    IsNull(),
}
impl AnyQuery for SargableQuery {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Renders the query as human-readable text. Display only — range
    /// conjunctions use `&&`, which is not SQL syntax.
    fn format(&self, col: &str) -> String {
        match self {
            Self::Range(lower, upper) => match (lower, upper) {
                // A fully unbounded range matches every row.
                (Bound::Unbounded, Bound::Unbounded) => "true".to_string(),
                (Bound::Unbounded, Bound::Included(rhs)) => format!("{} <= {}", col, rhs),
                (Bound::Unbounded, Bound::Excluded(rhs)) => format!("{} < {}", col, rhs),
                (Bound::Included(lhs), Bound::Unbounded) => format!("{} >= {}", col, lhs),
                (Bound::Included(lhs), Bound::Included(rhs)) => {
                    format!("{} >= {} && {} <= {}", col, lhs, col, rhs)
                }
                (Bound::Included(lhs), Bound::Excluded(rhs)) => {
                    format!("{} >= {} && {} < {}", col, lhs, col, rhs)
                }
                (Bound::Excluded(lhs), Bound::Unbounded) => format!("{} > {}", col, lhs),
                (Bound::Excluded(lhs), Bound::Included(rhs)) => {
                    format!("{} > {} && {} <= {}", col, lhs, col, rhs)
                }
                (Bound::Excluded(lhs), Bound::Excluded(rhs)) => {
                    format!("{} > {} && {} < {}", col, lhs, col, rhs)
                }
            },
            Self::IsIn(values) => {
                format!(
                    "{} IN [{}]",
                    col,
                    values
                        .iter()
                        .map(|val| val.to_string())
                        .collect::<Vec<_>>()
                        .join(",")
                )
            }
            Self::FullTextSearch(query) => {
                format!("fts({})", query.query)
            }
            Self::IsNull() => {
                format!("{} IS NULL", col)
            }
            Self::Equals(val) => {
                format!("{} = {}", col, val)
            }
        }
    }

    /// Lowers the query to a DataFusion expression over `col`.
    fn to_expr(&self, col: String) -> Expr {
        let col_expr = Expr::Column(Column::new_unqualified(col));
        match self {
            Self::Range(lower, upper) => match (lower, upper) {
                // Unbounded on both sides lowers to the constant `true`.
                (Bound::Unbounded, Bound::Unbounded) => {
                    Expr::Literal(ScalarValue::Boolean(Some(true)), None)
                }
                (Bound::Unbounded, Bound::Included(rhs)) => {
                    col_expr.lt_eq(Expr::Literal(rhs.clone(), None))
                }
                (Bound::Unbounded, Bound::Excluded(rhs)) => {
                    col_expr.lt(Expr::Literal(rhs.clone(), None))
                }
                (Bound::Included(lhs), Bound::Unbounded) => {
                    col_expr.gt_eq(Expr::Literal(lhs.clone(), None))
                }
                // Inclusive on both ends maps directly onto BETWEEN.
                (Bound::Included(lhs), Bound::Included(rhs)) => col_expr.between(
                    Expr::Literal(lhs.clone(), None),
                    Expr::Literal(rhs.clone(), None),
                ),
                (Bound::Included(lhs), Bound::Excluded(rhs)) => col_expr
                    .clone()
                    .gt_eq(Expr::Literal(lhs.clone(), None))
                    .and(col_expr.lt(Expr::Literal(rhs.clone(), None))),
                (Bound::Excluded(lhs), Bound::Unbounded) => {
                    col_expr.gt(Expr::Literal(lhs.clone(), None))
                }
                (Bound::Excluded(lhs), Bound::Included(rhs)) => col_expr
                    .clone()
                    .gt(Expr::Literal(lhs.clone(), None))
                    .and(col_expr.lt_eq(Expr::Literal(rhs.clone(), None))),
                (Bound::Excluded(lhs), Bound::Excluded(rhs)) => col_expr
                    .clone()
                    .gt(Expr::Literal(lhs.clone(), None))
                    .and(col_expr.lt(Expr::Literal(rhs.clone(), None))),
            },
            Self::IsIn(values) => col_expr.in_list(
                values
                    .iter()
                    .map(|val| Expr::Literal(val.clone(), None))
                    .collect::<Vec<_>>(),
                false,
            ),
            // NOTE(review): FTS is lowered to `col LIKE '<query>'`, which is
            // not semantically a full-text search — presumably a
            // representative/placeholder expression; confirm callers never
            // use it for exact re-evaluation.
            Self::FullTextSearch(query) => col_expr.like(Expr::Literal(
                ScalarValue::Utf8(Some(query.query.to_string())),
                None,
            )),
            Self::IsNull() => col_expr.is_null(),
            Self::Equals(value) => col_expr.eq(Expr::Literal(value.clone(), None)),
        }
    }

    fn dyn_eq(&self, other: &dyn AnyQuery) -> bool {
        match other.as_any().downcast_ref::<Self>() {
            Some(o) => self == o,
            None => false,
        }
    }
}
/// Queries against a label-list (list-of-values) column.
#[derive(Debug, Clone, PartialEq)]
pub enum LabelListQuery {
    /// Rows whose list contains every one of the given labels.
    HasAllLabels(Vec<ScalarValue>),
    /// Rows whose list contains at least one of the given labels.
    HasAnyLabel(Vec<ScalarValue>),
}
impl AnyQuery for LabelListQuery {
fn as_any(&self) -> &dyn Any {
self
}
fn format(&self, col: &str) -> String {
format!("{}", self.to_expr(col.to_string()))
}
fn to_expr(&self, col: String) -> Expr {
match self {
Self::HasAllLabels(labels) => {
let labels_arr = ScalarValue::iter_to_array(labels.iter().cloned()).unwrap();
let offsets_buffer =
OffsetBuffer::new(ScalarBuffer::<i32>::from(vec![0, labels_arr.len() as i32]));
let labels_list = ListArray::try_new(
Arc::new(Field::new("item", labels_arr.data_type().clone(), true)),
offsets_buffer,
labels_arr,
None,
)
.unwrap();
let labels_arr = Arc::new(labels_list);
Expr::ScalarFunction(ScalarFunction {
func: Arc::new(array_has::ArrayHasAll::new().into()),
args: vec![
Expr::Column(Column::new_unqualified(col)),
Expr::Literal(ScalarValue::List(labels_arr), None),
],
})
}
Self::HasAnyLabel(labels) => {
let labels_arr = ScalarValue::iter_to_array(labels.iter().cloned()).unwrap();
let offsets_buffer =
OffsetBuffer::new(ScalarBuffer::<i32>::from(vec![0, labels_arr.len() as i32]));
let labels_list = ListArray::try_new(
Arc::new(Field::new("item", labels_arr.data_type().clone(), true)),
offsets_buffer,
labels_arr,
None,
)
.unwrap();
let labels_arr = Arc::new(labels_list);
Expr::ScalarFunction(ScalarFunction {
func: Arc::new(array_has::ArrayHasAny::new().into()),
args: vec![
Expr::Column(Column::new_unqualified(col)),
Expr::Literal(ScalarValue::List(labels_arr), None),
],
})
}
}
}
fn dyn_eq(&self, other: &dyn AnyQuery) -> bool {
match other.as_any().downcast_ref::<Self>() {
Some(o) => self == o,
None => false,
}
}
}
/// Queries against a text column (e.g. served by the n-gram index).
#[derive(Debug, Clone, PartialEq)]
pub enum TextQuery {
    /// Rows whose string value contains the given substring.
    StringContains(String),
}
impl AnyQuery for TextQuery {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Displays the query via its expression form.
    fn format(&self, col: &str) -> String {
        self.to_expr(col.to_string()).to_string()
    }

    /// Lowers the query to a `contains(col, substr)` scalar function call.
    fn to_expr(&self, col: String) -> Expr {
        // Single-variant enum: the pattern is irrefutable.
        let Self::StringContains(substr) = self;
        let column = Expr::Column(Column::new_unqualified(col));
        let needle = Expr::Literal(ScalarValue::Utf8(Some(substr.clone())), None);
        Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ContainsFunc::new().into()),
            args: vec![column, needle],
        })
    }

    fn dyn_eq(&self, other: &dyn AnyQuery) -> bool {
        other
            .as_any()
            .downcast_ref::<Self>()
            .map_or(false, |o| self == o)
    }
}
/// Token-level queries (lowered through the `contains_tokens` UDF).
#[derive(Debug, Clone, PartialEq)]
pub enum TokenQuery {
    /// Rows whose tokens contain the given string.
    TokensContains(String),
}
/// Queries that a bloom-filter index can answer (membership-style only).
#[derive(Debug, Clone, PartialEq)]
pub enum BloomFilterQuery {
    /// Values equal to the given value.
    Equals(ScalarValue),
    /// Rows where the value is null.
    IsNull(),
    /// Values equal to any of the listed values.
    IsIn(Vec<ScalarValue>),
}
impl AnyQuery for BloomFilterQuery {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Human-readable rendering of the query against `col`.
    fn format(&self, col: &str) -> String {
        match self {
            Self::Equals(val) => format!("{} = {}", col, val),
            Self::IsNull() => format!("{} IS NULL", col),
            Self::IsIn(values) => {
                let rendered: Vec<String> = values.iter().map(ToString::to_string).collect();
                format!("{} IN [{}]", col, rendered.join(","))
            }
        }
    }

    /// Lowers the query to the equivalent DataFusion expression.
    fn to_expr(&self, col: String) -> Expr {
        let col_expr = Expr::Column(Column::new_unqualified(col));
        match self {
            Self::Equals(value) => col_expr.eq(Expr::Literal(value.clone(), None)),
            Self::IsNull() => col_expr.is_null(),
            Self::IsIn(values) => {
                let literals = values
                    .iter()
                    .cloned()
                    .map(|val| Expr::Literal(val, None))
                    .collect::<Vec<_>>();
                col_expr.in_list(literals, false)
            }
        }
    }

    fn dyn_eq(&self, other: &dyn AnyQuery) -> bool {
        other
            .as_any()
            .downcast_ref::<Self>()
            .map_or(false, |o| self == o)
    }
}
impl AnyQuery for TokenQuery {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Displays the query via its expression form.
    fn format(&self, col: &str) -> String {
        self.to_expr(col.to_string()).to_string()
    }

    /// Lowers the query to a `contains_tokens(col, substr)` UDF call.
    fn to_expr(&self, col: String) -> Expr {
        // Single-variant enum: the pattern is irrefutable.
        let Self::TokensContains(substr) = self;
        let column = Expr::Column(Column::new_unqualified(col));
        let needle = Expr::Literal(ScalarValue::Utf8(Some(substr.clone())), None);
        Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(CONTAINS_TOKENS_UDF.clone()),
            args: vec![column, needle],
        })
    }

    fn dyn_eq(&self, other: &dyn AnyQuery) -> bool {
        other
            .as_any()
            .downcast_ref::<Self>()
            .map_or(false, |o| self == o)
    }
}
/// A spatial relation query: a value compared against an indexed column.
#[cfg(feature = "geo")]
#[derive(Debug, Clone, PartialEq)]
pub struct RelationQuery {
    /// The value to relate against (presumably a geometry; confirm against
    /// the `rtree` module).
    pub value: ScalarValue,
    /// Arrow field describing the queried column.
    pub field: Field,
}
/// Queries answerable by a geospatial index.
#[cfg(feature = "geo")]
#[derive(Debug, Clone, PartialEq)]
pub enum GeoQuery {
    /// Rows whose value intersects the query value.
    IntersectQuery(RelationQuery),
    /// Rows where the value is null.
    IsNull,
}
#[cfg(feature = "geo")]
impl AnyQuery for GeoQuery {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Human-readable rendering of the query against `col`.
    fn format(&self, col: &str) -> String {
        match self {
            Self::IntersectQuery(query) => {
                format!("Intersect({} {})", col, query.value)
            }
            Self::IsNull => {
                format!("{} IS NULL", col)
            }
        }
    }

    /// Not yet implemented.
    ///
    /// # Panics
    /// Always panics (`todo!`); geo queries currently have no DataFusion
    /// expression lowering.
    fn to_expr(&self, _col: String) -> Expr {
        todo!()
    }

    fn dyn_eq(&self, other: &dyn AnyQuery) -> bool {
        match other.as_any().downcast_ref::<Self>() {
            Some(o) => self == o,
            None => false,
        }
    }
}
/// The outcome of a scalar index search, tagged with an exactness guarantee.
#[derive(Debug, PartialEq)]
pub enum SearchResult {
    /// Exactly the matching rows.
    Exact(NullableRowAddrSet),
    /// At most these rows match: the set may contain non-matching rows that
    /// need re-checking, but no true match is outside it.
    AtMost(NullableRowAddrSet),
    /// At least these rows match: every listed row matches, but additional
    /// matches may exist outside the set.
    AtLeast(NullableRowAddrSet),
}
impl SearchResult {
    // Shared helper: wrap the rows with an empty null set.
    fn non_null(row_ids: impl Into<RowAddrTreeMap>) -> NullableRowAddrSet {
        NullableRowAddrSet::new(row_ids.into(), Default::default())
    }

    /// An exact result (no nulls recorded).
    pub fn exact(row_ids: impl Into<RowAddrTreeMap>) -> Self {
        Self::Exact(Self::non_null(row_ids))
    }

    /// An at-most result (no nulls recorded).
    pub fn at_most(row_ids: impl Into<RowAddrTreeMap>) -> Self {
        Self::AtMost(Self::non_null(row_ids))
    }

    /// An at-least result (no nulls recorded).
    pub fn at_least(row_ids: impl Into<RowAddrTreeMap>) -> Self {
        Self::AtLeast(Self::non_null(row_ids))
    }

    /// Attaches the null rows, preserving the exactness variant.
    pub fn with_nulls(self, nulls: impl Into<RowAddrTreeMap>) -> Self {
        let nulls = nulls.into();
        match self {
            Self::Exact(rows) => Self::Exact(rows.with_nulls(nulls)),
            Self::AtMost(rows) => Self::AtMost(rows.with_nulls(nulls)),
            Self::AtLeast(rows) => Self::AtLeast(rows.with_nulls(nulls)),
        }
    }

    /// The row addresses, regardless of exactness.
    pub fn row_addrs(&self) -> &NullableRowAddrSet {
        match self {
            Self::Exact(addrs) | Self::AtMost(addrs) | Self::AtLeast(addrs) => addrs,
        }
    }

    /// True only for [`SearchResult::Exact`].
    pub fn is_exact(&self) -> bool {
        matches!(self, Self::Exact(_))
    }
}
/// Description of an index that was just written.
pub struct CreatedIndex {
    /// Implementation-specific details, packed as a protobuf `Any`.
    pub index_details: prost_types::Any,
    /// Version number of the written index format.
    pub index_version: u32,
}
/// Describes what data an index update operation needs.
pub struct UpdateCriteria {
    /// When true, the update must also read the index's pre-existing data,
    /// not just the newly added rows.
    pub requires_old_data: bool,
    /// Requirements the training data must satisfy (see [`TrainingCriteria`]).
    pub data_criteria: TrainingCriteria,
}
impl UpdateCriteria {
    /// Criteria for an update that must re-read the existing index data.
    pub fn requires_old_data(data_criteria: TrainingCriteria) -> Self {
        Self::with_old_data_flag(true, data_criteria)
    }

    /// Criteria for an update that only needs the newly added data.
    pub fn only_new_data(data_criteria: TrainingCriteria) -> Self {
        Self::with_old_data_flag(false, data_criteria)
    }

    // Shared constructor backing the two public variants.
    fn with_old_data_flag(requires_old_data: bool, data_criteria: TrainingCriteria) -> Self {
        Self {
            requires_old_data,
            data_criteria,
        }
    }
}
/// A scalar (non-vector) secondary index.
#[async_trait]
pub trait ScalarIndex: Send + Sync + std::fmt::Debug + Index + DeepSizeOf {
    /// Searches the index for rows satisfying `query`.
    ///
    /// `query` is type-erased; implementations downcast it to the concrete
    /// query type they understand (e.g. [`SargableQuery`]).
    async fn search(
        &self,
        query: &dyn AnyQuery,
        metrics: &dyn MetricsCollector,
    ) -> Result<SearchResult>;
    /// Whether this index supports [`ScalarIndex::remap`].
    fn can_remap(&self) -> bool;
    /// Rewrites the index into `dest_store`, translating row ids via
    /// `mapping` (a `None` target presumably marks a deleted row — confirm
    /// against implementations).
    async fn remap(
        &self,
        mapping: &HashMap<u64, Option<u64>>,
        dest_store: &dyn IndexStore,
    ) -> Result<CreatedIndex>;
    /// Incorporates `new_data` into the index, writing the result to
    /// `dest_store`. `valid_old_fragments`, when given, identifies which
    /// existing fragments are still valid (NOTE(review): confirm exact
    /// semantics against implementations).
    async fn update(
        &self,
        new_data: SendableRecordBatchStream,
        dest_store: &dyn IndexStore,
        valid_old_fragments: Option<&RoaringBitmap>,
    ) -> Result<CreatedIndex>;
    /// Describes what data a call to [`ScalarIndex::update`] requires.
    fn update_criteria(&self) -> UpdateCriteria;
    /// Reconstructs creation parameters equivalent to this index's current
    /// configuration.
    fn derive_index_params(&self) -> Result<ScalarIndexParams>;
}