1use crate::column::Column;
8use crate::dataframe::DataFrame;
9
10#[derive(Debug, Clone)]
12pub enum DExpr {
13 Col(String),
15 LitInt(i64),
17 LitFloat(f64),
19 LitBool(bool),
21 LitStr(String),
23 BinOp {
25 op: BinOp,
26 left: Box<DExpr>,
27 right: Box<DExpr>,
28 },
29 Not(Box<DExpr>),
31 And(Box<DExpr>, Box<DExpr>),
33 Or(Box<DExpr>, Box<DExpr>),
35}
36
/// Binary operator carried by `DExpr::BinOp`.
///
/// The first four variants are arithmetic; the remaining six are comparisons.
#[derive(Debug, Clone, Copy)]
pub enum BinOp {
    // --- arithmetic ---
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    // --- comparison ---
    /// Equal to.
    Eq,
    /// Not equal to.
    Ne,
    /// Less than.
    Lt,
    /// Less than or equal to.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal to.
    Ge,
}
51
/// Scalar produced by evaluating a `DExpr` against a single row.
#[derive(Debug, Clone)]
pub enum ExprValue {
    /// Signed 64-bit integer.
    Int(i64),
    /// 64-bit float.
    Float(f64),
    /// Boolean.
    Bool(bool),
    /// Owned string.
    Str(String),
}

