use std::sync::{Arc, LazyLock};
use arrow_array::{Array, RecordBatch, UInt32Array, builder::BinaryBuilder};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use roaring::RoaringBitmap;
use lance_core::{Error, Result};
use crate::mask::{NullableRowAddrMask, RowAddrMask, RowSetOps};
/// Result of evaluating an index expression over nullable row-address masks,
/// expressed as a pair of bounds on the matching rows.
///
/// `lower` and `upper` bracket the answer; `exact` records that the two bounds
/// are known to be the same mask. Use [`NullableIndexExprResult::drop_nulls`] to
/// obtain the non-nullable [`IndexExprResult`].
#[derive(Debug, Clone)]
pub struct NullableIndexExprResult {
    /// Lower bound of the matching rows; `allow_nothing()` when no lower bound
    /// is known (see `at_most`).
    pub lower: NullableRowAddrMask,
    /// Upper bound of the matching rows; `all_rows()` when no upper bound is
    /// known (see `at_least`).
    pub upper: NullableRowAddrMask,
    /// Set when the bounds are known to coincide. `false` does not guarantee
    /// they differ (e.g. `at_most(allow_nothing())` leaves it unset).
    exact: bool,
}
impl NullableIndexExprResult {
    /// Builds an exact result whose lower and upper bounds are both `mask`.
    pub fn exact(mask: NullableRowAddrMask) -> Self {
        let upper = mask.clone();
        Self {
            lower: mask,
            upper,
            exact: true,
        }
    }

    /// Builds a result bounded only from above by `mask`; the lower bound is
    /// `allow_nothing()`.
    pub fn at_most(mask: NullableRowAddrMask) -> Self {
        let lower = NullableRowAddrMask::allow_nothing();
        Self {
            lower,
            upper: mask,
            exact: false,
        }
    }

    /// Builds a result bounded only from below by `mask`; the upper bound is
    /// `all_rows()`.
    pub fn at_least(mask: NullableRowAddrMask) -> Self {
        let upper = NullableRowAddrMask::all_rows();
        Self {
            lower: mask,
            upper,
            exact: false,
        }
    }

    /// Builds a result from explicit bounds, canonicalizing to an exact result
    /// when `lower` and `upper` compare equal.
    pub fn new(lower: NullableRowAddrMask, upper: NullableRowAddrMask) -> Self {
        if lower != upper {
            return Self {
                lower,
                upper,
                exact: false,
            };
        }
        Self::exact(lower)
    }

    /// Returns the exactness flag.
    pub fn is_exact(&self) -> bool {
        self.exact
    }

    /// True when the lower bound is an empty allow-list (no rows are known to
    /// match).
    pub fn is_at_most(&self) -> bool {
        match &self.lower {
            NullableRowAddrMask::AllowList(set) => set.is_empty(),
            _ => false,
        }
    }

    /// True when the upper bound is an empty block-list (every row may match).
    pub fn is_at_least(&self) -> bool {
        match &self.upper {
            NullableRowAddrMask::BlockList(set) => set.is_empty(),
            _ => false,
        }
    }

    /// Converts into a non-nullable [`IndexExprResult`] by calling `drop_nulls`
    /// on each bound; the exactness flag is carried over unchanged.
    pub fn drop_nulls(self) -> IndexExprResult {
        let Self {
            lower,
            upper,
            exact,
        } = self;
        IndexExprResult {
            lower: lower.drop_nulls(),
            upper: upper.drop_nulls(),
            exact,
        }
    }
}
impl std::ops::Not for NullableIndexExprResult {
    type Output = Self;

    /// Negates the result: the complement of the old upper bound becomes the
    /// new lower bound and vice versa. Exactness is preserved.
    fn not(self) -> Self {
        let Self {
            lower,
            upper,
            exact,
        } = self;
        Self {
            lower: !upper,
            upper: !lower,
            exact,
        }
    }
}
impl std::ops::BitAnd<Self> for NullableIndexExprResult {
    type Output = Self;

