use arrow::array::{Array, Scalar as ArrowScalar};
use arrow::array::{BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, StringArray};
use arrow::compute::kernels::cmp as arrow_cmp;
use parquet::data_type::ByteArray;
use parquet::file::metadata::RowGroupMetaData;
use parquet::file::statistics::Statistics;
use parquet::schema::types::SchemaDescriptor;
use crate::ColumnarError;
/// A literal value that a [`Predicate::Cmp`] compares a column against.
///
/// For row-level filtering, each variant must line up with the column's Arrow array
/// type. `Bytes` is compared against Utf8 columns and must be valid UTF-8. For
/// row-group pruning it is compared against byte-array statistics. `Null` never
/// prunes a row group and is rejected by row-level filtering.
#[derive(Debug, Clone, PartialEq)]
pub enum Scalar {
    /// Boolean literal.
    Bool(bool),
    /// 32-bit signed integer literal.
    Int32(i32),
    /// 64-bit signed integer literal.
    Int64(i64),
    /// 32-bit float literal.
    Float32(f32),
    /// 64-bit float literal.
    Float64(f64),
    /// Raw bytes (UTF-8 text when compared against string columns).
    Bytes(Vec<u8>),
    /// The absent value.
    Null,
}
/// Comparison operator applied as `column <op> value`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CmpOp {
    /// `column == value`
    Eq,
    /// `column != value`
    Ne,
    /// `column < value`
    Lt,
    /// `column <= value`
    Le,
    /// `column > value`
    Gt,
    /// `column >= value`
    Ge,
}
/// A boolean filter expression over named columns.
///
/// It can be evaluated row by row with [`Predicate::evaluate_batch`], or checked
/// against Parquet row-group statistics with [`Predicate::row_group_might_match`].
#[derive(Debug, Clone)]
pub enum Predicate {
    /// Compares one column against a literal: `column <op> value`.
    Cmp {
        /// Column name, resolved against the batch schema when filtering rows.
        column: String,
        /// Comparison operator.
        op: CmpOp,
        /// Literal right-hand side.
        value: Scalar,
    },
    /// Conjunction of the children; an empty list matches every row.
    And(Vec<Predicate>),
    /// Disjunction of the children; an empty list matches no row.
    Or(Vec<Predicate>),
    /// Negation of the inner predicate.
    Not(Box<Predicate>),
    /// Matches every row.
    All,
    /// Matches no row.
    None,
}
impl Predicate {
    /// Evaluates the predicate against every row of `batch`.
    ///
    /// The returned mask has `batch.num_rows()` entries. `Cmp` delegates to Arrow's
    /// comparison kernels.
    ///
    /// `And` and `Or` combine child masks with Arrow's Kleene kernels
    /// (`and_kleene` / `or_kleene`). A null child therefore makes the combined row
    /// null only when it can change the outcome: `false AND null` is `false` and
    /// `true OR null` is `true`, as in SQL three-valued logic. An empty `And` is
    /// all-true and an empty `Or` is all-false.
    ///
    /// # Errors
    ///
    /// Returns `ColumnarError::SchemaMismatch` in two cases: a `Cmp` names a column
    /// missing from `batch`, or the column and scalar types are unsupported or do not
    /// match. Returns `ColumnarError::Arrow` if an Arrow kernel fails.
    pub fn evaluate_batch(
        &self,
        batch: &arrow::record_batch::RecordBatch,
    ) -> Result<BooleanArray, ColumnarError> {
        match self {
            Predicate::All => Ok(BooleanArray::from(vec![true; batch.num_rows()])),
            Predicate::None => Ok(BooleanArray::from(vec![false; batch.num_rows()])),
            Predicate::Cmp { column, op, value } => {
                let col_idx = batch.schema().index_of(column).map_err(|_| {
                    ColumnarError::SchemaMismatch(format!("column '{}' not found in batch", column))
                })?;
                let col = batch.column(col_idx);
                evaluate_cmp_column(col.as_ref(), op, value)
            }
            Predicate::And(preds) => {
                let mut result = BooleanArray::from(vec![true; batch.num_rows()]);
                for pred in preds {
                    let mask = pred.evaluate_batch(batch)?;
                    // Kleene AND: a definite `false` wins over null.
                    result = arrow::compute::and_kleene(&result, &mask)
                        .map_err(ColumnarError::Arrow)?;
                }
                Ok(result)
            }
            Predicate::Or(preds) => {
                let mut result = BooleanArray::from(vec![false; batch.num_rows()]);
                for pred in preds {
                    let mask = pred.evaluate_batch(batch)?;
                    // Kleene OR: a definite `true` wins over null. Plain `or` would
                    // null out (and thus drop) rows that another branch matches.
                    result = arrow::compute::or_kleene(&result, &mask)
                        .map_err(ColumnarError::Arrow)?;
                }
                Ok(result)
            }
            Predicate::Not(pred) => {
                let mask = pred.evaluate_batch(batch)?;
                arrow::compute::not(&mask).map_err(ColumnarError::Arrow)
            }
        }
    }

