virtual-frame 0.1.1

Deterministic data pipeline toolkit for LLM training — bitmask-filtered virtual views, NFA regex, Kahan summation, full audit trail. Python bindings included.
Documentation
//! Expression system — predicates for filter, computed columns for mutate.
//!
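//! [`DExpr`] is a simple expression tree: column references, literals,
//! binary operations (arithmetic and comparison), and boolean NOT/AND/OR.
//! It evaluates against a DataFrame row-by-row ([`eval_expr_row`]) or, for
//! simple column-vs-literal comparisons, column-by-column
//! ([`try_eval_predicate_columnar`], the columnar fast path).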

use crate::column::Column;
use crate::dataframe::DataFrame;

/// Data expression — used in filter predicates and mutate computations.
///
/// Evaluated per row by [`eval_expr_row`]. Only `BinOp` nodes with a `Col` on
/// one side and a literal on the other are candidates for the columnar fast
/// path ([`try_eval_predicate_columnar`]); every other shape is evaluated
/// row-wise.
#[derive(Debug, Clone)]
pub enum DExpr {
    /// Column reference: col("name"). Evaluation fails if the DataFrame has no
    /// column with this name.
    Col(String),
    /// Literal integer
    LitInt(i64),
    /// Literal float
    LitFloat(f64),
    /// Literal bool
    LitBool(bool),
    /// Literal string
    LitStr(String),
    /// Binary operation (arithmetic or comparison) on two sub-expressions;
    /// see [`BinOp`] for the operand/result typing rules.
    BinOp {
        op: BinOp,
        /// Left-hand operand (evaluated first).
        left: Box<DExpr>,
        /// Right-hand operand.
        right: Box<DExpr>,
    },
    /// Unary NOT; the operand must evaluate to `Bool`.
    Not(Box<DExpr>),
    /// AND of two boolean expressions. Both operands are always evaluated
    /// (no short-circuit) and both must evaluate to `Bool`.
    And(Box<DExpr>, Box<DExpr>),
    /// OR of two boolean expressions. Both operands are always evaluated
    /// (no short-circuit) and both must evaluate to `Bool`.
    Or(Box<DExpr>, Box<DExpr>),
}

/// Binary operators supported in expressions.
///
/// Arithmetic operators (`Add`, `Sub`, `Mul`, `Div`) convert both operands to
/// `f64` (`Bool` counts as 1.0/0.0, `Str` is an error) and always produce a
/// `Float`, even when both operands are `Int`. Comparison operators (`Eq`
/// through `Ge`) always produce a `Bool`.
#[derive(Debug, Clone, Copy)]
pub enum BinOp {
    /// `left + right`
    Add,
    /// `left - right`
    Sub,
    /// `left * right`
    Mul,
    /// `left / right`; plain `f64` division, so a zero divisor yields
    /// infinity or NaN rather than an error.
    Div,
    /// `left == right`
    Eq,
    /// `left != right`; also true when the operands cannot be ordered
    /// (e.g. a NaN float, or mismatched types such as `Bool` vs `Int`).
    Ne,
    /// `left < right`
    Lt,
    /// `left <= right`
    Le,
    /// `left > right`
    Gt,
    /// `left >= right`
    Ge,
}

/// Result of evaluating a DExpr at a single row.
#[derive(Debug, Clone)]
pub enum ExprValue {
    Int(i64),
    Float(f64),
    Bool(bool),
    Str(String),
}

impl ExprValue {
    /// Name of the variant, used in type-mismatch error messages.
    pub fn type_name(&self) -> &'static str {
        match *self {
            Self::Int(..) => "Int",
            Self::Float(..) => "Float",
            Self::Bool(..) => "Bool",
            Self::Str(..) => "Str",
        }
    }

    /// Numeric view of the value: `Int` is widened to `f64`, `Bool` maps to
    /// 1.0/0.0, `Float` is returned unchanged, and `Str` has no numeric view.
    pub fn as_f64(&self) -> Option<f64> {
        let number = match self {
            Self::Str(_) => return None,
            Self::Int(i) => *i as f64,
            Self::Float(f) => *f,
            Self::Bool(b) => f64::from(u8::from(*b)),
        };
        Some(number)
    }

    /// The wrapped boolean for `Bool`; `None` for every other variant
    /// (no truthiness coercion).
    pub fn as_bool(&self) -> Option<bool> {
        if let Self::Bool(flag) = *self {
            Some(flag)
        } else {
            None
        }
    }
}

/// Evaluate a DExpr at a single row of a DataFrame.
///
/// The tree is walked recursively. For `BinOp`, `And` and `Or` both operands
/// are always evaluated, left first, so `And`/`Or` do not short-circuit and an
/// error in either operand is always reported.
///
/// # Errors
///
/// Returns `Err` when a referenced column does not exist, when `row` is past
/// the end of a referenced column, when `Not`/`And`/`Or` receive a non-`Bool`
/// operand, or when an arithmetic operator receives a `Str`.
pub fn eval_expr_row(df: &DataFrame, expr: &DExpr, row: usize) -> Result<ExprValue, String> {
    match expr {
        DExpr::Col(name) => {
            let col = df
                .get_column(name)
                .ok_or_else(|| format!("column `{}` not found", name))?;
            // Checked access: an out-of-range row becomes an `Err` instead of
            // a panic inside a function that already reports failures via Result.
            let value = match col {
                Column::Int(v) => v.get(row).copied().map(ExprValue::Int),
                Column::Float(v) => v.get(row).copied().map(ExprValue::Float),
                Column::Str(v) => v.get(row).cloned().map(ExprValue::Str),
                Column::Bool(v) => v.get(row).copied().map(ExprValue::Bool),
            };
            value.ok_or_else(|| format!("row {} out of range for column `{}`", row, name))
        }
        DExpr::LitInt(v) => Ok(ExprValue::Int(*v)),
        DExpr::LitFloat(v) => Ok(ExprValue::Float(*v)),
        DExpr::LitBool(v) => Ok(ExprValue::Bool(*v)),
        DExpr::LitStr(v) => Ok(ExprValue::Str(v.clone())),
        DExpr::BinOp { op, left, right } => {
            let lv = eval_expr_row(df, left, row)?;
            let rv = eval_expr_row(df, right, row)?;
            eval_binop(*op, &lv, &rv)
        }
        DExpr::Not(inner) => {
            let v = eval_expr_row(df, inner, row)?;
            match v {
                ExprValue::Bool(b) => Ok(ExprValue::Bool(!b)),
                _ => Err(format!("NOT requires Bool, got {}", v.type_name())),
            }
        }
        DExpr::And(a, b) => {
            let av = eval_expr_row(df, a, row)?;
            let bv = eval_expr_row(df, b, row)?;
            match (av, bv) {
                (ExprValue::Bool(x), ExprValue::Bool(y)) => Ok(ExprValue::Bool(x && y)),
                _ => Err("AND requires two Bool operands".into()),
            }
        }
        DExpr::Or(a, b) => {
            let av = eval_expr_row(df, a, row)?;
            let bv = eval_expr_row(df, b, row)?;
            match (av, bv) {
                (ExprValue::Bool(x), ExprValue::Bool(y)) => Ok(ExprValue::Bool(x || y)),
                _ => Err("OR requires two Bool operands".into()),
            }
        }
    }
}

/// Apply a binary operator to two already-evaluated operands.
///
/// Comparison operators never fail and yield `Bool`. When `cmp_values` cannot
/// order the pair (e.g. `Bool` vs `Int`, or a NaN float), every comparison is
/// false except `Ne`, which is true.
///
/// Arithmetic operators convert both operands with `ExprValue::as_f64` and
/// always yield `Float`, even for two `Int`s. Division is plain `f64`
/// division, so a zero divisor gives infinity or NaN rather than an error.
///
/// # Errors
///
/// Returns `Err` if an arithmetic operand has no numeric view (a `Str`); the
/// left operand is checked first.
fn eval_binop(op: BinOp, lv: &ExprValue, rv: &ExprValue) -> Result<ExprValue, String> {
    use std::cmp::Ordering::{Equal, Greater, Less};

    let arithmetic: fn(f64, f64) -> f64 = match op {
        BinOp::Add => |a: f64, b: f64| a + b,
        BinOp::Sub => |a: f64, b: f64| a - b,
        BinOp::Mul => |a: f64, b: f64| a * b,
        BinOp::Div => |a: f64, b: f64| a / b,
        BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge => {
            // Comparison: order once, then read the answer off the ordering.
            let ordering = cmp_values(lv, rv);
            let holds = match op {
                BinOp::Eq => ordering == Some(Equal),
                BinOp::Ne => ordering != Some(Equal),
                BinOp::Lt => ordering == Some(Less),
                BinOp::Le => ordering == Some(Less) || ordering == Some(Equal),
                BinOp::Gt => ordering == Some(Greater),
                BinOp::Ge => ordering == Some(Greater) || ordering == Some(Equal),
                BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div => unreachable!(),
            };
            return Ok(ExprValue::Bool(holds));
        }
    };

    let numeric = |side: &ExprValue| {
        side.as_f64().ok_or_else(|| {
            format!("arithmetic requires numeric types, got {}", side.type_name())
        })
    };
    let (l, r) = (numeric(lv)?, numeric(rv)?);
    Ok(ExprValue::Float(arithmetic(l, r)))
}

fn cmp_values(a: &ExprValue, b: &ExprValue) -> Option<std::cmp::Ordering> {
    match (a, b) {
        (ExprValue::Int(x), ExprValue::Int(y)) => Some(x.cmp(y)),
        (ExprValue::Float(x), ExprValue::Float(y)) => x.partial_cmp(y),
        (ExprValue::Int(x), ExprValue::Float(y)) => (*x as f64).partial_cmp(y),
        (ExprValue::Float(x), ExprValue::Int(y)) => x.partial_cmp(&(*y as f64)),
        (ExprValue::Str(x), ExprValue::Str(y)) => Some(x.cmp(y)),
        (ExprValue::Bool(x), ExprValue::Bool(y)) => Some(x.cmp(y)),
        (ExprValue::Str(x), ExprValue::Int(y)) => Some(x.cmp(&y.to_string())),
        (ExprValue::Int(x), ExprValue::Str(y)) => Some(x.to_string().cmp(y)),
        _ => None,
    }
}

/// Try to evaluate a predicate in columnar mode (fast path).
///
/// For simple predicates like `col("x") > 5`, this scans the column
/// directly instead of evaluating row-by-row. Returns None if the
/// expression is too complex for the columnar fast path.
pub fn try_eval_predicate_columnar(
    df: &DataFrame,
    expr: &DExpr,
    current_mask: &crate::bitmask::BitMask,
) -> Option<crate::bitmask::BitMask> {
    match expr {
        DExpr::BinOp { op, left, right } => {
            // Only handle Col op Literal patterns
            let (col_name, lit, flip) = match (left.as_ref(), right.as_ref()) {
                (DExpr::Col(name), lit) if is_literal(lit) => (name.as_str(), lit, false),
                (lit, DExpr::Col(name)) if is_literal(lit) => (name.as_str(), lit, true),
                _ => return None,
            };

            let col = df.get_column(col_name)?;
            let nrows = df.nrows();
            let mut new_words = current_mask.words.clone();

            match (col, lit) {
                (Column::Int(data), DExpr::LitInt(val)) => {
                    for row in current_mask.iter_set() {
                        let (l, r) = if flip {
                            (*val, data[row])
                        } else {
                            (data[row], *val)
                        };
                        if !cmp_i64(*op, l, r) {
                            new_words[row / 64] &= !(1u64 << (row % 64));
                        }
                    }
                }
                (Column::Float(data), DExpr::LitFloat(val)) => {
                    for row in current_mask.iter_set() {
                        let (l, r) = if flip {
                            (*val, data[row])
                        } else {
                            (data[row], *val)
                        };
                        if !cmp_f64(*op, l, r) {
                            new_words[row / 64] &= !(1u64 << (row % 64));
                        }
                    }
                }
                (Column::Int(data), DExpr::LitFloat(val)) => {
                    for row in current_mask.iter_set() {
                        let (l, r) = if flip {
                            (*val, data[row] as f64)
                        } else {
                            (data[row] as f64, *val)
                        };
                        if !cmp_f64(*op, l, r) {
                            new_words[row / 64] &= !(1u64 << (row % 64));
                        }
                    }
                }
                (Column::Str(data), DExpr::LitStr(val)) => {
                    for row in current_mask.iter_set() {
                        let pass = if flip {
                            cmp_str(*op, val, &data[row])
                        } else {
                            cmp_str(*op, &data[row], val)
                        };
                        if !pass {
                            new_words[row / 64] &= !(1u64 << (row % 64));
                        }
                    }
                }
                _ => return None,
            }

            Some(crate::bitmask::BitMask {
                words: new_words,
                nrows,
            })
        }
        _ => None,
    }
}

/// True when `expr` is one of the literal leaves (`LitInt`, `LitFloat`,
/// `LitBool`, `LitStr`); false for columns and every composite node.
fn is_literal(expr: &DExpr) -> bool {
    match expr {
        DExpr::LitInt(_) | DExpr::LitFloat(_) | DExpr::LitBool(_) | DExpr::LitStr(_) => true,
        DExpr::Col(_) | DExpr::BinOp { .. } | DExpr::Not(_) | DExpr::And(..) | DExpr::Or(..) => {
            false
        }
    }
}

/// Evaluate a comparison operator on two ordered values.
///
/// This single operator table is shared by the typed wrappers below.
/// Arithmetic operators are not comparisons and always answer `false`. The
/// match is exhaustive, so adding a `BinOp` variant forces a decision here
/// instead of being silently swallowed by a catch-all.
#[inline]
fn cmp_ordered<T: PartialOrd + ?Sized>(op: BinOp, l: &T, r: &T) -> bool {
    match op {
        BinOp::Eq => l == r,
        BinOp::Ne => l != r,
        BinOp::Lt => l < r,
        BinOp::Le => l <= r,
        BinOp::Gt => l > r,
        BinOp::Ge => l >= r,
        BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div => false,
    }
}

/// Compare two `i64`s with `op`; arithmetic operators yield `false`.
#[inline]
fn cmp_i64(op: BinOp, l: i64, r: i64) -> bool {
    cmp_ordered(op, &l, &r)
}

/// Compare two `f64`s with `op` using IEEE-754 semantics: any comparison
/// involving NaN is false except `Ne`. Arithmetic operators yield `false`.
#[inline]
fn cmp_f64(op: BinOp, l: f64, r: f64) -> bool {
    cmp_ordered(op, &l, &r)
}

/// Compare two strings with `op` (lexicographic byte order, as `str`'s `Ord`);
/// arithmetic operators yield `false`.
#[inline]
fn cmp_str(op: BinOp, l: &str, r: &str) -> bool {
    cmp_ordered(op, l, r)
}

// ── Builder helpers (for Python API ergonomics) ──────────────────────────

/// Create a column reference expression (`DExpr::Col`) for the column `name`.
pub fn col(name: &str) -> DExpr {
    DExpr::Col(String::from(name))
}

/// Create a binary operation expression (`DExpr::BinOp`), boxing both operands.
pub fn binop(op: BinOp, left: DExpr, right: DExpr) -> DExpr {
    let (left, right) = (Box::new(left), Box::new(right));
    DExpr::BinOp { op, left, right }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Row-wise evaluation of `x > 15` over x = [10, 20, 30].
    #[test]
    fn test_eval_comparison() {
        let frame = DataFrame::from_columns(vec![("x".into(), Column::Int(vec![10, 20, 30]))])
            .unwrap();
        let gt_15 = binop(BinOp::Gt, col("x"), DExpr::LitInt(15));

        // (row, expected): 10 > 15 is false, 20 > 15 is true.
        for &(row, expected) in &[(0usize, false), (1, true)] {
            let value = eval_expr_row(&frame, &gt_15, row).unwrap();
            assert_eq!(value.as_bool(), Some(expected));
        }
    }

    /// Columnar evaluation of `x > 3` over x = [1, 2, 3, 4, 5] with every row selected.
    #[test]
    fn test_columnar_fast_path() {
        let frame = DataFrame::from_columns(vec![("x".into(), Column::Int(vec![1, 2, 3, 4, 5]))])
            .unwrap();
        let all_rows = crate::bitmask::BitMask::all_true(5);
        let gt_3 = binop(BinOp::Gt, col("x"), DExpr::LitInt(3));

        let kept = try_eval_predicate_columnar(&frame, &gt_3, &all_rows).unwrap();
        // Only x = 4 (row 3) and x = 5 (row 4) survive.
        assert_eq!(kept.iter_set().collect::<Vec<usize>>(), vec![3, 4]);
    }
}