Skip to main content

ggplot_rs/aes/
expr.rs

1//! A tiny arithmetic expression evaluator for **computed aesthetics** — so
2//! `aes` can map e.g. `"pop / 1e6"`, `"log(gdp)"`, or `"a * b + 1"` instead of a
3//! bare column name. Anything that isn't an existing column is parsed and
4//! evaluated per row against the data's numeric columns.
5//!
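//! Grammar (`+ -` bind loosest, then `* / %`, then `^`, which is
//! right-associative; a unary sign binds tighter than `^`, so `-a ^ 2` means
//! `(-a) ^ 2`):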
7//! `expr := term (('+'|'-') term)*`, `term := factor (('*'|'/'|'%') factor)*`,
8//! `factor := unary ('^' factor)?`, `unary := ('-'|'+') unary | primary`,
9//! `primary := number | ident | ident '(' expr ')' | '(' expr ')'`.
10//! Functions: `ln`/`log`, `log10`, `log2`, `sqrt`, `exp`, `abs`, `sin`, `cos`,
11//! `tan`, `floor`, `ceil`, `round`, `sign`.
12
13use crate::data::{DataFrame, Value};
14
/// Parsed expression tree; one node per syntactic construct of the grammar.
#[derive(Clone, Debug)]
enum Expr {
    /// Numeric literal.
    Num(f64),
    /// Bare identifier, looked up as a column name at evaluation time.
    Col(String),
    /// Unary minus applied to the boxed operand.
    Neg(Box<Expr>),
    /// Binary operation: operator character, left operand, right operand.
    Bin(char, Box<Expr>, Box<Expr>),
    /// Single-argument function call: function name and its argument.
    Func(String, Box<Expr>),
}
23
/// Lexical token produced by the tokenizer and consumed by the parser.
#[derive(Clone, Debug, PartialEq)]
enum Tok {
    /// Numeric literal, already converted to `f64`.
    Num(f64),
    /// Identifier: a column name or a function name.
    Ident(String),
    /// Single-character operator or parenthesis.
    Op(char),
}
30
31fn tokenize(s: &str) -> Option<Vec<Tok>> {
32    let chars: Vec<char> = s.chars().collect();
33    let mut toks = Vec::new();
34    let mut i = 0;
35    while i < chars.len() {
36        let c = chars[i];
37        if c.is_whitespace() {
38            i += 1;
39        } else if c.is_ascii_digit()
40            || (c == '.' && i + 1 < chars.len() && chars[i + 1].is_ascii_digit())
41        {
42            let start = i;
43            while i < chars.len() && (chars[i].is_ascii_digit() || chars[i] == '.') {
44                i += 1;
45            }
46            if i < chars.len() && (chars[i] == 'e' || chars[i] == 'E') {
47                i += 1;
48                if i < chars.len() && (chars[i] == '+' || chars[i] == '-') {
49                    i += 1;
50                }
51                while i < chars.len() && chars[i].is_ascii_digit() {
52                    i += 1;
53                }
54            }
55            let num: String = chars[start..i].iter().collect();
56            toks.push(Tok::Num(num.parse().ok()?));
57        } else if c.is_alphabetic() || c == '_' {
58            let start = i;
59            while i < chars.len()
60                && (chars[i].is_alphanumeric() || chars[i] == '_' || chars[i] == '.')
61            {
62                i += 1;
63            }
64            toks.push(Tok::Ident(chars[start..i].iter().collect()));
65        } else if "+-*/%^()".contains(c) {
66            toks.push(Tok::Op(c));
67            i += 1;
68        } else {
69            return None; // unknown character → not an expression
70        }
71    }
72    Some(toks)
73}
74
75struct Parser {
76    toks: Vec<Tok>,
77    pos: usize,
78}
79
80impl Parser {
81    fn peek(&self) -> Option<&Tok> {
82        self.toks.get(self.pos)
83    }
84    fn eat_op(&mut self, c: char) -> bool {
85        if matches!(self.peek(), Some(Tok::Op(o)) if *o == c) {
86            self.pos += 1;
87            true
88        } else {
89            false
90        }
91    }
92    fn expr(&mut self) -> Option<Expr> {
93        let mut left = self.term()?;
94        while let Some(Tok::Op(c @ ('+' | '-'))) = self.peek().cloned() {
95            self.pos += 1;
96            let right = self.term()?;
97            left = Expr::Bin(c, Box::new(left), Box::new(right));
98        }
99        Some(left)
100    }
101    fn term(&mut self) -> Option<Expr> {
102        let mut left = self.factor()?;
103        while let Some(Tok::Op(c @ ('*' | '/' | '%'))) = self.peek().cloned() {
104            self.pos += 1;
105            let right = self.factor()?;
106            left = Expr::Bin(c, Box::new(left), Box::new(right));
107        }
108        Some(left)
109    }
110    fn factor(&mut self) -> Option<Expr> {
111        let base = self.unary()?;
112        if self.eat_op('^') {
113            let exp = self.factor()?; // right-associative
114            return Some(Expr::Bin('^', Box::new(base), Box::new(exp)));
115        }
116        Some(base)
117    }
118    fn unary(&mut self) -> Option<Expr> {
119        if self.eat_op('-') {
120            return Some(Expr::Neg(Box::new(self.unary()?)));
121        }
122        if self.eat_op('+') {
123            return self.unary();
124        }
125        self.primary()
126    }
127    fn primary(&mut self) -> Option<Expr> {
128        let tok = self.toks.get(self.pos).cloned()?;
129        self.pos += 1;
130        match tok {
131            Tok::Num(n) => Some(Expr::Num(n)),
132            Tok::Op('(') => {
133                let e = self.expr()?;
134                self.eat_op(')').then_some(e)
135            }
136            Tok::Ident(name) => {
137                if self.eat_op('(') {
138                    let arg = self.expr()?;
139                    if !self.eat_op(')') {
140                        return None;
141                    }
142                    Some(Expr::Func(name.to_lowercase(), Box::new(arg)))
143                } else {
144                    Some(Expr::Col(name))
145                }
146            }
147            _ => None,
148        }
149    }
150}
151
152fn parse(s: &str) -> Option<Expr> {
153    let toks = tokenize(s)?;
154    if toks.is_empty() {
155        return None;
156    }
157    let mut p = Parser { toks, pos: 0 };
158    let e = p.expr()?;
159    (p.pos == p.toks.len()).then_some(e)
160}
161
162fn eval(e: &Expr, data: &DataFrame, row: usize) -> Option<f64> {
163    match e {
164        Expr::Num(n) => Some(*n),
165        Expr::Col(name) => data
166            .column(name)
167            .and_then(|c| c.get(row))
168            .and_then(|v| v.as_f64()),
169        Expr::Neg(a) => Some(-eval(a, data, row)?),
170        Expr::Bin(op, a, b) => {
171            let (x, y) = (eval(a, data, row)?, eval(b, data, row)?);
172            Some(match op {
173                '+' => x + y,
174                '-' => x - y,
175                '*' => x * y,
176                '/' => x / y,
177                '%' => x % y,
178                '^' => x.powf(y),
179                _ => return None,
180            })
181        }
182        Expr::Func(name, a) => {
183            let x = eval(a, data, row)?;
184            Some(match name.as_str() {
185                "ln" | "log" => x.ln(),
186                "log10" => x.log10(),
187                "log2" => x.log2(),
188                "sqrt" => x.sqrt(),
189                "exp" => x.exp(),
190                "abs" => x.abs(),
191                "sin" => x.sin(),
192                "cos" => x.cos(),
193                "tan" => x.tan(),
194                "floor" => x.floor(),
195                "ceil" => x.ceil(),
196                "round" => x.round(),
197                "sign" => x.signum(),
198                _ => return None,
199            })
200        }
201    }
202}
203
204fn references_known_column(e: &Expr, data: &DataFrame) -> bool {
205    match e {
206        Expr::Col(name) => data.has_column(name),
207        Expr::Num(_) => false,
208        Expr::Neg(a) | Expr::Func(_, a) => references_known_column(a, data),
209        Expr::Bin(_, a, b) => references_known_column(a, data) || references_known_column(b, data),
210    }
211}
212
213/// Evaluate `expr` over every row of `data`, producing one `Value` per row
214/// (non-finite results become `Value::Na`). Returns `None` if the string is not
215/// a valid expression or references no existing column (so a plain unknown
216/// column name / typo is left for the caller to handle, not silently computed).
217pub fn eval_expression(expr: &str, data: &DataFrame) -> Option<Vec<Value>> {
218    let parsed = parse(expr)?;
219    if !references_known_column(&parsed, data) {
220        return None;
221    }
222    let n = data.nrows();
223    let mut out = Vec::with_capacity(n);
224    for row in 0..n {
225        out.push(match eval(&parsed, data, row) {
226            Some(v) if v.is_finite() => Value::Float(v),
227            _ => Value::Na,
228        });
229    }
230    Some(out)
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-row fixture with columns `a = [2, 4]` and `b = [8, 2]`.
    fn sample() -> DataFrame {
        let mut frame = DataFrame::new();
        frame.add_column("a".into(), vec![Value::Float(2.0), Value::Float(4.0)]);
        frame.add_column("b".into(), vec![Value::Float(8.0), Value::Float(2.0)]);
        frame
    }

    /// Evaluates `expr` (which must be accepted) and keeps the entries that
    /// have a numeric view, so a missing value shows up as a shorter vector.
    fn nums(expr: &str, data: &DataFrame) -> Vec<f64> {
        eval_expression(expr, data)
            .unwrap()
            .iter()
            .filter_map(|v| v.as_f64())
            .collect()
    }

    #[test]
    fn arithmetic_and_precedence() {
        let d = sample();
        let cases: [(&str, [f64; 2]); 5] = [
            ("a / b", [0.25, 2.0]),
            ("a + b * 2", [18.0, 8.0]),
            ("(a + b) * 2", [20.0, 12.0]),
            ("2 ^ a", [4.0, 16.0]),
            ("-a", [-2.0, -4.0]),
        ];
        for (src, want) in cases.iter() {
            assert_eq!(nums(src, &d), want.to_vec(), "expr: {}", src);
        }
    }

    #[test]
    fn functions() {
        let d = sample();
        assert_eq!(nums("sqrt(b)", &d), vec![8f64.sqrt(), 2f64.sqrt()]);
        assert_eq!(nums("log2(b)", &d), vec![3.0, 1.0]);
        assert_eq!(nums("abs(a - b)", &d), vec![6.0, 2.0]);
    }

    #[test]
    fn non_expression_or_unknown_returns_none() {
        let d = sample();
        let rejected = [
            "nonexistent_col", // unknown bare column
            "1 + 2",           // no column referenced
            "a +",             // parse error
            "a $ b",           // bad char
        ];
        for src in rejected.iter() {
            assert!(eval_expression(src, &d).is_none(), "expr: {}", src);
        }
    }

    #[test]
    fn division_by_zero_is_na() {
        let mut d = DataFrame::new();
        d.add_column("a".into(), vec![Value::Float(1.0)]);
        d.add_column("z".into(), vec![Value::Float(0.0)]);
        let out = eval_expression("a / z", &d).unwrap();
        assert!(matches!(out[0], Value::Na));
    }
}