    /// Intersects two results bound-by-bound; the result is exact only when
    /// both operands are.
    fn bitand(self, rhs: Self) -> Self {
        let Self {
            lower: left_lower,
            upper: left_upper,
            exact: left_exact,
        } = self;
        let Self {
            lower: right_lower,
            upper: right_upper,
            exact: right_exact,
        } = rhs;
        Self {
            lower: left_lower & right_lower,
            upper: left_upper & right_upper,
            exact: left_exact && right_exact,
        }
    }
}
impl std::ops::BitOr<Self> for NullableIndexExprResult {
    type Output = Self;

    /// Unions two results bound-by-bound; the result is exact only when both
    /// operands are.
    fn bitor(self, rhs: Self) -> Self {
        let Self {
            lower: left_lower,
            upper: left_upper,
            exact: left_exact,
        } = self;
        let Self {
            lower: right_lower,
            upper: right_upper,
            exact: right_exact,
        } = rhs;
        Self {
            lower: left_lower | right_lower,
            upper: left_upper | right_upper,
            exact: left_exact && right_exact,
        }
    }
}
/// Result of evaluating an index expression, expressed as a pair of
/// (non-nullable) row-address masks bounding the matching rows.
///
/// Can be written to / read from Arrow via `serialize` and `deserialize`.
#[derive(Debug, Clone)]
pub struct IndexExprResult {
    /// Lower bound of the matching rows; `allow_nothing()` when no lower bound
    /// is known (see `at_most`).
    pub lower: RowAddrMask,
    /// Upper bound of the matching rows; `all_rows()` when no upper bound is
    /// known (see `at_least`).
    pub upper: RowAddrMask,
    /// Set when the bounds are known to coincide. The two-mask serializer uses
    /// it to avoid writing `upper`.
    exact: bool,
}
impl IndexExprResult {
    /// Builds an exact result whose lower and upper bounds are both `mask`.
    pub fn exact(mask: RowAddrMask) -> Self {
        let upper = mask.clone();
        Self {
            lower: mask,
            upper,
            exact: true,
        }
    }

    /// Builds a result bounded only from above by `mask`; the lower bound is
    /// `allow_nothing()`.
    pub fn at_most(mask: RowAddrMask) -> Self {
        let lower = RowAddrMask::allow_nothing();
        Self {
            lower,
            upper: mask,
            exact: false,
        }
    }

    /// Builds a result bounded only from below by `mask`; the upper bound is
    /// `all_rows()`.
    pub fn at_least(mask: RowAddrMask) -> Self {
        let upper = RowAddrMask::all_rows();
        Self {
            lower: mask,
            upper,
            exact: false,
        }
    }

    /// Builds a result from explicit bounds. The result is always marked
    /// inexact; no equality check is performed.
    pub fn new(lower: RowAddrMask, upper: RowAddrMask) -> Self {
        Self {
            lower,
            upper,
            exact: false,
        }
    }

    /// Returns the exactness flag.
    pub fn is_exact(&self) -> bool {
        self.exact
    }

    /// True when the lower bound is an empty allow-list (no rows are known to
    /// match).
    pub fn is_at_most(&self) -> bool {
        if let RowAddrMask::AllowList(set) = &self.lower {
            set.is_empty()
        } else {
            false
        }
    }

