Skip to main content

formualizer_eval/
planner.rs

//! Expression planner for interpreter-level execution strategies.
//!
//! Produces a small plan graph per AST subtree that encodes where to run
//! sequentially vs. in parallel (arg fan-out) and when to chunk window scans.
5
6use crate::function::{FnCaps, Function};
7use formualizer_parse::parser::{ASTNode, ASTNodeType, ReferenceType};
8use rustc_hash::FxHashMap;
9use std::sync::Arc;
10
11type RangeDimsProbe<'a> = dyn Fn(&ReferenceType) -> Option<(u32, u32)> + 'a;
12type FunctionLookup<'a> = dyn Fn(&str, &str) -> Option<Arc<dyn Function>> + 'a;
13
/// How the interpreter should evaluate a single plan node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecStrategy {
    /// Evaluate children one after another. Chosen for short-circuit and
    /// volatile nodes, when parallelism is disabled, and whenever no
    /// parallel threshold is met.
    Sequential,
    /// Evaluate arguments as independent tasks. Chosen for pure nodes whose
    /// estimated cost and fan-out both reach the configured minimums.
    ArgParallel,
    /// Split a range scan into chunks and reduce. Chosen for pure nodes that
    /// touch a range with at least `chunk_min_cells` cells.
    ChunkedReduce,
}
20
/// Evaluation semantics of a node; anything other than `Pure` forces
/// sequential execution in the planner.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Semantics {
    /// No ordering or side-effect constraints were detected.
    Pure,
    /// The node is a function whose capabilities include `SHORT_CIRCUIT`.
    ShortCircuit,
    /// The AST reports `contains_volatile()`, or the node is a function whose
    /// capabilities include `VOLATILE`.
    Volatile,
}
27
/// Rough, unit-less cost figures for one AST node including its subtree.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NodeCost {
    /// Rough cost estimate, nominally in nanoseconds.
    pub est_nanos: u64,
    /// Number of cells touched; used for windowed scans.
    pub cells: u64,
    /// Number of child tasks.
    pub fanout: u16,
}
34
/// Structural hints gathered while annotating a node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NodeHints {
    /// True when the node or any descendant is a reference.
    pub has_range: bool,
    /// `(rows, cols)` extent associated with the node, when known.
    pub dims: Option<(u32, u32)>,
    /// Number of repeated subtree fingerprints among children.
    pub repeated_fp_count: u16,
}
41
42#[derive(Debug, Clone, PartialEq, Eq)]
43pub struct NodeAnnot {
44    pub semantics: Semantics,
45    pub cost: NodeCost,
46    pub hints: NodeHints,
47}
48
49#[derive(Debug, Clone, PartialEq, Eq)]
50pub struct PlanNode {
51    pub strategy: ExecStrategy,
52    pub children: Vec<PlanNode>,
53}
54
/// Thresholds that steer strategy selection.
#[derive(Debug, Clone)]
pub struct PlanConfig {
    /// Master switch; when false every node is planned `Sequential`.
    pub enable_parallel: bool,
    /// Minimum estimated cost a node needs before `ArgParallel` is considered.
    pub arg_parallel_min_cost_ns: u64,
    /// Minimum fan-out a node needs before `ArgParallel` is considered.
    pub arg_parallel_min_children: u16,
    /// Minimum number of range cells before `ChunkedReduce` is chosen.
    pub chunk_min_cells: u64,
    /// Target number of partitions for chunked scans. The planner in this
    /// file does not read it; presumably the executor does (to be confirmed).
    pub chunk_target_partitions: u16,
}
63
64impl Default for PlanConfig {
65    fn default() -> Self {
66        Self {
67            enable_parallel: true,
68            arg_parallel_min_cost_ns: 200_000, // 0.2ms
69            arg_parallel_min_children: 3,
70            chunk_min_cells: 10_000,
71            chunk_target_partitions: 8,
72        }
73    }
74}
75
76#[derive(Debug, Clone, PartialEq, Eq)]
77pub struct ExecPlan {
78    pub root: PlanNode,
79}
80
81pub struct Planner<'a> {
82    config: PlanConfig,
83    // cache subtree fingerprints to count repeats among siblings
84    fp_cache: FxHashMap<u64, u16>,
85    // optionally accept range-dims peek from the engine; stubbed for now
86    _range_dims_probe: Option<&'a RangeDimsProbe<'a>>,
87    // function registry getter
88    get_fn: Option<&'a FunctionLookup<'a>>,
89}
90
91impl<'a> Planner<'a> {
92    pub fn new(config: PlanConfig) -> Self {
93        Self {
94            config,
95            fp_cache: FxHashMap::default(),
96            _range_dims_probe: None,
97            get_fn: None,
98        }
99    }
100
101    pub fn with_range_probe(mut self, probe: &'a RangeDimsProbe<'a>) -> Self {
102        self._range_dims_probe = Some(probe);
103        self
104    }
105
106    pub fn with_function_lookup(mut self, get_fn: &'a FunctionLookup<'a>) -> Self {
107        self.get_fn = Some(get_fn);
108        self
109    }
110
111    pub fn plan(&mut self, ast: &ASTNode) -> ExecPlan {
112        self.fp_cache.clear();
113        let annot = self.annotate(ast);
114        let root = self.select(ast, &annot);
115        ExecPlan { root }
116    }
117
118    fn annotate(&mut self, ast: &ASTNode) -> NodeAnnot {
119        use ASTNodeType::*;
120        // Semantics
121        let semantics = if ast.contains_volatile() {
122            Semantics::Volatile
123        } else {
124            match &ast.node_type {
125                ASTNodeType::Function { name, .. } => {
126                    if let Some(get) = &self.get_fn {
127                        if let Some(f) = get("", name) {
128                            let caps = f.caps();
129                            if caps.contains(FnCaps::VOLATILE) {
130                                Semantics::Volatile
131                            } else if caps.contains(FnCaps::SHORT_CIRCUIT) {
132                                Semantics::ShortCircuit
133                            } else {
134                                Semantics::Pure
135                            }
136                        } else {
137                            Semantics::Pure
138                        }
139                    } else {
140                        Semantics::Pure
141                    }
142                }
143                _ => Semantics::Pure,
144            }
145        };
146
147        // Basic structure & cost estimation (very rough)
148        let (cost, has_range, dims, fanout) = match &ast.node_type {
149            Literal(_) => (
150                NodeCost {
151                    est_nanos: 50,
152                    cells: 0,
153                    fanout: 0,
154                },
155                false,
156                None,
157                0,
158            ),
159            Reference { reference, .. } => {
160                let dims = self._range_dims_probe.and_then(|p| p(reference));
161                // assume cheap resolve, expensive if many cells
162                let cells = dims.map(|(r, c)| (r as u64) * (c as u64)).unwrap_or(0);
163                let est = 10_000 + cells / 10; // arbitrary unit cost
164                (
165                    NodeCost {
166                        est_nanos: est,
167                        cells,
168                        fanout: 0,
169                    },
170                    true,
171                    dims,
172                    0,
173                )
174            }
175            UnaryOp { expr, .. } => {
176                let a = self.annotate(expr);
177                (a.cost, a.hints.has_range, a.hints.dims, 1)
178            }
179            BinaryOp { left, right, op: _ } => {
180                let a = self.annotate(left);
181                let b = self.annotate(right);
182                let est = a.cost.est_nanos + b.cost.est_nanos + 1_000;
183                let cells = a.cost.cells + b.cost.cells;
184                let has_range = a.hints.has_range || b.hints.has_range;
185                let dims = a.hints.dims.or(b.hints.dims);
186                (
187                    NodeCost {
188                        est_nanos: est,
189                        cells,
190                        fanout: 2,
191                    },
192                    has_range,
193                    dims,
194                    2,
195                )
196            }
197            Function { name, args } => {
198                // Child annotations
199                let child_annots: Vec<NodeAnnot> = args.iter().map(|a| self.annotate(a)).collect();
200                // Cost model stub: classify some known heavy functions
201                let lname = name.to_ascii_lowercase();
202                let base = match lname.as_str() {
203                    "sumifs" | "countifs" | "averageifs" => 200_000, // heavy base
204                    "vlookup" | "xlookup" | "search" | "find" => 80_000,
205                    _ => 5_000,
206                };
207                let children_cost: u64 = child_annots.iter().map(|a| a.cost.est_nanos).sum();
208                let cells: u64 = child_annots.iter().map(|a| a.cost.cells).sum();
209                let has_range = child_annots.iter().any(|a| a.hints.has_range);
210                let dims = child_annots.iter().find_map(|a| a.hints.dims);
211                let fanout = args.len() as u16;
212                (
213                    NodeCost {
214                        est_nanos: base + children_cost,
215                        cells,
216                        fanout,
217                    },
218                    has_range,
219                    dims,
220                    fanout,
221                )
222            }
223            Array(rows) => {
224                let mut est = 2_000;
225                let mut has_range = false;
226                let mut dims = Some((
227                    rows.len() as u32,
228                    rows.first().map(|r| r.len()).unwrap_or(0) as u32,
229                ));
230                for r in rows {
231                    for c in r {
232                        let a = self.annotate(c);
233                        est += a.cost.est_nanos;
234                        has_range |= a.hints.has_range;
235                        if dims.is_none() {
236                            dims = a.hints.dims;
237                        }
238                    }
239                }
240                (
241                    NodeCost {
242                        est_nanos: est,
243                        cells: 0,
244                        fanout: 0,
245                    },
246                    has_range,
247                    dims,
248                    0,
249                )
250            }
251            Call { callee, args } => {
252                let callee_annot = self.annotate(callee);
253                let child_annots: Vec<NodeAnnot> = args.iter().map(|a| self.annotate(a)).collect();
254                let children_cost: u64 = callee_annot.cost.est_nanos
255                    + child_annots.iter().map(|a| a.cost.est_nanos).sum::<u64>();
256                let cells: u64 = callee_annot.cost.cells
257                    + child_annots.iter().map(|a| a.cost.cells).sum::<u64>();
258                let has_range =
259                    callee_annot.hints.has_range || child_annots.iter().any(|a| a.hints.has_range);
260                let dims = callee_annot
261                    .hints
262                    .dims
263                    .or_else(|| child_annots.iter().find_map(|a| a.hints.dims));
264                let fanout = (args.len() + 1) as u16;
265                (
266                    NodeCost {
267                        est_nanos: 5_000 + children_cost,
268                        cells,
269                        fanout,
270                    },
271                    has_range,
272                    dims,
273                    fanout,
274                )
275            }
276        };
277
278        // Sibling repeat detection (simple count of identical fingerprints among children)
279        let repeated_fp_count = match &ast.node_type {
280            ASTNodeType::Function { args, .. } => {
281                let mut map: FxHashMap<u64, u16> = FxHashMap::default();
282                for a in args {
283                    let fp = a.fingerprint();
284                    *map.entry(fp).or_insert(0) += 1;
285                }
286                map.values().copied().filter(|&n| n > 1).sum()
287            }
288            ASTNodeType::BinaryOp { left, right, .. } => {
289                (left.fingerprint() == right.fingerprint()) as u16
290            }
291            _ => 0,
292        };
293
294        NodeAnnot {
295            semantics,
296            cost,
297            hints: NodeHints {
298                has_range,
299                dims,
300                repeated_fp_count,
301            },
302        }
303    }
304
305    fn select(&mut self, ast: &ASTNode, annot: &NodeAnnot) -> PlanNode {
306        use ExecStrategy::*;
307        // Strategy selection per semantics and cost
308        let strategy = match annot.semantics {
309            Semantics::ShortCircuit => Sequential,
310            Semantics::Volatile => Sequential,
311            Semantics::Pure => {
312                if !self.config.enable_parallel {
313                    Sequential
314                } else if annot.hints.has_range && annot.cost.cells >= self.config.chunk_min_cells {
315                    ChunkedReduce
316                } else if annot.cost.est_nanos >= self.config.arg_parallel_min_cost_ns
317                    && annot.cost.fanout >= self.config.arg_parallel_min_children
318                {
319                    ArgParallel
320                } else {
321                    Sequential
322                }
323            }
324        };
325
326        // Recurse to children
327        let children = match &ast.node_type {
328            ASTNodeType::UnaryOp { expr, .. } => {
329                let a = self.annotate(expr);
330                vec![self.select(expr, &a)]
331            }
332            ASTNodeType::BinaryOp { left, right, .. } => {
333                let la = self.annotate(left);
334                let ra = self.annotate(right);
335                vec![self.select(left, &la), self.select(right, &ra)]
336            }
337            ASTNodeType::Function { args, .. } => {
338                let mut v = Vec::with_capacity(args.len());
339                for a in args {
340                    let an = self.annotate(a);
341                    v.push(self.select(a, &an));
342                }
343                v
344            }
345            ASTNodeType::Call { callee, args } => {
346                let mut v = Vec::with_capacity(args.len() + 1);
347                let callee_annot = self.annotate(callee);
348                v.push(self.select(callee, &callee_annot));
349                for a in args {
350                    let an = self.annotate(a);
351                    v.push(self.select(a, &an));
352                }
353                v
354            }
355            ASTNodeType::Array(rows) => {
356                let mut v = Vec::new();
357                for r in rows {
358                    for a in r {
359                        let an = self.annotate(a);
360                        v.push(self.select(a, &an));
361                    }
362                }
363                v
364            }
365            _ => Vec::new(),
366        };
367
368        PlanNode { strategy, children }
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers the builtins these tests rely on, once per test process.
    fn ensure_builtins_registered() {
        use std::sync::Once;
        static ONCE: Once = Once::new();
        ONCE.call_once(|| {
            // Register a representative set of builtins used by these tests
            crate::builtins::logical::register_builtins();
            crate::builtins::logical_ext::register_builtins();
            crate::builtins::datetime::register_builtins();
            crate::builtins::math::register_builtins();
            crate::builtins::text::register_builtins();
        });
    }

    /// Plans `formula` with default thresholds and the global function registry.
    fn plan_for(formula: &str) -> ExecPlan {
        ensure_builtins_registered();
        let ast = formualizer_parse::parser::parse(formula).unwrap();
        let mut planner = Planner::new(PlanConfig::default())
            .with_function_lookup(&|ns, name| crate::function_registry::get(ns, name));
        planner.plan(&ast)
    }

    #[test]
    fn trivial_arith_is_sequential() {
        let p = plan_for("=1+2+3");
        assert!(matches!(p.root.strategy, ExecStrategy::Sequential));
    }

    #[test]
    fn sum_of_many_args_prefers_arg_parallel() {
        let p = plan_for("=SUM(1,2,3,4,5,6)");
        // Fanout 6 clears the child threshold, but six literals stay far
        // below the default cost threshold, so either strategy is accepted.
        assert!(!p.root.children.is_empty()); // has children
        assert!(matches!(
            p.root.strategy,
            ExecStrategy::ArgParallel | ExecStrategy::Sequential
        ));
    }

    #[test]
    fn sumifs_triggers_chunked_reduce_when_large() {
        // Without registered builtins the lookup would return None and the
        // test would depend on which other tests ran first.
        ensure_builtins_registered();
        // Fake a large range by hinting the probe
        let ast = formualizer_parse::parser::parse(r#"=SUMIFS(A:A, A:A, ">0")"#).unwrap();
        let mut planner = Planner::new(PlanConfig {
            chunk_min_cells: 1000,
            ..Default::default()
        })
        .with_function_lookup(&|ns, name| crate::function_registry::get(ns, name))
        .with_range_probe(&|r: &ReferenceType| match r {
            ReferenceType::Range {
                start_row: None,
                end_row: None,
                ..
            } => Some((10_000, 1)),
            _ => None,
        });
        let plan = planner.plan(&ast);
        assert!(matches!(
            plan.root.strategy,
            ExecStrategy::ChunkedReduce | ExecStrategy::ArgParallel
        ));
    }

    #[test]
    fn short_circuit_functions_are_sequential() {
        let p = plan_for("=IF(1,2,3)");
        assert!(matches!(p.root.strategy, ExecStrategy::Sequential));
        let p2 = plan_for("=AND(TRUE(), FALSE())");
        assert!(matches!(p2.root.strategy, ExecStrategy::Sequential));
    }

    #[test]
    fn parentheses_do_not_force_parallelism() {
        // Trivial groups should stay sequential under default thresholds
        let p = plan_for("=(1+2)+(2+3)");
        assert!(matches!(p.root.strategy, ExecStrategy::Sequential));
    }

    #[test]
    fn repeated_subtrees_in_sum_encourage_arg_parallel() {
        // SUM(f(), f(), f(), f()) where f is same subtree
        let p = plan_for("=SUM(1+2, 1+2, 1+2, 1+2)");
        // Fanout 4 may or may not cross threshold; accept either but ensure children exist
        assert!(!p.root.children.is_empty());
    }

    #[test]
    fn volatile_forces_sequential() {
        // NOW() is only volatile "via caps" if the datetime builtins are registered.
        ensure_builtins_registered();
        let ast = formualizer_parse::parser::parse("=NOW()+1").unwrap();
        let mut planner = Planner::new(PlanConfig::default())
            .with_function_lookup(&|ns, name| crate::function_registry::get(ns, name));
        let plan = planner.plan(&ast);
        assert!(matches!(plan.root.strategy, ExecStrategy::Sequential));
    }

    #[test]
    fn whole_column_ranges_prefer_chunked_reduce() {
        ensure_builtins_registered();
        // Probe A:A to be large → ChunkedReduce at root
        let ast =
            formualizer_parse::parser::parse(r#"=SUMIFS(A:A, A:A, ">0", B:B, "<5")"#).unwrap();
        let mut planner = Planner::new(PlanConfig {
            chunk_min_cells: 1000,
            ..Default::default()
        })
        .with_function_lookup(&|ns, name| crate::function_registry::get(ns, name))
        .with_range_probe(&|r: &ReferenceType| match r {
            ReferenceType::Range {
                start_row: None,
                end_row: None,
                ..
            } => Some((50_000, 1)),
            _ => None,
        });
        let plan = planner.plan(&ast);
        assert!(matches!(
            plan.root.strategy,
            ExecStrategy::ChunkedReduce | ExecStrategy::ArgParallel
        ));
    }

    #[test]
    fn deep_sub_ast_criteria_still_plans() {
        // Deep sub-AST in criteria (e.g., TEXT + DATE math)
        let p = plan_for("=SUMIFS(A1:A100, B1:B100, TEXT(2024+1, \"0\"))");
        // Should produce a plan with children; exact strategy may vary
        assert!(!p.root.children.is_empty());
    }

    #[test]
    fn sum_mixed_scalars_and_large_range_prefers_chunked_reduce() {
        ensure_builtins_registered();
        // SUM over a large column plus scalars → prefer chunked reduce due to range cost
        let ast = formualizer_parse::parser::parse(r#"=SUM(A:A, 1, 2, 3)"#).unwrap();
        let mut planner = Planner::new(PlanConfig {
            chunk_min_cells: 500,
            ..Default::default()
        })
        .with_function_lookup(&|ns, name| crate::function_registry::get(ns, name))
        .with_range_probe(&|r: &ReferenceType| match r {
            ReferenceType::Range {
                start_row: None,
                end_row: None,
                ..
            } => Some((25_000, 1)),
            _ => None,
        });
        let plan = planner.plan(&ast);
        assert!(matches!(
            plan.root.strategy,
            ExecStrategy::ChunkedReduce | ExecStrategy::ArgParallel
        ));
    }

    #[test]
    fn nested_short_circuit_child_remains_sequential_under_parallel_parent() {
        ensure_builtins_registered();
        // Force low thresholds to encourage arg-parallel at parent, but AND child must stay Sequential
        let ast = formualizer_parse::parser::parse("=SUM(AND(TRUE(), FALSE()), 1, 2, 3)").unwrap();
        let cfg = PlanConfig {
            enable_parallel: true,
            arg_parallel_min_cost_ns: 0,
            arg_parallel_min_children: 2,
            chunk_min_cells: 1_000_000, // disable chunking here
            chunk_target_partitions: 8,
        };
        let mut planner = Planner::new(cfg)
            .with_function_lookup(&|ns, name| crate::function_registry::get(ns, name));
        let plan = planner.plan(&ast);
        // Parent may be ArgParallel under these thresholds
        assert!(matches!(
            plan.root.strategy,
            ExecStrategy::ArgParallel | ExecStrategy::Sequential
        ));
        // First child corresponds to AND(...) and must be Sequential due to SHORT_CIRCUIT
        assert!(!plan.root.children.is_empty());
        assert!(matches!(
            plan.root.children[0].strategy,
            ExecStrategy::Sequential
        ));
    }

    #[test]
    fn repeated_identical_ranges_defaults_to_sequential() {
        ensure_builtins_registered();
        // Repeated A:A references with tiny dims should not trigger chunking and stay Sequential by default thresholds
        let ast = formualizer_parse::parser::parse(r#"=SUM(A:A, A:A, A:A)"#).unwrap();
        let mut planner = Planner::new(PlanConfig::default())
            .with_function_lookup(&|ns, name| crate::function_registry::get(ns, name))
            .with_range_probe(&|r: &ReferenceType| match r {
                ReferenceType::Range {
                    start_row: None,
                    end_row: None,
                    ..
                } => Some((3, 1)),
                _ => None,
            });
        let plan = planner.plan(&ast);
        assert!(matches!(plan.root.strategy, ExecStrategy::Sequential));
        assert_eq!(plan.root.children.len(), 3);
    }
}