1use crate::function::{FnCaps, Function};
7use formualizer_parse::parser::{ASTNode, ASTNodeType, ReferenceType};
8use rustc_hash::FxHashMap;
9use std::sync::Arc;
10
11type RangeDimsProbe<'a> = dyn Fn(&ReferenceType) -> Option<(u32, u32)> + 'a;
12type FunctionLookup<'a> = dyn Fn(&str, &str) -> Option<Arc<dyn Function>> + 'a;
13
/// Execution strategy the planner assigns to a single plan node.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExecStrategy {
    /// Evaluate the node on the calling thread, children in order.
    Sequential,
    /// Chosen for costly nodes with enough direct children; presumably the
    /// executor evaluates the children concurrently (confirm in the executor).
    ArgParallel,
    /// Chosen for nodes whose subtree touches at least `chunk_min_cells` cells;
    /// presumably the executor partitions the range work (confirm in the executor).
    ChunkedReduce,
}
20
/// Evaluation semantics of a node, which gate whether parallel strategies may be used.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Semantics {
    /// No volatile or short-circuit capability was detected; parallel strategies are allowed.
    Pure,
    /// The node's function reports `FnCaps::SHORT_CIRCUIT`; the node is always planned sequentially.
    ShortCircuit,
    /// The subtree contains volatility (or the function reports `FnCaps::VOLATILE`);
    /// the node is always planned sequentially.
    Volatile,
}
27
/// Rough cost model for one AST node, including its whole subtree.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct NodeCost {
    /// Estimated evaluation time in nanoseconds (heuristic constants plus children).
    pub est_nanos: u64,
    /// Cells touched, derived from the range probe and summed over children
    /// (array literals report 0).
    pub cells: u64,
    /// Number of direct sub-expressions considered for argument-level parallelism.
    pub fanout: u16,
}
34
/// Structural hints gathered while annotating a node.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NodeHints {
    /// True when the subtree contains any `Reference` node, whatever its size.
    pub has_range: bool,
    /// First known `(rows, cols)` shape found in the subtree, if any.
    pub dims: Option<(u32, u32)>,
    /// For function calls: how many arguments share a fingerprint with another
    /// argument. For binary operators: 1 when both operands fingerprint equal.
    /// Otherwise 0.
    pub repeated_fp_count: u16,
}
41
42#[derive(Debug, Clone, PartialEq, Eq)]
43pub struct NodeAnnot {
44 pub semantics: Semantics,
45 pub cost: NodeCost,
46 pub hints: NodeHints,
47}
48
49#[derive(Debug, Clone, PartialEq, Eq)]
50pub struct PlanNode {
51 pub strategy: ExecStrategy,
52 pub children: Vec<PlanNode>,
53}
54
/// Thresholds controlling which parallel strategies the planner may select.
#[derive(Clone, Debug)]
pub struct PlanConfig {
    /// Master switch; when false every node is planned `Sequential`.
    pub enable_parallel: bool,
    /// Minimum estimated cost (ns) before a node is considered for `ArgParallel`.
    pub arg_parallel_min_cost_ns: u64,
    /// Minimum fanout (direct children) before a node is considered for `ArgParallel`.
    pub arg_parallel_min_children: u16,
    /// Minimum touched cells before a range-bearing node gets `ChunkedReduce`.
    pub chunk_min_cells: u64,
    /// Desired partition count for chunked work. Not read by the planner itself;
    /// presumably consumed by the executor (confirm).
    pub chunk_target_partitions: u16,
}
63
64impl Default for PlanConfig {
65 fn default() -> Self {
66 Self {
67 enable_parallel: true,
68 arg_parallel_min_cost_ns: 200_000, arg_parallel_min_children: 3,
70 chunk_min_cells: 10_000,
71 chunk_target_partitions: 8,
72 }
73 }
74}
75
76#[derive(Debug, Clone, PartialEq, Eq)]
77pub struct ExecPlan {
78 pub root: PlanNode,
79}
80
81pub struct Planner<'a> {
82 config: PlanConfig,
83 fp_cache: FxHashMap<u64, u16>,
85 _range_dims_probe: Option<&'a RangeDimsProbe<'a>>,
87 get_fn: Option<&'a FunctionLookup<'a>>,
89}
90
91impl<'a> Planner<'a> {
92 pub fn new(config: PlanConfig) -> Self {
93 Self {
94 config,
95 fp_cache: FxHashMap::default(),
96 _range_dims_probe: None,
97 get_fn: None,
98 }
99 }
100
101 pub fn with_range_probe(mut self, probe: &'a RangeDimsProbe<'a>) -> Self {
102 self._range_dims_probe = Some(probe);
103 self
104 }
105
106 pub fn with_function_lookup(mut self, get_fn: &'a FunctionLookup<'a>) -> Self {
107 self.get_fn = Some(get_fn);
108 self
109 }
110
111 pub fn plan(&mut self, ast: &ASTNode) -> ExecPlan {
112 self.fp_cache.clear();
113 let annot = self.annotate(ast);
114 let root = self.select(ast, &annot);
115 ExecPlan { root }
116 }
117
118 fn annotate(&mut self, ast: &ASTNode) -> NodeAnnot {
119 use ASTNodeType::*;
120 let semantics = if ast.contains_volatile() {
122 Semantics::Volatile
123 } else {
124 match &ast.node_type {
125 ASTNodeType::Function { name, .. } => {
126 if let Some(get) = &self.get_fn {
127 if let Some(f) = get("", name) {
128 let caps = f.caps();
129 if caps.contains(FnCaps::VOLATILE) {
130 Semantics::Volatile
131 } else if caps.contains(FnCaps::SHORT_CIRCUIT) {
132 Semantics::ShortCircuit
133 } else {
134 Semantics::Pure
135 }
136 } else {
137 Semantics::Pure
138 }
139 } else {
140 Semantics::Pure
141 }
142 }
143 _ => Semantics::Pure,
144 }
145 };
146
147 let (cost, has_range, dims, fanout) = match &ast.node_type {
149 Literal(_) => (
150 NodeCost {
151 est_nanos: 50,
152 cells: 0,
153 fanout: 0,
154 },
155 false,
156 None,
157 0,
158 ),
159 Reference { reference, .. } => {
160 let dims = self._range_dims_probe.and_then(|p| p(reference));
161 let cells = dims.map(|(r, c)| (r as u64) * (c as u64)).unwrap_or(0);
163 let est = 10_000 + cells / 10; (
165 NodeCost {
166 est_nanos: est,
167 cells,
168 fanout: 0,
169 },
170 true,
171 dims,
172 0,
173 )
174 }
175 UnaryOp { expr, .. } => {
176 let a = self.annotate(expr);
177 (a.cost, a.hints.has_range, a.hints.dims, 1)
178 }
179 BinaryOp { left, right, op: _ } => {
180 let a = self.annotate(left);
181 let b = self.annotate(right);
182 let est = a.cost.est_nanos + b.cost.est_nanos + 1_000;
183 let cells = a.cost.cells + b.cost.cells;
184 let has_range = a.hints.has_range || b.hints.has_range;
185 let dims = a.hints.dims.or(b.hints.dims);
186 (
187 NodeCost {
188 est_nanos: est,
189 cells,
190 fanout: 2,
191 },
192 has_range,
193 dims,
194 2,
195 )
196 }
197 Function { name, args } => {
198 let child_annots: Vec<NodeAnnot> = args.iter().map(|a| self.annotate(a)).collect();
200 let lname = name.to_ascii_lowercase();
202 let base = match lname.as_str() {
203 "sumifs" | "countifs" | "averageifs" => 200_000, "vlookup" | "xlookup" | "search" | "find" => 80_000,
205 _ => 5_000,
206 };
207 let children_cost: u64 = child_annots.iter().map(|a| a.cost.est_nanos).sum();
208 let cells: u64 = child_annots.iter().map(|a| a.cost.cells).sum();
209 let has_range = child_annots.iter().any(|a| a.hints.has_range);
210 let dims = child_annots.iter().find_map(|a| a.hints.dims);
211 let fanout = args.len() as u16;
212 (
213 NodeCost {
214 est_nanos: base + children_cost,
215 cells,
216 fanout,
217 },
218 has_range,
219 dims,
220 fanout,
221 )
222 }
223 Array(rows) => {
224 let mut est = 2_000;
225 let mut has_range = false;
226 let mut dims = Some((
227 rows.len() as u32,
228 rows.first().map(|r| r.len()).unwrap_or(0) as u32,
229 ));
230 for r in rows {
231 for c in r {
232 let a = self.annotate(c);
233 est += a.cost.est_nanos;
234 has_range |= a.hints.has_range;
235 if dims.is_none() {
236 dims = a.hints.dims;
237 }
238 }
239 }
240 (
241 NodeCost {
242 est_nanos: est,
243 cells: 0,
244 fanout: 0,
245 },
246 has_range,
247 dims,
248 0,
249 )
250 }
251 Call { callee, args } => {
252 let callee_annot = self.annotate(callee);
253 let child_annots: Vec<NodeAnnot> = args.iter().map(|a| self.annotate(a)).collect();
254 let children_cost: u64 = callee_annot.cost.est_nanos
255 + child_annots.iter().map(|a| a.cost.est_nanos).sum::<u64>();
256 let cells: u64 = callee_annot.cost.cells
257 + child_annots.iter().map(|a| a.cost.cells).sum::<u64>();
258 let has_range =
259 callee_annot.hints.has_range || child_annots.iter().any(|a| a.hints.has_range);
260 let dims = callee_annot
261 .hints
262 .dims
263 .or_else(|| child_annots.iter().find_map(|a| a.hints.dims));
264 let fanout = (args.len() + 1) as u16;
265 (
266 NodeCost {
267 est_nanos: 5_000 + children_cost,
268 cells,
269 fanout,
270 },
271 has_range,
272 dims,
273 fanout,
274 )
275 }
276 };
277
278 let repeated_fp_count = match &ast.node_type {
280 ASTNodeType::Function { args, .. } => {
281 let mut map: FxHashMap<u64, u16> = FxHashMap::default();
282 for a in args {
283 let fp = a.fingerprint();
284 *map.entry(fp).or_insert(0) += 1;
285 }
286 map.values().copied().filter(|&n| n > 1).sum()
287 }
288 ASTNodeType::BinaryOp { left, right, .. } => {
289 (left.fingerprint() == right.fingerprint()) as u16
290 }
291 _ => 0,
292 };
293
294 NodeAnnot {
295 semantics,
296 cost,
297 hints: NodeHints {
298 has_range,
299 dims,
300 repeated_fp_count,
301 },
302 }
303 }
304
305 fn select(&mut self, ast: &ASTNode, annot: &NodeAnnot) -> PlanNode {
306 use ExecStrategy::*;
307 let strategy = match annot.semantics {
309 Semantics::ShortCircuit => Sequential,
310 Semantics::Volatile => Sequential,
311 Semantics::Pure => {
312 if !self.config.enable_parallel {
313 Sequential
314 } else if annot.hints.has_range && annot.cost.cells >= self.config.chunk_min_cells {
315 ChunkedReduce
316 } else if annot.cost.est_nanos >= self.config.arg_parallel_min_cost_ns
317 && annot.cost.fanout >= self.config.arg_parallel_min_children
318 {
319 ArgParallel
320 } else {
321 Sequential
322 }
323 }
324 };
325
326 let children = match &ast.node_type {
328 ASTNodeType::UnaryOp { expr, .. } => {
329 let a = self.annotate(expr);
330 vec![self.select(expr, &a)]
331 }
332 ASTNodeType::BinaryOp { left, right, .. } => {
333 let la = self.annotate(left);
334 let ra = self.annotate(right);
335 vec![self.select(left, &la), self.select(right, &ra)]
336 }
337 ASTNodeType::Function { args, .. } => {
338 let mut v = Vec::with_capacity(args.len());
339 for a in args {
340 let an = self.annotate(a);
341 v.push(self.select(a, &an));
342 }
343 v
344 }
345 ASTNodeType::Call { callee, args } => {
346 let mut v = Vec::with_capacity(args.len() + 1);
347 let callee_annot = self.annotate(callee);
348 v.push(self.select(callee, &callee_annot));
349 for a in args {
350 let an = self.annotate(a);
351 v.push(self.select(a, &an));
352 }
353 v
354 }
355 ASTNodeType::Array(rows) => {
356 let mut v = Vec::new();
357 for r in rows {
358 for a in r {
359 let an = self.annotate(a);
360 v.push(self.select(a, &an));
361 }
362 }
363 v
364 }
365 _ => Vec::new(),
366 };
367
368 PlanNode { strategy, children }
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers the builtin function families exactly once per test process.
    ///
    /// Every test that installs the registry-backed function lookup must call
    /// this. Otherwise its result would depend on whether another test happened
    /// to register builtins first.
    fn ensure_builtins_registered() {
        use std::sync::Once;
        static ONCE: Once = Once::new();
        ONCE.call_once(|| {
            crate::builtins::logical::register_builtins();
            crate::builtins::logical_ext::register_builtins();
            crate::builtins::datetime::register_builtins();
            crate::builtins::math::register_builtins();
            crate::builtins::text::register_builtins();
        });
    }

    /// Plans `formula` with the default config and the registry function lookup.
    fn plan_for(formula: &str) -> ExecPlan {
        ensure_builtins_registered();
        let ast = formualizer_parse::parser::parse(formula).unwrap();
        let mut planner = Planner::new(PlanConfig::default())
            .with_function_lookup(&|ns, name| crate::function_registry::get(ns, name));
        planner.plan(&ast)
    }

    #[test]
    fn trivial_arith_is_sequential() {
        let p = plan_for("=1+2+3");
        assert!(matches!(p.root.strategy, ExecStrategy::Sequential));
        assert_eq!(p.root.children.len(), 2);
    }

    #[test]
    fn sum_of_many_args_prefers_arg_parallel() {
        let p = plan_for("=SUM(1,2,3,4,5,6)");
        assert_eq!(p.root.children.len(), 6);
        assert!(matches!(
            p.root.strategy,
            ExecStrategy::ArgParallel | ExecStrategy::Sequential
        ));
    }

    #[test]
    fn sumifs_triggers_chunked_reduce_when_large() {
        ensure_builtins_registered();
        let ast = formualizer_parse::parser::parse(r#"=SUMIFS(A:A, A:A, ">0")"#).unwrap();
        let mut planner = Planner::new(PlanConfig {
            chunk_min_cells: 1000,
            ..Default::default()
        })
        .with_function_lookup(&|ns, name| crate::function_registry::get(ns, name))
        .with_range_probe(&|r: &ReferenceType| match r {
            ReferenceType::Range {
                start_row: None,
                end_row: None,
                ..
            } => Some((10_000, 1)),
            _ => None,
        });
        let plan = planner.plan(&ast);
        assert!(matches!(
            plan.root.strategy,
            ExecStrategy::ChunkedReduce | ExecStrategy::ArgParallel
        ));
    }

    #[test]
    fn short_circuit_functions_are_sequential() {
        let p = plan_for("=IF(1,2,3)");
        assert!(matches!(p.root.strategy, ExecStrategy::Sequential));
        let p2 = plan_for("=AND(TRUE(), FALSE())");
        assert!(matches!(p2.root.strategy, ExecStrategy::Sequential));
    }

    #[test]
    fn parentheses_do_not_force_parallelism() {
        let p = plan_for("=(1+2)+(2+3)");
        assert!(matches!(p.root.strategy, ExecStrategy::Sequential));
    }

    #[test]
    fn repeated_subtrees_in_sum_encourage_arg_parallel() {
        let p = plan_for("=SUM(1+2, 1+2, 1+2, 1+2)");
        assert_eq!(p.root.children.len(), 4);
    }

    #[test]
    fn volatile_forces_sequential() {
        ensure_builtins_registered();
        let ast = formualizer_parse::parser::parse("=NOW()+1").unwrap();
        let mut planner = Planner::new(PlanConfig::default())
            .with_function_lookup(&|ns, name| crate::function_registry::get(ns, name));
        let plan = planner.plan(&ast);
        assert!(matches!(plan.root.strategy, ExecStrategy::Sequential));
    }

    #[test]
    fn whole_column_ranges_prefer_chunked_reduce() {
        ensure_builtins_registered();
        let ast =
            formualizer_parse::parser::parse(r#"=SUMIFS(A:A, A:A, ">0", B:B, "<5")"#).unwrap();
        let mut planner = Planner::new(PlanConfig {
            chunk_min_cells: 1000,
            ..Default::default()
        })
        .with_function_lookup(&|ns, name| crate::function_registry::get(ns, name))
        .with_range_probe(&|r: &ReferenceType| match r {
            ReferenceType::Range {
                start_row: None,
                end_row: None,
                ..
            } => Some((50_000, 1)),
            _ => None,
        });
        let plan = planner.plan(&ast);
        assert!(matches!(
            plan.root.strategy,
            ExecStrategy::ChunkedReduce | ExecStrategy::ArgParallel
        ));
    }

    #[test]
    fn deep_sub_ast_criteria_still_plans() {
        let p = plan_for("=SUMIFS(A1:A100, B1:B100, TEXT(2024+1, \"0\"))");
        assert_eq!(p.root.children.len(), 3);
    }

    #[test]
    fn sum_mixed_scalars_and_large_range_prefers_chunked_reduce() {
        ensure_builtins_registered();
        let ast = formualizer_parse::parser::parse(r#"=SUM(A:A, 1, 2, 3)"#).unwrap();
        let mut planner = Planner::new(PlanConfig {
            chunk_min_cells: 500,
            ..Default::default()
        })
        .with_function_lookup(&|ns, name| crate::function_registry::get(ns, name))
        .with_range_probe(&|r: &ReferenceType| match r {
            ReferenceType::Range {
                start_row: None,
                end_row: None,
                ..
            } => Some((25_000, 1)),
            _ => None,
        });
        let plan = planner.plan(&ast);
        assert!(matches!(
            plan.root.strategy,
            ExecStrategy::ChunkedReduce | ExecStrategy::ArgParallel
        ));
    }

    #[test]
    fn nested_short_circuit_child_remains_sequential_under_parallel_parent() {
        ensure_builtins_registered();
        let ast = formualizer_parse::parser::parse("=SUM(AND(TRUE(), FALSE()), 1, 2, 3)").unwrap();
        let cfg = PlanConfig {
            enable_parallel: true,
            arg_parallel_min_cost_ns: 0,
            arg_parallel_min_children: 2,
            // Effectively disables chunking so only ArgParallel is in play.
            chunk_min_cells: 1_000_000,
            chunk_target_partitions: 8,
        };
        let mut planner = Planner::new(cfg)
            .with_function_lookup(&|ns, name| crate::function_registry::get(ns, name));
        let plan = planner.plan(&ast);
        assert!(matches!(
            plan.root.strategy,
            ExecStrategy::ArgParallel | ExecStrategy::Sequential
        ));
        assert!(!plan.root.children.is_empty());
        // The AND child is short-circuit and must stay sequential.
        assert!(matches!(
            plan.root.children[0].strategy,
            ExecStrategy::Sequential
        ));
    }

    #[test]
    fn repeated_identical_ranges_defaults_to_sequential() {
        ensure_builtins_registered();
        let ast = formualizer_parse::parser::parse(r#"=SUM(A:A, A:A, A:A)"#).unwrap();
        let mut planner = Planner::new(PlanConfig::default())
            .with_function_lookup(&|ns, name| crate::function_registry::get(ns, name))
            .with_range_probe(&|r: &ReferenceType| match r {
                ReferenceType::Range {
                    start_row: None,
                    end_row: None,
                    ..
                } => Some((3, 1)),
                _ => None,
            });
        let plan = planner.plan(&ast);
        assert!(matches!(plan.root.strategy, ExecStrategy::Sequential));
        assert_eq!(plan.root.children.len(), 3);
    }
}