1use crate::formula::ast::*;
2use crate::formula::functions;
3use crate::formula::parser;
4use crate::model::{CellError, CellPos, CellValue, Sheet};
5
6fn expand_range(start: &CellRef, end: &CellRef) -> Vec<CellPos> {
7 let mut positions = Vec::new();
8 let r1 = start.row.min(end.row);
9 let r2 = start.row.max(end.row);
10 let c1 = start.col.min(end.col);
11 let c2 = start.col.max(end.col);
12 for r in r1..=r2 {
13 for c in c1..=c2 {
14 positions.push((r, c));
15 }
16 }
17 positions
18}
19
20fn resolve_cell(sheet: &Sheet, pos: CellPos) -> CellValue {
21 match sheet.get_cell(pos) {
22 Some(cell) => cell.value.clone(),
23 None => CellValue::Empty,
24 }
25}
26
27fn cell_value_to_number(v: &CellValue) -> Result<f64, CellError> {
28 match v {
29 CellValue::Number(n) => Ok(*n),
30 CellValue::Empty => Ok(0.0),
31 CellValue::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
32 CellValue::Error(e) => Err(e.clone()),
33 CellValue::Text(_) => Err(CellError::Value),
34 }
35}
36
37fn eval_expr(expr: &Expr, sheet: &Sheet) -> CellValue {
38 match expr {
39 Expr::Number(n) => CellValue::Number(*n),
40 Expr::Text(s) => CellValue::Text(s.clone()),
41 Expr::Bool(b) => CellValue::Bool(*b),
42 Expr::CellRef(cell_ref) => {
43 let val = resolve_cell(sheet, (cell_ref.row, cell_ref.col));
44 if val == CellValue::Empty {
45 CellValue::Number(0.0)
46 } else {
47 val
48 }
49 }
50 Expr::Range { .. } => CellValue::Error(CellError::Value),
51 Expr::UnaryNeg(inner) => {
52 let val = eval_expr(inner, sheet);
53 match cell_value_to_number(&val) {
54 Ok(n) => CellValue::Number(-n),
55 Err(e) => CellValue::Error(e),
56 }
57 }
58 Expr::BinaryOp { op, left, right } => {
59 let lval = eval_expr(left, sheet);
60 let rval = eval_expr(right, sheet);
61
62 if let CellValue::Error(e) = &lval {
63 return CellValue::Error(e.clone());
64 }
65 if let CellValue::Error(e) = &rval {
66 return CellValue::Error(e.clone());
67 }
68
69 match op {
70 Op::Add | Op::Sub | Op::Mul | Op::Div => {
71 let ln = match cell_value_to_number(&lval) {
72 Ok(n) => n,
73 Err(e) => return CellValue::Error(e),
74 };
75 let rn = match cell_value_to_number(&rval) {
76 Ok(n) => n,
77 Err(e) => return CellValue::Error(e),
78 };
79 match op {
80 Op::Add => CellValue::Number(ln + rn),
81 Op::Sub => CellValue::Number(ln - rn),
82 Op::Mul => CellValue::Number(ln * rn),
83 Op::Div => {
84 if rn == 0.0 {
85 CellValue::Error(CellError::DivZero)
86 } else {
87 CellValue::Number(ln / rn)
88 }
89 }
90 _ => unreachable!(),
91 }
92 }
93 Op::Gt | Op::Gte | Op::Lt | Op::Lte | Op::Eq | Op::Neq => {
94 let ln = match cell_value_to_number(&lval) {
95 Ok(n) => n,
96 Err(e) => return CellValue::Error(e),
97 };
98 let rn = match cell_value_to_number(&rval) {
99 Ok(n) => n,
100 Err(e) => return CellValue::Error(e),
101 };
102 let result = match op {
103 Op::Gt => ln > rn,
104 Op::Gte => ln >= rn,
105 Op::Lt => ln < rn,
106 Op::Lte => ln <= rn,
107 Op::Eq => (ln - rn).abs() < f64::EPSILON,
108 Op::Neq => (ln - rn).abs() >= f64::EPSILON,
109 _ => unreachable!(),
110 };
111 CellValue::Bool(result)
112 }
113 }
114 }
115 Expr::FnCall { name, args } => {
116 let upper = name.to_uppercase();
117
118 if upper == "IF" {
119 let evaled: Vec<CellValue> = args.iter().map(|a| eval_expr(a, sheet)).collect();
120 return functions::fn_if(&evaled);
121 }
122
123 let mut values = Vec::new();
124 for arg in args {
125 match arg {
126 Expr::Range { start, end } => {
127 for pos in expand_range(start, end) {
128 values.push(resolve_cell(sheet, pos));
129 }
130 }
131 other => {
132 values.push(eval_expr(other, sheet));
133 }
134 }
135 }
136
137 match upper.as_str() {
138 "SUM" => functions::fn_sum(&values),
139 "AVERAGE" => functions::fn_average(&values),
140 "COUNT" => functions::fn_count(&values),
141 "MIN" => functions::fn_min(&values),
142 "MAX" => functions::fn_max(&values),
143 _ => CellValue::Error(CellError::Name),
144 }
145 }
146 }
147}
148
149pub fn evaluate(formula: &str, sheet: &Sheet) -> CellValue {
150 match parser::parse(formula) {
151 Ok(expr) => eval_expr(&expr, sheet),
152 Err(e) => CellValue::Error(e),
153 }
154}
155
#[cfg(test)]
mod tests {
    //! Unit tests for formula evaluation, driven through `evaluate`.