    /// True when the upper bound is an empty block-list (every row may match).
    pub fn is_at_least(&self) -> bool {
        if let RowAddrMask::BlockList(set) = &self.upper {
            set.is_empty()
        } else {
            false
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{NullableRowAddrSet, RowAddrTreeMap};

    /// Builds a nullable allow-list selecting `rows`; the second
    /// `RowAddrTreeMap` passed to `NullableRowAddrSet::new` is empty.
    fn allow(rows: &[u64]) -> NullableRowAddrMask {
        let selected = RowAddrTreeMap::from_iter(rows.iter().copied());
        let empty = RowAddrTreeMap::new();
        NullableRowAddrMask::AllowList(NullableRowAddrSet::new(selected, empty))
    }

    #[test]
    fn nullable_index_expr_result_new_canonicalizes_exact() {
        let expected = allow(&[1, 2]);
        let result = NullableIndexExprResult::new(expected.clone(), expected.clone());
        assert!(result.is_exact());
        let NullableIndexExprResult { lower, upper, .. } = result;
        assert_eq!(lower, expected);
        assert_eq!(upper, expected);
    }

    #[test]
    fn nullable_index_expr_result_new_preserves_interval() {
        let (expected_lower, expected_upper) = (allow(&[1, 2]), allow(&[1, 2, 3]));
        let result =
            NullableIndexExprResult::new(expected_lower.clone(), expected_upper.clone());
        assert!(!result.is_exact());
        let NullableIndexExprResult { lower, upper, .. } = result;
        assert_eq!(lower, expected_lower);
        assert_eq!(upper, expected_upper);
    }
}
impl std::ops::Not for IndexExprResult {
    type Output = Self;

    /// Negates the result: the complement of the old upper bound becomes the
    /// new lower bound and vice versa. Exactness is preserved.
    fn not(self) -> Self {
        let Self {
            lower,
            upper,
            exact,
        } = self;
        Self {
            lower: !upper,
            upper: !lower,
            exact,
        }
    }
}
impl std::ops::BitAnd<Self> for IndexExprResult {
    type Output = Self;

    /// Intersects two results bound-by-bound; the result is exact only when
    /// both operands are.
    fn bitand(self, rhs: Self) -> Self {
        let Self {
            lower: left_lower,
            upper: left_upper,
            exact: left_exact,
        } = self;
        let Self {
            lower: right_lower,
            upper: right_upper,
            exact: right_exact,
        } = rhs;
        Self {
            lower: left_lower & right_lower,
            upper: left_upper & right_upper,
            exact: left_exact && right_exact,
        }
    }
}
impl std::ops::BitOr<Self> for IndexExprResult {
    type Output = Self;

    /// Unions two results bound-by-bound; the result is exact only when both
    /// operands are.
    fn bitor(self, rhs: Self) -> Self {
        let Self {
            lower: left_lower,
            upper: left_upper,
            exact: left_exact,
        } = self;
        let Self {
            lower: right_lower,
            upper: right_upper,
            exact: right_exact,
        } = rhs;
        Self {
            lower: left_lower | right_lower,
            upper: left_upper | right_upper,
            exact: left_exact && right_exact,
        }
    }
}
/// Schema of the two-mask wire format: nullable binary columns `lower`,
/// `upper` and `fragments_covered`.
static TWO_MASK_RESULT_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
    let binary_field = |name: &str| Field::new(name, DataType::Binary, true);
    Arc::new(Schema::new(vec![
        binary_field("lower"),
        binary_field("upper"),
        binary_field("fragments_covered"),
    ]))
});
/// Schema of the legacy three-variant wire format: a binary `result` mask, a
/// `UInt32` `discriminant`, and a binary `fragments_covered` column (all
/// nullable).
static THREE_VARIANT_RESULT_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
    let fields = vec![
        Field::new("result", DataType::Binary, true),
        Field::new("discriminant", DataType::UInt32, true),
        Field::new("fragments_covered", DataType::Binary, true),
    ];
    Arc::new(Schema::new(fields))
});
/// Layout used when an [`IndexExprResult`] is serialized into a `RecordBatch`.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexExprResultWireFormat {
    /// Legacy layout: a single mask plus a discriminant (exact / at-most /
    /// at-least). Results bounded on both sides lose precision in this layout.
    ThreeVariant,
    /// Default layout: separate `lower` and `upper` masks.
    #[default]
    TwoMask,
}
impl IndexExprResultWireFormat {
    /// Returns the Arrow schema of batches written in this wire format.
    pub fn schema(&self) -> &SchemaRef {
        match *self {
            Self::TwoMask => &*TWO_MASK_RESULT_SCHEMA,
            Self::ThreeVariant => &*THREE_VARIANT_RESULT_SCHEMA,
        }
    }
}
impl IndexExprResult {
#[tracing::instrument(skip_all)]
fn serialize_standard(&self, fragments_covered: &RoaringBitmap) -> Result<RecordBatch> {
let lower_arr = self.lower.into_arrow()?;
let upper_arr = if self.is_exact() {
let mut b = BinaryBuilder::new();
b.append_null();
b.append_null();
b.finish()
} else {
self.upper.into_arrow()?
};
let mut frags_builder = BinaryBuilder::new();
let mut frags_bytes = Vec::with_capacity(fragments_covered.serialized_size());
fragments_covered.serialize_into(&mut frags_bytes)?;
frags_builder.append_value(frags_bytes);
frags_builder.append_null();
Ok(RecordBatch::try_new(
TWO_MASK_RESULT_SCHEMA.clone(),
vec![
Arc::new(lower_arr),
Arc::new(upper_arr),
Arc::new(frags_builder.finish()) as Arc<dyn Array>,
],
)?)
}
fn serialize_three_variant(&self, fragments_covered: &RoaringBitmap) -> Result<RecordBatch> {
let (mask, discriminant) = if self.is_exact() {
(&self.lower, 0u32)
} else if self.is_at_most() {
(&self.upper, 1)
} else if self.is_at_least() {
(&self.lower, 2)
} else {
tracing::warn!(
"Legacy serialization of refined index-expr result: degrading to AtMost(upper); \
answer will remain correct but query will be more expensive"
);
(&self.upper, 1)
};
let mask_arr = mask.into_arrow()?;
let discriminant_arr =
Arc::new(UInt32Array::from(vec![discriminant, discriminant])) as Arc<dyn Array>;
let mut frags_builder = BinaryBuilder::new();
let mut frags_bytes = Vec::with_capacity(fragments_covered.serialized_size());
fragments_covered.serialize_into(&mut frags_bytes)?;
frags_builder.append_value(frags_bytes);
frags_builder.append_null();
Ok(RecordBatch::try_new(
THREE_VARIANT_RESULT_SCHEMA.clone(),
vec![
Arc::new(mask_arr),
discriminant_arr,
Arc::new(frags_builder.finish()) as Arc<dyn Array>,
],
)?)
}
pub fn serialize(
&self,
fragments_covered: &RoaringBitmap,
format: IndexExprResultWireFormat,
) -> Result<RecordBatch> {
match format {
IndexExprResultWireFormat::ThreeVariant => {
self.serialize_three_variant(fragments_covered)
}
IndexExprResultWireFormat::TwoMask => self.serialize_standard(fragments_covered),
}
}
pub fn deserialize(batch: &RecordBatch) -> Result<(Self, RoaringBitmap)> {
use arrow_array::cast::AsArray;
if batch.num_rows() != 2 {
return Err(Error::invalid_input_source(
format!(
"Expected a batch with exactly 2 rows but there are {} rows",
batch.num_rows()
)
.into(),
));
}
if batch.num_columns() != 3 {
return Err(Error::invalid_input_source(
format!(
"Expected a batch with exactly three columns but there are {} columns",
batch.num_columns()
)
.into(),
));
}
let first_col_name = batch.schema().field(0).name().clone();
let index_result = if first_col_name == "lower" {
let lower = RowAddrMask::from_arrow(batch.column(0).as_binary())?;
let upper_col = batch.column(1).as_binary::<i32>();
if upper_col.is_null(0) && upper_col.is_null(1) {
Self::exact(lower)
} else {
let upper = RowAddrMask::from_arrow(upper_col)?;
Self {
lower,
upper,
exact: false,
}
}
} else if first_col_name == "result" {
let row_addr_mask = RowAddrMask::from_arrow(batch.column(0).as_binary())?;
let match_type = batch
.column(1)
.as_primitive::<arrow_array::types::UInt32Type>()
.values()[0];
if match_type == 0 {
Self::exact(row_addr_mask)
} else if match_type == 1 {
Self::at_most(row_addr_mask)
} else if match_type == 2 {
Self::at_least(row_addr_mask)
} else {
return Err(Error::internal(format!(
"Unexpected match type: {match_type}"
)));
}
} else {
return Err(Error::internal(format!(
"Unexpected column name: {first_col_name}"
)));
};
let frags_col = batch.column(2).as_binary::<i32>();
let fragments = RoaringBitmap::deserialize_from(frags_col.value(0))?;
Ok((index_result, fragments))
}
}