use crate::column::Column;
use crate::dataframe::DataFrame;
/// Expression tree that can be evaluated against the rows of a `DataFrame`.
///
/// Leaves are column references and literals; interior nodes combine
/// sub-expressions with a binary operator or a logical connective.
#[derive(Clone, Debug)]
pub enum DExpr {
    /// Reference to a column, looked up by name at evaluation time.
    Col(String),
    /// Signed 64-bit integer literal.
    LitInt(i64),
    /// 64-bit floating-point literal.
    LitFloat(f64),
    /// Boolean literal.
    LitBool(bool),
    /// String literal.
    LitStr(String),
    /// Arithmetic or comparison operator applied to two sub-expressions.
    BinOp {
        /// Operator to apply.
        op: BinOp,
        /// Left-hand operand.
        left: Box<DExpr>,
        /// Right-hand operand.
        right: Box<DExpr>,
    },
    /// Logical negation of a sub-expression.
    Not(Box<DExpr>),
    /// Logical conjunction of two sub-expressions.
    And(Box<DExpr>, Box<DExpr>),
    /// Logical disjunction of two sub-expressions.
    Or(Box<DExpr>, Box<DExpr>),
}
/// Binary operators usable inside a `DExpr::BinOp` node.
///
/// `Add`, `Sub`, `Mul` and `Div` are arithmetic operators; `Eq`, `Ne`, `Lt`,
/// `Le`, `Gt` and `Ge` are comparison operators.
///
/// The type is a plain field-less enum, so it also derives `PartialEq`, `Eq`
/// and `Hash`: operators can be compared with `==` and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Equal to.
    Eq,
    /// Not equal to.
    Ne,
    /// Strictly less than.
    Lt,
    /// Less than or equal to.
    Le,
    /// Strictly greater than.
    Gt,
    /// Greater than or equal to.
    Ge,
}
/// A single scalar value produced by evaluating an expression for one row.
///
/// `PartialEq` is structural: two values are equal only when they hold the
/// same variant and payload, so `Int(1)` and `Float(1.0)` are *not* equal
/// under `==`. `Eq` is not derived because the `Float` payload is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum ExprValue {
    /// Signed 64-bit integer.
    Int(i64),
    /// 64-bit floating-point number.
    Float(f64),
    /// Boolean.
    Bool(bool),
    /// Owned string.
    Str(String),
}