    /// Returns `false` only if `rg`'s column statistics prove that no row can satisfy
    /// the predicate. `true` means the row group must be read.
    ///
    /// `And` requires every child to possibly match, and `Or` requires at least one.
    /// `Not` is never pruned, because min/max statistics are not used to bound a
    /// negation here.
    pub fn row_group_might_match(&self, rg: &RowGroupMetaData, schema: &SchemaDescriptor) -> bool {
        match self {
            Predicate::All => true,
            Predicate::None => false,
            Predicate::Not(_) => true,
            Predicate::And(preds) => preds.iter().all(|p| p.row_group_might_match(rg, schema)),
            Predicate::Or(preds) => preds.iter().any(|p| p.row_group_might_match(rg, schema)),
            Predicate::Cmp { column, op, value } => cmp_might_match(rg, schema, column, op, value),
        }
    }
}
fn evaluate_cmp_column(
col: &dyn Array,
op: &CmpOp,
value: &Scalar,
) -> Result<BooleanArray, ColumnarError> {
match value {
Scalar::Int32(v) => {
let arr = col.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
ColumnarError::SchemaMismatch("type mismatch: expected Int32".into())
})?;
let rhs = ArrowScalar::new(Int32Array::from(vec![*v]));
apply_cmp_i32(arr, op, &rhs)
}
Scalar::Int64(v) => {
let arr = col.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
ColumnarError::SchemaMismatch("type mismatch: expected Int64".into())
})?;
let rhs = ArrowScalar::new(Int64Array::from(vec![*v]));
apply_cmp_i64(arr, op, &rhs)
}
Scalar::Float32(v) => {
let arr = col.as_any().downcast_ref::<Float32Array>().ok_or_else(|| {
ColumnarError::SchemaMismatch("type mismatch: expected Float32".into())
})?;
let rhs = ArrowScalar::new(Float32Array::from(vec![*v]));
apply_cmp_f32(arr, op, &rhs)
}
Scalar::Float64(v) => {
let arr = col.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
ColumnarError::SchemaMismatch("type mismatch: expected Float64".into())
})?;
let rhs = ArrowScalar::new(Float64Array::from(vec![*v]));
apply_cmp_f64(arr, op, &rhs)
}
Scalar::Bytes(v) => {
let arr = col.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
ColumnarError::SchemaMismatch(
"type mismatch: expected Utf8 (StringArray) for Bytes comparison".into(),
)
})?;
let s = std::str::from_utf8(v).map_err(|_| {
ColumnarError::SchemaMismatch(
"Bytes scalar is not valid UTF-8; cannot compare against Utf8 column".into(),
)
})?;
let rhs = ArrowScalar::new(StringArray::from(vec![s]));
apply_cmp_str(arr, op, &rhs)
}
Scalar::Bool(_) | Scalar::Null => Err(ColumnarError::SchemaMismatch(format!(
"unsupported scalar type for row-level filter: {value:?}",
))),
}
}
/// Compares every value of an `Int32Array` against a single Int32 scalar using `op`.
///
/// # Errors
///
/// Propagates any Arrow kernel error as `ColumnarError::Arrow`.
fn apply_cmp_i32(
    col: &Int32Array,
    op: &CmpOp,
    rhs: &ArrowScalar<Int32Array>,
) -> Result<BooleanArray, ColumnarError> {
    let outcome = match op {
        CmpOp::Lt => arrow_cmp::lt(col, rhs),
        CmpOp::Le => arrow_cmp::lt_eq(col, rhs),
        CmpOp::Gt => arrow_cmp::gt(col, rhs),
        CmpOp::Ge => arrow_cmp::gt_eq(col, rhs),
        CmpOp::Eq => arrow_cmp::eq(col, rhs),
        CmpOp::Ne => arrow_cmp::neq(col, rhs),
    };
    outcome.map_err(ColumnarError::Arrow)
}
/// Compares every value of an `Int64Array` against a single Int64 scalar using `op`.
///
/// # Errors
///
/// Propagates any Arrow kernel error as `ColumnarError::Arrow`.
fn apply_cmp_i64(
    col: &Int64Array,
    op: &CmpOp,
    rhs: &ArrowScalar<Int64Array>,
) -> Result<BooleanArray, ColumnarError> {
    let outcome = match op {
        CmpOp::Lt => arrow_cmp::lt(col, rhs),
        CmpOp::Le => arrow_cmp::lt_eq(col, rhs),
        CmpOp::Gt => arrow_cmp::gt(col, rhs),
        CmpOp::Ge => arrow_cmp::gt_eq(col, rhs),
        CmpOp::Eq => arrow_cmp::eq(col, rhs),
        CmpOp::Ne => arrow_cmp::neq(col, rhs),
    };
    outcome.map_err(ColumnarError::Arrow)
}
/// Compares every value of a `Float32Array` against a single Float32 scalar using `op`.
///
/// # Errors
///
/// Propagates any Arrow kernel error as `ColumnarError::Arrow`.
fn apply_cmp_f32(
    col: &Float32Array,
    op: &CmpOp,
    rhs: &ArrowScalar<Float32Array>,
) -> Result<BooleanArray, ColumnarError> {
    let outcome = match op {
        CmpOp::Lt => arrow_cmp::lt(col, rhs),
        CmpOp::Le => arrow_cmp::lt_eq(col, rhs),
        CmpOp::Gt => arrow_cmp::gt(col, rhs),
        CmpOp::Ge => arrow_cmp::gt_eq(col, rhs),
        CmpOp::Eq => arrow_cmp::eq(col, rhs),
        CmpOp::Ne => arrow_cmp::neq(col, rhs),
    };
    outcome.map_err(ColumnarError::Arrow)
}
/// Compares every value of a `Float64Array` against a single Float64 scalar using `op`.
///
/// # Errors
///
/// Propagates any Arrow kernel error as `ColumnarError::Arrow`.
fn apply_cmp_f64(
    col: &Float64Array,
    op: &CmpOp,
    rhs: &ArrowScalar<Float64Array>,
) -> Result<BooleanArray, ColumnarError> {
    let outcome = match op {
        CmpOp::Lt => arrow_cmp::lt(col, rhs),
        CmpOp::Le => arrow_cmp::lt_eq(col, rhs),
        CmpOp::Gt => arrow_cmp::gt(col, rhs),
        CmpOp::Ge => arrow_cmp::gt_eq(col, rhs),
        CmpOp::Eq => arrow_cmp::eq(col, rhs),
        CmpOp::Ne => arrow_cmp::neq(col, rhs),
    };
    outcome.map_err(ColumnarError::Arrow)
}
/// Compares every value of a `StringArray` against a single Utf8 scalar using `op`.
///
/// # Errors
///
/// Propagates any Arrow kernel error as `ColumnarError::Arrow`.
fn apply_cmp_str(
    col: &StringArray,
    op: &CmpOp,
    rhs: &ArrowScalar<StringArray>,
) -> Result<BooleanArray, ColumnarError> {
    let outcome = match op {
        CmpOp::Lt => arrow_cmp::lt(col, rhs),
        CmpOp::Le => arrow_cmp::lt_eq(col, rhs),
        CmpOp::Gt => arrow_cmp::gt(col, rhs),
        CmpOp::Ge => arrow_cmp::gt_eq(col, rhs),
        CmpOp::Eq => arrow_cmp::eq(col, rhs),
        CmpOp::Ne => arrow_cmp::neq(col, rhs),
    };
    outcome.map_err(ColumnarError::Arrow)
}
/// Decides, from one row group's column statistics, whether any row in `rg` could
/// satisfy `column <op> value`.
///
/// Returns `false` only when the statistics prove that no row can match. Any missing
/// information keeps the group: a null scalar, an unknown column, a column index
/// beyond the row group, or absent statistics.
///
/// `column` is matched only against leaf columns whose path has exactly one element,
/// i.e. top-level primitive columns. This mirrors `evaluate_batch`, which resolves
/// `column` as a top-level field name. Matching on the leaf name alone would let a
/// nested leaf such as `s.x` shadow a top-level column `x`, so the wrong statistics
/// could prune a group that does contain matching rows.
fn cmp_might_match(
    rg: &RowGroupMetaData,
    schema: &SchemaDescriptor,
    column: &str,
    op: &CmpOp,
    value: &Scalar,
) -> bool {
    if matches!(value, Scalar::Null) {
        return true;
    }
    let Some(col_idx) = schema
        .columns()
        .iter()
        .position(|c| matches!(c.path().parts(), [only] if only.as_str() == column))
    else {
        return true;
    };
    if col_idx >= rg.num_columns() {
        return true;
    }
    let Some(stats) = rg.column(col_idx).statistics() else {
        return true;
    };
    evaluate_cmp(stats, op, value)
}
/// Applies `op value` to one column chunk's min/max statistics.
///
/// Returns `true` ("might match") in two cases:
/// - the statistics variant and scalar variant do not pair up (e.g. Int32
///   statistics with an Int64 scalar);
/// - either bound is absent.
///
/// A row group is therefore only discarded on positive evidence.
fn evaluate_cmp(stats: &Statistics, op: &CmpOp, value: &Scalar) -> bool {
    match (stats, value) {
        (Statistics::Boolean(s), Scalar::Bool(v)) => match (s.min_opt(), s.max_opt()) {
            (Some(&lo), Some(&hi)) => apply_bool_op(op, lo, hi, *v),
            _ => true,
        },
        (Statistics::Int32(s), Scalar::Int32(v)) => match (s.min_opt(), s.max_opt()) {
            (Some(&lo), Some(&hi)) => apply_ord_op(op, lo, hi, *v, s.null_count_opt()),
            _ => true,
        },
        (Statistics::Int64(s), Scalar::Int64(v)) => match (s.min_opt(), s.max_opt()) {
            (Some(&lo), Some(&hi)) => apply_ord_op(op, lo, hi, *v, s.null_count_opt()),
            _ => true,
        },
        (Statistics::Float(s), Scalar::Float32(v)) => match (s.min_opt(), s.max_opt()) {
            (Some(&lo), Some(&hi)) => apply_float_op_f32(op, lo, hi, *v, s.null_count_opt()),
            _ => true,
        },
        (Statistics::Double(s), Scalar::Float64(v)) => match (s.min_opt(), s.max_opt()) {
            (Some(&lo), Some(&hi)) => apply_float_op_f64(op, lo, hi, *v, s.null_count_opt()),
            _ => true,
        },
        (Statistics::ByteArray(s), Scalar::Bytes(v)) => match (s.min_opt(), s.max_opt()) {
            (Some(lo), Some(hi)) => apply_bytes_op(op, lo, hi, v, s.null_count_opt()),
            _ => true,
        },
        // Mismatched or unsupported statistics/scalar pairs: keep the group.
        _ => true,
    }
}
/// Decides whether a boolean column chunk with statistics `[min, max]` might contain a
/// value `x` with `x <op> v`, using the ordering `false < true`.
///
/// `Ne` always keeps the chunk.
fn apply_bool_op(op: &CmpOp, min: bool, max: bool, v: bool) -> bool {
    match (op, v) {
        // A `true` exists iff max is true; a `false` exists iff min is false.
        (CmpOp::Eq, true) => max,
        (CmpOp::Eq, false) => !min,
        (CmpOp::Ne, _) => true,
        // Only `false < true`, so `x < true` needs some `false`.
        (CmpOp::Lt, true) => !min,
        (CmpOp::Lt, false) => false,
        (CmpOp::Le, true) => true,
        (CmpOp::Le, false) => !min,
        // Only `true > false`, so `x > false` needs some `true`.
        (CmpOp::Gt, true) => false,
        (CmpOp::Gt, false) => max,
        (CmpOp::Ge, true) => max,
        (CmpOp::Ge, false) => true,
    }
}
/// Decides whether a column chunk whose statistics report `min` and `max` might
/// contain a value `x` with `x <op> v`.
///
/// For `Ne`, the chunk is skipped only when every value equals `v` (`min == max == v`)
/// and the null count is known to be zero; an unknown null count keeps the chunk.
fn apply_ord_op<T: PartialOrd + PartialEq>(
    op: &CmpOp,
    min: T,
    max: T,
    v: T,
    null_count: Option<u64>,
) -> bool {
    match op {
        CmpOp::Lt => min < v,
        CmpOp::Le => min <= v,
        CmpOp::Gt => max > v,
        CmpOp::Ge => max >= v,
        CmpOp::Eq => min <= v && v <= max,
        CmpOp::Ne => {
            let constant_v = min == v && max == v;
            null_count != Some(0) || !constant_v
        }
    }
}
/// Decides whether a FLOAT column chunk whose statistics report `min` and `max` might
/// contain a value `x` with `x <op> v`.
///
/// The result is conservative: `true` means "keep the chunk". Any NaN among `v`,
/// `min`, or `max` keeps the chunk. For `Ne`, the chunk is skipped only when every
/// value is exactly `v` and the null count is known to be zero.
///
/// Two adjustments keep the bounds sound:
/// * Parquet's compatibility rules for float statistics say a min of `+0.0` may hide
///   `-0.0` values and a max of `-0.0` may hide `+0.0` values. Both bounds are
///   therefore widened across zero before comparing.
/// * Comparisons use [`f32::total_cmp`], which orders `-0.0` below `+0.0`. With the
///   widened bounds, this never prunes a chunk that IEEE `<` / `==` comparison would
///   keep. It also lines up with Arrow's float comparison kernels, which are
///   documented to use IEEE 754 totalOrder.
///
/// NOTE(review): Parquet writers typically leave NaN out of min/max. Under totalOrder,
/// NaN compares greater than every number (less, for a negative NaN). `Gt`, `Ge`,
/// `Lt`, `Le` and `Ne` may therefore still prune a chunk whose only matching rows are
/// NaN. Confirm this is acceptable.
fn apply_float_op_f32(op: &CmpOp, min: f32, max: f32, v: f32, null_count: Option<u64>) -> bool {
    if v.is_nan() || min.is_nan() || max.is_nan() {
        return true;
    }
    // `== 0.0` is true for both signed zeros; widen min down to -0 and max up to +0.
    let min = if min == 0.0 { -0.0 } else { min };
    let max = if max == 0.0 { 0.0 } else { max };
    let lo = min.total_cmp(&v);
    let hi = max.total_cmp(&v);
    match op {
        CmpOp::Eq => lo.is_le() && hi.is_ge(),
        CmpOp::Ne => {
            // Exact equality: an epsilon tolerance would treat e.g. a constant 1e-10
            // chunk as equal to 0.0 and prune rows that do satisfy `x != 0.0`.
            let all_same = lo.is_eq() && hi.is_eq();
            let no_nulls = null_count == Some(0);
            !(all_same && no_nulls)
        }
        CmpOp::Lt => lo.is_lt(),
        CmpOp::Le => lo.is_le(),
        CmpOp::Gt => hi.is_gt(),
        CmpOp::Ge => hi.is_ge(),
    }
}
/// Decides whether a DOUBLE column chunk whose statistics report `min` and `max` might
/// contain a value `x` with `x <op> v`.
///
/// The result is conservative: `true` means "keep the chunk". Any NaN among `v`,
/// `min`, or `max` keeps the chunk. For `Ne`, the chunk is skipped only when every
/// value is exactly `v` and the null count is known to be zero.
///
/// Two adjustments keep the bounds sound:
/// * Parquet's compatibility rules for float statistics say a min of `+0.0` may hide
///   `-0.0` values and a max of `-0.0` may hide `+0.0` values. Both bounds are
///   therefore widened across zero before comparing.
/// * Comparisons use [`f64::total_cmp`], which orders `-0.0` below `+0.0`. With the
///   widened bounds, this never prunes a chunk that IEEE `<` / `==` comparison would
///   keep. It also lines up with Arrow's float comparison kernels, which are
///   documented to use IEEE 754 totalOrder.
///
/// NOTE(review): Parquet writers typically leave NaN out of min/max. Under totalOrder,
/// NaN compares greater than every number (less, for a negative NaN). `Gt`, `Ge`,
/// `Lt`, `Le` and `Ne` may therefore still prune a chunk whose only matching rows are
/// NaN. Confirm this is acceptable.
fn apply_float_op_f64(op: &CmpOp, min: f64, max: f64, v: f64, null_count: Option<u64>) -> bool {
    if v.is_nan() || min.is_nan() || max.is_nan() {
        return true;
    }
    // `== 0.0` is true for both signed zeros; widen min down to -0 and max up to +0.
    let min = if min == 0.0 { -0.0 } else { min };
    let max = if max == 0.0 { 0.0 } else { max };
    let lo = min.total_cmp(&v);
    let hi = max.total_cmp(&v);
    match op {
        CmpOp::Eq => lo.is_le() && hi.is_ge(),
        CmpOp::Ne => {
            // Exact equality: an epsilon tolerance would treat e.g. a constant 1e-20
            // chunk as equal to 0.0 and prune rows that do satisfy `x != 0.0`.
            let all_same = lo.is_eq() && hi.is_eq();
            let no_nulls = null_count == Some(0);
            !(all_same && no_nulls)
        }
        CmpOp::Lt => lo.is_lt(),
        CmpOp::Le => lo.is_le(),
        CmpOp::Gt => hi.is_gt(),
        CmpOp::Ge => hi.is_ge(),
    }
}
/// Decides whether a byte-array column chunk whose statistics report `min` and `max`
/// might contain a value `x` with `x <op> v`, using lexicographic byte order.
///
/// For `Ne`, the chunk is skipped only when every value equals `v` (`min == max == v`)
/// and the null count is known to be zero.
fn apply_bytes_op(
    op: &CmpOp,
    min: &ByteArray,
    max: &ByteArray,
    v: &[u8],
    null_count: Option<u64>,
) -> bool {
    let (lo, hi) = (min.data(), max.data());
    match op {
        CmpOp::Lt => lo < v,
        CmpOp::Le => lo <= v,
        CmpOp::Gt => hi > v,
        CmpOp::Ge => hi >= v,
        CmpOp::Eq => (lo..=hi).contains(&v),
        CmpOp::Ne => {
            let constant_v = lo == v && hi == v;
            null_count != Some(0) || !constant_v
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::ArrayRef;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use std::sync::Arc;

    /// Builds a three-row batch with one non-nullable Int32 column `x = [1, 2, 3]`.
    fn int32_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
        let column: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        RecordBatch::try_new(schema, vec![column]).expect("schema and column agree")
    }

    /// Evaluates `pred` on `batch`, panicking if evaluation fails.
    fn mask(pred: &Predicate, batch: &RecordBatch) -> BooleanArray {
        match pred.evaluate_batch(batch) {
            Ok(mask) => mask,
            Err(_) => panic!("predicate evaluation failed"),
        }
    }

    /// Shorthand for `x <op> v` with an Int32 literal.
    fn cmp_x(op: CmpOp, v: i32) -> Predicate {
        Predicate::Cmp {
            column: "x".to_string(),
            op,
            value: Scalar::Int32(v),
        }
    }

    #[test]
    fn predicate_all_and_none() {
        let batch = int32_batch();
        let all_true = BooleanArray::from(vec![true; 3]);
        let all_false = BooleanArray::from(vec![false; 3]);
        assert_eq!(mask(&Predicate::All, &batch), all_true);
        assert_eq!(mask(&Predicate::None, &batch), all_false);
        // Empty conjunctions/disjunctions are the identities of AND/OR.
        assert_eq!(mask(&Predicate::And(vec![]), &batch), all_true);
        assert_eq!(mask(&Predicate::Or(vec![]), &batch), all_false);
        assert_eq!(mask(&Predicate::Not(Box::new(Predicate::All)), &batch), all_false);
    }

    #[test]
    fn cmp_filters_rows() {
        let batch = int32_batch();
        assert_eq!(
            mask(&cmp_x(CmpOp::Gt, 1), &batch),
            BooleanArray::from(vec![false, true, true])
        );
        let either_end = Predicate::Or(vec![cmp_x(CmpOp::Eq, 1), cmp_x(CmpOp::Eq, 3)]);
        assert_eq!(
            mask(&either_end, &batch),
            BooleanArray::from(vec![true, false, true])
        );
        let missing = Predicate::Cmp {
            column: "y".to_string(),
            op: CmpOp::Eq,
            value: Scalar::Int32(1),
        };
        assert!(missing.evaluate_batch(&batch).is_err());
        let wrong_type = Predicate::Cmp {
            column: "x".to_string(),
            op: CmpOp::Eq,
            value: Scalar::Int64(1),
        };
        assert!(wrong_type.evaluate_batch(&batch).is_err());
    }

    #[test]
    fn int32_range_pruning() {
        let stats = Statistics::int32(Some(10), Some(20), None, Some(0), false);
        assert!(evaluate_cmp(&stats, &CmpOp::Eq, &Scalar::Int32(15)));
        assert!(!evaluate_cmp(&stats, &CmpOp::Eq, &Scalar::Int32(5)));
        assert!(evaluate_cmp(&stats, &CmpOp::Gt, &Scalar::Int32(15)));
        assert!(!evaluate_cmp(&stats, &CmpOp::Gt, &Scalar::Int32(25)));
        assert!(evaluate_cmp(&stats, &CmpOp::Lt, &Scalar::Int32(15)));
        assert!(!evaluate_cmp(&stats, &CmpOp::Lt, &Scalar::Int32(5)));
    }

    #[test]
    fn ne_pruning_requires_no_nulls() {
        let stats_no_nulls = Statistics::int32(Some(10), Some(10), None, Some(0), false);
        assert!(!evaluate_cmp(
            &stats_no_nulls,
            &CmpOp::Ne,
            &Scalar::Int32(10)
        ));
        assert!(evaluate_cmp(&stats_no_nulls, &CmpOp::Ne, &Scalar::Int32(5)));
        let stats_with_nulls = Statistics::int32(Some(10), Some(10), None, Some(1), false);
        assert!(evaluate_cmp(
            &stats_with_nulls,
            &CmpOp::Ne,
            &Scalar::Int32(10)
        ));
    }

    #[test]
    fn type_mismatch_keeps_group() {
        let stats = Statistics::int32(Some(10), Some(20), None, Some(0), false);
        assert!(evaluate_cmp(&stats, &CmpOp::Eq, &Scalar::Int64(15)));
    }

    #[test]
    fn null_scalar_keeps_group() {
        let stats = Statistics::int32(Some(10), Some(20), None, Some(0), false);
        assert!(evaluate_cmp(&stats, &CmpOp::Eq, &Scalar::Null));
    }

    #[test]
    fn ge_boundary_keeps_group_at_max() {
        let stats = Statistics::int64(Some(1), Some(100), None, Some(0), false);
        assert!(evaluate_cmp(&stats, &CmpOp::Ge, &Scalar::Int64(100)));
        assert!(!evaluate_cmp(&stats, &CmpOp::Ge, &Scalar::Int64(101)));
    }
}