Skip to main content

cell_sheet_core/formula/
deps.rs

1use std::collections::{HashMap, HashSet, VecDeque};
2
3use crate::formula::ast::*;
4use crate::formula::eval;
5use crate::formula::parser;
6use crate::model::{CellError, CellPos, CellValue, Sheet};
7
8pub struct DepGraph {
9    pub dependents: HashMap<CellPos, HashSet<CellPos>>,
10    pub dependencies: HashMap<CellPos, HashSet<CellPos>>,
11}
12
13impl Default for DepGraph {
14    fn default() -> Self {
15        Self::new()
16    }
17}
18
19impl DepGraph {
20    pub fn new() -> Self {
21        DepGraph {
22            dependents: HashMap::new(),
23            dependencies: HashMap::new(),
24        }
25    }
26
27    pub fn set_dependencies(&mut self, cell: CellPos, deps: Vec<CellPos>) {
28        if let Some(old_deps) = self.dependencies.remove(&cell) {
29            for dep in &old_deps {
30                if let Some(set) = self.dependents.get_mut(dep) {
31                    set.remove(&cell);
32                }
33            }
34        }
35        let dep_set: HashSet<CellPos> = deps.into_iter().collect();
36        for dep in &dep_set {
37            self.dependents.entry(*dep).or_default().insert(cell);
38        }
39        self.dependencies.insert(cell, dep_set);
40    }
41}
42
43pub fn extract_deps(formula: &str) -> Vec<CellPos> {
44    let expr = match parser::parse(formula) {
45        Ok(e) => e,
46        Err(_) => return vec![],
47    };
48    let mut deps = Vec::new();
49    collect_refs(&expr, &mut deps);
50    deps
51}
52
53fn collect_refs(expr: &Expr, out: &mut Vec<CellPos>) {
54    match expr {
55        Expr::CellRef(r) => {
56            out.push((r.row, r.col));
57        }
58        Expr::Range { start, end } => {
59            let r1 = start.row.min(end.row);
60            let r2 = start.row.max(end.row);
61            let c1 = start.col.min(end.col);
62            let c2 = start.col.max(end.col);
63            for r in r1..=r2 {
64                for c in c1..=c2 {
65                    out.push((r, c));
66                }
67            }
68        }
69        Expr::BinaryOp { left, right, .. } => {
70            collect_refs(left, out);
71            collect_refs(right, out);
72        }
73        Expr::UnaryNeg(inner) => collect_refs(inner, out),
74        Expr::FnCall { args, .. } => {
75            for arg in args {
76                collect_refs(arg, out);
77            }
78        }
79        _ => {}
80    }
81}
82
83pub fn set_formula(sheet: &mut Sheet, deps: &mut DepGraph, pos: CellPos, raw: &str) {
84    sheet.set_cell(pos, raw);
85    let formula = &raw[1..]; // strip '='
86    let dep_list = extract_deps(formula);
87    deps.set_dependencies(pos, dep_list);
88}
89
90pub fn mark_dirty(sheet: &mut Sheet, deps: &DepGraph, pos: CellPos) {
91    let mut queue = VecDeque::new();
92    queue.push_back(pos);
93    while let Some(cell) = queue.pop_front() {
94        if let Some(dependents) = deps.dependents.get(&cell) {
95            for &dep in dependents {
96                if let Some(c) = sheet.cells.get_mut(&dep) {
97                    if !c.dirty {
98                        c.dirty = true;
99                        queue.push_back(dep);
100                    }
101                }
102            }
103        }
104    }
105}
106
107pub fn recalculate(sheet: &mut Sheet, deps: &DepGraph) {
108    let formula_cells: Vec<CellPos> = sheet
109        .cells
110        .iter()
111        .filter(|(_, cell)| cell.raw.starts_with('='))
112        .map(|(pos, _)| *pos)
113        .collect();
114
115    let formula_set: HashSet<CellPos> = formula_cells.iter().cloned().collect();
116
117    let mut in_degree: HashMap<CellPos, usize> = HashMap::new();
118    for &cell in &formula_cells {
119        let count = deps
120            .dependencies
121            .get(&cell)
122            .map(|d| d.iter().filter(|p| formula_set.contains(p)).count())
123            .unwrap_or(0);
124        in_degree.insert(cell, count);
125    }
126
127    let mut queue: VecDeque<CellPos> = in_degree
128        .iter()
129        .filter(|(_, &deg)| deg == 0)
130        .map(|(&pos, _)| pos)
131        .collect();
132
133    let mut order = Vec::new();
134
135    while let Some(cell) = queue.pop_front() {
136        order.push(cell);
137        if let Some(dependents) = deps.dependents.get(&cell) {
138            for &dep in dependents {
139                if let Some(deg) = in_degree.get_mut(&dep) {
140                    *deg -= 1;
141                    if *deg == 0 {
142                        queue.push_back(dep);
143                    }
144                }
145            }
146        }
147    }
148
149    let ordered_set: HashSet<CellPos> = order.iter().cloned().collect();
150    for &cell in &formula_cells {
151        if !ordered_set.contains(&cell) {
152            if let Some(c) = sheet.cells.get_mut(&cell) {
153                c.value = CellValue::Error(CellError::Circ);
154                c.dirty = false;
155            }
156        }
157    }
158
159    for pos in order {
160        let raw = match sheet.get_cell(pos) {
161            Some(cell) if cell.raw.starts_with('=') => cell.raw.clone(),
162            _ => continue,
163        };
164        let formula = &raw[1..];
165        let value = eval::evaluate(formula, sheet);
166        if let Some(cell) = sheet.cells.get_mut(&pos) {
167            cell.value = value;
168            cell.dirty = false;
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{CellValue, Sheet};

    /// Asserts that the cell at `pos` exists and currently holds `expected`.
    #[track_caller]
    fn assert_value(sheet: &Sheet, pos: CellPos, expected: CellValue) {
        assert_eq!(sheet.get_cell(pos).unwrap().value, expected);
    }

    #[test]
    fn extract_deps_cell_ref() {
        assert_eq!(extract_deps("A1+B1"), vec![(0, 0), (0, 1)]);
    }

    #[test]
    fn extract_deps_range() {
        // A1:A3 expands down the column, top to bottom.
        assert_eq!(extract_deps("SUM(A1:A3)"), vec![(0, 0), (1, 0), (2, 0)]);
    }

    #[test]
    fn extract_deps_no_refs() {
        assert!(extract_deps("1+2").is_empty());
    }

    #[test]
    fn recalc_simple() {
        let mut sheet = Sheet::new();
        let mut graph = DepGraph::new();
        sheet.set_cell((0, 0), "10");
        set_formula(&mut sheet, &mut graph, (0, 1), "=A1+5");
        recalculate(&mut sheet, &graph);
        assert_value(&sheet, (0, 1), CellValue::Number(15.0));
    }

    #[test]
    fn recalc_chain() {
        let mut sheet = Sheet::new();
        let mut graph = DepGraph::new();
        sheet.set_cell((0, 0), "10");
        set_formula(&mut sheet, &mut graph, (0, 1), "=A1*2");
        set_formula(&mut sheet, &mut graph, (0, 2), "=B1+1");
        recalculate(&mut sheet, &graph);
        assert_value(&sheet, (0, 1), CellValue::Number(20.0));
        assert_value(&sheet, (0, 2), CellValue::Number(21.0));
    }

    #[test]
    fn recalc_circular_reference() {
        let mut sheet = Sheet::new();
        let mut graph = DepGraph::new();
        set_formula(&mut sheet, &mut graph, (0, 0), "=B1");
        set_formula(&mut sheet, &mut graph, (0, 1), "=A1");
        recalculate(&mut sheet, &graph);
        assert_value(&sheet, (0, 0), CellValue::Error(CellError::Circ));
        assert_value(&sheet, (0, 1), CellValue::Error(CellError::Circ));
    }

    #[test]
    fn recalc_after_value_change() {
        let mut sheet = Sheet::new();
        let mut graph = DepGraph::new();
        sheet.set_cell((0, 0), "10");
        set_formula(&mut sheet, &mut graph, (0, 1), "=A1+5");
        recalculate(&mut sheet, &graph);
        assert_value(&sheet, (0, 1), CellValue::Number(15.0));

        // Change the input, propagate dirtiness, and recompute.
        sheet.set_cell((0, 0), "20");
        mark_dirty(&mut sheet, &graph, (0, 0));
        recalculate(&mut sheet, &graph);
        assert_value(&sheet, (0, 1), CellValue::Number(25.0));
    }

    #[test]
    fn self_reference_is_circular() {
        let mut sheet = Sheet::new();
        let mut graph = DepGraph::new();
        set_formula(&mut sheet, &mut graph, (0, 0), "=A1+1");
        recalculate(&mut sheet, &graph);
        assert_value(&sheet, (0, 0), CellValue::Error(CellError::Circ));
    }
}