1use std::collections::{HashMap, HashSet, VecDeque};
2
3use crate::formula::ast::*;
4use crate::formula::eval;
5use crate::formula::parser;
6use crate::model::{CellError, CellPos, CellValue, Sheet};
7
8#[derive(Debug, Clone)]
9pub struct DepGraph {
10 pub dependents: HashMap<CellPos, HashSet<CellPos>>,
11 pub dependencies: HashMap<CellPos, HashSet<CellPos>>,
12}
13
14impl Default for DepGraph {
15 fn default() -> Self {
16 Self::new()
17 }
18}
19
20impl DepGraph {
21 pub fn new() -> Self {
22 DepGraph {
23 dependents: HashMap::new(),
24 dependencies: HashMap::new(),
25 }
26 }
27
28 pub fn set_dependencies(&mut self, cell: CellPos, deps: Vec<CellPos>) {
29 if let Some(old_deps) = self.dependencies.remove(&cell) {
30 for dep in &old_deps {
31 if let Some(set) = self.dependents.get_mut(dep) {
32 set.remove(&cell);
33 }
34 }
35 }
36 let dep_set: HashSet<CellPos> = deps.into_iter().collect();
37 for dep in &dep_set {
38 self.dependents.entry(*dep).or_default().insert(cell);
39 }
40 self.dependencies.insert(cell, dep_set);
41 }
42
43 pub fn remove(&mut self, cell: CellPos) {
47 if let Some(old_deps) = self.dependencies.remove(&cell) {
48 for dep in &old_deps {
49 if let Some(set) = self.dependents.get_mut(dep) {
50 set.remove(&cell);
51 }
52 }
53 }
54 }
55}
56
57pub fn extract_deps(formula: &str) -> Vec<CellPos> {
58 let expr = match parser::parse(formula) {
59 Ok(e) => e,
60 Err(_) => return vec![],
61 };
62 let mut deps = Vec::new();
63 collect_refs(&expr, &mut deps);
64 deps
65}
66
67fn collect_refs(expr: &Expr, out: &mut Vec<CellPos>) {
68 match expr {
69 Expr::CellRef(r) => {
70 out.push((r.row, r.col));
71 }
72 Expr::Range { start, end } => {
73 let r1 = start.row.min(end.row);
74 let r2 = start.row.max(end.row);
75 let c1 = start.col.min(end.col);
76 let c2 = start.col.max(end.col);
77 for r in r1..=r2 {
78 for c in c1..=c2 {
79 out.push((r, c));
80 }
81 }
82 }
83 Expr::BinaryOp { left, right, .. } => {
84 collect_refs(left, out);
85 collect_refs(right, out);
86 }
87 Expr::UnaryNeg(inner) => collect_refs(inner, out),
88 Expr::FnCall { args, .. } => {
89 for arg in args {
90 collect_refs(arg, out);
91 }
92 }
93 _ => {}
94 }
95}
96
97pub fn set_formula(sheet: &mut Sheet, deps: &mut DepGraph, pos: CellPos, raw: &str) {
98 sheet.set_cell(pos, raw);
99 let formula = &raw[1..]; let dep_list = extract_deps(formula);
101 deps.set_dependencies(pos, dep_list);
102}
103
104pub fn mark_dirty(sheet: &mut Sheet, deps: &DepGraph, pos: CellPos) {
105 let mut queue = VecDeque::new();
106 queue.push_back(pos);
107 while let Some(cell) = queue.pop_front() {
108 if let Some(dependents) = deps.dependents.get(&cell) {
109 for &dep in dependents {
110 if let Some(c) = sheet.cells.get_mut(&dep) {
111 if !c.dirty {
112 c.dirty = true;
113 queue.push_back(dep);
114 }
115 }
116 }
117 }
118 }
119}
120
121pub fn recalculate(sheet: &mut Sheet, deps: &DepGraph) {
122 let formula_cells: Vec<CellPos> = sheet
123 .cells
124 .iter()
125 .filter(|(_, cell)| cell.raw.starts_with('='))
126 .map(|(pos, _)| *pos)
127 .collect();
128
129 let formula_set: HashSet<CellPos> = formula_cells.iter().cloned().collect();
130
131 let mut in_degree: HashMap<CellPos, usize> = HashMap::new();
132 for &cell in &formula_cells {
133 let count = deps
134 .dependencies
135 .get(&cell)
136 .map(|d| d.iter().filter(|p| formula_set.contains(p)).count())
137 .unwrap_or(0);
138 in_degree.insert(cell, count);
139 }
140
141 let mut queue: VecDeque<CellPos> = in_degree
142 .iter()
143 .filter(|(_, °)| deg == 0)
144 .map(|(&pos, _)| pos)
145 .collect();
146
147 let mut order = Vec::new();
148
149 while let Some(cell) = queue.pop_front() {
150 order.push(cell);
151 if let Some(dependents) = deps.dependents.get(&cell) {
152 for &dep in dependents {
153 if let Some(deg) = in_degree.get_mut(&dep) {
154 *deg -= 1;
155 if *deg == 0 {
156 queue.push_back(dep);
157 }
158 }
159 }
160 }
161 }
162
163 let ordered_set: HashSet<CellPos> = order.iter().cloned().collect();
164 for &cell in &formula_cells {
165 if !ordered_set.contains(&cell) {
166 if let Some(c) = sheet.cells.get_mut(&cell) {
167 c.value = CellValue::Error(CellError::Circ);
168 c.dirty = false;
169 }
170 }
171 }
172
173 for pos in order {
174 let raw = match sheet.get_cell(pos) {
175 Some(cell) if cell.raw.starts_with('=') => cell.raw.clone(),
176 _ => continue,
177 };
178 let formula = &raw[1..];
179 let value = eval::evaluate(formula, sheet);
180 if let Some(cell) = sheet.cells.get_mut(&pos) {
181 cell.value = value;
182 cell.dirty = false;
183 }
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{CellValue, Sheet};

    #[test]
    fn extract_deps_cell_ref() {
        let deps = extract_deps("A1+B1");
        assert_eq!(deps, vec![(0, 0), (0, 1)]);
    }

    #[test]
    fn extract_deps_range() {
        let deps = extract_deps("SUM(A1:A3)");
        assert_eq!(deps, vec![(0, 0), (1, 0), (2, 0)]);
    }

    #[test]
    fn extract_deps_no_refs() {
        let deps = extract_deps("1+2");
        assert!(deps.is_empty());
    }

    #[test]
    fn recalc_simple() {
        let mut sheet = Sheet::new();
        let mut deps = DepGraph::new();
        sheet.set_cell((0, 0), "10");
        set_formula(&mut sheet, &mut deps, (0, 1), "=A1+5");
        recalculate(&mut sheet, &deps);
        assert_eq!(
            sheet.get_cell((0, 1)).unwrap().value,
            CellValue::Number(15.0)
        );
    }

    #[test]
    fn recalc_chain() {
        let mut sheet = Sheet::new();
        let mut deps = DepGraph::new();
        sheet.set_cell((0, 0), "10");
        set_formula(&mut sheet, &mut deps, (0, 1), "=A1*2");
        set_formula(&mut sheet, &mut deps, (0, 2), "=B1+1");
        recalculate(&mut sheet, &deps);
        assert_eq!(
            sheet.get_cell((0, 1)).unwrap().value,
            CellValue::Number(20.0)
        );
        assert_eq!(
            sheet.get_cell((0, 2)).unwrap().value,
            CellValue::Number(21.0)
        );
    }

    // D1 reads both B1 and C1, which each read A1, so D1 must be evaluated last.
    #[test]
    fn recalc_diamond() {
        let mut sheet = Sheet::new();
        let mut deps = DepGraph::new();
        sheet.set_cell((0, 0), "1");
        set_formula(&mut sheet, &mut deps, (0, 1), "=A1+1");
        set_formula(&mut sheet, &mut deps, (0, 2), "=A1*2");
        set_formula(&mut sheet, &mut deps, (0, 3), "=B1+C1");
        recalculate(&mut sheet, &deps);
        assert_eq!(
            sheet.get_cell((0, 3)).unwrap().value,
            CellValue::Number(4.0)
        );
    }

    #[test]
    fn recalc_circular_reference() {
        let mut sheet = Sheet::new();
        let mut deps = DepGraph::new();
        set_formula(&mut sheet, &mut deps, (0, 0), "=B1");
        set_formula(&mut sheet, &mut deps, (0, 1), "=A1");
        recalculate(&mut sheet, &deps);
        assert_eq!(
            sheet.get_cell((0, 0)).unwrap().value,
            CellValue::Error(CellError::Circ)
        );
        assert_eq!(
            sheet.get_cell((0, 1)).unwrap().value,
            CellValue::Error(CellError::Circ)
        );
    }

    #[test]
    fn recalc_after_value_change() {
        let mut sheet = Sheet::new();
        let mut deps = DepGraph::new();
        sheet.set_cell((0, 0), "10");
        set_formula(&mut sheet, &mut deps, (0, 1), "=A1+5");
        recalculate(&mut sheet, &deps);
        assert_eq!(
            sheet.get_cell((0, 1)).unwrap().value,
            CellValue::Number(15.0)
        );

        sheet.set_cell((0, 0), "20");
        mark_dirty(&mut sheet, &deps, (0, 0));
        recalculate(&mut sheet, &deps);
        assert_eq!(
            sheet.get_cell((0, 1)).unwrap().value,
            CellValue::Number(25.0)
        );
    }

    #[test]
    fn remove_drops_edges() {
        let mut sheet = Sheet::new();
        let mut deps = DepGraph::new();
        sheet.set_cell((0, 0), "10");
        set_formula(&mut sheet, &mut deps, (0, 1), "=A1+5");
        assert!(deps.dependencies.contains_key(&(0, 1)));
        assert!(deps.dependents.get(&(0, 0)).unwrap().contains(&(0, 1)));

        deps.remove((0, 1));

        assert!(!deps.dependencies.contains_key(&(0, 1)));
        assert!(!deps
            .dependents
            .get(&(0, 0))
            .map(|s| s.contains(&(0, 1)))
            .unwrap_or(false));
    }

    // Rewriting a formula must drop the edges of its old references.
    #[test]
    fn set_formula_repoints_edges() {
        let mut sheet = Sheet::new();
        let mut deps = DepGraph::new();
        set_formula(&mut sheet, &mut deps, (0, 1), "=A1");
        set_formula(&mut sheet, &mut deps, (0, 1), "=C1");

        assert!(!deps
            .dependents
            .get(&(0, 0))
            .map(|s| s.contains(&(0, 1)))
            .unwrap_or(false));
        assert!(deps.dependents.get(&(0, 2)).unwrap().contains(&(0, 1)));
        let current = deps.dependencies.get(&(0, 1)).unwrap();
        assert_eq!(current.len(), 1);
        assert!(current.contains(&(0, 2)));
    }

    #[test]
    fn self_reference_is_circular() {
        let mut sheet = Sheet::new();
        let mut deps = DepGraph::new();
        set_formula(&mut sheet, &mut deps, (0, 0), "=A1+1");
        recalculate(&mut sheet, &deps);
        assert_eq!(
            sheet.get_cell((0, 0)).unwrap().value,
            CellValue::Error(CellError::Circ)
        );
    }
}