Skip to main content

pdforg_sheets/
evaluator.rs

1//! Formula evaluator — evaluates Expr ASTs against a sheet context.
2
use std::cmp::Ordering;
use std::collections::HashSet;

use pdf_core::{CellAddress, CellError, CellValue, Sheet, Workbook};
use thiserror::Error;

use crate::functions::FUNCTIONS;
use crate::parser::{BinOpKind, Expr, UnaryOpKind};
8
9#[derive(Debug, Error)]
10pub enum EvalError {
11    #[error("Unknown function: {0}")]
12    UnknownFunc(String),
13    #[error("Wrong number of arguments for {0}: expected {1}, got {2}")]
14    ArgCount(String, String, usize),
15    #[error("Type error: {0}")]
16    TypeError(String),
17    #[error("Division by zero")]
18    DivZero,
19    #[error("Circular reference detected")]
20    CircularRef,
21    #[error("Invalid reference: {0}")]
22    InvalidRef(String),
23    #[error("Formula error: {0}")]
24    Formula(String),
25}
26
27/// Context for formula evaluation
28pub struct EvalContext<'a> {
29    pub workbook: &'a Workbook,
30    pub current_sheet: usize,
31    pub current_cell: CellAddress,
32    pub call_stack: HashSet<CellAddress>,  // for circular ref detection
33}
34
35impl<'a> EvalContext<'a> {
36    pub fn new(workbook: &'a Workbook, sheet: usize, cell: CellAddress) -> Self {
37        EvalContext {
38            workbook,
39            current_sheet: sheet,
40            current_cell: cell,
41            call_stack: HashSet::new(),
42        }
43    }
44
45    pub fn sheet(&self) -> Option<&Sheet> {
46        self.workbook.sheets.get(self.current_sheet)
47    }
48
49    pub fn get_cell_value(&self, addr: &CellAddress) -> CellValue {
50        let sheet_idx = if let Some(ref name) = addr.sheet {
51            self.workbook.sheets.iter().position(|s| &s.name == name)
52                .unwrap_or(self.current_sheet)
53        } else {
54            self.current_sheet
55        };
56
57        self.workbook.sheets.get(sheet_idx)
58            .and_then(|s| s.get_cell(addr.row, addr.col))
59            .map(|c| c.value.clone())
60            .unwrap_or(CellValue::Empty)
61    }
62
63    pub fn get_range_values(&self, start: &CellAddress, end: &CellAddress) -> Vec<Vec<CellValue>> {
64        let sheet_idx = self.current_sheet;
65        let sheet = match self.workbook.sheets.get(sheet_idx) {
66            Some(s) => s,
67            None => return vec![],
68        };
69
70        let r1 = start.row.min(end.row);
71        let r2 = start.row.max(end.row);
72        let c1 = start.col.min(end.col);
73        let c2 = start.col.max(end.col);
74
75        (r1..=r2).map(|row| {
76            (c1..=c2).map(|col| {
77                sheet.get_cell(row, col)
78                    .map(|c| c.value.clone())
79                    .unwrap_or(CellValue::Empty)
80            }).collect()
81        }).collect()
82    }
83}
84
85/// Evaluate an expression in a given context
86pub fn eval(expr: &Expr, ctx: &EvalContext) -> Result<CellValue, EvalError> {
87    match expr {
88        Expr::Number(n) => Ok(CellValue::Number(*n)),
89        Expr::Text(s) => Ok(CellValue::Text(s.clone())),
90        Expr::Bool(b) => Ok(CellValue::Bool(*b)),
91        Expr::Error(e) => Ok(CellValue::Error(parse_error(e))),
92
93        Expr::Percent(inner) => {
94            let v = eval(inner, ctx)?;
95            match v.as_number() {
96                Some(n) => Ok(CellValue::Number(n / 100.0)),
97                None => Err(EvalError::TypeError("Expected number for %".into())),
98            }
99        }
100
101        Expr::CellRef(addr) => Ok(ctx.get_cell_value(addr)),
102
103        Expr::RangeRef(start, end) => {
104            // When a range is used as a scalar, return the top-left cell
105            Ok(ctx.get_cell_value(start))
106        }
107
108        Expr::NamedRange(name) => {
109            // Look up named range
110            if let Some(range) = ctx.workbook.named_ranges.get(name.as_str()) {
111                Ok(ctx.get_cell_value(&range.start))
112            } else {
113                Ok(CellValue::Error(CellError::Name))
114            }
115        }
116
117        Expr::Array(rows) => {
118            // For now, return first element of array
119            if let Some(first_row) = rows.first() {
120                if let Some(first_cell) = first_row.first() {
121                    return eval(first_cell, ctx);
122                }
123            }
124            Ok(CellValue::Empty)
125        }
126
127        Expr::UnaryOp { op, expr } => {
128            let v = eval(expr, ctx)?;
129            match op {
130                UnaryOpKind::Neg => match v.as_number() {
131                    Some(n) => Ok(CellValue::Number(-n)),
132                    None => Err(EvalError::TypeError("Expected number for negation".into())),
133                },
134                UnaryOpKind::Plus => Ok(v),
135            }
136        }
137
138        Expr::BinOp { left, op, right } => eval_binop(left, op, right, ctx),
139
140        Expr::Call { name, args } => eval_call(name, args, ctx),
141    }
142}
143
144fn eval_binop(left: &Expr, op: &BinOpKind, right: &Expr, ctx: &EvalContext) -> Result<CellValue, EvalError> {
145    // Special case for range expansion in aggregate functions (handled in functions)
146    let lv = eval(left, ctx)?;
147    let rv = eval(right, ctx)?;
148
149    match op {
150        BinOpKind::Add => num_op(lv, rv, |a, b| a + b),
151        BinOpKind::Sub => num_op(lv, rv, |a, b| a - b),
152        BinOpKind::Mul => num_op(lv, rv, |a, b| a * b),
153        BinOpKind::Div => {
154            let b = rv.as_number().ok_or_else(|| EvalError::TypeError("Expected number".into()))?;
155            if b == 0.0 { return Ok(CellValue::Error(CellError::Div0)); }
156            num_op(lv, CellValue::Number(b), |a, _| a / b)
157        }
158        BinOpKind::Pow => num_op(lv, rv, |a, b| a.powf(b)),
159        BinOpKind::Concat => {
160            let s = format!("{}{}", lv.as_text(), rv.as_text());
161            Ok(CellValue::Text(s))
162        }
163        BinOpKind::Eq => Ok(CellValue::Bool(cell_eq(&lv, &rv))),
164        BinOpKind::Ne => Ok(CellValue::Bool(!cell_eq(&lv, &rv))),
165        BinOpKind::Lt => Ok(CellValue::Bool(cell_lt(&lv, &rv))),
166        BinOpKind::Le => Ok(CellValue::Bool(cell_lt(&lv, &rv) || cell_eq(&lv, &rv))),
167        BinOpKind::Gt => Ok(CellValue::Bool(!cell_lt(&lv, &rv) && !cell_eq(&lv, &rv))),
168        BinOpKind::Ge => Ok(CellValue::Bool(!cell_lt(&lv, &rv))),
169    }
170}
171
172fn num_op(a: CellValue, b: CellValue, f: impl Fn(f64, f64) -> f64) -> Result<CellValue, EvalError> {
173    let an = a.as_number().ok_or_else(|| EvalError::TypeError(format!("Expected number, got {:?}", a)))?;
174    let bn = b.as_number().ok_or_else(|| EvalError::TypeError(format!("Expected number, got {:?}", b)))?;
175    Ok(CellValue::Number(f(an, bn)))
176}
177
178fn cell_eq(a: &CellValue, b: &CellValue) -> bool {
179    match (a, b) {
180        (CellValue::Number(x), CellValue::Number(y)) => x == y,
181        (CellValue::Text(x), CellValue::Text(y)) => x.to_uppercase() == y.to_uppercase(),
182        (CellValue::Bool(x), CellValue::Bool(y)) => x == y,
183        (CellValue::Empty, CellValue::Empty) => true,
184        _ => false,
185    }
186}
187
188fn cell_lt(a: &CellValue, b: &CellValue) -> bool {
189    match (a, b) {
190        (CellValue::Number(x), CellValue::Number(y)) => x < y,
191        (CellValue::Text(x), CellValue::Text(y)) => x.to_uppercase() < y.to_uppercase(),
192        (CellValue::Bool(x), CellValue::Bool(y)) => !x && *y,
193        _ => false,
194    }
195}
196
197fn eval_call(name: &str, args: &[Expr], ctx: &EvalContext) -> Result<CellValue, EvalError> {
198    // Collect range values for functions that need them specially
199    let fn_args = collect_fn_args(name, args, ctx)?;
200
201    // Look up and call function
202    if let Some(func) = FUNCTIONS.get(name.to_uppercase().as_str()) {
203        func(&fn_args).map_err(|e| EvalError::Formula(e))
204    } else {
205        Err(EvalError::UnknownFunc(name.to_string()))
206    }
207}
208
209/// Collect function arguments, expanding ranges when needed
210fn collect_fn_args(name: &str, args: &[Expr], ctx: &EvalContext) -> Result<Vec<CellValue>, EvalError> {
211    // Functions that expand ranges into flat lists of values
212    let range_expanding = [
213        "SUM", "AVERAGE", "MIN", "MAX", "COUNT", "COUNTA", "COUNTBLANK",
214        "PRODUCT", "STDEV", "STDEVP", "VAR", "VARP", "MEDIAN", "MODE",
215    ];
216    let expands = range_expanding.contains(&name.to_uppercase().as_str());
217
218    let mut result = vec![];
219    for arg in args {
220        match arg {
221            Expr::RangeRef(start, end) if expands => {
222                let values = ctx.get_range_values(start, end);
223                for row in values {
224                    for v in row {
225                        result.push(v);
226                    }
227                }
228            }
229            _ => result.push(eval(arg, ctx)?),
230        }
231    }
232    Ok(result)
233}
234
235fn parse_error(s: &str) -> CellError {
236    match s {
237        "#DIV/0!" => CellError::Div0,
238        "#N/A" => CellError::NA,
239        "#NAME?" => CellError::Name,
240        "#NULL!" => CellError::Null,
241        "#NUM!" => CellError::Num,
242        "#REF!" => CellError::Ref,
243        "#VALUE!" => CellError::Value,
244        _ => CellError::Value,
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_formula;
    use pdf_core::Cell;

    /// Parses `formula` and evaluates it as the formula of A1 on the first sheet.
    /// Any `EvalError` is collapsed into `#VALUE!` so tests can compare plain values.
    fn eval_str(formula: &str, wb: &Workbook) -> CellValue {
        let ast = parse_formula(formula).unwrap();
        let ctx = EvalContext::new(wb, 0, CellAddress::new(0, 0));
        match eval(&ast, &ctx) {
            Ok(value) => value,
            Err(_) => CellValue::Error(CellError::Value),
        }
    }

    /// Builds a cell holding only the numeric value `n`.
    fn number_cell(n: f64) -> Cell {
        Cell { value: CellValue::Number(n), ..Default::default() }
    }

    #[test]
    fn test_eval_arithmetic() {
        let wb = Workbook::default();
        let cases = [("=2+3", 5.0), ("=10/4", 2.5), ("=2^10", 1024.0)];
        for &(formula, expected) in cases.iter() {
            assert_eq!(eval_str(formula, &wb), CellValue::Number(expected), "formula: {}", formula);
        }
    }

    #[test]
    fn test_eval_sum() {
        let mut wb = Workbook::default();
        let sheet = wb.sheets.get_mut(0).unwrap();
        // Fill A1:A3 with 1, 2, 3.
        for (row, &n) in (0..).zip([1.0, 2.0, 3.0].iter()) {
            sheet.set_cell(row, 0, number_cell(n));
        }
        assert_eq!(eval_str("=SUM(A1:A3)", &wb), CellValue::Number(6.0));
    }

    #[test]
    fn test_eval_if() {
        let wb = Workbook::default();
        let yes = CellValue::Text("yes".into());
        let no = CellValue::Text("no".into());
        assert_eq!(eval_str(r#"=IF(1>0,"yes","no")"#, &wb), yes);
        assert_eq!(eval_str(r#"=IF(0>1,"yes","no")"#, &wb), no);
    }

    #[test]
    fn test_eval_concatenate() {
        let wb = Workbook::default();
        let expected = CellValue::Text("Hello World".into());
        assert_eq!(eval_str(r#"="Hello"&" "&"World""#, &wb), expected);
    }
}