use std::sync::Arc;
use arrow_array::{
ArrayRef, BooleanArray, Date32Array, Float32Array, Float64Array, Int8Array, Int16Array,
Int32Array, Int64Array, RecordBatch, Scalar, StringArray, TimestampMicrosecondArray,
UInt8Array, UInt16Array, UInt32Array, UInt64Array,
};
use arrow_ord::cmp;
use crate::error::IndexError;
use crate::filter::FilterIndex;
/// A test of a single column against constant operand(s), evaluated by [`evaluate`].
///
/// Comparison variants are executed with the `arrow_ord::cmp` kernels; the operand's
/// Arrow type presumably has to match the column's type for the kernel to accept it
/// (NOTE(review): confirm against callers).
#[derive(Debug, Clone, PartialEq)]
pub enum ScalarPredicate {
    /// Rows equal to the value.
    Eq(ScalarValue),
    /// Rows not equal to the value.
    NotEq(ScalarValue),
    /// Rows strictly less than the value.
    Lt(ScalarValue),
    /// Rows less than or equal to the value.
    Lte(ScalarValue),
    /// Rows strictly greater than the value.
    Gt(ScalarValue),
    /// Rows greater than or equal to the value.
    Gte(ScalarValue),
    /// Rows whose value is null.
    IsNull,
    /// Rows whose value is not null.
    IsNotNull,
    /// Rows with `lo <= value <= hi` (inclusive on both ends).
    Between {
        /// Inclusive lower bound.
        lo: ScalarValue,
        /// Inclusive upper bound.
        hi: ScalarValue,
    },
}

