1use crate::data::{DataFrame, Value};
17
/// Parsed expression tree produced by the parser and consumed by `eval`.
#[derive(Debug, Clone)]
enum Expr {
    /// Numeric literal.
    Num(f64),
    /// Reference to a data-frame column by name.
    Col(String),
    /// Unary negation of the inner expression.
    Neg(Box<Expr>),
    /// Binary operation: operator character, left operand, right operand.
    /// `eval` understands `+`, `-`, `*`, `/`, `%` and `^`.
    Bin(char, Box<Expr>, Box<Expr>),
    /// Single-argument function call; the parser stores the name lower-cased.
    Func(String, Box<Expr>),
}
26
/// Lexical token emitted by `tokenize`.
#[derive(Debug, Clone, PartialEq)]
enum Tok {
    /// Numeric literal, already parsed to `f64`.
    Num(f64),
    /// Identifier: starts with a letter or `_`, continues with letters, digits, `_` or `.`.
    Ident(String),
    /// One of the single-character operators or parentheses `+ - * / % ^ ( )`.
    Op(char),
}
33
34fn tokenize(s: &str) -> Option<Vec<Tok>> {
35 let chars: Vec<char> = s.chars().collect();
36 let mut toks = Vec::new();
37 let mut i = 0;
38 while i < chars.len() {
39 let c = chars[i];
40 if c.is_whitespace() {
41 i += 1;
42 } else if c.is_ascii_digit()
43 || (c == '.' && i + 1 < chars.len() && chars[i + 1].is_ascii_digit())
44 {
45 let start = i;
46 while i < chars.len() && (chars[i].is_ascii_digit() || chars[i] == '.') {
47 i += 1;
48 }
49 if i < chars.len() && (chars[i] == 'e' || chars[i] == 'E') {
50 i += 1;
51 if i < chars.len() && (chars[i] == '+' || chars[i] == '-') {
52 i += 1;
53 }
54 while i < chars.len() && chars[i].is_ascii_digit() {
55 i += 1;
56 }
57 }
58 let num: String = chars[start..i].iter().collect();
59 toks.push(Tok::Num(num.parse().ok()?));
60 } else if c.is_alphabetic() || c == '_' {
61 let start = i;
62 while i < chars.len()
63 && (chars[i].is_alphanumeric() || chars[i] == '_' || chars[i] == '.')
64 {
65 i += 1;
66 }
67 toks.push(Tok::Ident(chars[start..i].iter().collect()));
68 } else if "+-*/%^()".contains(c) {
69 toks.push(Tok::Op(c));
70 i += 1;
71 } else {
72 return None; }
74 }
75 Some(toks)
76}
77
/// Recursive-descent parser state over a token list.
struct Parser {
    /// Tokens to parse.
    toks: Vec<Tok>,
    /// Index into `toks` of the next token to consume.
    pos: usize,
}
82
83impl Parser {
84 fn peek(&self) -> Option<&Tok> {
85 self.toks.get(self.pos)
86 }
87 fn eat_op(&mut self, c: char) -> bool {
88 if matches!(self.peek(), Some(Tok::Op(o)) if *o == c) {
89 self.pos += 1;
90 true
91 } else {
92 false
93 }
94 }
95 fn expr(&mut self) -> Option<Expr> {
96 let mut left = self.term()?;
97 while let Some(Tok::Op(c @ ('+' | '-'))) = self.peek().cloned() {
98 self.pos += 1;
99 let right = self.term()?;
100 left = Expr::Bin(c, Box::new(left), Box::new(right));
101 }
102 Some(left)
103 }
104 fn term(&mut self) -> Option<Expr> {
105 let mut left = self.factor()?;
106 while let Some(Tok::Op(c @ ('*' | '/' | '%'))) = self.peek().cloned() {
107 self.pos += 1;
108 let right = self.factor()?;
109 left = Expr::Bin(c, Box::new(left), Box::new(right));
110 }
111 Some(left)
112 }
113 fn factor(&mut self) -> Option<Expr> {
114 let base = self.unary()?;
115 if self.eat_op('^') {
116 let exp = self.factor()?; return Some(Expr::Bin('^', Box::new(base), Box::new(exp)));
118 }
119 Some(base)
120 }
121 fn unary(&mut self) -> Option<Expr> {
122 if self.eat_op('-') {
123 return Some(Expr::Neg(Box::new(self.unary()?)));
124 }
125 if self.eat_op('+') {
126 return self.unary();
127 }
128 self.primary()
129 }
130 fn primary(&mut self) -> Option<Expr> {
131 let tok = self.toks.get(self.pos).cloned()?;
132 self.pos += 1;
133 match tok {
134 Tok::Num(n) => Some(Expr::Num(n)),
135 Tok::Op('(') => {
136 let e = self.expr()?;
137 self.eat_op(')').then_some(e)
138 }
139 Tok::Ident(name) => {
140 if self.eat_op('(') {
141 let arg = self.expr()?;
142 if !self.eat_op(')') {
143 return None;
144 }
145 Some(Expr::Func(name.to_lowercase(), Box::new(arg)))
146 } else {
147 Some(Expr::Col(name))
148 }
149 }
150 _ => None,
151 }
152 }
153}
154
155fn parse(s: &str) -> Option<Expr> {
156 let toks = tokenize(s)?;
157 if toks.is_empty() {
158 return None;
159 }
160 let mut p = Parser { toks, pos: 0 };
161 let e = p.expr()?;
162 (p.pos == p.toks.len()).then_some(e)
163}
164
165fn eval(e: &Expr, data: &DataFrame, row: usize) -> Option<f64> {
166 match e {
167 Expr::Num(n) => Some(*n),
168 Expr::Col(name) => data
169 .column(name)
170 .and_then(|c| c.get(row))
171 .and_then(|v| v.as_f64()),
172 Expr::Neg(a) => Some(-eval(a, data, row)?),
173 Expr::Bin(op, a, b) => {
174 let (x, y) = (eval(a, data, row)?, eval(b, data, row)?);
175 Some(match op {
176 '+' => x + y,
177 '-' => x - y,
178 '*' => x * y,
179 '/' => x / y,
180 '%' => x % y,
181 '^' => x.powf(y),
182 _ => return None,
183 })
184 }
185 Expr::Func(name, a) => {
186 if let Some(agg) = aggregate(name) {
189 let vals: Vec<f64> = (0..data.nrows())
190 .filter_map(|r| eval(a, data, r))
191 .filter(|v| v.is_finite())
192 .collect();
193 return Some(agg(&vals));
194 }
195 let x = eval(a, data, row)?;
196 Some(match name.as_str() {
197 "ln" | "log" => x.ln(),
198 "log10" => x.log10(),
199 "log2" => x.log2(),
200 "sqrt" => x.sqrt(),
201 "exp" => x.exp(),
202 "abs" => x.abs(),
203 "sin" => x.sin(),
204 "cos" => x.cos(),
205 "tan" => x.tan(),
206 "floor" => x.floor(),
207 "ceil" => x.ceil(),
208 "round" => x.round(),
209 "sign" => x.signum(),
210 _ => return None,
211 })
212 }
213 }
214}
215
/// Looks up a whole-column aggregate by name.
///
/// Recognised names are `sum`, `count`, `prod`, `mean` / `avg`, `max`, `min` and `median`.
/// Returns a function reducing a slice of samples to one value, or `None` when `name` is not an
/// aggregate. `mean`, `max`, `min` and `median` yield NaN for an empty slice.
fn aggregate(name: &str) -> Option<fn(&[f64]) -> f64> {
    let reducer: fn(&[f64]) -> f64 = match name {
        "sum" => |v: &[f64]| v.iter().sum::<f64>(),
        "count" => |v: &[f64]| v.len() as f64,
        "prod" => |v: &[f64]| v.iter().product::<f64>(),
        "mean" | "avg" => |v: &[f64]| {
            if v.is_empty() {
                f64::NAN
            } else {
                v.iter().sum::<f64>() / v.len() as f64
            }
        },
        // `f64::max` / `f64::min` return the non-NaN operand, so the NaN seed only survives
        // when the slice is empty.
        "max" => |v: &[f64]| v.iter().copied().fold(f64::NAN, f64::max),
        "min" => |v: &[f64]| v.iter().copied().fold(f64::NAN, f64::min),
        "median" => |v: &[f64]| {
            if v.is_empty() {
                return f64::NAN;
            }
            let mut sorted = v.to_vec();
            // A total order cannot fail, so a stray NaN no longer panics the sort.
            sorted.sort_unstable_by(f64::total_cmp);
            let mid = sorted.len() / 2;
            if sorted.len().is_multiple_of(2) {
                (sorted[mid - 1] + sorted[mid]) / 2.0
            } else {
                sorted[mid]
            }
        },
        _ => return None,
    };
    Some(reducer)
}
249
250fn references_known_column(e: &Expr, data: &DataFrame) -> bool {
251 match e {
252 Expr::Col(name) => data.has_column(name),
253 Expr::Num(_) => false,
254 Expr::Neg(a) | Expr::Func(_, a) => references_known_column(a, data),
255 Expr::Bin(_, a, b) => references_known_column(a, data) || references_known_column(b, data),
256 }
257}
258
259pub fn eval_expression(expr: &str, data: &DataFrame) -> Option<Vec<Value>> {
264 let parsed = parse(expr)?;
265 if !references_known_column(&parsed, data) {
266 return None;
267 }
268 let n = data.nrows();
269 let mut out = Vec::with_capacity(n);
270 for row in 0..n {
271 out.push(match eval(&parsed, data, row) {
272 Some(v) if v.is_finite() => Value::Float(v),
273 _ => Value::Na,
274 });
275 }
276 Some(out)
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-row fixture with columns `a = [2, 4]` and `b = [8, 2]`.
    fn df() -> DataFrame {
        let mut frame = DataFrame::new();
        frame.add_column("a".into(), vec![Value::Float(2.0), Value::Float(4.0)]);
        frame.add_column("b".into(), vec![Value::Float(8.0), Value::Float(2.0)]);
        frame
    }

    /// Keeps only the values that convert to `f64`.
    fn f(vals: &[Value]) -> Vec<f64> {
        vals.iter().filter_map(|v| v.as_f64()).collect()
    }

    /// Evaluates `expr` against `data`, panicking if it is rejected, and returns the floats.
    fn floats(expr: &str, data: &DataFrame) -> Vec<f64> {
        let values = eval_expression(expr, data).unwrap();
        f(&values)
    }

    #[test]
    fn arithmetic_and_precedence() {
        let data = df();
        assert_eq!(floats("a / b", &data), vec![0.25, 2.0]);
        assert_eq!(floats("a + b * 2", &data), vec![18.0, 8.0]);
        assert_eq!(floats("(a + b) * 2", &data), vec![20.0, 12.0]);
        assert_eq!(floats("2 ^ a", &data), vec![4.0, 16.0]);
        assert_eq!(floats("-a", &data), vec![-2.0, -4.0]);
    }

    #[test]
    fn functions() {
        let data = df();
        assert_eq!(floats("sqrt(b)", &data), vec![8f64.sqrt(), 2f64.sqrt()]);
        assert_eq!(floats("log2(b)", &data), vec![3.0, 1.0]);
        assert_eq!(floats("abs(a - b)", &data), vec![6.0, 2.0]);
    }

    #[test]
    fn non_expression_or_unknown_returns_none() {
        let data = df();
        // Unknown column, no column at all, dangling operator, illegal character.
        for rejected in ["nonexistent_col", "1 + 2", "a +", "a $ b"] {
            assert!(eval_expression(rejected, &data).is_none());
        }
    }

    #[test]
    fn aggregates_broadcast_over_column() {
        let mut data = DataFrame::new();
        // The column shares its name with the `count` aggregate on purpose: without a
        // following `(` the identifier must resolve to the column.
        data.add_column(
            "count".into(),
            vec![Value::Float(1.0), Value::Float(3.0), Value::Float(4.0)],
        );
        assert_eq!(floats("count / sum(count)", &data), vec![0.125, 0.375, 0.5]);
        assert_eq!(floats("count / max(count)", &data), vec![0.25, 0.75, 1.0]);
        assert_eq!(floats("mean(count)", &data), vec![8.0 / 3.0; 3]);
    }

    #[test]
    fn division_by_zero_is_na() {
        let mut data = DataFrame::new();
        data.add_column("a".into(), vec![Value::Float(1.0)]);
        data.add_column("z".into(), vec![Value::Float(0.0)]);
        let out = eval_expression("a / z", &data).unwrap();
        assert!(matches!(out[0], Value::Na));
    }
}