impl ExprValue {
    /// Name of the variant (`"Int"`, `"Float"`, `"Bool"` or `"Str"`), used in error messages.
    pub fn type_name(&self) -> &'static str {
        match self {
            Self::Str(_) => "Str",
            Self::Bool(_) => "Bool",
            Self::Float(_) => "Float",
            Self::Int(_) => "Int",
        }
    }

    /// Numeric view of the value.
    ///
    /// `Int` is cast to `f64`, `Float` is returned as-is, `Bool` maps to `1.0`/`0.0`,
    /// and `Str` has no numeric view (`None`).
    pub fn as_f64(&self) -> Option<f64> {
        let number = match *self {
            Self::Int(i) => i as f64,
            Self::Float(f) => f,
            Self::Bool(b) => f64::from(u8::from(b)),
            Self::Str(_) => return None,
        };
        Some(number)
    }

    /// Returns the inner boolean for `Bool`, and `None` for every other variant.
    pub fn as_bool(&self) -> Option<bool> {
        if let Self::Bool(b) = self {
            Some(*b)
        } else {
            None
        }
    }
}
87
88pub fn eval_expr_row(df: &DataFrame, expr: &DExpr, row: usize) -> Result<ExprValue, String> {
90 match expr {
91 DExpr::Col(name) => {
92 let col = df
93 .get_column(name)
94 .ok_or_else(|| format!("column `{}` not found", name))?;
95 Ok(match col {
96 Column::Int(v) => ExprValue::Int(v[row]),
97 Column::Float(v) => ExprValue::Float(v[row]),
98 Column::Str(v) => ExprValue::Str(v[row].clone()),
99 Column::Bool(v) => ExprValue::Bool(v[row]),
100 })
101 }
102 DExpr::LitInt(v) => Ok(ExprValue::Int(*v)),
103 DExpr::LitFloat(v) => Ok(ExprValue::Float(*v)),
104 DExpr::LitBool(v) => Ok(ExprValue::Bool(*v)),
105 DExpr::LitStr(v) => Ok(ExprValue::Str(v.clone())),
106 DExpr::BinOp { op, left, right } => {
107 let lv = eval_expr_row(df, left, row)?;
108 let rv = eval_expr_row(df, right, row)?;
109 eval_binop(*op, &lv, &rv)
110 }
111 DExpr::Not(inner) => {
112 let v = eval_expr_row(df, inner, row)?;
113 match v {
114 ExprValue::Bool(b) => Ok(ExprValue::Bool(!b)),
115 _ => Err(format!("NOT requires Bool, got {}", v.type_name())),
116 }
117 }
118 DExpr::And(a, b) => {
119 let av = eval_expr_row(df, a, row)?;
120 let bv = eval_expr_row(df, b, row)?;
121 match (av, bv) {
122 (ExprValue::Bool(x), ExprValue::Bool(y)) => Ok(ExprValue::Bool(x && y)),
123 _ => Err("AND requires two Bool operands".into()),
124 }
125 }
126 DExpr::Or(a, b) => {
127 let av = eval_expr_row(df, a, row)?;
128 let bv = eval_expr_row(df, b, row)?;
129 match (av, bv) {
130 (ExprValue::Bool(x), ExprValue::Bool(y)) => Ok(ExprValue::Bool(x || y)),
131 _ => Err("OR requires two Bool operands".into()),
132 }
133 }
134 }
135}
136
137fn eval_binop(op: BinOp, lv: &ExprValue, rv: &ExprValue) -> Result<ExprValue, String> {
138 match op {
139 BinOp::Eq => Ok(ExprValue::Bool(cmp_values(lv, rv) == Some(std::cmp::Ordering::Equal))),
141 BinOp::Ne => Ok(ExprValue::Bool(cmp_values(lv, rv) != Some(std::cmp::Ordering::Equal))),
142 BinOp::Lt => Ok(ExprValue::Bool(cmp_values(lv, rv) == Some(std::cmp::Ordering::Less))),
143 BinOp::Le => Ok(ExprValue::Bool(matches!(
144 cmp_values(lv, rv),
145 Some(std::cmp::Ordering::Less) | Some(std::cmp::Ordering::Equal)
146 ))),
147 BinOp::Gt => Ok(ExprValue::Bool(cmp_values(lv, rv) == Some(std::cmp::Ordering::Greater))),
148 BinOp::Ge => Ok(ExprValue::Bool(matches!(
149 cmp_values(lv, rv),
150 Some(std::cmp::Ordering::Greater) | Some(std::cmp::Ordering::Equal)
151 ))),
152 BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div => {
154 let l = lv.as_f64().ok_or_else(|| {
155 format!("arithmetic requires numeric types, got {}", lv.type_name())
156 })?;
157 let r = rv.as_f64().ok_or_else(|| {
158 format!("arithmetic requires numeric types, got {}", rv.type_name())
159 })?;
160 let result = match op {
161 BinOp::Add => l + r,
162 BinOp::Sub => l - r,
163 BinOp::Mul => l * r,
164 BinOp::Div => l / r,
165 _ => unreachable!(),
166 };
167 Ok(ExprValue::Float(result))
168 }
169 }
170}
171
172fn cmp_values(a: &ExprValue, b: &ExprValue) -> Option<std::cmp::Ordering> {
173 match (a, b) {
174 (ExprValue::Int(x), ExprValue::Int(y)) => Some(x.cmp(y)),
175 (ExprValue::Float(x), ExprValue::Float(y)) => x.partial_cmp(y),
176 (ExprValue::Int(x), ExprValue::Float(y)) => (*x as f64).partial_cmp(y),
177 (ExprValue::Float(x), ExprValue::Int(y)) => x.partial_cmp(&(*y as f64)),
178 (ExprValue::Str(x), ExprValue::Str(y)) => Some(x.cmp(y)),
179 (ExprValue::Bool(x), ExprValue::Bool(y)) => Some(x.cmp(y)),
180 (ExprValue::Str(x), ExprValue::Int(y)) => Some(x.cmp(&y.to_string())),
181 (ExprValue::Int(x), ExprValue::Str(y)) => Some(x.to_string().cmp(y)),
182 _ => None,
183 }
184}
185
186pub fn try_eval_predicate_columnar(
192 df: &DataFrame,
193 expr: &DExpr,
194 current_mask: &crate::bitmask::BitMask,
195) -> Option<crate::bitmask::BitMask> {
196 match expr {
197 DExpr::BinOp { op, left, right } => {
198 let (col_name, lit, flip) = match (left.as_ref(), right.as_ref()) {
200 (DExpr::Col(name), lit) if is_literal(lit) => (name.as_str(), lit, false),
201 (lit, DExpr::Col(name)) if is_literal(lit) => (name.as_str(), lit, true),
202 _ => return None,
203 };
204
205 let col = df.get_column(col_name)?;
206 let nrows = df.nrows();
207 let mut new_words = current_mask.words.clone();
208
209 match (col, lit) {
210 (Column::Int(data), DExpr::LitInt(val)) => {
211 for row in current_mask.iter_set() {
212 let (l, r) = if flip {
213 (*val, data[row])
214 } else {
215 (data[row], *val)
216 };
217 if !cmp_i64(*op, l, r) {
218 new_words[row / 64] &= !(1u64 << (row % 64));
219 }
220 }
221 }
222 (Column::Float(data), DExpr::LitFloat(val)) => {
223 for row in current_mask.iter_set() {
224 let (l, r) = if flip {
225 (*val, data[row])
226 } else {
227 (data[row], *val)
228 };
229 if !cmp_f64(*op, l, r) {
230 new_words[row / 64] &= !(1u64 << (row % 64));
231 }
232 }
233 }
234 (Column::Int(data), DExpr::LitFloat(val)) => {
235 for row in current_mask.iter_set() {
236 let (l, r) = if flip {
237 (*val, data[row] as f64)
238 } else {
239 (data[row] as f64, *val)
240 };
241 if !cmp_f64(*op, l, r) {
242 new_words[row / 64] &= !(1u64 << (row % 64));
243 }
244 }
245 }
246 (Column::Str(data), DExpr::LitStr(val)) => {
247 for row in current_mask.iter_set() {
248 let pass = if flip {
249 cmp_str(*op, val, &data[row])
250 } else {
251 cmp_str(*op, &data[row], val)
252 };
253 if !pass {
254 new_words[row / 64] &= !(1u64 << (row % 64));
255 }
256 }
257 }
258 _ => return None,
259 }
260
261 Some(crate::bitmask::BitMask {
262 words: new_words,
263 nrows,
264 })
265 }
266 _ => None,
267 }
268}
269
270fn is_literal(expr: &DExpr) -> bool {
271 matches!(
272 expr,
273 DExpr::LitInt(_) | DExpr::LitFloat(_) | DExpr::LitBool(_) | DExpr::LitStr(_)
274 )
275}
276
277#[inline]
278fn cmp_i64(op: BinOp, l: i64, r: i64) -> bool {
279 match op {
280 BinOp::Eq => l == r,
281 BinOp::Ne => l != r,
282 BinOp::Lt => l < r,
283 BinOp::Le => l <= r,
284 BinOp::Gt => l > r,
285 BinOp::Ge => l >= r,
286 _ => false,
287 }
288}
289
290#[inline]
291fn cmp_f64(op: BinOp, l: f64, r: f64) -> bool {
292 match op {
293 BinOp::Eq => l == r,
294 BinOp::Ne => l != r,
295 BinOp::Lt => l < r,
296 BinOp::Le => l <= r,
297 BinOp::Gt => l > r,
298 BinOp::Ge => l >= r,
299 _ => false,
300 }
301}
302
303#[inline]
304fn cmp_str(op: BinOp, l: &str, r: &str) -> bool {
305 match op {
306 BinOp::Eq => l == r,
307 BinOp::Ne => l != r,
308 BinOp::Lt => l < r,
309 BinOp::Le => l <= r,
310 BinOp::Gt => l > r,
311 BinOp::Ge => l >= r,
312 _ => false,
313 }
314}
315
316pub fn col(name: &str) -> DExpr {
320 DExpr::Col(name.to_string())
321}
322
323pub fn binop(op: BinOp, left: DExpr, right: DExpr) -> DExpr {
325 DExpr::BinOp {
326 op,
327 left: Box::new(left),
328 right: Box::new(right),
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-wise evaluation of `x > 15` on rows 0 and 1 of `x = [10, 20, 30]`.
    #[test]
    fn test_eval_comparison() {
        let frame =
            DataFrame::from_columns(vec![("x".into(), Column::Int(vec![10, 20, 30]))]).unwrap();
        let gt_15 = binop(BinOp::Gt, col("x"), DExpr::LitInt(15));
        let outcomes: Vec<Option<bool>> = (0..2)
            .map(|row| eval_expr_row(&frame, &gt_15, row).unwrap().as_bool())
            .collect();
        assert_eq!(outcomes, vec![Some(false), Some(true)]);
    }

    /// Columnar evaluation of `x > 3` over an all-true mask keeps only rows 3 and 4.
    #[test]
    fn test_columnar_fast_path() {
        let frame =
            DataFrame::from_columns(vec![("x".into(), Column::Int(vec![1, 2, 3, 4, 5]))])
                .unwrap();
        let all_rows = crate::bitmask::BitMask::all_true(5);
        let gt_3 = binop(BinOp::Gt, col("x"), DExpr::LitInt(3));
        let filtered = try_eval_predicate_columnar(&frame, &gt_3, &all_rows)
            .expect("Int column vs Int literal should take the columnar path");
        assert_eq!(filtered.iter_set().collect::<Vec<usize>>(), vec![3, 4]);
    }
}