use crate::expr::{ColumnRef, ComparisonOp, Datum, Predicate};
use crate::spec::{PrimitiveType, Schema, Type};
use std::collections::HashMap;
/// Maps a column reference to the field id used as the key in the statistics maps.
///
/// An explicit `ColumnRef::Id` is returned unchanged (it is not validated against `schema`).
/// A `ColumnRef::Named` is looked up with `schema.as_struct().field_by_name`, yielding
/// `None` when no field with that name is found.
fn resolve_column_id(col: &ColumnRef, schema: &Schema) -> Option<i32> {
    let name = match col {
        ColumnRef::Id(id) => return Some(*id),
        ColumnRef::Named(name) => name,
    };
    schema
        .as_struct()
        .field_by_name(name)
        .map(|field| field.id())
}
/// Returns the primitive type of the field with id `field_id`.
///
/// Yields `None` when `schema` has no such field (per `field_by_id`) or when the field's
/// type is not `Type::Primitive`; bounds for such fields cannot be decoded.
fn get_field_type(field_id: i32, schema: &Schema) -> Option<&PrimitiveType> {
    let field = schema.as_struct().field_by_id(field_id)?;
    match field.field_type() {
        Type::Primitive(primitive) => Some(primitive),
        _ => None,
    }
}
pub fn evaluate_bounds(
predicate: &Predicate,
schema: &Schema,
lower_bounds: &HashMap<i32, Vec<u8>>,
upper_bounds: &HashMap<i32, Vec<u8>>,
null_counts: &HashMap<i32, i64>,
row_count: i64,
) -> bool {
match predicate {
Predicate::AlwaysTrue => true,
Predicate::AlwaysFalse => false,
Predicate::Comparison { column, op, value } => {
let Some(field_id) = resolve_column_id(column, schema) else {
return true;
};
let Some(prim_type) = get_field_type(field_id, schema) else {
return true;
};
let lower = lower_bounds
.get(&field_id)
.and_then(|b| Datum::from_bytes(b, prim_type));
let upper = upper_bounds
.get(&field_id)
.and_then(|b| Datum::from_bytes(b, prim_type));
evaluate_comparison(value, *op, lower.as_ref(), upper.as_ref())
}
Predicate::IsNull(column) => {
let Some(field_id) = resolve_column_id(column, schema) else {
return true;
};
match null_counts.get(&field_id) {
Some(&0) => false,
_ => true, }
}
Predicate::IsNotNull(column) => {
let Some(field_id) = resolve_column_id(column, schema) else {
return true;
};
match null_counts.get(&field_id) {
Some(&count) if count == row_count => false,
_ => true, }
}
Predicate::In { column, values } => {
let Some(field_id) = resolve_column_id(column, schema) else {
return true;
};
let Some(prim_type) = get_field_type(field_id, schema) else {
return true;
};
let lower = lower_bounds
.get(&field_id)
.and_then(|b| Datum::from_bytes(b, prim_type));
let upper = upper_bounds
.get(&field_id)
.and_then(|b| Datum::from_bytes(b, prim_type));
if let (Some(lower), Some(upper)) = (&lower, &upper) {
for v in values {
let ge_lower = v
.compare(lower)
.map(|o| o != std::cmp::Ordering::Less)
.unwrap_or(true);
let le_upper = v
.compare(upper)
.map(|o| o != std::cmp::Ordering::Greater)
.unwrap_or(true);
if ge_lower && le_upper {
return true;
}
}
return false;
}
true
}
Predicate::And(preds) => preds.iter().all(|p| {
evaluate_bounds(
p,
schema,
lower_bounds,
upper_bounds,
null_counts,
row_count,
)
}),
Predicate::Or(preds) => preds.iter().any(|p| {
evaluate_bounds(
p,
schema,
lower_bounds,
upper_bounds,
null_counts,
row_count,
)
}),
Predicate::Not(_inner) => {
true
}
}
}
/// Decides whether rows satisfying `col <op> value` might exist, given that the column's
/// values are bounded below by `lower` and above by `upper` (both inclusive).
///
/// Returns `false` only when a present bound proves that no value can satisfy the comparison.
/// A missing bound never causes a skip. Neither does a bound that `Datum::compare` cannot
/// order against `value` (it returns `None`).
fn evaluate_comparison(
    value: &Datum,
    op: ComparisonOp,
    lower: Option<&Datum>,
    upper: Option<&Datum>,
) -> bool {
    use std::cmp::Ordering::{Equal, Greater, Less};

    // Orderings are computed lazily, so each arm performs only the comparisons it needs.
    // `Eq` compares `value` against a bound. The range operators compare a bound against
    // `value`. The direction matters because `compare` may not be symmetric.
    let value_vs_lower = || lower.and_then(|lo| value.compare(lo));
    let value_vs_upper = || upper.and_then(|hi| value.compare(hi));
    let lower_vs_value = || lower.and_then(|lo| lo.compare(value));
    let upper_vs_value = || upper.and_then(|hi| hi.compare(value));

    let provably_empty = match op {
        // value < min or value > max: no row can equal `value`.
        ComparisonOp::Eq => {
            value_vs_lower() == Some(Less) || value_vs_upper() == Some(Greater)
        }
        // min == max == value: every bounded value equals `value`.
        ComparisonOp::NotEq => match (lower, upper) {
            (Some(lo), Some(hi)) => lo == hi && value == lo,
            _ => false,
        },
        // min >= value: nothing lies strictly below `value`.
        ComparisonOp::Lt => matches!(lower_vs_value(), Some(Equal | Greater)),
        // min > value: nothing lies at or below `value`.
        ComparisonOp::LtEq => lower_vs_value() == Some(Greater),
        // max <= value: nothing lies strictly above `value`.
        ComparisonOp::Gt => matches!(upper_vs_value(), Some(Less | Equal)),
        // max < value: nothing lies at or above `value`.
        ComparisonOp::GtEq => upper_vs_value() == Some(Less),
    };
    !provably_empty
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates `col <op> value` against an Int column whose bounds are `[10, 100]`.
    fn eval_within_10_100(value: Datum, op: ComparisonOp) -> bool {
        let lower = Datum::Int(10);
        let upper = Datum::Int(100);
        evaluate_comparison(&value, op, Some(&lower), Some(&upper))
    }

    #[test]
    fn test_evaluate_eq_in_range() {
        assert!(eval_within_10_100(Datum::Int(50), ComparisonOp::Eq));
    }

    #[test]
    fn test_evaluate_eq_below_range() {
        assert!(!eval_within_10_100(Datum::Int(5), ComparisonOp::Eq));
    }

    #[test]
    fn test_evaluate_eq_above_range() {
        assert!(!eval_within_10_100(Datum::Int(150), ComparisonOp::Eq));
    }

    #[test]
    fn test_evaluate_eq_at_bounds_keeps_file() {
        // Bounds are inclusive: the min and max themselves may be present.
        assert!(eval_within_10_100(Datum::Int(10), ComparisonOp::Eq));
        assert!(eval_within_10_100(Datum::Int(100), ComparisonOp::Eq));
    }

    #[test]
    fn test_evaluate_lt_skip() {
        assert!(!eval_within_10_100(Datum::Int(5), ComparisonOp::Lt));
    }

    #[test]
    fn test_evaluate_lt_no_skip() {
        assert!(eval_within_10_100(Datum::Int(50), ComparisonOp::Lt));
    }

    #[test]
    fn test_evaluate_lt_at_lower_bound_skips() {
        // `col < 10` is impossible when every value is >= 10.
        assert!(!eval_within_10_100(Datum::Int(10), ComparisonOp::Lt));
        assert!(eval_within_10_100(Datum::Int(11), ComparisonOp::Lt));
    }

    #[test]
    fn test_evaluate_lt_eq_boundaries() {
        assert!(eval_within_10_100(Datum::Int(10), ComparisonOp::LtEq));
        assert!(!eval_within_10_100(Datum::Int(9), ComparisonOp::LtEq));
    }

    #[test]
    fn test_evaluate_gt_skip() {
        assert!(!eval_within_10_100(Datum::Int(150), ComparisonOp::Gt));
    }

    #[test]
    fn test_evaluate_gt_at_upper_bound_skips() {
        // `col > 100` is impossible when every value is <= 100.
        assert!(!eval_within_10_100(Datum::Int(100), ComparisonOp::Gt));
        assert!(eval_within_10_100(Datum::Int(99), ComparisonOp::Gt));
    }

    #[test]
    fn test_evaluate_gt_eq_boundaries() {
        assert!(eval_within_10_100(Datum::Int(100), ComparisonOp::GtEq));
        assert!(!eval_within_10_100(Datum::Int(101), ComparisonOp::GtEq));
    }

    #[test]
    fn test_evaluate_not_eq() {
        let single = Datum::Int(7);
        // Every value equals 7, so `col != 7` cannot match.
        assert!(!evaluate_comparison(
            &Datum::Int(7),
            ComparisonOp::NotEq,
            Some(&single),
            Some(&single)
        ));
        assert!(evaluate_comparison(
            &Datum::Int(8),
            ComparisonOp::NotEq,
            Some(&single),
            Some(&single)
        ));
        // A real range always keeps the file for `!=`.
        assert!(eval_within_10_100(Datum::Int(10), ComparisonOp::NotEq));
    }

    #[test]
    fn test_evaluate_without_bounds_never_skips() {
        for op in [
            ComparisonOp::Eq,
            ComparisonOp::NotEq,
            ComparisonOp::Lt,
            ComparisonOp::LtEq,
            ComparisonOp::Gt,
            ComparisonOp::GtEq,
        ] {
            assert!(evaluate_comparison(&Datum::Int(5), op, None, None));
        }
    }

    #[test]
    fn test_decode_bound_int() {
        let bytes = 42i32.to_le_bytes().to_vec();
        assert_eq!(
            Datum::from_bytes(&bytes, &PrimitiveType::Int),
            Some(Datum::Int(42))
        );
    }

    #[test]
    fn test_decode_bound_string() {
        let bytes = b"hello".to_vec();
        assert_eq!(
            Datum::from_bytes(&bytes, &PrimitiveType::String),
            Some(Datum::String("hello".to_string()))
        );
    }
}