impl ExprValue {
    /// Returns the variant name (`"Int"`, `"Float"`, `"Bool"` or `"Str"`),
    /// suitable for use in error messages.
    pub fn type_name(&self) -> &'static str {
        match self {
            ExprValue::Int(_) => "Int",
            ExprValue::Float(_) => "Float",
            ExprValue::Bool(_) => "Bool",
            ExprValue::Str(_) => "Str",
        }
    }

    /// Converts the value to `f64` where a numeric reading exists.
    ///
    /// `Int` is cast with `as` (lossy for magnitudes that do not fit in an
    /// `f64` mantissa), `Float` is returned unchanged, `Bool` maps to `1.0` /
    /// `0.0`, and `Str` yields `None`.
    pub fn as_f64(&self) -> Option<f64> {
        match self {
            ExprValue::Int(v) => Some(*v as f64),
            ExprValue::Float(v) => Some(*v),
            ExprValue::Bool(v) => Some(if *v { 1.0 } else { 0.0 }),
            ExprValue::Str(_) => None,
        }
    }

    /// Returns the payload of a `Bool` value, or `None` for any other variant.
    ///
    /// No truthiness conversion is performed: `Int(0)` is `None`, not
    /// `Some(false)`.
    pub fn as_bool(&self) -> Option<bool> {
        if let ExprValue::Bool(v) = self {
            Some(*v)
        } else {
            None
        }
    }
}
/// Evaluates `expr` for a single row of `df`.
///
/// Column references read the value stored at index `row`; literals evaluate
/// to themselves; operator nodes evaluate their operands first and then
/// combine the results. `And` / `Or` do not short-circuit: both operands are
/// always evaluated, so an error in either one is reported.
///
/// # Errors
///
/// Returns a message when
/// * a referenced column does not exist,
/// * `row` is past the end of a referenced column,
/// * `Not`, `And` or `Or` receives a non-`Bool` operand, or
/// * a binary operator rejects its operands.
pub fn eval_expr_row(df: &DataFrame, expr: &DExpr, row: usize) -> Result<ExprValue, String> {
    match expr {
        DExpr::Col(name) => {
            let col = df
                .get_column(name)
                .ok_or_else(|| format!("column `{}` not found", name))?;
            // Checked access: this function reports failures through `Result`,
            // so an out-of-range row must become an `Err`, not a panic.
            let out_of_range =
                || format!("row {} is out of bounds for column `{}`", row, name);
            let value = match col {
                Column::Int(v) => ExprValue::Int(v.get(row).copied().ok_or_else(out_of_range)?),
                Column::Float(v) => {
                    ExprValue::Float(v.get(row).copied().ok_or_else(out_of_range)?)
                }
                Column::Str(v) => ExprValue::Str(v.get(row).cloned().ok_or_else(out_of_range)?),
                Column::Bool(v) => ExprValue::Bool(v.get(row).copied().ok_or_else(out_of_range)?),
            };
            Ok(value)
        }
        DExpr::LitInt(v) => Ok(ExprValue::Int(*v)),
        DExpr::LitFloat(v) => Ok(ExprValue::Float(*v)),
        DExpr::LitBool(v) => Ok(ExprValue::Bool(*v)),
        DExpr::LitStr(v) => Ok(ExprValue::Str(v.clone())),
        DExpr::BinOp { op, left, right } => {
            let lv = eval_expr_row(df, left, row)?;
            let rv = eval_expr_row(df, right, row)?;
            eval_binop(*op, &lv, &rv)
        }
        DExpr::Not(inner) => {
            let v = eval_expr_row(df, inner, row)?;
            match v {
                ExprValue::Bool(b) => Ok(ExprValue::Bool(!b)),
                _ => Err(format!("NOT requires Bool, got {}", v.type_name())),
            }
        }
        DExpr::And(a, b) => {
            // Both sides are evaluated before combining (no short-circuit).
            let av = eval_expr_row(df, a, row)?;
            let bv = eval_expr_row(df, b, row)?;
            match (av, bv) {
                (ExprValue::Bool(x), ExprValue::Bool(y)) => Ok(ExprValue::Bool(x && y)),
                _ => Err("AND requires two Bool operands".into()),
            }
        }
        DExpr::Or(a, b) => {
            // Both sides are evaluated before combining (no short-circuit).
            let av = eval_expr_row(df, a, row)?;
            let bv = eval_expr_row(df, b, row)?;
            match (av, bv) {
                (ExprValue::Bool(x), ExprValue::Bool(y)) => Ok(ExprValue::Bool(x || y)),
                _ => Err("OR requires two Bool operands".into()),
            }
        }
    }
}
/// Applies a binary operator to two already-evaluated operands.
///
/// Comparison operators always succeed and produce a `Bool`; operand pairs
/// that `cmp_values` cannot order make every comparison false except `Ne`,
/// which is true. Arithmetic operators convert both operands with
/// `ExprValue::as_f64` and always produce a `Float`.
///
/// # Errors
///
/// Returns a message when an arithmetic operand has no numeric reading. The
/// left operand is checked first.
fn eval_binop(op: BinOp, lv: &ExprValue, rv: &ExprValue) -> Result<ExprValue, String> {
    use std::cmp::Ordering::{Equal, Greater, Less};

    // Evaluated lazily so arithmetic operators never run the comparison.
    let ordering = || cmp_values(lv, rv);
    let numeric = |value: &ExprValue| -> Result<f64, String> {
        value.as_f64().ok_or_else(|| {
            format!("arithmetic requires numeric types, got {}", value.type_name())
        })
    };

    let outcome = match op {
        BinOp::Eq => ExprValue::Bool(ordering() == Some(Equal)),
        BinOp::Ne => ExprValue::Bool(ordering() != Some(Equal)),
        BinOp::Lt => ExprValue::Bool(ordering() == Some(Less)),
        BinOp::Le => ExprValue::Bool(matches!(ordering(), Some(Less) | Some(Equal))),
        BinOp::Gt => ExprValue::Bool(ordering() == Some(Greater)),
        BinOp::Ge => ExprValue::Bool(matches!(ordering(), Some(Greater) | Some(Equal))),
        // `?` on the left operand runs before the right one is inspected.
        BinOp::Add => ExprValue::Float(numeric(lv)? + numeric(rv)?),
        BinOp::Sub => ExprValue::Float(numeric(lv)? - numeric(rv)?),
        BinOp::Mul => ExprValue::Float(numeric(lv)? * numeric(rv)?),
        BinOp::Div => ExprValue::Float(numeric(lv)? / numeric(rv)?),
    };
    Ok(outcome)
}
/// Orders two values, or returns `None` when the pair is not comparable.
///
/// * `Int`/`Int`, `Str`/`Str` and `Bool`/`Bool` use the payload's total order.
/// * `Float`/`Float` and mixed `Int`/`Float` compare as `f64`; a NaN operand
///   yields `None`.
/// * Mixed `Int`/`Str` compares the integer's decimal rendering with the
///   string, lexicographically.
/// * Every other combination yields `None`.
fn cmp_values(a: &ExprValue, b: &ExprValue) -> Option<std::cmp::Ordering> {
    match a {
        ExprValue::Int(x) => match b {
            ExprValue::Int(y) => Some(x.cmp(y)),
            ExprValue::Float(y) => (*x as f64).partial_cmp(y),
            ExprValue::Str(y) => Some(x.to_string().cmp(y)),
            ExprValue::Bool(_) => None,
        },
        ExprValue::Float(x) => match b {
            ExprValue::Float(y) => x.partial_cmp(y),
            ExprValue::Int(y) => x.partial_cmp(&(*y as f64)),
            ExprValue::Str(_) | ExprValue::Bool(_) => None,
        },
        ExprValue::Str(x) => match b {
            ExprValue::Str(y) => Some(x.cmp(y)),
            ExprValue::Int(y) => Some(x.cmp(&y.to_string())),
            ExprValue::Float(_) | ExprValue::Bool(_) => None,
        },
        ExprValue::Bool(x) => match b {
            ExprValue::Bool(y) => Some(x.cmp(y)),
            ExprValue::Int(_) | ExprValue::Float(_) | ExprValue::Str(_) => None,
        },
    }
}
pub fn try_eval_predicate_columnar(
df: &DataFrame,
expr: &DExpr,
current_mask: &crate::bitmask::BitMask,
) -> Option<crate::bitmask::BitMask> {
match expr {
DExpr::BinOp { op, left, right } => {
let (col_name, lit, flip) = match (left.as_ref(), right.as_ref()) {
(DExpr::Col(name), lit) if is_literal(lit) => (name.as_str(), lit, false),
(lit, DExpr::Col(name)) if is_literal(lit) => (name.as_str(), lit, true),
_ => return None,
};
let col = df.get_column(col_name)?;
let nrows = df.nrows();
let mut new_words = current_mask.words.clone();
match (col, lit) {
(Column::Int(data), DExpr::LitInt(val)) => {
for row in current_mask.iter_set() {
let (l, r) = if flip {
(*val, data[row])
} else {
(data[row], *val)
};
if !cmp_i64(*op, l, r) {
new_words[row / 64] &= !(1u64 << (row % 64));
}
}
}
(Column::Float(data), DExpr::LitFloat(val)) => {
for row in current_mask.iter_set() {
let (l, r) = if flip {
(*val, data[row])
} else {
(data[row], *val)
};
if !cmp_f64(*op, l, r) {
new_words[row / 64] &= !(1u64 << (row % 64));
}
}
}
(Column::Int(data), DExpr::LitFloat(val)) => {
for row in current_mask.iter_set() {
let (l, r) = if flip {
(*val, data[row] as f64)
} else {
(data[row] as f64, *val)
};
if !cmp_f64(*op, l, r) {
new_words[row / 64] &= !(1u64 << (row % 64));
}
}
}
(Column::Str(data), DExpr::LitStr(val)) => {
for row in current_mask.iter_set() {
let pass = if flip {
cmp_str(*op, val, &data[row])
} else {
cmp_str(*op, &data[row], val)
};
if !pass {
new_words[row / 64] &= !(1u64 << (row % 64));
}
}
}
_ => return None,
}
Some(crate::bitmask::BitMask {
words: new_words,
nrows,
})
}
_ => None,
}
}
/// Returns `true` when `expr` is one of the literal variants
/// (`LitInt`, `LitFloat`, `LitBool`, `LitStr`).
fn is_literal(expr: &DExpr) -> bool {
    match expr {
        DExpr::LitInt(_) | DExpr::LitFloat(_) | DExpr::LitBool(_) | DExpr::LitStr(_) => true,
        DExpr::Col(_)
        | DExpr::BinOp { .. }
        | DExpr::Not(_)
        | DExpr::And(..)
        | DExpr::Or(..) => false,
    }
}
/// Tests `l <op> r` for integers.
///
/// Arithmetic operators are not predicates and always yield `false`.
#[inline]
fn cmp_i64(op: BinOp, l: i64, r: i64) -> bool {
    match op {
        BinOp::Lt => l < r,
        BinOp::Le => l <= r,
        BinOp::Gt => l > r,
        BinOp::Ge => l >= r,
        BinOp::Eq => l == r,
        BinOp::Ne => l != r,
        BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div => false,
    }
}
/// Tests `l <op> r` for floats using the built-in `f64` operators.
///
/// With a NaN operand every comparison is `false` except `Ne`, which is
/// `true`. Arithmetic operators are not predicates and always yield `false`.
#[inline]
fn cmp_f64(op: BinOp, l: f64, r: f64) -> bool {
    match op {
        BinOp::Lt => l < r,
        BinOp::Le => l <= r,
        BinOp::Gt => l > r,
        BinOp::Ge => l >= r,
        BinOp::Eq => l == r,
        BinOp::Ne => l != r,
        BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div => false,
    }
}
/// Tests `l <op> r` for strings using `str`'s built-in ordering.
///
/// Arithmetic operators are not predicates and always yield `false`.
#[inline]
fn cmp_str(op: BinOp, l: &str, r: &str) -> bool {
    match op {
        BinOp::Lt => l < r,
        BinOp::Le => l <= r,
        BinOp::Gt => l > r,
        BinOp::Ge => l >= r,
        BinOp::Eq => l == r,
        BinOp::Ne => l != r,
        BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div => false,
    }
}
/// Builds a column-reference expression for the column called `name`.
pub fn col(name: &str) -> DExpr {
    let owned = String::from(name);
    DExpr::Col(owned)
}
/// Builds a binary-operator expression, boxing both operands.
pub fn binop(op: BinOp, left: DExpr, right: DExpr) -> DExpr {
    let left = Box::new(left);
    let right = Box::new(right);
    DExpr::BinOp { op, left, right }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-wise evaluation of `x > 15` yields the expected `Bool` per row.
    #[test]
    fn test_eval_comparison() {
        let df = DataFrame::from_columns(vec![("x".into(), Column::Int(vec![10, 20, 30]))])
            .unwrap();
        let expr = binop(BinOp::Gt, col("x"), DExpr::LitInt(15));

        let expectations = [(0, false), (1, true)];
        for (row, expected) in expectations.iter().copied() {
            let value = eval_expr_row(&df, &expr, row).unwrap();
            assert_eq!(value.as_bool(), Some(expected));
        }
    }

    /// The columnar fast path keeps only the rows where `x > 3`.
    #[test]
    fn test_columnar_fast_path() {
        let df = DataFrame::from_columns(vec![("x".into(), Column::Int(vec![1, 2, 3, 4, 5]))])
            .unwrap();
        let mask = crate::bitmask::BitMask::all_true(5);
        let expr = binop(BinOp::Gt, col("x"), DExpr::LitInt(3));

        let narrowed = try_eval_predicate_columnar(&df, &expr, &mask).unwrap();

        assert_eq!(narrowed.iter_set().collect::<Vec<usize>>(), vec![3, 4]);
    }
}