    use super::*;
    use crate::model::Sheet;

    /// Evaluates `formula` against an existing sheet.
    fn eval_with_sheet(formula: &str, sheet: &Sheet) -> CellValue {
        evaluate(formula, sheet)
    }

    /// Evaluates `formula` against a fresh, empty sheet.
    fn eval(formula: &str) -> CellValue {
        let sheet = Sheet::new();
        eval_with_sheet(formula, &sheet)
    }

    /// Creates a cell at `pos` and then overwrites its stored value with
    /// `err`, so error propagation can be exercised directly.
    fn put_error(sheet: &mut Sheet, pos: CellPos, err: CellError) {
        sheet.set_cell(pos, "=1/0");
        sheet
            .cells
            .get_mut(&pos)
            .expect("set_cell should have created the cell")
            .value = CellValue::Error(err);
    }

    #[test]
    fn eval_number() {
        assert_eq!(eval("42"), CellValue::Number(42.0));
    }

    #[test]
    fn eval_addition() {
        assert_eq!(eval("1+2"), CellValue::Number(3.0));
    }

    #[test]
    fn eval_subtraction() {
        assert_eq!(eval("5-3"), CellValue::Number(2.0));
    }

    #[test]
    fn eval_multiplication() {
        assert_eq!(eval("3*4"), CellValue::Number(12.0));
    }

    #[test]
    fn eval_division() {
        assert_eq!(eval("10/4"), CellValue::Number(2.5));
    }

    #[test]
    fn eval_division_by_zero() {
        assert_eq!(eval("1/0"), CellValue::Error(CellError::DivZero));
    }

    #[test]
    fn eval_precedence() {
        assert_eq!(eval("1+2*3"), CellValue::Number(7.0));
    }

    #[test]
    fn eval_parentheses() {
        assert_eq!(eval("(1+2)*3"), CellValue::Number(9.0));
    }

    #[test]
    fn eval_negation() {
        assert_eq!(eval("-5"), CellValue::Number(-5.0));
    }

    #[test]
    fn eval_cell_ref() {
        let mut sheet = Sheet::new();
        sheet.set_cell((0, 0), "10");
        assert_eq!(eval_with_sheet("A1", &sheet), CellValue::Number(10.0));
    }

    #[test]
    fn eval_cell_ref_empty() {
        let sheet = Sheet::new();
        assert_eq!(eval_with_sheet("A1", &sheet), CellValue::Number(0.0));
    }

    #[test]
    fn eval_comparison_gt() {
        assert_eq!(eval("3>2"), CellValue::Bool(true));
        assert_eq!(eval("2>3"), CellValue::Bool(false));
    }

    #[test]
    fn eval_comparison_eq() {
        assert_eq!(eval("3=3"), CellValue::Bool(true));
        assert_eq!(eval("3=4"), CellValue::Bool(false));
    }

    #[test]
    fn eval_comparison_lt() {
        assert_eq!(eval("2<3"), CellValue::Bool(true));
        assert_eq!(eval("3<2"), CellValue::Bool(false));
    }

    #[test]
    fn eval_comparison_inclusive() {
        assert_eq!(eval("2>=2"), CellValue::Bool(true));
        assert_eq!(eval("1>=2"), CellValue::Bool(false));
        assert_eq!(eval("2<=2"), CellValue::Bool(true));
        assert_eq!(eval("3<=2"), CellValue::Bool(false));
    }

    #[test]
    fn eval_equality_tolerates_rounding() {
        // 0.1 + 0.2 is 0.30000000000000004 in binary floating point.
        assert_eq!(eval("(0.1+0.2)=0.3"), CellValue::Bool(true));
    }

    #[test]
    fn eval_string() {
        assert_eq!(eval("\"hello\""), CellValue::Text("hello".into()));
    }

