Skip to main content

cell_sheet_core/formula/
eval.rs

use std::cmp::Ordering;

use crate::formula::ast::*;
use crate::formula::functions;
use crate::formula::parser;
use crate::model::{CellError, CellPos, CellValue, Sheet};
5
6fn expand_range(start: &CellRef, end: &CellRef) -> Vec<CellPos> {
7    let mut positions = Vec::new();
8    let r1 = start.row.min(end.row);
9    let r2 = start.row.max(end.row);
10    let c1 = start.col.min(end.col);
11    let c2 = start.col.max(end.col);
12    for r in r1..=r2 {
13        for c in c1..=c2 {
14            positions.push((r, c));
15        }
16    }
17    positions
18}
19
20fn resolve_cell(sheet: &Sheet, pos: CellPos) -> CellValue {
21    match sheet.get_cell(pos) {
22        Some(cell) => cell.value.clone(),
23        None => CellValue::Empty,
24    }
25}
26
27fn cell_value_to_number(v: &CellValue) -> Result<f64, CellError> {
28    match v {
29        CellValue::Number(n) => Ok(*n),
30        CellValue::Empty => Ok(0.0),
31        CellValue::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
32        CellValue::Error(e) => Err(e.clone()),
33        CellValue::Text(_) => Err(CellError::Value),
34    }
35}
36
37fn eval_expr(expr: &Expr, sheet: &Sheet) -> CellValue {
38    match expr {
39        Expr::Number(n) => CellValue::Number(*n),
40        Expr::Text(s) => CellValue::Text(s.clone()),
41        Expr::Bool(b) => CellValue::Bool(*b),
42        Expr::CellRef(cell_ref) => {
43            let val = resolve_cell(sheet, (cell_ref.row, cell_ref.col));
44            if val == CellValue::Empty {
45                CellValue::Number(0.0)
46            } else {
47                val
48            }
49        }
50        Expr::Range { .. } => CellValue::Error(CellError::Value),
51        Expr::UnaryNeg(inner) => {
52            let val = eval_expr(inner, sheet);
53            match cell_value_to_number(&val) {
54                Ok(n) => CellValue::Number(-n),
55                Err(e) => CellValue::Error(e),
56            }
57        }
58        Expr::BinaryOp { op, left, right } => {
59            let lval = eval_expr(left, sheet);
60            let rval = eval_expr(right, sheet);
61
62            if let CellValue::Error(e) = &lval {
63                return CellValue::Error(e.clone());
64            }
65            if let CellValue::Error(e) = &rval {
66                return CellValue::Error(e.clone());
67            }
68
69            match op {
70                Op::Add | Op::Sub | Op::Mul | Op::Div => {
71                    let ln = match cell_value_to_number(&lval) {
72                        Ok(n) => n,
73                        Err(e) => return CellValue::Error(e),
74                    };
75                    let rn = match cell_value_to_number(&rval) {
76                        Ok(n) => n,
77                        Err(e) => return CellValue::Error(e),
78                    };
79                    match op {
80                        Op::Add => CellValue::Number(ln + rn),
81                        Op::Sub => CellValue::Number(ln - rn),
82                        Op::Mul => CellValue::Number(ln * rn),
83                        Op::Div => {
84                            if rn == 0.0 {
85                                CellValue::Error(CellError::DivZero)
86                            } else {
87                                CellValue::Number(ln / rn)
88                            }
89                        }
90                        _ => unreachable!(),
91                    }
92                }
93                Op::Gt | Op::Gte | Op::Lt | Op::Lte | Op::Eq | Op::Neq => {
94                    let ln = match cell_value_to_number(&lval) {
95                        Ok(n) => n,
96                        Err(e) => return CellValue::Error(e),
97                    };
98                    let rn = match cell_value_to_number(&rval) {
99                        Ok(n) => n,
100                        Err(e) => return CellValue::Error(e),
101                    };
102                    let result = match op {
103                        Op::Gt => ln > rn,
104                        Op::Gte => ln >= rn,
105                        Op::Lt => ln < rn,
106                        Op::Lte => ln <= rn,
107                        Op::Eq => (ln - rn).abs() < f64::EPSILON,
108                        Op::Neq => (ln - rn).abs() >= f64::EPSILON,
109                        _ => unreachable!(),
110                    };
111                    CellValue::Bool(result)
112                }
113            }
114        }
115        Expr::FnCall { name, args } => {
116            let upper = name.to_uppercase();
117
118            if upper == "IF" {
119                let evaled: Vec<CellValue> = args.iter().map(|a| eval_expr(a, sheet)).collect();
120                return functions::fn_if(&evaled);
121            }
122
123            let mut values = Vec::new();
124            for arg in args {
125                match arg {
126                    Expr::Range { start, end } => {
127                        for pos in expand_range(start, end) {
128                            values.push(resolve_cell(sheet, pos));
129                        }
130                    }
131                    other => {
132                        values.push(eval_expr(other, sheet));
133                    }
134                }
135            }
136
137            match upper.as_str() {
138                "SUM" => functions::fn_sum(&values),
139                "AVERAGE" => functions::fn_average(&values),
140                "COUNT" => functions::fn_count(&values),
141                "MIN" => functions::fn_min(&values),
142                "MAX" => functions::fn_max(&values),
143                _ => CellValue::Error(CellError::Name),
144            }
145        }
146    }
147}
148
149pub fn evaluate(formula: &str, sheet: &Sheet) -> CellValue {
150    match parser::parse(formula) {
151        Ok(expr) => eval_expr(&expr, sheet),
152        Err(e) => CellValue::Error(e),
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::Sheet;

    /// Evaluates `formula` against a brand-new, empty sheet.
    fn eval(formula: &str) -> CellValue {
        evaluate(formula, &Sheet::new())
    }

    /// Builds a sheet by storing each `(position, raw input)` pair, in order.
    fn sheet_with(entries: &[(CellPos, &str)]) -> Sheet {
        let mut sheet = Sheet::new();
        for &(pos, raw) in entries {
            sheet.set_cell(pos, raw);
        }
        sheet
    }

    /// Shorthand for an expected numeric result.
    fn number(n: f64) -> CellValue {
        CellValue::Number(n)
    }

    // Literals and arithmetic.

    #[test]
    fn eval_number() {
        assert_eq!(eval("42"), number(42.0));
    }

    #[test]
    fn eval_addition() {
        assert_eq!(eval("1+2"), number(3.0));
    }

    #[test]
    fn eval_subtraction() {
        assert_eq!(eval("5-3"), number(2.0));
    }

    #[test]
    fn eval_multiplication() {
        assert_eq!(eval("3*4"), number(12.0));
    }

    #[test]
    fn eval_division() {
        assert_eq!(eval("10/4"), number(2.5));
    }

    #[test]
    fn eval_division_by_zero() {
        assert_eq!(eval("1/0"), CellValue::Error(CellError::DivZero));
    }

    #[test]
    fn eval_precedence() {
        assert_eq!(eval("1+2*3"), number(7.0));
    }

    #[test]
    fn eval_parentheses() {
        assert_eq!(eval("(1+2)*3"), number(9.0));
    }

    #[test]
    fn eval_negation() {
        assert_eq!(eval("-5"), number(-5.0));
    }

    // Cell references.

    #[test]
    fn eval_cell_ref() {
        let sheet = sheet_with(&[((0, 0), "10")]);
        assert_eq!(evaluate("A1", &sheet), number(10.0));
    }

    #[test]
    fn eval_cell_ref_empty() {
        assert_eq!(evaluate("A1", &Sheet::new()), number(0.0));
    }

    // Comparisons, text and booleans.

    #[test]
    fn eval_comparison_gt() {
        assert_eq!(eval("3>2"), CellValue::Bool(true));
        assert_eq!(eval("2>3"), CellValue::Bool(false));
    }

    #[test]
    fn eval_comparison_eq() {
        assert_eq!(eval("3=3"), CellValue::Bool(true));
        assert_eq!(eval("3=4"), CellValue::Bool(false));
    }

    #[test]
    fn eval_string() {
        assert_eq!(eval("\"hello\""), CellValue::Text("hello".into()));
    }

    #[test]
    fn eval_string_add_error() {
        assert_eq!(eval("\"hello\"+1"), CellValue::Error(CellError::Value));
    }

    #[test]
    fn eval_bool() {
        assert_eq!(eval("TRUE"), CellValue::Bool(true));
    }

    // Aggregate functions over A1..A3 (rows 0..=2 of column 0).

    #[test]
    fn eval_sum() {
        let sheet = sheet_with(&[((0, 0), "1"), ((1, 0), "2"), ((2, 0), "3")]);
        assert_eq!(evaluate("SUM(A1:A3)", &sheet), number(6.0));
    }

    #[test]
    fn eval_average() {
        let sheet = sheet_with(&[((0, 0), "2"), ((1, 0), "4")]);
        assert_eq!(evaluate("AVERAGE(A1:A2)", &sheet), number(3.0));
    }

    #[test]
    fn eval_count() {
        let sheet = sheet_with(&[((0, 0), "1"), ((1, 0), "hello"), ((2, 0), "3")]);
        assert_eq!(evaluate("COUNT(A1:A3)", &sheet), number(2.0));
    }

    #[test]
    fn eval_min() {
        let sheet = sheet_with(&[((0, 0), "5"), ((1, 0), "2"), ((2, 0), "8")]);
        assert_eq!(evaluate("MIN(A1:A3)", &sheet), number(2.0));
    }

    #[test]
    fn eval_max() {
        let sheet = sheet_with(&[((0, 0), "5"), ((1, 0), "2"), ((2, 0), "8")]);
        assert_eq!(evaluate("MAX(A1:A3)", &sheet), number(8.0));
    }

    // IF, unknown functions and error propagation.

    #[test]
    fn eval_if_true() {
        assert_eq!(eval("IF(TRUE,1,2)"), number(1.0));
    }

    #[test]
    fn eval_if_false() {
        assert_eq!(eval("IF(FALSE,1,2)"), number(2.0));
    }

    #[test]
    fn eval_unknown_function() {
        assert_eq!(eval("FOO(1)"), CellValue::Error(CellError::Name));
    }

    #[test]
    fn eval_error_propagation() {
        let mut sheet = sheet_with(&[((0, 0), "=1/0")]);
        // Overwrite the stored value directly. Presumably `set_cell` alone does not
        // evaluate the formula (TODO confirm), so the error is planted by hand.
        sheet.cells.get_mut(&(0, 0)).unwrap().value = CellValue::Error(CellError::DivZero);
        assert_eq!(
            evaluate("A1+1", &sheet),
            CellValue::Error(CellError::DivZero)
        );
    }
}