/// A constant operand for a [`ScalarPredicate`], tagged with the Arrow type it is
/// materialised as when compared against a column.
#[derive(Debug, Clone, PartialEq)]
pub enum ScalarValue {
    /// Arrow `Boolean`.
    Bool(bool),
    /// Arrow `Int8`.
    Int8(i8),
    /// Arrow `Int16`.
    Int16(i16),
    /// Arrow `Int32`.
    Int32(i32),
    /// Arrow `Int64`.
    Int64(i64),
    /// Arrow `UInt8`.
    UInt8(u8),
    /// Arrow `UInt16`.
    UInt16(u16),
    /// Arrow `UInt32`.
    UInt32(u32),
    /// Arrow `UInt64`.
    UInt64(u64),
    /// Arrow `Float32`.
    Float32(f32),
    /// Arrow `Float64`.
    Float64(f64),
    /// Arrow `Utf8` (32-bit offset string).
    Utf8(String),
    /// Arrow `Date32`, stored as its raw `i32` representation.
    Date32(i32),
    /// Arrow `Timestamp(Microsecond, _)` raw `i64` value; materialised without a timezone.
    TimestampMicros(i64),
}
impl ScalarValue {
fn to_array(&self) -> ArrayRef {
match self {
Self::Bool(v) => Arc::new(BooleanArray::from(vec![*v])),
Self::Int8(v) => Arc::new(Int8Array::from(vec![*v])),
Self::Int16(v) => Arc::new(Int16Array::from(vec![*v])),
Self::Int32(v) => Arc::new(Int32Array::from(vec![*v])),
Self::Int64(v) => Arc::new(Int64Array::from(vec![*v])),
Self::UInt8(v) => Arc::new(UInt8Array::from(vec![*v])),
Self::UInt16(v) => Arc::new(UInt16Array::from(vec![*v])),
Self::UInt32(v) => Arc::new(UInt32Array::from(vec![*v])),
Self::UInt64(v) => Arc::new(UInt64Array::from(vec![*v])),
Self::Float32(v) => Arc::new(Float32Array::from(vec![*v])),
Self::Float64(v) => Arc::new(Float64Array::from(vec![*v])),
Self::Utf8(v) => Arc::new(StringArray::from(vec![v.as_str()])),
Self::Date32(v) => Arc::new(Date32Array::from(vec![*v])),
Self::TimestampMicros(v) => Arc::new(TimestampMicrosecondArray::from(vec![*v])),
}
}
fn to_scalar(&self) -> Scalar<ArrayRef> {
Scalar::new(self.to_array())
}
}
/// Converts an Arrow kernel error into this crate's `IndexError`, keeping only its
/// rendered message.
///
/// The error is taken by value, which is why the lint is silenced: that is what lets this
/// function be passed straight to `Result::map_err` as `map_err(map_err)`.
#[allow(clippy::needless_pass_by_value)]
fn map_err(err: arrow_schema::ArrowError) -> IndexError {
    let message = err.to_string();
    IndexError::PredicateEvalFailed(message)
}
fn is_null_mask(column: &ArrayRef) -> BooleanArray {
(0..column.len())
.map(|i| Some(column.is_null(i)))
.collect::<BooleanArray>()
}
fn is_not_null_mask(column: &ArrayRef) -> BooleanArray {
(0..column.len())
.map(|i| Some(column.is_valid(i)))
.collect::<BooleanArray>()
}
/// Combines two boolean masks element-wise with three-valued (Kleene) AND.
///
/// For each pair:
/// - the result is `false` if either side is `false`;
/// - it is `true` if both sides are `true`;
/// - otherwise it is null.
///
/// Pairs are formed positionally, so if the inputs differ in length the result is as long
/// as the shorter one.
fn bool_and(a: &BooleanArray, b: &BooleanArray) -> BooleanArray {
    let and3 = |lhs: Option<bool>, rhs: Option<bool>| -> Option<bool> {
        if lhs == Some(false) || rhs == Some(false) {
            Some(false)
        } else if lhs.is_some() && rhs.is_some() {
            // Neither side is false and neither is null, so both must be true.
            Some(true)
        } else {
            None
        }
    };
    a.iter().zip(b.iter()).map(|(lhs, rhs)| and3(lhs, rhs)).collect()
}
pub fn evaluate(column: &ArrayRef, predicate: &ScalarPredicate) -> Result<FilterIndex, IndexError> {
let mask: BooleanArray = match predicate {
ScalarPredicate::Eq(v) => cmp::eq(column, &v.to_scalar()).map_err(map_err)?,
ScalarPredicate::NotEq(v) => cmp::neq(column, &v.to_scalar()).map_err(map_err)?,
ScalarPredicate::Lt(v) => cmp::lt(column, &v.to_scalar()).map_err(map_err)?,
ScalarPredicate::Lte(v) => cmp::lt_eq(column, &v.to_scalar()).map_err(map_err)?,
ScalarPredicate::Gt(v) => cmp::gt(column, &v.to_scalar()).map_err(map_err)?,
ScalarPredicate::Gte(v) => cmp::gt_eq(column, &v.to_scalar()).map_err(map_err)?,
ScalarPredicate::IsNull => is_null_mask(column),
ScalarPredicate::IsNotNull => is_not_null_mask(column),
ScalarPredicate::Between { lo, hi } => {
let gte = cmp::gt_eq(column, &lo.to_scalar()).map_err(map_err)?;
let lte = cmp::lt_eq(column, &hi.to_scalar()).map_err(map_err)?;
bool_and(>e, <e)
}
};
Ok(FilterIndex::from_boolean_array(&mask))
}
/// Evaluates `predicate` over a sequence of record batches and merges the per-batch matches
/// into one `FilterIndex` of global row ids.
///
/// Each item is `(batch, col_idx)`: a batch plus the index of the column to test within it.
/// Row ids are assigned as if the batches were concatenated in iteration order. The first
/// row of a batch therefore gets an id equal to the total row count of all preceding batches.
///
/// # Errors
///
/// - `IndexError::TooManyRows` if the running row total would exceed `u32::MAX`; the
///   error carries the total that would have been reached.
/// - `IndexError::PredicateEvalFailed` if `col_idx` is out of range for its batch, or if
///   evaluating the predicate on a column fails.
pub fn evaluate_batches(
    batches: impl Iterator<Item = (RecordBatch, usize)>,
    predicate: &ScalarPredicate,
) -> Result<FilterIndex, IndexError> {
    let mut combined = FilterIndex::from_ids(std::iter::empty::<u32>());
    // Rows consumed so far, i.e. the global id of the current batch's first row.
    let mut offset: u64 = 0;
    for (batch, col_idx) in batches {
        let n = batch.num_rows() as u64;
        let end = offset + n;
        // Keep the total row count representable as u32 so every id (< end) fits as well.
        if end > u64::from(u32::MAX) {
            return Err(IndexError::TooManyRows(end));
        }
        // Cannot fail given the check above (offset <= end <= u32::MAX); checked conversion
        // replaces the previous truncating `as` cast.
        let base = u32::try_from(offset).map_err(|_| IndexError::TooManyRows(end))?;
        // A caller-supplied index must not panic the indexer; report it as an error instead.
        let column = batch.columns().get(col_idx).ok_or_else(|| {
            IndexError::PredicateEvalFailed(format!(
                "column index {col_idx} out of range for batch with {} columns",
                batch.num_columns()
            ))
        })?;
        let local_filter = evaluate(column, predicate)?;
        let global_ids = local_filter.iter().map(|id| id + base);
        combined = combined.union(&FilterIndex::from_ids(global_ids));
        offset = end;
    }
    Ok(combined)
}