    #[test]
    fn eval_string_add_error() {
        assert_eq!(eval("\"hello\"+1"), CellValue::Error(CellError::Value));
    }

    #[test]
    fn eval_bool() {
        assert_eq!(eval("TRUE"), CellValue::Bool(true));
    }

    #[test]
    fn eval_bool_in_arithmetic() {
        assert_eq!(eval("TRUE+1"), CellValue::Number(2.0));
    }

    #[test]
    fn eval_negated_cell_ref() {
        let mut sheet = Sheet::new();
        sheet.set_cell((0, 0), "10");
        assert_eq!(eval_with_sheet("-A1", &sheet), CellValue::Number(-10.0));
    }

    #[test]
    fn eval_division_by_empty_cell() {
        let sheet = Sheet::new();
        assert_eq!(
            eval_with_sheet("1/A1", &sheet),
            CellValue::Error(CellError::DivZero)
        );
    }

    #[test]
    fn eval_sum() {
        let mut sheet = Sheet::new();
        sheet.set_cell((0, 0), "1");
        sheet.set_cell((1, 0), "2");
        sheet.set_cell((2, 0), "3");
        assert_eq!(
            eval_with_sheet("SUM(A1:A3)", &sheet),
            CellValue::Number(6.0)
        );
    }

    #[test]
    fn eval_reversed_range() {
        let mut sheet = Sheet::new();
        sheet.set_cell((0, 0), "1");
        sheet.set_cell((1, 0), "2");
        sheet.set_cell((2, 0), "3");
        assert_eq!(
            eval_with_sheet("SUM(A3:A1)", &sheet),
            CellValue::Number(6.0)
        );
    }

    #[test]
    fn eval_average() {
        let mut sheet = Sheet::new();
        sheet.set_cell((0, 0), "2");
        sheet.set_cell((1, 0), "4");
        assert_eq!(
            eval_with_sheet("AVERAGE(A1:A2)", &sheet),
            CellValue::Number(3.0)
        );
    }

    #[test]
    fn eval_count() {
        let mut sheet = Sheet::new();
        sheet.set_cell((0, 0), "1");
        sheet.set_cell((1, 0), "hello");
        sheet.set_cell((2, 0), "3");
        assert_eq!(
            eval_with_sheet("COUNT(A1:A3)", &sheet),
            CellValue::Number(2.0)
        );
    }

    #[test]
    fn eval_min() {
        let mut sheet = Sheet::new();
        sheet.set_cell((0, 0), "5");
        sheet.set_cell((1, 0), "2");
        sheet.set_cell((2, 0), "8");
        assert_eq!(
            eval_with_sheet("MIN(A1:A3)", &sheet),
            CellValue::Number(2.0)
        );
    }

    #[test]
    fn eval_max() {
        let mut sheet = Sheet::new();
        sheet.set_cell((0, 0), "5");
        sheet.set_cell((1, 0), "2");
        sheet.set_cell((2, 0), "8");
        assert_eq!(
            eval_with_sheet("MAX(A1:A3)", &sheet),
            CellValue::Number(8.0)
        );
    }

    #[test]
    fn eval_if_true() {
        assert_eq!(eval("IF(TRUE,1,2)"), CellValue::Number(1.0));
    }

    #[test]
    fn eval_if_false() {
        assert_eq!(eval("IF(FALSE,1,2)"), CellValue::Number(2.0));
    }

    #[test]
    fn eval_unknown_function() {
        assert_eq!(eval("FOO(1)"), CellValue::Error(CellError::Name));
    }

    #[test]
    fn eval_error_propagation() {
        let mut sheet = Sheet::new();
        sheet.set_cell((0, 0), "=1/0");
        sheet.cells.get_mut(&(0, 0)).unwrap().value = CellValue::Error(CellError::DivZero);
        assert_eq!(
            eval_with_sheet("A1+1", &sheet),
            CellValue::Error(CellError::DivZero)
        );
    }

    #[test]
    fn eval_error_in_right_operand() {
        let mut sheet = Sheet::new();
        put_error(&mut sheet, (0, 0), CellError::DivZero);
        assert_eq!(
            eval_with_sheet("1+A1", &sheet),
            CellValue::Error(CellError::DivZero)
        );
    }

    #[test]
    fn eval_left_error_takes_precedence() {
        let mut sheet = Sheet::new();
        put_error(&mut sheet, (0, 0), CellError::DivZero);
        put_error(&mut sheet, (1, 0), CellError::Value);
        assert_eq!(
            eval_with_sheet("A1+A2", &sheet),
            CellValue::Error(CellError::DivZero)
        );
        assert_eq!(
            eval_with_sheet("A2+A1", &sheet),
            CellValue::Error(CellError::Value)
        );
    }
}