Skip to main content

ggplot_rs/aes/
expr.rs

//! A tiny arithmetic expression evaluator for **computed aesthetics** — so
//! `aes` can map e.g. `"pop / 1e6"`, `"log(gdp)"`, or `"a * b + 1"` instead of a
//! bare column name. Anything that isn't an existing column is parsed and
//! evaluated per row against the data's numeric columns.
//!
//! Grammar (standard precedence, `^` right-associative; note that unary
//! `-`/`+` bind tighter than `^`, so `-a ^ 2` means `(-a) ^ 2`):
//! `expr := term (('+'|'-') term)*`, `term := factor (('*'|'/'|'%') factor)*`,
//! `factor := unary ('^' factor)?`, `unary := ('-'|'+') unary | primary`,
//! `primary := number | ident | ident '(' expr ')' | '(' expr ')'`.
//! Functions: `ln`/`log` (both the natural logarithm), `log10`, `log2`, `sqrt`,
//! `exp`, `abs`, `sin`, `cos`, `tan`, `floor`, `ceil`, `round`, `sign`.
//! Aggregate functions reduce their argument over *all* rows to a scalar
//! (broadcast to every row): `sum`, `mean` (`avg`), `max`, `min`, `count`,
//! `median`, `prod` — enabling normalized `after_stat` mappings such as
//! `"count / sum(count)"`.
15
16use crate::data::{DataFrame, Value};
17
/// Parsed expression tree, evaluated once per data row.
#[derive(Clone, Debug)]
enum Expr {
    /// Numeric literal.
    Num(f64),
    /// Reference to a data column by name.
    Col(String),
    /// Unary negation of the inner expression.
    Neg(Box<Expr>),
    /// Binary operation: operator character, left operand, right operand.
    Bin(char, Box<Expr>, Box<Expr>),
    /// Single-argument function call: function name and its argument.
    Func(String, Box<Expr>),
}
26
/// Lexical token produced by the tokenizer.
#[derive(Clone, Debug, PartialEq)]
enum Tok {
    /// Numeric literal.
    Num(f64),
    /// Identifier: a column name or a function name.
    Ident(String),
    /// Single-character operator or parenthesis.
    Op(char),
}
33
34fn tokenize(s: &str) -> Option<Vec<Tok>> {
35    let chars: Vec<char> = s.chars().collect();
36    let mut toks = Vec::new();
37    let mut i = 0;
38    while i < chars.len() {
39        let c = chars[i];
40        if c.is_whitespace() {
41            i += 1;
42        } else if c.is_ascii_digit()
43            || (c == '.' && i + 1 < chars.len() && chars[i + 1].is_ascii_digit())
44        {
45            let start = i;
46            while i < chars.len() && (chars[i].is_ascii_digit() || chars[i] == '.') {
47                i += 1;
48            }
49            if i < chars.len() && (chars[i] == 'e' || chars[i] == 'E') {
50                i += 1;
51                if i < chars.len() && (chars[i] == '+' || chars[i] == '-') {
52                    i += 1;
53                }
54                while i < chars.len() && chars[i].is_ascii_digit() {
55                    i += 1;
56                }
57            }
58            let num: String = chars[start..i].iter().collect();
59            toks.push(Tok::Num(num.parse().ok()?));
60        } else if c.is_alphabetic() || c == '_' {
61            let start = i;
62            while i < chars.len()
63                && (chars[i].is_alphanumeric() || chars[i] == '_' || chars[i] == '.')
64            {
65                i += 1;
66            }
67            toks.push(Tok::Ident(chars[start..i].iter().collect()));
68        } else if "+-*/%^()".contains(c) {
69            toks.push(Tok::Op(c));
70            i += 1;
71        } else {
72            return None; // unknown character → not an expression
73        }
74    }
75    Some(toks)
76}
77
78struct Parser {
79    toks: Vec<Tok>,
80    pos: usize,
81}
82
83impl Parser {
84    fn peek(&self) -> Option<&Tok> {
85        self.toks.get(self.pos)
86    }
87    fn eat_op(&mut self, c: char) -> bool {
88        if matches!(self.peek(), Some(Tok::Op(o)) if *o == c) {
89            self.pos += 1;
90            true
91        } else {
92            false
93        }
94    }
95    fn expr(&mut self) -> Option<Expr> {
96        let mut left = self.term()?;
97        while let Some(Tok::Op(c @ ('+' | '-'))) = self.peek().cloned() {
98            self.pos += 1;
99            let right = self.term()?;
100            left = Expr::Bin(c, Box::new(left), Box::new(right));
101        }
102        Some(left)
103    }
104    fn term(&mut self) -> Option<Expr> {
105        let mut left = self.factor()?;
106        while let Some(Tok::Op(c @ ('*' | '/' | '%'))) = self.peek().cloned() {
107            self.pos += 1;
108            let right = self.factor()?;
109            left = Expr::Bin(c, Box::new(left), Box::new(right));
110        }
111        Some(left)
112    }
113    fn factor(&mut self) -> Option<Expr> {
114        let base = self.unary()?;
115        if self.eat_op('^') {
116            let exp = self.factor()?; // right-associative
117            return Some(Expr::Bin('^', Box::new(base), Box::new(exp)));
118        }
119        Some(base)
120    }
121    fn unary(&mut self) -> Option<Expr> {
122        if self.eat_op('-') {
123            return Some(Expr::Neg(Box::new(self.unary()?)));
124        }
125        if self.eat_op('+') {
126            return self.unary();
127        }
128        self.primary()
129    }
130    fn primary(&mut self) -> Option<Expr> {
131        let tok = self.toks.get(self.pos).cloned()?;
132        self.pos += 1;
133        match tok {
134            Tok::Num(n) => Some(Expr::Num(n)),
135            Tok::Op('(') => {
136                let e = self.expr()?;
137                self.eat_op(')').then_some(e)
138            }
139            Tok::Ident(name) => {
140                if self.eat_op('(') {
141                    let arg = self.expr()?;
142                    if !self.eat_op(')') {
143                        return None;
144                    }
145                    Some(Expr::Func(name.to_lowercase(), Box::new(arg)))
146                } else {
147                    Some(Expr::Col(name))
148                }
149            }
150            _ => None,
151        }
152    }
153}
154
155fn parse(s: &str) -> Option<Expr> {
156    let toks = tokenize(s)?;
157    if toks.is_empty() {
158        return None;
159    }
160    let mut p = Parser { toks, pos: 0 };
161    let e = p.expr()?;
162    (p.pos == p.toks.len()).then_some(e)
163}
164
165fn eval(e: &Expr, data: &DataFrame, row: usize) -> Option<f64> {
166    match e {
167        Expr::Num(n) => Some(*n),
168        Expr::Col(name) => data
169            .column(name)
170            .and_then(|c| c.get(row))
171            .and_then(|v| v.as_f64()),
172        Expr::Neg(a) => Some(-eval(a, data, row)?),
173        Expr::Bin(op, a, b) => {
174            let (x, y) = (eval(a, data, row)?, eval(b, data, row)?);
175            Some(match op {
176                '+' => x + y,
177                '-' => x - y,
178                '*' => x * y,
179                '/' => x / y,
180                '%' => x % y,
181                '^' => x.powf(y),
182                _ => return None,
183            })
184        }
185        Expr::Func(name, a) => {
186            // Aggregate functions reduce the argument over all rows to a scalar,
187            // broadcast identically to every row (e.g. sum(count) for proportions).
188            if let Some(agg) = aggregate(name) {
189                let vals: Vec<f64> = (0..data.nrows())
190                    .filter_map(|r| eval(a, data, r))
191                    .filter(|v| v.is_finite())
192                    .collect();
193                return Some(agg(&vals));
194            }
195            let x = eval(a, data, row)?;
196            Some(match name.as_str() {
197                "ln" | "log" => x.ln(),
198                "log10" => x.log10(),
199                "log2" => x.log2(),
200                "sqrt" => x.sqrt(),
201                "exp" => x.exp(),
202                "abs" => x.abs(),
203                "sin" => x.sin(),
204                "cos" => x.cos(),
205                "tan" => x.tan(),
206                "floor" => x.floor(),
207                "ceil" => x.ceil(),
208                "round" => x.round(),
209                "sign" => x.signum(),
210                _ => return None,
211            })
212        }
213    }
214}
215
/// If `name` is an aggregate function, return its reducer over a slice of
/// values; otherwise return `None`.
///
/// Callers are expected to pass only finite values. On an empty slice `sum`
/// and `count` give 0, `prod` gives 1 (the empty product), and `mean`/`avg`,
/// `max`, `min` and `median` give NaN (which callers turn into `Na`).
fn aggregate(name: &str) -> Option<fn(&[f64]) -> f64> {
    let reducer: fn(&[f64]) -> f64 = match name {
        "sum" => |v: &[f64]| v.iter().sum::<f64>(),
        "count" => |v: &[f64]| v.len() as f64,
        "prod" => |v: &[f64]| v.iter().product::<f64>(),
        "mean" | "avg" => |v: &[f64]| {
            if v.is_empty() {
                f64::NAN
            } else {
                v.iter().sum::<f64>() / v.len() as f64
            }
        },
        // `f64::max`/`f64::min` ignore a NaN operand, so seeding the fold with
        // NaN yields the true extremum, or NaN when the slice is empty.
        "max" => |v: &[f64]| v.iter().copied().fold(f64::NAN, f64::max),
        "min" => |v: &[f64]| v.iter().copied().fold(f64::NAN, f64::min),
        "median" => |v: &[f64]| {
            if v.is_empty() {
                return f64::NAN;
            }
            let mut sorted = v.to_vec();
            // `total_cmp` is a total order, so a stray NaN cannot panic the
            // sort the way `partial_cmp(..).unwrap()` would.
            sorted.sort_unstable_by(f64::total_cmp);
            let upper = sorted.len() / 2;
            // Equals `upper` for odd lengths, `upper - 1` for even lengths.
            let lower = (sorted.len() - 1) / 2;
            if lower == upper {
                sorted[upper]
            } else {
                (sorted[lower] + sorted[upper]) / 2.0
            }
        },
        _ => return None,
    };
    Some(reducer)
}
249
250fn references_known_column(e: &Expr, data: &DataFrame) -> bool {
251    match e {
252        Expr::Col(name) => data.has_column(name),
253        Expr::Num(_) => false,
254        Expr::Neg(a) | Expr::Func(_, a) => references_known_column(a, data),
255        Expr::Bin(_, a, b) => references_known_column(a, data) || references_known_column(b, data),
256    }
257}
258
259/// Evaluate `expr` over every row of `data`, producing one `Value` per row
260/// (non-finite results become `Value::Na`). Returns `None` if the string is not
261/// a valid expression or references no existing column (so a plain unknown
262/// column name / typo is left for the caller to handle, not silently computed).
263pub fn eval_expression(expr: &str, data: &DataFrame) -> Option<Vec<Value>> {
264    let parsed = parse(expr)?;
265    if !references_known_column(&parsed, data) {
266        return None;
267    }
268    let n = data.nrows();
269    let mut out = Vec::with_capacity(n);
270    for row in 0..n {
271        out.push(match eval(&parsed, data, row) {
272            Some(v) if v.is_finite() => Value::Float(v),
273            _ => Value::Na,
274        });
275    }
276    Some(out)
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a float column from plain numbers.
    fn floats(vals: &[f64]) -> Vec<Value> {
        vals.iter().copied().map(Value::Float).collect()
    }

    /// Two-row fixture with columns `a = [2, 4]` and `b = [8, 2]`.
    fn df() -> DataFrame {
        let mut frame = DataFrame::new();
        frame.add_column("a".into(), floats(&[2.0, 4.0]));
        frame.add_column("b".into(), floats(&[8.0, 2.0]));
        frame
    }

    /// Keep only the numeric entries of `vals`.
    fn f(vals: &[Value]) -> Vec<f64> {
        vals.iter().filter_map(|v| v.as_f64()).collect()
    }

    /// Evaluate `expr` against `data`, panicking if it is rejected, and return
    /// the numeric results.
    fn run(expr: &str, data: &DataFrame) -> Vec<f64> {
        f(&eval_expression(expr, data).unwrap())
    }

    #[test]
    fn arithmetic_and_precedence() {
        let d = df();
        let cases = [
            ("a / b", [0.25, 2.0]),
            ("a + b * 2", [18.0, 8.0]),
            ("(a + b) * 2", [20.0, 12.0]),
            ("2 ^ a", [4.0, 16.0]),
            ("-a", [-2.0, -4.0]),
        ];
        for (expr, want) in cases.iter() {
            assert_eq!(run(expr, &d), want.to_vec());
        }
    }

    #[test]
    fn functions() {
        let d = df();
        assert_eq!(run("sqrt(b)", &d), vec![8f64.sqrt(), 2f64.sqrt()]);
        assert_eq!(run("log2(b)", &d), vec![3.0, 1.0]);
        assert_eq!(run("abs(a - b)", &d), vec![6.0, 2.0]);
    }

    #[test]
    fn non_expression_or_unknown_returns_none() {
        let d = df();
        let rejected = [
            "nonexistent_col",
            "1 + 2", // no column referenced
            "a +",   // parse error
            "a $ b", // bad char
        ];
        for expr in rejected.iter() {
            assert!(eval_expression(expr, &d).is_none());
        }
    }

    #[test]
    fn aggregates_broadcast_over_column() {
        let mut d = DataFrame::new();
        d.add_column("count".into(), floats(&[1.0, 3.0, 4.0]));
        // proportion = count / sum(count); sum = 8
        assert_eq!(run("count / sum(count)", &d), vec![0.125, 0.375, 0.5]);
        // normalized = count / max(count); max = 4
        assert_eq!(run("count / max(count)", &d), vec![0.25, 0.75, 1.0]);
        assert_eq!(run("mean(count)", &d), vec![8.0 / 3.0; 3]);
    }

    #[test]
    fn division_by_zero_is_na() {
        let mut d = DataFrame::new();
        d.add_column("a".into(), floats(&[1.0]));
        d.add_column("z".into(), floats(&[0.0]));
        let out = eval_expression("a / z", &d).unwrap();
        assert!(matches!(out[0], Value::Na));
    }
}
362            eval_expression("a / z", &d).unwrap()[0],
363            Value::Na
364        ));
365    }
366}