Skip to main content

cjc_data/
lazy.rs

1//! Lazy evaluation IR for TidyView.
2//!
3//! A `LazyView` captures a chain of operations as a tree of `ViewNode`s.
4//! On `.collect()`, the tree is optimized by rule-based passes, then executed.
5//!
6//! # Determinism
7//!
8//! - All plan transformations are pure, deterministic functions.
9//! - No HashMap/HashSet usage -- only Vec and BTreeSet for column tracking.
10//! - Execution delegates to TidyView/TidyFrame methods which already guarantee
11//!   deterministic output (Kahan summation, BTreeMap groups, stable sorts).
12
13use crate::{ArrangeKey, Column, DExpr, DBinOp, DataFrame, TidyAgg, TidyError, TidyFrame};
14use std::collections::BTreeSet;
15use std::rc::Rc;
16
17// ── ViewNode IR ──────────────────────────────────────────────────────────────
18
/// A node in the lazy evaluation tree.
///
/// `Scan` is the only leaf; every other variant owns its upstream plan
/// (`input`, or `left`/`right` for `Join`). Trees are built by [`LazyView`],
/// rewritten by [`optimize`], and then executed bottom-up.
#[derive(Debug, Clone)]
pub enum ViewNode {
    /// Leaf: scan a base DataFrame.
    ///
    /// The frame is held in an `Rc`, so cloning a plan shares the frame
    /// rather than copying its data.
    Scan { df: Rc<DataFrame> },
    /// Filter rows by predicate.
    Filter {
        input: Box<ViewNode>,
        predicate: DExpr,
    },
    /// Project to subset of columns.
    Select {
        input: Box<ViewNode>,
        columns: Vec<String>,
    },
    /// Add/replace columns via expressions.
    ///
    /// Each `(name, expr)` pair assigns the value of `expr` to column `name`.
    Mutate {
        input: Box<ViewNode>,
        assignments: Vec<(String, DExpr)>,
    },
    /// Sort by keys.
    Arrange {
        input: Box<ViewNode>,
        keys: Vec<ArrangeKey>,
    },
    /// Group + summarise (pipeline breaker).
    ///
    /// The optimizer models its output as the group keys plus one column per
    /// aggregation name.
    GroupSummarise {
        input: Box<ViewNode>,
        group_keys: Vec<String>,
        aggregations: Vec<(String, TidyAgg)>,
    },
    /// v3 Phase 6: streaming-friendly group + summarise. Produced by
    /// the lazy optimizer (`annotate_streamable_summarise`) when every
    /// aggregation is one of {Count, Sum, Mean, Min, Max, Var, Sd}. At
    /// execution this dispatches to `TidyView::summarise_streaming`,
    /// avoiding the per-group `Vec<usize>` materialisation.
    StreamingGroupSummarise {
        input: Box<ViewNode>,
        group_keys: Vec<String>,
        aggregations: Vec<(String, crate::StreamingAgg)>,
    },
    /// Distinct on columns.
    Distinct {
        input: Box<ViewNode>,
        columns: Vec<String>,
    },
    /// Join two inputs.
    ///
    /// Each `on` pair is `(left column name, right column name)`.
    Join {
        left: Box<ViewNode>,
        right: Box<ViewNode>,
        on: Vec<(String, String)>,
        kind: JoinType,
    },
}
73
/// The kind of join.
///
/// The executor dispatches each kind to the matching `TidyView` method
/// (`inner_join`, `left_join`, `semi_join`, `anti_join`); Semi/Anti results
/// are materialized into a new frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Inner join (`TidyView::inner_join`).
    Inner,
    /// Left join (`TidyView::left_join`).
    Left,
    /// Semi join (`TidyView::semi_join`).
    Semi,
    /// Anti join (`TidyView::anti_join`).
    Anti,
}
82
83// ── LazyView builder ─────────────────────────────────────────────────────────
84
85/// A lazy view that captures operations without executing them.
86pub struct LazyView {
87    plan: ViewNode,
88}
89
90impl LazyView {
91    /// Create from a DataFrame (takes ownership, wraps in Rc).
92    pub fn from_df(df: DataFrame) -> Self {
93        LazyView {
94            plan: ViewNode::Scan { df: Rc::new(df) },
95        }
96    }
97
98    /// Create from an Rc<DataFrame>.
99    pub fn from_rc(df: Rc<DataFrame>) -> Self {
100        LazyView {
101            plan: ViewNode::Scan { df },
102        }
103    }
104
105    /// Filter rows by a DExpr predicate.
106    pub fn filter(self, predicate: DExpr) -> Self {
107        LazyView {
108            plan: ViewNode::Filter {
109                input: Box::new(self.plan),
110                predicate,
111            },
112        }
113    }
114
115    /// Project to a subset of columns.
116    pub fn select(self, columns: Vec<String>) -> Self {
117        LazyView {
118            plan: ViewNode::Select {
119                input: Box::new(self.plan),
120                columns,
121            },
122        }
123    }
124
125    /// Add or replace columns via expressions.
126    pub fn mutate(self, assignments: Vec<(String, DExpr)>) -> Self {
127        LazyView {
128            plan: ViewNode::Mutate {
129                input: Box::new(self.plan),
130                assignments,
131            },
132        }
133    }
134
135    /// Sort rows by keys.
136    pub fn arrange(self, keys: Vec<ArrangeKey>) -> Self {
137        LazyView {
138            plan: ViewNode::Arrange {
139                input: Box::new(self.plan),
140                keys,
141            },
142        }
143    }
144
145    /// Group by keys and aggregate.
146    pub fn group_summarise(
147        self,
148        group_keys: Vec<String>,
149        aggregations: Vec<(String, TidyAgg)>,
150    ) -> Self {
151        LazyView {
152            plan: ViewNode::GroupSummarise {
153                input: Box::new(self.plan),
154                group_keys,
155                aggregations,
156            },
157        }
158    }
159
160    /// Keep distinct rows by columns.
161    pub fn distinct(self, columns: Vec<String>) -> Self {
162        LazyView {
163            plan: ViewNode::Distinct {
164                input: Box::new(self.plan),
165                columns,
166            },
167        }
168    }
169
170    /// Join with another LazyView.
171    pub fn join(self, right: LazyView, on: Vec<(String, String)>, kind: JoinType) -> Self {
172        LazyView {
173            plan: ViewNode::Join {
174                left: Box::new(self.plan),
175                right: Box::new(right.plan),
176                on,
177                kind,
178            },
179        }
180    }
181
182    /// Optimize and execute the plan, returning a TidyFrame.
183    pub fn collect(self) -> Result<TidyFrame, TidyError> {
184        let optimized = optimize(self.plan);
185        execute(optimized)
186    }
187
188    /// Inspect the plan tree (for testing/debugging).
189    pub fn plan(&self) -> &ViewNode {
190        &self.plan
191    }
192
193    /// Consume and return the optimized plan without executing (for testing).
194    pub fn optimized_plan(self) -> ViewNode {
195        optimize(self.plan)
196    }
197}
198
199// ── Optimizer ────────────────────────────────────────────────────────────────
200
201/// Apply all optimization passes to a ViewNode tree.
202///
203/// Pass order matters: merge filters first so pushdown sees fewer nodes,
204/// then push predicates toward scans, then prune redundant selects.
205pub fn optimize(plan: ViewNode) -> ViewNode {
206    let plan = merge_filters(plan);
207    let plan = push_predicates_down(plan);
208    let plan = eliminate_redundant_selects(plan);
209    // v3 Phase 6: rewrite GroupSummarise → StreamingGroupSummarise when
210    // every aggregation is streaming-friendly. This avoids materialising
211    // the per-group `Vec<usize>` row index buffers (~8 bytes × N rows).
212    let plan = annotate_streamable_summarise(plan);
213    plan
214}
215
216// ── v3 Phase 6: streamable-summarise annotation ─────────────────────────────
217//
218// `GroupSummarise` becomes `StreamingGroupSummarise` when every aggregation
219// is one of {Count, Sum, Mean, Min, Max, Var, Sd}. Median / Quantile /
220// NDistinct / IQR / First / Last require the full row-index list and stay
221// on the legacy path.
222//
223// Output shape: byte-equal to legacy path on the streamable subset.
224// Determinism: BTreeMap iteration order over key tuples (Vec<u32> codes
225// for cat-aware path, Vec<String> displays otherwise).
226
227fn try_streaming_agg(agg: &TidyAgg) -> Option<crate::StreamingAgg> {
228    use crate::StreamingAgg;
229    match agg {
230        TidyAgg::Count => Some(StreamingAgg::Count),
231        TidyAgg::Sum(c) => Some(StreamingAgg::Sum(c.clone())),
232        TidyAgg::Mean(c) => Some(StreamingAgg::Mean(c.clone())),
233        TidyAgg::Min(c) => Some(StreamingAgg::Min(c.clone())),
234        TidyAgg::Max(c) => Some(StreamingAgg::Max(c.clone())),
235        TidyAgg::Var(c) => Some(StreamingAgg::Var(c.clone())),
236        TidyAgg::Sd(c) => Some(StreamingAgg::Sd(c.clone())),
237        // Median / Quantile / NDistinct / Iqr / First / Last need the
238        // materialised row index list — not streaming-friendly.
239        _ => None,
240    }
241}
242
243fn annotate_streamable_summarise(plan: ViewNode) -> ViewNode {
244    match plan {
245        ViewNode::GroupSummarise {
246            input,
247            group_keys,
248            aggregations,
249        } => {
250            let input = Box::new(annotate_streamable_summarise(*input));
251            // All-or-nothing: if any aggregation is non-streaming, keep
252            // the whole node on the legacy path. Mixed dispatch would
253            // require a second walk.
254            let all_streaming: Option<Vec<(String, crate::StreamingAgg)>> = aggregations
255                .iter()
256                .map(|(name, agg)| try_streaming_agg(agg).map(|sa| (name.clone(), sa)))
257                .collect();
258            match all_streaming {
259                Some(streaming_aggs) => ViewNode::StreamingGroupSummarise {
260                    input,
261                    group_keys,
262                    aggregations: streaming_aggs,
263                },
264                None => ViewNode::GroupSummarise {
265                    input,
266                    group_keys,
267                    aggregations,
268                },
269            }
270        }
271        ViewNode::Filter { input, predicate } => ViewNode::Filter {
272            input: Box::new(annotate_streamable_summarise(*input)),
273            predicate,
274        },
275        ViewNode::Select { input, columns } => ViewNode::Select {
276            input: Box::new(annotate_streamable_summarise(*input)),
277            columns,
278        },
279        ViewNode::Mutate { input, assignments } => ViewNode::Mutate {
280            input: Box::new(annotate_streamable_summarise(*input)),
281            assignments,
282        },
283        ViewNode::Arrange { input, keys } => ViewNode::Arrange {
284            input: Box::new(annotate_streamable_summarise(*input)),
285            keys,
286        },
287        ViewNode::Distinct { input, columns } => ViewNode::Distinct {
288            input: Box::new(annotate_streamable_summarise(*input)),
289            columns,
290        },
291        ViewNode::Join {
292            left,
293            right,
294            on,
295            kind,
296        } => ViewNode::Join {
297            left: Box::new(annotate_streamable_summarise(*left)),
298            right: Box::new(annotate_streamable_summarise(*right)),
299            on,
300            kind,
301        },
302        ViewNode::StreamingGroupSummarise { .. } => plan,
303        ViewNode::Scan { .. } => plan,
304    }
305}
306
307// ── Pass 1: Filter Merging ───────────────────────────────────────────────────
308
309/// Merge consecutive Filter nodes into a single Filter with AND predicate.
310///
311/// `Filter(Filter(input, p1), p2)` becomes `Filter(input, p1 AND p2)`.
312fn merge_filters(plan: ViewNode) -> ViewNode {
313    match plan {
314        ViewNode::Filter { input, predicate } => {
315            let merged_input = merge_filters(*input);
316            match merged_input {
317                ViewNode::Filter {
318                    input: inner,
319                    predicate: inner_pred,
320                } => {
321                    // Combine: inner_pred AND predicate
322                    let combined = DExpr::BinOp {
323                        op: DBinOp::And,
324                        left: Box::new(inner_pred),
325                        right: Box::new(predicate),
326                    };
327                    ViewNode::Filter {
328                        input: inner,
329                        predicate: combined,
330                    }
331                }
332                other => ViewNode::Filter {
333                    input: Box::new(other),
334                    predicate,
335                },
336            }
337        }
338        // Recurse into all other node types
339        ViewNode::Select { input, columns } => ViewNode::Select {
340            input: Box::new(merge_filters(*input)),
341            columns,
342        },
343        ViewNode::Mutate {
344            input,
345            assignments,
346        } => ViewNode::Mutate {
347            input: Box::new(merge_filters(*input)),
348            assignments,
349        },
350        ViewNode::Arrange { input, keys } => ViewNode::Arrange {
351            input: Box::new(merge_filters(*input)),
352            keys,
353        },
354        ViewNode::GroupSummarise {
355            input,
356            group_keys,
357            aggregations,
358        } => ViewNode::GroupSummarise {
359            input: Box::new(merge_filters(*input)),
360            group_keys,
361            aggregations,
362        },
363        ViewNode::Distinct { input, columns } => ViewNode::Distinct {
364            input: Box::new(merge_filters(*input)),
365            columns,
366        },
367        ViewNode::Join {
368            left,
369            right,
370            on,
371            kind,
372        } => ViewNode::Join {
373            left: Box::new(merge_filters(*left)),
374            right: Box::new(merge_filters(*right)),
375            on,
376            kind,
377        },
378        other => other, // Scan
379    }
380}
381
382// ── Pass 2: Predicate Pushdown ───────────────────────────────────────────────
383
384/// Push Filter nodes closer to Scan nodes.
385///
386/// Rules:
387/// - Filter past Select: always safe (filter refs columns that must exist).
388/// - Filter past Mutate: only if predicate does NOT reference any mutated column.
389/// - Filter into Join: push to the side that owns ALL referenced columns.
390/// - Filter past Arrange: always safe (sort order preserved after filter).
391/// - Do NOT push past GroupSummarise (aggregation changes row identity).
392/// - Do NOT push past Distinct (distinct changes row identity).
393fn push_predicates_down(plan: ViewNode) -> ViewNode {
394    match plan {
395        ViewNode::Filter { input, predicate } => {
396            let optimized_input = push_predicates_down(*input);
397            push_filter_into(optimized_input, predicate)
398        }
399        // Recurse into all other nodes
400        ViewNode::Select { input, columns } => ViewNode::Select {
401            input: Box::new(push_predicates_down(*input)),
402            columns,
403        },
404        ViewNode::Mutate {
405            input,
406            assignments,
407        } => ViewNode::Mutate {
408            input: Box::new(push_predicates_down(*input)),
409            assignments,
410        },
411        ViewNode::Arrange { input, keys } => ViewNode::Arrange {
412            input: Box::new(push_predicates_down(*input)),
413            keys,
414        },
415        ViewNode::GroupSummarise {
416            input,
417            group_keys,
418            aggregations,
419        } => ViewNode::GroupSummarise {
420            input: Box::new(push_predicates_down(*input)),
421            group_keys,
422            aggregations,
423        },
424        ViewNode::Distinct { input, columns } => ViewNode::Distinct {
425            input: Box::new(push_predicates_down(*input)),
426            columns,
427        },
428        ViewNode::Join {
429            left,
430            right,
431            on,
432            kind,
433        } => ViewNode::Join {
434            left: Box::new(push_predicates_down(*left)),
435            right: Box::new(push_predicates_down(*right)),
436            on,
437            kind,
438        },
439        other => other,
440    }
441}
442
443/// Try to push a filter predicate below the given node.
444fn push_filter_into(node: ViewNode, predicate: DExpr) -> ViewNode {
445    match node {
446        // Push filter past Select (always safe -- filter references columns
447        // that are in the select list or the query is malformed anyway).
448        ViewNode::Select { input, columns } => ViewNode::Select {
449            input: Box::new(push_filter_into(*input, predicate)),
450            columns,
451        },
452
453        // Push filter past Arrange (filter doesn't affect sort order).
454        ViewNode::Arrange { input, keys } => ViewNode::Arrange {
455            input: Box::new(push_filter_into(*input, predicate)),
456            keys,
457        },
458
459        // Push filter past Mutate only if the predicate does NOT reference
460        // any column that Mutate introduces/replaces.
461        ViewNode::Mutate {
462            input,
463            assignments,
464        } => {
465            let pred_cols = expr_columns(&predicate);
466            let mutated_cols: BTreeSet<String> =
467                assignments.iter().map(|(name, _)| name.clone()).collect();
468            let references_mutated = pred_cols.iter().any(|c| mutated_cols.contains(c));
469
470            if references_mutated {
471                // Cannot push -- predicate depends on mutated columns.
472                ViewNode::Filter {
473                    input: Box::new(ViewNode::Mutate {
474                        input,
475                        assignments,
476                    }),
477                    predicate,
478                }
479            } else {
480                // Safe to push below.
481                ViewNode::Mutate {
482                    input: Box::new(push_filter_into(*input, predicate)),
483                    assignments,
484                }
485            }
486        }
487
488        // Push filter into Join: if predicate references only left-side columns,
489        // push into left; if only right-side, push into right; otherwise keep above.
490        ViewNode::Join {
491            left,
492            right,
493            on,
494            kind,
495        } => {
496            let pred_cols = expr_columns(&predicate);
497            let left_cols = node_output_columns(&left);
498            let right_cols = node_output_columns(&right);
499
500            let all_in_left = pred_cols.iter().all(|c| left_cols.contains(c));
501            let all_in_right = pred_cols.iter().all(|c| right_cols.contains(c));
502
503            if all_in_left {
504                ViewNode::Join {
505                    left: Box::new(push_filter_into(*left, predicate)),
506                    right,
507                    on,
508                    kind,
509                }
510            } else if all_in_right {
511                ViewNode::Join {
512                    left,
513                    right: Box::new(push_filter_into(*right, predicate)),
514                    on,
515                    kind,
516                }
517            } else {
518                // Predicate spans both sides -- keep above.
519                ViewNode::Filter {
520                    input: Box::new(ViewNode::Join {
521                        left,
522                        right,
523                        on,
524                        kind,
525                    }),
526                    predicate,
527                }
528            }
529        }
530
531        // Do NOT push past GroupSummarise or Distinct -- they change row identity.
532        other => ViewNode::Filter {
533            input: Box::new(other),
534            predicate,
535        },
536    }
537}
538
539// ── Pass 3: Redundant Select Elimination ─────────────────────────────────────
540
541/// Remove Select nodes that select all columns from their input
542/// (i.e., the select list matches the input's output columns exactly).
543fn eliminate_redundant_selects(plan: ViewNode) -> ViewNode {
544    match plan {
545        ViewNode::Select { input, columns } => {
546            let optimized_input = eliminate_redundant_selects(*input);
547            let input_cols = node_output_columns(&optimized_input);
548
549            // If the select list matches all input columns (same set), remove it.
550            let select_set: BTreeSet<&str> = columns.iter().map(|s| s.as_str()).collect();
551            let input_set: BTreeSet<&str> = input_cols.iter().map(|s| s.as_str()).collect();
552
553            if select_set == input_set {
554                optimized_input
555            } else {
556                ViewNode::Select {
557                    input: Box::new(optimized_input),
558                    columns,
559                }
560            }
561        }
562        ViewNode::Filter { input, predicate } => ViewNode::Filter {
563            input: Box::new(eliminate_redundant_selects(*input)),
564            predicate,
565        },
566        ViewNode::Mutate {
567            input,
568            assignments,
569        } => ViewNode::Mutate {
570            input: Box::new(eliminate_redundant_selects(*input)),
571            assignments,
572        },
573        ViewNode::Arrange { input, keys } => ViewNode::Arrange {
574            input: Box::new(eliminate_redundant_selects(*input)),
575            keys,
576        },
577        ViewNode::GroupSummarise {
578            input,
579            group_keys,
580            aggregations,
581        } => ViewNode::GroupSummarise {
582            input: Box::new(eliminate_redundant_selects(*input)),
583            group_keys,
584            aggregations,
585        },
586        ViewNode::Distinct { input, columns } => ViewNode::Distinct {
587            input: Box::new(eliminate_redundant_selects(*input)),
588            columns,
589        },
590        ViewNode::Join {
591            left,
592            right,
593            on,
594            kind,
595        } => ViewNode::Join {
596            left: Box::new(eliminate_redundant_selects(*left)),
597            right: Box::new(eliminate_redundant_selects(*right)),
598            on,
599            kind,
600        },
601        other => other,
602    }
603}
604
605// ── Executor ─────────────────────────────────────────────────────────────────
606
607/// Execute an optimized ViewNode tree, producing a TidyFrame.
608fn execute(node: ViewNode) -> Result<TidyFrame, TidyError> {
609    match node {
610        ViewNode::Scan { df } => Ok(TidyFrame::from_df((*df).clone())),
611
612        ViewNode::Filter { input, predicate } => {
613            let frame = execute(*input)?;
614            let view = frame.view();
615            let filtered = view.filter(&predicate)?;
616            let df = filtered.materialize()?;
617            Ok(TidyFrame::from_df(df))
618        }
619
620        ViewNode::Select { input, columns } => {
621            let frame = execute(*input)?;
622            let view = frame.view();
623            let col_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
624            let selected = view.select(&col_refs)?;
625            let df = selected.materialize()?;
626            Ok(TidyFrame::from_df(df))
627        }
628
629        ViewNode::Mutate {
630            input,
631            assignments,
632        } => {
633            let frame = execute(*input)?;
634            let view = frame.view();
635            let assign_refs: Vec<(&str, DExpr)> = assignments
636                .into_iter()
637                .map(|(name, expr)| (leaked_str(&name), expr))
638                .collect();
639            // Use TidyView::mutate which materializes + applies assignments.
640            let result = view.mutate(&assign_refs.iter().map(|(n, e)| (*n, e.clone())).collect::<Vec<_>>())?;
641            Ok(result)
642        }
643
644        ViewNode::Arrange { input, keys } => {
645            let frame = execute(*input)?;
646            let view = frame.view();
647            let arranged = view.arrange(&keys)?;
648            let df = arranged.materialize()?;
649            Ok(TidyFrame::from_df(df))
650        }
651
652        ViewNode::GroupSummarise {
653            input,
654            group_keys,
655            aggregations,
656        } => {
657            let frame = execute(*input)?;
658            let view = frame.view();
659            let key_refs: Vec<&str> = group_keys.iter().map(|s| s.as_str()).collect();
660            let grouped = view.group_by(&key_refs)?;
661            let agg_refs: Vec<(&str, TidyAgg)> = aggregations
662                .into_iter()
663                .map(|(name, agg)| (leaked_str(&name), agg))
664                .collect();
665            let result = grouped.summarise(
666                &agg_refs.iter().map(|(n, a)| (*n, a.clone())).collect::<Vec<_>>(),
667            )?;
668            Ok(result)
669        }
670
671        ViewNode::StreamingGroupSummarise {
672            input,
673            group_keys,
674            aggregations,
675        } => {
676            let frame = execute(*input)?;
677            let view = frame.view();
678            let key_refs: Vec<&str> = group_keys.iter().map(|s| s.as_str()).collect();
679            let agg_owned: Vec<(String, crate::StreamingAgg)> = aggregations;
680            let agg_refs: Vec<(&str, crate::StreamingAgg)> = agg_owned
681                .iter()
682                .map(|(name, sa)| (leaked_str(name), sa.clone()))
683                .collect();
684            view.summarise_streaming(&key_refs, &agg_refs)
685        }
686
687        ViewNode::Distinct { input, columns } => {
688            let frame = execute(*input)?;
689            let view = frame.view();
690            let col_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
691            let distinct = view.distinct(&col_refs)?;
692            let df = distinct.materialize()?;
693            Ok(TidyFrame::from_df(df))
694        }
695
696        ViewNode::Join {
697            left,
698            right,
699            on,
700            kind,
701        } => {
702            let left_frame = execute(*left)?;
703            let right_frame = execute(*right)?;
704            let left_view = left_frame.view();
705            let right_view = right_frame.view();
706            let on_refs: Vec<(&str, &str)> = on
707                .iter()
708                .map(|(l, r)| (l.as_str(), r.as_str()))
709                .collect();
710
711            match kind {
712                JoinType::Inner => left_view.inner_join(&right_view, &on_refs),
713                JoinType::Left => left_view.left_join(&right_view, &on_refs),
714                JoinType::Semi => {
715                    let result = left_view.semi_join(&right_view, &on_refs)?;
716                    let df = result.materialize()?;
717                    Ok(TidyFrame::from_df(df))
718                }
719                JoinType::Anti => {
720                    let result = left_view.anti_join(&right_view, &on_refs)?;
721                    let df = result.materialize()?;
722                    Ok(TidyFrame::from_df(df))
723                }
724            }
725        }
726    }
727}
728
729// ── Helper: column reference extraction from DExpr ───────────────────────────
730
731/// Collect all column names referenced by an expression.
732fn expr_columns(expr: &DExpr) -> BTreeSet<String> {
733    let mut cols = BTreeSet::new();
734    collect_expr_cols(expr, &mut cols);
735    cols
736}
737
738fn collect_expr_cols(expr: &DExpr, cols: &mut BTreeSet<String>) {
739    match expr {
740        DExpr::Col(name) => {
741            cols.insert(name.clone());
742        }
743        DExpr::BinOp { left, right, .. } => {
744            collect_expr_cols(left, cols);
745            collect_expr_cols(right, cols);
746        }
747        DExpr::Agg(_, inner) => collect_expr_cols(inner, cols),
748        DExpr::FnCall(_, args) => {
749            for arg in args {
750                collect_expr_cols(arg, cols);
751            }
752        }
753        DExpr::CumSum(e)
754        | DExpr::CumProd(e)
755        | DExpr::CumMax(e)
756        | DExpr::CumMin(e)
757        | DExpr::Lag(e, _)
758        | DExpr::Lead(e, _)
759        | DExpr::Rank(e)
760        | DExpr::DenseRank(e) => {
761            collect_expr_cols(e, cols);
762        }
763        // Rolling window functions reference a column by name (String field).
764        DExpr::RollingSum(col, _)
765        | DExpr::RollingMean(col, _)
766        | DExpr::RollingMin(col, _)
767        | DExpr::RollingMax(col, _)
768        | DExpr::RollingVar(col, _)
769        | DExpr::RollingSd(col, _) => {
770            cols.insert(col.clone());
771        }
772        DExpr::LitInt(_)
773        | DExpr::LitFloat(_)
774        | DExpr::LitBool(_)
775        | DExpr::LitStr(_)
776        | DExpr::Count
777        | DExpr::RowNumber => {}
778    }
779}
780
781// ── Helper: infer output columns of a ViewNode ───────────────────────────────
782
783/// Return the set of column names that a node produces.
784///
785/// For Scan nodes, reads from the DataFrame directly.
786/// For others, infers from the node type.
787fn node_output_columns(node: &ViewNode) -> BTreeSet<String> {
788    match node {
789        ViewNode::Scan { df } => df.column_names().into_iter().map(|s| s.to_string()).collect(),
790        ViewNode::Filter { input, .. } => node_output_columns(input),
791        ViewNode::Select { columns, .. } => columns.iter().cloned().collect(),
792        ViewNode::Mutate {
793            input,
794            assignments,
795        } => {
796            let mut cols = node_output_columns(input);
797            for (name, _) in assignments {
798                cols.insert(name.clone());
799            }
800            cols
801        }
802        ViewNode::Arrange { input, .. } => node_output_columns(input),
803        ViewNode::GroupSummarise {
804            group_keys,
805            aggregations,
806            ..
807        } => {
808            let mut cols: BTreeSet<String> = group_keys.iter().cloned().collect();
809            for (name, _) in aggregations {
810                cols.insert(name.clone());
811            }
812            cols
813        }
814        ViewNode::StreamingGroupSummarise {
815            group_keys,
816            aggregations,
817            ..
818        } => {
819            let mut cols: BTreeSet<String> = group_keys.iter().cloned().collect();
820            for (name, _) in aggregations {
821                cols.insert(name.clone());
822            }
823            cols
824        }
825        ViewNode::Distinct { input, .. } => node_output_columns(input),
826        ViewNode::Join {
827            left, right, on, ..
828        } => {
829            let mut cols = node_output_columns(left);
830            let right_cols = node_output_columns(right);
831            // Right join keys that duplicate left keys are excluded in output
832            let left_keys: BTreeSet<&str> = on.iter().map(|(l, _)| l.as_str()).collect();
833            let right_keys: BTreeSet<&str> = on.iter().map(|(_, r)| r.as_str()).collect();
834            for c in &right_cols {
835                if !right_keys.contains(c.as_str()) || !left_keys.contains(c.as_str()) {
836                    cols.insert(c.clone());
837                }
838            }
839            cols
840        }
841    }
842}
843
844// ── Helper: leak a String into a &'static str for API compatibility ──────────
845
/// Leak a copy of `s` to obtain a `&'static str`.
///
/// The executor uses this to hand column names to `TidyView` APIs. Each call
/// allocates a fresh copy that is never freed. The amount leaked per
/// execution is proportional to the number of names in the plan, not the data
/// size, but it accumulates across repeated executions.
fn leaked_str(s: &str) -> &'static str {
    let boxed: Box<str> = Box::from(s);
    Box::leak(boxed)
}
854
855// ── Plan inspection helpers (for testing) ────────────────────────────────────
856
857impl ViewNode {
858    /// Count the number of Filter nodes in the tree.
859    pub fn count_filters(&self) -> usize {
860        match self {
861            ViewNode::Filter { input, .. } => 1 + input.count_filters(),
862            ViewNode::Select { input, .. } => input.count_filters(),
863            ViewNode::Mutate { input, .. } => input.count_filters(),
864            ViewNode::Arrange { input, .. } => input.count_filters(),
865            ViewNode::GroupSummarise { input, .. } => input.count_filters(),
866            ViewNode::StreamingGroupSummarise { input, .. } => input.count_filters(),
867            ViewNode::Distinct { input, .. } => input.count_filters(),
868            ViewNode::Join { left, right, .. } => {
869                left.count_filters() + right.count_filters()
870            }
871            ViewNode::Scan { .. } => 0,
872        }
873    }
874
875    /// Check if the immediate child (input) of the outermost node is a Scan.
876    /// Useful for verifying predicate pushdown moved a filter near the scan.
877    pub fn is_filter_on_scan(&self) -> bool {
878        match self {
879            ViewNode::Filter { input, .. } => matches!(input.as_ref(), ViewNode::Scan { .. }),
880            _ => false,
881        }
882    }
883
884    /// Return the innermost node (the leaf Scan) by walking `input` chains.
885    pub fn innermost(&self) -> &ViewNode {
886        match self {
887            ViewNode::Filter { input, .. }
888            | ViewNode::Select { input, .. }
889            | ViewNode::Mutate { input, .. }
890            | ViewNode::Arrange { input, .. }
891            | ViewNode::GroupSummarise { input, .. }
892            | ViewNode::StreamingGroupSummarise { input, .. }
893            | ViewNode::Distinct { input, .. } => input.innermost(),
894            ViewNode::Join { left, .. } => left.innermost(),
895            ViewNode::Scan { .. } => self,
896        }
897    }
898
899    /// Return the node kind name (for test assertions).
900    pub fn kind(&self) -> &'static str {
901        match self {
902            ViewNode::Scan { .. } => "Scan",
903            ViewNode::Filter { .. } => "Filter",
904            ViewNode::Select { .. } => "Select",
905            ViewNode::Mutate { .. } => "Mutate",
906            ViewNode::Arrange { .. } => "Arrange",
907            ViewNode::GroupSummarise { .. } => "GroupSummarise",
908            ViewNode::StreamingGroupSummarise { .. } => "StreamingGroupSummarise",
909            ViewNode::Distinct { .. } => "Distinct",
910            ViewNode::Join { .. } => "Join",
911        }
912    }
913
914    /// Walk the plan tree depth-first and collect node kinds top-down.
915    pub fn node_kinds(&self) -> Vec<&'static str> {
916        let mut out = vec![self.kind()];
917        match self {
918            ViewNode::Filter { input, .. }
919            | ViewNode::Select { input, .. }
920            | ViewNode::Mutate { input, .. }
921            | ViewNode::Arrange { input, .. }
922            | ViewNode::GroupSummarise { input, .. }
923            | ViewNode::StreamingGroupSummarise { input, .. }
924            | ViewNode::Distinct { input, .. } => {
925                out.extend(input.node_kinds());
926            }
927            ViewNode::Join { left, right, .. } => {
928                out.extend(left.node_kinds());
929                out.extend(right.node_kinds());
930            }
931            ViewNode::Scan { .. } => {}
932        }
933        out
934    }
935}
936
937// ── Batch Executor ────────────────────────────────────────────────────────────
938
/// Upper bound on the number of rows in one batch.
///
/// The final batch cut from a frame may hold fewer rows.
const BATCH_SIZE: usize = 2_048;
941
942/// A chunk of up to `BATCH_SIZE` rows for vectorized processing.
943///
944/// Batches are processed sequentially in order (batch 0, batch 1, ...)
945/// to preserve deterministic row ordering.
946#[derive(Debug, Clone)]
947pub struct Batch {
948    pub columns: Vec<(String, Column)>,
949    pub nrows: usize,
950}
951
952impl Batch {
953    /// Convert this batch into a DataFrame.
954    fn into_dataframe(self) -> DataFrame {
955        DataFrame {
956            columns: self.columns,
957        }
958    }
959
960    /// Get a column by name.
961    fn get_column(&self, name: &str) -> Option<&Column> {
962        self.columns.iter().find(|(n, _)| n == name).map(|(_, c)| c)
963    }
964
965    /// Column names in order.
966    fn column_names(&self) -> Vec<&str> {
967        self.columns.iter().map(|(n, _)| n.as_str()).collect()
968    }
969}
970
971/// Slice a column from `start..end`.
972fn slice_column(col: &Column, start: usize, end: usize) -> Column {
973    if matches!(col, Column::CategoricalAdaptive(_)) {
974        return slice_column(&col.to_legacy_categorical(), start, end);
975    }
976    match col {
977        Column::Float(v) => Column::Float(v[start..end].to_vec()),
978        Column::Int(v) => Column::Int(v[start..end].to_vec()),
979        Column::Str(v) => Column::Str(v[start..end].to_vec()),
980        Column::Bool(v) => Column::Bool(v[start..end].to_vec()),
981        Column::Categorical { levels, codes } => Column::Categorical {
982            levels: levels.clone(),
983            codes: codes[start..end].to_vec(),
984        },
985        Column::DateTime(v) => Column::DateTime(v[start..end].to_vec()),
986        Column::CategoricalAdaptive(_) => unreachable!("handled by early return"),
987    }
988}
989
990/// Split a DataFrame into batches of up to `BATCH_SIZE` rows.
991fn split_batches(df: &DataFrame) -> Vec<Batch> {
992    let nrows = df.nrows();
993    if nrows == 0 {
994        return vec![Batch {
995            columns: df.columns.iter().map(|(n, c)| {
996                (n.clone(), slice_column(c, 0, 0))
997            }).collect(),
998            nrows: 0,
999        }];
1000    }
1001    let mut batches = Vec::new();
1002    let mut offset = 0;
1003    while offset < nrows {
1004        let end = (offset + BATCH_SIZE).min(nrows);
1005        let batch_cols = df
1006            .columns
1007            .iter()
1008            .map(|(name, col)| (name.clone(), slice_column(col, offset, end)))
1009            .collect();
1010        batches.push(Batch {
1011            columns: batch_cols,
1012            nrows: end - offset,
1013        });
1014        offset = end;
1015    }
1016    batches
1017}
1018
1019/// Merge a vector of batches back into a single DataFrame.
1020///
1021/// Batches must have identical column schemas. Empty batches are skipped.
1022fn merge_batches(batches: Vec<Batch>) -> Result<DataFrame, TidyError> {
1023    if batches.is_empty() {
1024        return Ok(DataFrame::new());
1025    }
1026
1027    // Determine schema from first non-empty batch (or just first batch).
1028    let schema: Vec<String> = batches[0].column_names().iter().map(|s| s.to_string()).collect();
1029    if schema.is_empty() {
1030        return Ok(DataFrame::new());
1031    }
1032
1033    // Pre-allocate merged columns.
1034    let total_rows: usize = batches.iter().map(|b| b.nrows).sum();
1035    let mut merged_cols: Vec<(String, Column)> = schema
1036        .iter()
1037        .map(|name| {
1038            // Determine type from first batch's column.
1039            let first_col = batches[0].get_column(name).unwrap();
1040            let empty = match first_col {
1041                Column::Float(_) => Column::Float(Vec::with_capacity(total_rows)),
1042                Column::Int(_) => Column::Int(Vec::with_capacity(total_rows)),
1043                Column::Str(_) => Column::Str(Vec::with_capacity(total_rows)),
1044                Column::Bool(_) => Column::Bool(Vec::with_capacity(total_rows)),
1045                Column::Categorical { levels, .. } => Column::Categorical {
1046                    levels: levels.clone(),
1047                    codes: Vec::with_capacity(total_rows),
1048                },
1049                Column::CategoricalAdaptive(_) => {
1050                    // Empty buffer matching legacy categorical shape.
1051                    let legacy = first_col.to_legacy_categorical();
1052                    if let Column::Categorical { levels, .. } = legacy {
1053                        Column::Categorical {
1054                            levels,
1055                            codes: Vec::with_capacity(total_rows),
1056                        }
1057                    } else {
1058                        // Non-UTF-8 / null fallback: Str buffer.
1059                        Column::Str(Vec::with_capacity(total_rows))
1060                    }
1061                }
1062                Column::DateTime(_) => Column::DateTime(Vec::with_capacity(total_rows)),
1063            };
1064            (name.clone(), empty)
1065        })
1066        .collect();
1067
1068    // Append each batch's data.
1069    for batch in &batches {
1070        if batch.nrows == 0 {
1071            continue;
1072        }
1073        for (i, (name, merged_col)) in merged_cols.iter_mut().enumerate() {
1074            let batch_col = batch.get_column(name).ok_or_else(|| {
1075                TidyError::ColumnNotFound(format!(
1076                    "batch merge: column '{}' missing in batch (index {})",
1077                    name, i
1078                ))
1079            })?;
1080            append_column(merged_col, batch_col);
1081        }
1082    }
1083
1084    Ok(DataFrame { columns: merged_cols })
1085}
1086
1087/// Append all rows from `src` into `dst` (same type assumed).
1088fn append_column(dst: &mut Column, src: &Column) {
1089    match (dst, src) {
1090        (Column::Float(d), Column::Float(s)) => d.extend_from_slice(s),
1091        (Column::Int(d), Column::Int(s)) => d.extend_from_slice(s),
1092        (Column::Str(d), Column::Str(s)) => d.extend(s.iter().cloned()),
1093        (Column::Bool(d), Column::Bool(s)) => d.extend_from_slice(s),
1094        (Column::Categorical { codes: d, .. }, Column::Categorical { codes: s, .. }) => {
1095            d.extend_from_slice(s);
1096        }
1097        (Column::DateTime(d), Column::DateTime(s)) => d.extend_from_slice(s),
1098        _ => {} // Type mismatch: should not happen if schema is consistent.
1099    }
1100}
1101
1102// ── Streamable operation representation ──────────────────────────────────────
1103
1104/// A streamable (non-breaking) operation that can be applied per-batch.
1105#[derive(Debug, Clone)]
1106enum StreamableOp {
1107    Filter { predicate: DExpr },
1108    Select { columns: Vec<String> },
1109    Mutate { assignments: Vec<(String, DExpr)> },
1110}
1111
1112/// Returns true if the node is a pipeline breaker (requires full materialization).
1113fn is_pipeline_breaker(node: &ViewNode) -> bool {
1114    matches!(
1115        node,
1116        ViewNode::Arrange { .. }
1117            | ViewNode::GroupSummarise { .. }
1118            | ViewNode::StreamingGroupSummarise { .. }
1119            | ViewNode::Distinct { .. }
1120            | ViewNode::Join { .. }
1121    )
1122}
1123
1124/// Walk the plan tree and collect a chain of streamable operations from the top.
1125///
1126/// Returns `(streamable_ops_in_execution_order, base_node)`.
1127/// The chain is collected top-down (outermost first), then reversed so the
1128/// innermost (closest to scan) operation is applied first.
1129fn collect_streamable_chain(node: ViewNode) -> (Vec<StreamableOp>, Box<ViewNode>) {
1130    let mut ops = Vec::new();
1131    let mut current = node;
1132
1133    loop {
1134        match current {
1135            ViewNode::Filter { input, predicate } => {
1136                ops.push(StreamableOp::Filter { predicate });
1137                current = *input;
1138            }
1139            ViewNode::Select { input, columns } => {
1140                ops.push(StreamableOp::Select { columns });
1141                current = *input;
1142            }
1143            ViewNode::Mutate { input, assignments } => {
1144                ops.push(StreamableOp::Mutate { assignments });
1145                current = *input;
1146            }
1147            // Any other node is the base (Scan or a pipeline breaker).
1148            other => {
1149                // Reverse: we collected outermost-first, but need to apply innermost-first.
1150                ops.reverse();
1151                return (ops, Box::new(other));
1152            }
1153        }
1154    }
1155}
1156
1157/// Apply a single streamable operation to a batch.
1158fn apply_op_to_batch(batch: Batch, op: &StreamableOp) -> Result<Batch, TidyError> {
1159    match op {
1160        StreamableOp::Filter { predicate } => {
1161            // Materialize batch into a temporary DataFrame, apply filter via TidyView.
1162            let df = batch.into_dataframe();
1163            if df.nrows() == 0 {
1164                return Ok(Batch {
1165                    nrows: 0,
1166                    columns: df.columns,
1167                });
1168            }
1169            let frame = TidyFrame::from_df(df);
1170            let view = frame.view();
1171            let filtered = view.filter(predicate)?;
1172            let result_df = filtered.materialize()?;
1173            let nrows = result_df.nrows();
1174            Ok(Batch {
1175                columns: result_df.columns,
1176                nrows,
1177            })
1178        }
1179        StreamableOp::Select { columns } => {
1180            // Keep only the named columns, in the requested order.
1181            let selected: Vec<(String, Column)> = columns
1182                .iter()
1183                .filter_map(|name| {
1184                    batch
1185                        .columns
1186                        .iter()
1187                        .find(|(n, _)| n == name)
1188                        .cloned()
1189                })
1190                .collect();
1191            Ok(Batch {
1192                nrows: batch.nrows,
1193                columns: selected,
1194            })
1195        }
1196        StreamableOp::Mutate { assignments } => {
1197            // Materialize batch into a temporary DataFrame, apply mutate via TidyView.
1198            let df = batch.into_dataframe();
1199            let frame = TidyFrame::from_df(df);
1200            let view = frame.view();
1201            let assign_refs: Vec<(&str, DExpr)> = assignments
1202                .iter()
1203                .map(|(name, expr)| (leaked_str(name), expr.clone()))
1204                .collect();
1205            let result = view.mutate(
1206                &assign_refs
1207                    .iter()
1208                    .map(|(n, e)| (*n, e.clone()))
1209                    .collect::<Vec<_>>(),
1210            )?;
1211            let result_df = result.borrow().clone();
1212            let nrows = result_df.nrows();
1213            Ok(Batch {
1214                columns: result_df.columns,
1215                nrows,
1216            })
1217        }
1218    }
1219}
1220
1221/// Apply a chain of streamable operations to a DataFrame in batches.
1222fn apply_chain_batched(
1223    frame: &TidyFrame,
1224    chain: &[StreamableOp],
1225) -> Result<TidyFrame, TidyError> {
1226    let df = frame.borrow().clone();
1227    let batches = split_batches(&df);
1228
1229    let mut result_batches: Vec<Batch> = Vec::new();
1230    for batch in batches {
1231        let mut current = batch;
1232        for op in chain {
1233            current = apply_op_to_batch(current, op)?;
1234        }
1235        if current.nrows > 0 {
1236            result_batches.push(current);
1237        }
1238    }
1239
1240    if result_batches.is_empty() {
1241        // Preserve schema from original DataFrame but with zero rows.
1242        let empty_df = DataFrame {
1243            columns: df
1244                .columns
1245                .iter()
1246                .map(|(name, col)| {
1247                    (name.clone(), slice_column(col, 0, 0))
1248                })
1249                .collect(),
1250        };
1251        // If chain includes a Select, apply column pruning to the empty frame.
1252        let mut result_cols: Option<Vec<String>> = None;
1253        for op in chain {
1254            if let StreamableOp::Select { columns } = op {
1255                result_cols = Some(columns.clone());
1256            }
1257        }
1258        if let Some(cols) = result_cols {
1259            let pruned: Vec<(String, Column)> = cols
1260                .iter()
1261                .filter_map(|name| {
1262                    empty_df
1263                        .columns
1264                        .iter()
1265                        .find(|(n, _)| n == name)
1266                        .cloned()
1267                })
1268                .collect();
1269            return Ok(TidyFrame::from_df(DataFrame { columns: pruned }));
1270        }
1271        return Ok(TidyFrame::from_df(empty_df));
1272    }
1273
1274    let merged = merge_batches(result_batches)?;
1275    Ok(TidyFrame::from_df(merged))
1276}
1277
1278/// Execute an optimized ViewNode tree using batch processing where possible.
1279///
1280/// Streamable operations (Filter, Select, Mutate) are fused into a single
1281/// batch pass over 2048-row chunks. Pipeline breakers (Arrange, GroupSummarise,
1282/// Distinct, Join) force full materialization before proceeding.
1283///
1284/// # Determinism
1285///
1286/// Batches are processed sequentially in row order. The merged output has
1287/// rows in the same order as non-batched execution. Float reductions use
1288/// Kahan summation (delegated to TidyView internals).
1289pub fn execute_batched(node: ViewNode) -> Result<TidyFrame, TidyError> {
1290    match &node {
1291        // Leaf: just materialize.
1292        ViewNode::Scan { .. } => execute(node),
1293
1294        // Streamable operations: collect chain and batch-execute.
1295        _ if !is_pipeline_breaker(&node) => {
1296            let (chain, base) = collect_streamable_chain(node);
1297            if chain.is_empty() {
1298                // Shouldn't happen, but handle gracefully.
1299                return execute_batched(*base);
1300            }
1301            let base_frame = execute_batched(*base)?;
1302            apply_chain_batched(&base_frame, &chain)
1303        }
1304
1305        // Pipeline breakers: execute children batched, then apply breaker eagerly.
1306        _ => execute_breaker_batched(node),
1307    }
1308}
1309
1310/// Execute a pipeline-breaking node by first executing its children via
1311/// batch processing, then applying the breaker operation eagerly.
1312fn execute_breaker_batched(node: ViewNode) -> Result<TidyFrame, TidyError> {
1313    match node {
1314        ViewNode::Arrange { input, keys } => {
1315            let frame = execute_batched(*input)?;
1316            let view = frame.view();
1317            let arranged = view.arrange(&keys)?;
1318            let df = arranged.materialize()?;
1319            Ok(TidyFrame::from_df(df))
1320        }
1321
1322        ViewNode::GroupSummarise {
1323            input,
1324            group_keys,
1325            aggregations,
1326        } => {
1327            let frame = execute_batched(*input)?;
1328            let view = frame.view();
1329            let key_refs: Vec<&str> = group_keys.iter().map(|s| s.as_str()).collect();
1330            let grouped = view.group_by(&key_refs)?;
1331            let agg_refs: Vec<(&str, TidyAgg)> = aggregations
1332                .into_iter()
1333                .map(|(name, agg)| (leaked_str(&name), agg))
1334                .collect();
1335            let result = grouped.summarise(
1336                &agg_refs
1337                    .iter()
1338                    .map(|(n, a)| (*n, a.clone()))
1339                    .collect::<Vec<_>>(),
1340            )?;
1341            Ok(result)
1342        }
1343
1344        ViewNode::StreamingGroupSummarise {
1345            input,
1346            group_keys,
1347            aggregations,
1348        } => {
1349            let frame = execute_batched(*input)?;
1350            let view = frame.view();
1351            let key_refs: Vec<&str> = group_keys.iter().map(|s| s.as_str()).collect();
1352            let agg_owned: Vec<(String, crate::StreamingAgg)> = aggregations;
1353            let agg_refs: Vec<(&str, crate::StreamingAgg)> = agg_owned
1354                .iter()
1355                .map(|(name, sa)| (leaked_str(name), sa.clone()))
1356                .collect();
1357            view.summarise_streaming(&key_refs, &agg_refs)
1358        }
1359
1360        ViewNode::Distinct { input, columns } => {
1361            let frame = execute_batched(*input)?;
1362            let view = frame.view();
1363            let col_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
1364            let distinct = view.distinct(&col_refs)?;
1365            let df = distinct.materialize()?;
1366            Ok(TidyFrame::from_df(df))
1367        }
1368
1369        ViewNode::Join {
1370            left,
1371            right,
1372            on,
1373            kind,
1374        } => {
1375            let left_frame = execute_batched(*left)?;
1376            let right_frame = execute_batched(*right)?;
1377            let left_view = left_frame.view();
1378            let right_view = right_frame.view();
1379            let on_refs: Vec<(&str, &str)> =
1380                on.iter().map(|(l, r)| (l.as_str(), r.as_str())).collect();
1381
1382            match kind {
1383                JoinType::Inner => left_view.inner_join(&right_view, &on_refs),
1384                JoinType::Left => left_view.left_join(&right_view, &on_refs),
1385                JoinType::Semi => {
1386                    let result = left_view.semi_join(&right_view, &on_refs)?;
1387                    let df = result.materialize()?;
1388                    Ok(TidyFrame::from_df(df))
1389                }
1390                JoinType::Anti => {
1391                    let result = left_view.anti_join(&right_view, &on_refs)?;
1392                    let df = result.materialize()?;
1393                    Ok(TidyFrame::from_df(df))
1394                }
1395            }
1396        }
1397
1398        // Non-breakers should not reach here, but handle gracefully.
1399        other => execute(other),
1400    }
1401}
1402
1403impl LazyView {
1404    /// Optimize and execute the plan using batch processing.
1405    ///
1406    /// This is an alternative to `collect()` that processes data in
1407    /// 2048-row batches, fusing chains of streamable operations
1408    /// (Filter, Select, Mutate) into a single pass per batch.
1409    ///
1410    /// Pipeline breakers (Arrange, GroupSummarise, Distinct, Join) cause
1411    /// full materialization at that point.
1412    ///
1413    /// The output is identical to `collect()` -- this is purely an
1414    /// execution strategy optimization.
1415    pub fn collect_batched(self) -> Result<TidyFrame, TidyError> {
1416        let optimized = optimize(self.plan);
1417        execute_batched(optimized)
1418    }
1419}
1420
1421// ══════════════════════════════════════════════════════════════════════════════
1422// Tests
1423// ══════════════════════════════════════════════════════════════════════════════
1424
1425#[cfg(test)]
1426mod tests {
1427    use super::*;
1428    use crate::{Column, DExpr, DBinOp, DataFrame, TidyAgg, ArrangeKey, TidyView};
1429
1430    /// Build a small test DataFrame: name(Str), age(Int), score(Float).
1431    fn test_df() -> DataFrame {
1432        DataFrame {
1433            columns: vec![
1434                (
1435                    "name".to_string(),
1436                    Column::Str(vec![
1437                        "Alice".into(),
1438                        "Bob".into(),
1439                        "Carol".into(),
1440                        "Dave".into(),
1441                    ]),
1442                ),
1443                ("age".to_string(), Column::Int(vec![30, 25, 35, 25])),
1444                (
1445                    "score".to_string(),
1446                    Column::Float(vec![90.0, 85.0, 95.0, 80.0]),
1447                ),
1448            ],
1449        }
1450    }
1451
1452    /// Build a second DataFrame for join tests: name(Str), dept(Str).
1453    fn dept_df() -> DataFrame {
1454        DataFrame {
1455            columns: vec![
1456                (
1457                    "name".to_string(),
1458                    Column::Str(vec!["Alice".into(), "Bob".into(), "Eve".into()]),
1459                ),
1460                (
1461                    "dept".to_string(),
1462                    Column::Str(vec!["Eng".into(), "Sales".into(), "Eng".into()]),
1463                ),
1464            ],
1465        }
1466    }
1467
1468    // ── Basic lazy chain produces same result as eager ────────────────────
1469
1470    #[test]
1471    fn lazy_filter_matches_eager() {
1472        let df = test_df();
1473        let predicate = DExpr::BinOp {
1474            op: DBinOp::Gt,
1475            left: Box::new(DExpr::Col("age".into())),
1476            right: Box::new(DExpr::LitInt(25)),
1477        };
1478
1479        // Eager
1480        let eager_view = TidyView::from_df(df.clone());
1481        let eager_filtered = eager_view.filter(&predicate).unwrap();
1482        let eager_df = eager_filtered.materialize().unwrap();
1483
1484        // Lazy
1485        let lazy_frame = LazyView::from_df(df)
1486            .filter(predicate)
1487            .collect()
1488            .unwrap();
1489        let lazy_df = lazy_frame.borrow();
1490
1491        assert_eq!(eager_df.nrows(), lazy_df.nrows());
1492        assert_eq!(eager_df.nrows(), 2); // Alice(30) and Carol(35)
1493
1494        // Verify same names
1495        let eager_names: Vec<String> = match eager_df.get_column("name").unwrap() {
1496            Column::Str(v) => v.clone(),
1497            _ => panic!("expected Str"),
1498        };
1499        let lazy_names: Vec<String> = match lazy_df.get_column("name").unwrap() {
1500            Column::Str(v) => v.clone(),
1501            _ => panic!("expected Str"),
1502        };
1503        assert_eq!(eager_names, lazy_names);
1504    }
1505
1506    #[test]
1507    fn lazy_select_matches_eager() {
1508        let df = test_df();
1509
1510        // Eager
1511        let eager_view = TidyView::from_df(df.clone());
1512        let eager_selected = eager_view.select(&["name", "age"]).unwrap();
1513        let eager_df = eager_selected.materialize().unwrap();
1514
1515        // Lazy
1516        let lazy_frame = LazyView::from_df(df)
1517            .select(vec!["name".into(), "age".into()])
1518            .collect()
1519            .unwrap();
1520        let lazy_df = lazy_frame.borrow();
1521
1522        assert_eq!(eager_df.ncols(), 2);
1523        assert_eq!(lazy_df.ncols(), 2);
1524        assert_eq!(eager_df.column_names(), lazy_df.column_names());
1525    }
1526
1527    #[test]
1528    fn lazy_arrange_matches_eager() {
1529        let df = test_df();
1530        let keys = vec![ArrangeKey::asc("age")];
1531
1532        // Eager
1533        let eager_view = TidyView::from_df(df.clone());
1534        let eager_arranged = eager_view.arrange(&keys).unwrap();
1535        let eager_df = eager_arranged.materialize().unwrap();
1536
1537        // Lazy
1538        let lazy_frame = LazyView::from_df(df)
1539            .arrange(keys)
1540            .collect()
1541            .unwrap();
1542        let lazy_df = lazy_frame.borrow();
1543
1544        let eager_ages = match eager_df.get_column("age").unwrap() {
1545            Column::Int(v) => v.clone(),
1546            _ => panic!("expected Int"),
1547        };
1548        let lazy_ages = match lazy_df.get_column("age").unwrap() {
1549            Column::Int(v) => v.clone(),
1550            _ => panic!("expected Int"),
1551        };
1552        assert_eq!(eager_ages, lazy_ages);
1553        // Should be sorted ascending
1554        assert_eq!(eager_ages, vec![25, 25, 30, 35]);
1555    }
1556
1557    #[test]
1558    fn lazy_group_summarise_matches_eager() {
1559        let df = test_df();
1560
1561        // Eager
1562        let eager_view = TidyView::from_df(df.clone());
1563        let grouped = eager_view.group_by(&["age"]).unwrap();
1564        let eager_frame = grouped
1565            .summarise(&[("count", TidyAgg::Count)])
1566            .unwrap();
1567        let eager_df = eager_frame.borrow();
1568
1569        // Lazy
1570        let lazy_frame = LazyView::from_df(df)
1571            .group_summarise(
1572                vec!["age".into()],
1573                vec![("count".into(), TidyAgg::Count)],
1574            )
1575            .collect()
1576            .unwrap();
1577        let lazy_df = lazy_frame.borrow();
1578
1579        assert_eq!(eager_df.nrows(), lazy_df.nrows());
1580        assert_eq!(eager_df.column_names(), lazy_df.column_names());
1581    }
1582
1583    // ── Predicate pushdown ───────────────────────────────────────────────
1584
1585    #[test]
1586    fn predicate_pushdown_past_select() {
1587        let df = test_df();
1588        let predicate = DExpr::BinOp {
1589            op: DBinOp::Gt,
1590            left: Box::new(DExpr::Col("age".into())),
1591            right: Box::new(DExpr::LitInt(25)),
1592        };
1593
1594        // Build: Scan -> Select -> Filter
1595        let lazy = LazyView::from_df(df)
1596            .select(vec!["name".into(), "age".into()])
1597            .filter(predicate);
1598
1599        let optimized = lazy.optimized_plan();
1600
1601        // After pushdown, the filter should be below the select.
1602        // Plan should be: Select -> Filter -> Scan
1603        let kinds = optimized.node_kinds();
1604        assert_eq!(kinds, vec!["Select", "Filter", "Scan"]);
1605    }
1606
1607    #[test]
1608    fn predicate_pushdown_past_arrange() {
1609        let df = test_df();
1610        let predicate = DExpr::BinOp {
1611            op: DBinOp::Gt,
1612            left: Box::new(DExpr::Col("age".into())),
1613            right: Box::new(DExpr::LitInt(25)),
1614        };
1615
1616        // Build: Scan -> Arrange -> Filter
1617        let lazy = LazyView::from_df(df)
1618            .arrange(vec![ArrangeKey::asc("age")])
1619            .filter(predicate);
1620
1621        let optimized = lazy.optimized_plan();
1622
1623        // Filter should be pushed below Arrange.
1624        let kinds = optimized.node_kinds();
1625        assert_eq!(kinds, vec!["Arrange", "Filter", "Scan"]);
1626    }
1627
1628    #[test]
1629    fn predicate_not_pushed_past_mutate_when_dependent() {
1630        let df = test_df();
1631        // Mutate adds "doubled_age" = age * 2
1632        // Filter on "doubled_age" > 50 -- references mutated column, cannot push.
1633        let predicate = DExpr::BinOp {
1634            op: DBinOp::Gt,
1635            left: Box::new(DExpr::Col("doubled_age".into())),
1636            right: Box::new(DExpr::LitInt(50)),
1637        };
1638
1639        let lazy = LazyView::from_df(df)
1640            .mutate(vec![(
1641                "doubled_age".into(),
1642                DExpr::BinOp {
1643                    op: DBinOp::Mul,
1644                    left: Box::new(DExpr::Col("age".into())),
1645                    right: Box::new(DExpr::LitInt(2)),
1646                },
1647            )])
1648            .filter(predicate);
1649
1650        let optimized = lazy.optimized_plan();
1651
1652        // Filter should stay ABOVE Mutate (cannot push).
1653        let kinds = optimized.node_kinds();
1654        assert_eq!(kinds, vec!["Filter", "Mutate", "Scan"]);
1655    }
1656
1657    #[test]
1658    fn predicate_pushed_past_mutate_when_independent() {
1659        let df = test_df();
1660        // Mutate adds "doubled_age" = age * 2
1661        // Filter on "score" > 85 -- does NOT reference mutated column, can push.
1662        let predicate = DExpr::BinOp {
1663            op: DBinOp::Gt,
1664            left: Box::new(DExpr::Col("score".into())),
1665            right: Box::new(DExpr::LitFloat(85.0)),
1666        };
1667
1668        let lazy = LazyView::from_df(df)
1669            .mutate(vec![(
1670                "doubled_age".into(),
1671                DExpr::BinOp {
1672                    op: DBinOp::Mul,
1673                    left: Box::new(DExpr::Col("age".into())),
1674                    right: Box::new(DExpr::LitInt(2)),
1675                },
1676            )])
1677            .filter(predicate);
1678
1679        let optimized = lazy.optimized_plan();
1680
1681        // Filter should be pushed below Mutate.
1682        let kinds = optimized.node_kinds();
1683        assert_eq!(kinds, vec!["Mutate", "Filter", "Scan"]);
1684    }
1685
1686    #[test]
1687    fn predicate_not_pushed_past_group_summarise() {
1688        let df = test_df();
1689        let predicate = DExpr::BinOp {
1690            op: DBinOp::Gt,
1691            left: Box::new(DExpr::Col("count".into())),
1692            right: Box::new(DExpr::LitInt(1)),
1693        };
1694
1695        let lazy = LazyView::from_df(df)
1696            .group_summarise(
1697                vec!["age".into()],
1698                vec![("count".into(), TidyAgg::Count)],
1699            )
1700            .filter(predicate);
1701
1702        let optimized = lazy.optimized_plan();
1703
1704        // Filter must stay above the group node. v3 Phase 6: the lazy
1705        // optimizer rewrites Count-only GroupSummarise to
1706        // StreamingGroupSummarise; both are pipeline breakers, so the
1707        // filter staying above either form is the invariant being
1708        // pinned here.
1709        let kinds = optimized.node_kinds();
1710        assert!(
1711            kinds == vec!["Filter", "GroupSummarise", "Scan"]
1712                || kinds == vec!["Filter", "StreamingGroupSummarise", "Scan"],
1713            "filter must stay above the group node, got {:?}",
1714            kinds
1715        );
1716    }
1717
1718    // ── Filter merging ───────────────────────────────────────────────────
1719
1720    #[test]
1721    fn consecutive_filters_merged() {
1722        let df = test_df();
1723        let pred1 = DExpr::BinOp {
1724            op: DBinOp::Gt,
1725            left: Box::new(DExpr::Col("age".into())),
1726            right: Box::new(DExpr::LitInt(20)),
1727        };
1728        let pred2 = DExpr::BinOp {
1729            op: DBinOp::Lt,
1730            left: Box::new(DExpr::Col("score".into())),
1731            right: Box::new(DExpr::LitFloat(95.0)),
1732        };
1733
1734        let lazy = LazyView::from_df(df).filter(pred1).filter(pred2);
1735
1736        let optimized = lazy.optimized_plan();
1737
1738        // Should have only 1 filter node (merged), not 2.
1739        assert_eq!(optimized.count_filters(), 1);
1740
1741        // The merged filter should produce correct results.
1742        // age > 20 AND score < 95 => Alice(30,90), Bob(25,85), Dave(25,80)
1743        let df2 = test_df();
1744        let result = LazyView::from_df(df2)
1745            .filter(DExpr::BinOp {
1746                op: DBinOp::Gt,
1747                left: Box::new(DExpr::Col("age".into())),
1748                right: Box::new(DExpr::LitInt(20)),
1749            })
1750            .filter(DExpr::BinOp {
1751                op: DBinOp::Lt,
1752                left: Box::new(DExpr::Col("score".into())),
1753                right: Box::new(DExpr::LitFloat(95.0)),
1754            })
1755            .collect()
1756            .unwrap();
1757
1758        let result_df = result.borrow();
1759        assert_eq!(result_df.nrows(), 3);
1760    }
1761
1762    // ── Redundant select elimination ─────────────────────────────────────
1763
1764    #[test]
1765    fn redundant_select_eliminated() {
1766        let df = test_df();
1767
1768        // Select all 3 columns (same as input) -- should be eliminated.
1769        let lazy = LazyView::from_df(df)
1770            .select(vec!["name".into(), "age".into(), "score".into()]);
1771
1772        let optimized = lazy.optimized_plan();
1773
1774        // Should be just a Scan (redundant Select removed).
1775        assert_eq!(optimized.kind(), "Scan");
1776    }
1777
1778    #[test]
1779    fn non_redundant_select_kept() {
1780        let df = test_df();
1781
1782        // Select only 2 of 3 columns -- should NOT be eliminated.
1783        let lazy = LazyView::from_df(df).select(vec!["name".into(), "age".into()]);
1784
1785        let optimized = lazy.optimized_plan();
1786
1787        assert_eq!(optimized.kind(), "Select");
1788    }
1789
1790    // ── Determinism ──────────────────────────────────────────────────────
1791
1792    #[test]
1793    fn determinism_3_runs_identical() {
1794        for _ in 0..3 {
1795            let df = test_df();
1796            let result = LazyView::from_df(df)
1797                .filter(DExpr::BinOp {
1798                    op: DBinOp::Gt,
1799                    left: Box::new(DExpr::Col("age".into())),
1800                    right: Box::new(DExpr::LitInt(20)),
1801                })
1802                .select(vec!["name".into(), "age".into()])
1803                .arrange(vec![ArrangeKey::desc("age")])
1804                .collect()
1805                .unwrap();
1806
1807            let result_df = result.borrow();
1808            assert_eq!(result_df.nrows(), 4);
1809
1810            let ages = match result_df.get_column("age").unwrap() {
1811                Column::Int(v) => v.clone(),
1812                _ => panic!("expected Int"),
1813            };
1814            // Descending: 35, 30, 25, 25
1815            assert_eq!(ages, vec![35, 30, 25, 25]);
1816
1817            let names = match result_df.get_column("name").unwrap() {
1818                Column::Str(v) => v.clone(),
1819                _ => panic!("expected Str"),
1820            };
1821            assert_eq!(names, vec!["Carol", "Alice", "Bob", "Dave"]);
1822        }
1823    }
1824
1825    // ── Join execution ───────────────────────────────────────────────────
1826
1827    #[test]
1828    fn lazy_inner_join() {
1829        let left = test_df();
1830        let right = dept_df();
1831
1832        let result = LazyView::from_df(left)
1833            .join(
1834                LazyView::from_df(right),
1835                vec![("name".into(), "name".into())],
1836                JoinType::Inner,
1837            )
1838            .collect()
1839            .unwrap();
1840
1841        let result_df = result.borrow();
1842        // Only Alice and Bob match
1843        assert_eq!(result_df.nrows(), 2);
1844        assert!(result_df.get_column("dept").is_some());
1845    }
1846
1847    #[test]
1848    fn lazy_semi_join() {
1849        let left = test_df();
1850        let right = dept_df();
1851
1852        let result = LazyView::from_df(left)
1853            .join(
1854                LazyView::from_df(right),
1855                vec![("name".into(), "name".into())],
1856                JoinType::Semi,
1857            )
1858            .collect()
1859            .unwrap();
1860
1861        let result_df = result.borrow();
1862        // Semi join: Alice and Bob from left
1863        assert_eq!(result_df.nrows(), 2);
1864        // Semi join should NOT include right columns
1865        assert!(result_df.get_column("dept").is_none());
1866    }
1867
1868    #[test]
1869    fn lazy_anti_join() {
1870        let left = test_df();
1871        let right = dept_df();
1872
1873        let result = LazyView::from_df(left)
1874            .join(
1875                LazyView::from_df(right),
1876                vec![("name".into(), "name".into())],
1877                JoinType::Anti,
1878            )
1879            .collect()
1880            .unwrap();
1881
1882        let result_df = result.borrow();
1883        // Anti join: Carol and Dave (not in right)
1884        assert_eq!(result_df.nrows(), 2);
1885    }
1886
1887    // ── Distinct ─────────────────────────────────────────────────────────
1888
1889    #[test]
1890    fn lazy_distinct() {
1891        let df = test_df();
1892
1893        let result = LazyView::from_df(df)
1894            .distinct(vec!["age".into()])
1895            .collect()
1896            .unwrap();
1897
1898        let result_df = result.borrow();
1899        // 3 distinct ages: 25, 30, 35
1900        assert_eq!(result_df.nrows(), 3);
1901    }
1902
1903    // ── Complex chain ────────────────────────────────────────────────────
1904
1905    #[test]
1906    fn complex_lazy_chain() {
1907        let df = test_df();
1908
1909        // filter(age > 20) -> mutate(bonus = score * 1.1) -> select(name, bonus) -> arrange(bonus desc)
1910        let result = LazyView::from_df(df)
1911            .filter(DExpr::BinOp {
1912                op: DBinOp::Gt,
1913                left: Box::new(DExpr::Col("age".into())),
1914                right: Box::new(DExpr::LitInt(20)),
1915            })
1916            .mutate(vec![(
1917                "bonus".into(),
1918                DExpr::BinOp {
1919                    op: DBinOp::Mul,
1920                    left: Box::new(DExpr::Col("score".into())),
1921                    right: Box::new(DExpr::LitFloat(1.1)),
1922                },
1923            )])
1924            .select(vec!["name".into(), "bonus".into()])
1925            .arrange(vec![ArrangeKey::desc("bonus")])
1926            .collect()
1927            .unwrap();
1928
1929        let result_df = result.borrow();
1930        assert_eq!(result_df.nrows(), 4);
1931        assert_eq!(result_df.ncols(), 2);
1932        assert_eq!(result_df.column_names(), vec!["name", "bonus"]);
1933    }
1934
1935    // ── Predicate pushdown into join ─────────────────────────────────────
1936
1937    #[test]
1938    fn predicate_pushdown_into_join_left_side() {
1939        let left = test_df();
1940        let right = dept_df();
1941
1942        // Join then filter on "age" > 25 -- "age" only exists in left.
1943        let lazy = LazyView::from_df(left)
1944            .join(
1945                LazyView::from_df(right),
1946                vec![("name".into(), "name".into())],
1947                JoinType::Inner,
1948            )
1949            .filter(DExpr::BinOp {
1950                op: DBinOp::Gt,
1951                left: Box::new(DExpr::Col("age".into())),
1952                right: Box::new(DExpr::LitInt(25)),
1953            });
1954
1955        let optimized = lazy.optimized_plan();
1956
1957        // The filter should be pushed into the left side of the join.
1958        let kinds = optimized.node_kinds();
1959        // Expect: Join -> [Filter -> Scan (left), Scan (right)]
1960        assert_eq!(kinds[0], "Join");
1961        // The left subtree should contain a Filter.
1962        if let ViewNode::Join { left, right, .. } = &optimized {
1963            assert_eq!(left.kind(), "Filter");
1964            assert_eq!(right.kind(), "Scan");
1965        } else {
1966            panic!("expected Join at top");
1967        }
1968    }
1969
1970    // ══════════════════════════════════════════════════════════════════════
1971    // Batch executor tests
1972    // ══════════════════════════════════════════════════════════════════════
1973
1974    /// Helper: compare two DataFrames column-by-column for equality.
1975    fn assert_df_eq(a: &DataFrame, b: &DataFrame, context: &str) {
1976        assert_eq!(
1977            a.nrows(),
1978            b.nrows(),
1979            "{}: nrows differ ({} vs {})",
1980            context,
1981            a.nrows(),
1982            b.nrows()
1983        );
1984        assert_eq!(
1985            a.column_names(),
1986            b.column_names(),
1987            "{}: column names differ",
1988            context
1989        );
1990        for (name_a, col_a) in &a.columns {
1991            let col_b = b.get_column(name_a).unwrap_or_else(|| {
1992                panic!("{}: column '{}' missing in b", context, name_a)
1993            });
1994            assert_col_eq(col_a, col_b, &format!("{} col '{}'", context, name_a));
1995        }
1996    }
1997
1998    fn assert_col_eq(a: &Column, b: &Column, context: &str) {
1999        match (a, b) {
2000            (Column::Int(va), Column::Int(vb)) => assert_eq!(va, vb, "{}", context),
2001            (Column::Float(va), Column::Float(vb)) => {
2002                assert_eq!(va.len(), vb.len(), "{}: float len", context);
2003                for (i, (x, y)) in va.iter().zip(vb.iter()).enumerate() {
2004                    assert!(
2005                        (x - y).abs() < 1e-12,
2006                        "{}: float[{}] {} != {}",
2007                        context,
2008                        i,
2009                        x,
2010                        y
2011                    );
2012                }
2013            }
2014            (Column::Str(va), Column::Str(vb)) => assert_eq!(va, vb, "{}", context),
2015            (Column::Bool(va), Column::Bool(vb)) => assert_eq!(va, vb, "{}", context),
2016            _ => panic!("{}: column type mismatch", context),
2017        }
2018    }
2019
2020    // ── Parity: collect_batched == collect ───────────────────────────────
2021
2022    #[test]
2023    fn batched_filter_parity() {
2024        let predicate = DExpr::BinOp {
2025            op: DBinOp::Gt,
2026            left: Box::new(DExpr::Col("age".into())),
2027            right: Box::new(DExpr::LitInt(25)),
2028        };
2029
2030        let eager = LazyView::from_df(test_df())
2031            .filter(predicate.clone())
2032            .collect()
2033            .unwrap();
2034        let batched = LazyView::from_df(test_df())
2035            .filter(predicate)
2036            .collect_batched()
2037            .unwrap();
2038
2039        assert_df_eq(&eager.borrow(), &batched.borrow(), "filter parity");
2040    }
2041
2042    #[test]
2043    fn batched_select_parity() {
2044        let cols = vec!["name".into(), "score".into()];
2045
2046        let eager = LazyView::from_df(test_df())
2047            .select(cols.clone())
2048            .collect()
2049            .unwrap();
2050        let batched = LazyView::from_df(test_df())
2051            .select(cols)
2052            .collect_batched()
2053            .unwrap();
2054
2055        assert_df_eq(&eager.borrow(), &batched.borrow(), "select parity");
2056    }
2057
2058    #[test]
2059    fn batched_mutate_parity() {
2060        let assignments = vec![(
2061            "doubled".into(),
2062            DExpr::BinOp {
2063                op: DBinOp::Mul,
2064                left: Box::new(DExpr::Col("age".into())),
2065                right: Box::new(DExpr::LitInt(2)),
2066            },
2067        )];
2068
2069        let eager = LazyView::from_df(test_df())
2070            .mutate(assignments.clone())
2071            .collect()
2072            .unwrap();
2073        let batched = LazyView::from_df(test_df())
2074            .mutate(assignments)
2075            .collect_batched()
2076            .unwrap();
2077
2078        assert_df_eq(&eager.borrow(), &batched.borrow(), "mutate parity");
2079    }
2080
2081    #[test]
2082    fn batched_filter_select_mutate_chain_parity() {
2083        let predicate = DExpr::BinOp {
2084            op: DBinOp::Gt,
2085            left: Box::new(DExpr::Col("age".into())),
2086            right: Box::new(DExpr::LitInt(20)),
2087        };
2088        let assignments = vec![(
2089            "bonus".into(),
2090            DExpr::BinOp {
2091                op: DBinOp::Mul,
2092                left: Box::new(DExpr::Col("score".into())),
2093                right: Box::new(DExpr::LitFloat(1.1)),
2094            },
2095        )];
2096
2097        let eager = LazyView::from_df(test_df())
2098            .filter(predicate.clone())
2099            .mutate(assignments.clone())
2100            .select(vec!["name".into(), "bonus".into()])
2101            .collect()
2102            .unwrap();
2103        let batched = LazyView::from_df(test_df())
2104            .filter(predicate)
2105            .mutate(assignments)
2106            .select(vec!["name".into(), "bonus".into()])
2107            .collect_batched()
2108            .unwrap();
2109
2110        assert_df_eq(
2111            &eager.borrow(),
2112            &batched.borrow(),
2113            "filter+mutate+select chain parity",
2114        );
2115    }
2116
2117    #[test]
2118    fn batched_group_summarise_parity() {
2119        let eager = LazyView::from_df(test_df())
2120            .group_summarise(
2121                vec!["age".into()],
2122                vec![("count".into(), TidyAgg::Count)],
2123            )
2124            .collect()
2125            .unwrap();
2126        let batched = LazyView::from_df(test_df())
2127            .group_summarise(
2128                vec!["age".into()],
2129                vec![("count".into(), TidyAgg::Count)],
2130            )
2131            .collect_batched()
2132            .unwrap();
2133
2134        assert_df_eq(
2135            &eager.borrow(),
2136            &batched.borrow(),
2137            "group_summarise parity",
2138        );
2139    }
2140
2141    #[test]
2142    fn batched_arrange_parity() {
2143        let keys = vec![ArrangeKey::asc("age")];
2144
2145        let eager = LazyView::from_df(test_df())
2146            .arrange(keys.clone())
2147            .collect()
2148            .unwrap();
2149        let batched = LazyView::from_df(test_df())
2150            .arrange(keys)
2151            .collect_batched()
2152            .unwrap();
2153
2154        assert_df_eq(&eager.borrow(), &batched.borrow(), "arrange parity");
2155    }
2156
2157    #[test]
2158    fn batched_distinct_parity() {
2159        let eager = LazyView::from_df(test_df())
2160            .distinct(vec!["age".into()])
2161            .collect()
2162            .unwrap();
2163        let batched = LazyView::from_df(test_df())
2164            .distinct(vec!["age".into()])
2165            .collect_batched()
2166            .unwrap();
2167
2168        assert_df_eq(&eager.borrow(), &batched.borrow(), "distinct parity");
2169    }
2170
2171    #[test]
2172    fn batched_join_parity() {
2173        let eager = LazyView::from_df(test_df())
2174            .join(
2175                LazyView::from_df(dept_df()),
2176                vec![("name".into(), "name".into())],
2177                JoinType::Inner,
2178            )
2179            .collect()
2180            .unwrap();
2181        let batched = LazyView::from_df(test_df())
2182            .join(
2183                LazyView::from_df(dept_df()),
2184                vec![("name".into(), "name".into())],
2185                JoinType::Inner,
2186            )
2187            .collect_batched()
2188            .unwrap();
2189
2190        assert_df_eq(&eager.borrow(), &batched.borrow(), "join parity");
2191    }
2192
2193    #[test]
2194    fn batched_complex_pipeline_parity() {
2195        // filter -> mutate -> select -> arrange (has a pipeline breaker at end)
2196        let predicate = DExpr::BinOp {
2197            op: DBinOp::Gt,
2198            left: Box::new(DExpr::Col("age".into())),
2199            right: Box::new(DExpr::LitInt(20)),
2200        };
2201        let assignments = vec![(
2202            "bonus".into(),
2203            DExpr::BinOp {
2204                op: DBinOp::Mul,
2205                left: Box::new(DExpr::Col("score".into())),
2206                right: Box::new(DExpr::LitFloat(1.1)),
2207            },
2208        )];
2209
2210        let eager = LazyView::from_df(test_df())
2211            .filter(predicate.clone())
2212            .mutate(assignments.clone())
2213            .select(vec!["name".into(), "bonus".into()])
2214            .arrange(vec![ArrangeKey::desc("bonus")])
2215            .collect()
2216            .unwrap();
2217        let batched = LazyView::from_df(test_df())
2218            .filter(predicate)
2219            .mutate(assignments)
2220            .select(vec!["name".into(), "bonus".into()])
2221            .arrange(vec![ArrangeKey::desc("bonus")])
2222            .collect_batched()
2223            .unwrap();
2224
2225        assert_df_eq(
2226            &eager.borrow(),
2227            &batched.borrow(),
2228            "complex pipeline parity",
2229        );
2230    }
2231
2232    // ── Determinism: 3 runs identical ───────────────────────────────────
2233
2234    #[test]
2235    fn batched_determinism_3_runs() {
2236        let mut results: Vec<Vec<i64>> = Vec::new();
2237        let mut results_names: Vec<Vec<String>> = Vec::new();
2238
2239        for _ in 0..3 {
2240            let result = LazyView::from_df(test_df())
2241                .filter(DExpr::BinOp {
2242                    op: DBinOp::Gt,
2243                    left: Box::new(DExpr::Col("age".into())),
2244                    right: Box::new(DExpr::LitInt(20)),
2245                })
2246                .select(vec!["name".into(), "age".into()])
2247                .arrange(vec![ArrangeKey::desc("age")])
2248                .collect_batched()
2249                .unwrap();
2250
2251            let df = result.borrow();
2252            let ages = match df.get_column("age").unwrap() {
2253                Column::Int(v) => v.clone(),
2254                _ => panic!("expected Int"),
2255            };
2256            let names = match df.get_column("name").unwrap() {
2257                Column::Str(v) => v.clone(),
2258                _ => panic!("expected Str"),
2259            };
2260            results.push(ages);
2261            results_names.push(names);
2262        }
2263
2264        // All 3 runs must be identical.
2265        assert_eq!(results[0], results[1]);
2266        assert_eq!(results[1], results[2]);
2267        assert_eq!(results_names[0], results_names[1]);
2268        assert_eq!(results_names[1], results_names[2]);
2269        // Verify expected values.
2270        assert_eq!(results[0], vec![35, 30, 25, 25]);
2271        assert_eq!(results_names[0], vec!["Carol", "Alice", "Bob", "Dave"]);
2272    }
2273
2274    // ── Large data: batching actually kicks in (>2048 rows) ─────────────
2275
2276    /// Build a large DataFrame with 10,000 rows.
2277    fn large_df() -> DataFrame {
2278        let n = 10_000usize;
2279        let names: Vec<String> = (0..n).map(|i| format!("user_{}", i)).collect();
2280        let ages: Vec<i64> = (0..n).map(|i| (i % 80) as i64 + 18).collect();
2281        let scores: Vec<f64> = (0..n).map(|i| 50.0 + (i % 50) as f64).collect();
2282        DataFrame {
2283            columns: vec![
2284                ("name".to_string(), Column::Str(names)),
2285                ("age".to_string(), Column::Int(ages)),
2286                ("score".to_string(), Column::Float(scores)),
2287            ],
2288        }
2289    }
2290
2291    #[test]
2292    fn batched_large_data_filter_parity() {
2293        let predicate = DExpr::BinOp {
2294            op: DBinOp::Gt,
2295            left: Box::new(DExpr::Col("age".into())),
2296            right: Box::new(DExpr::LitInt(50)),
2297        };
2298
2299        let eager = LazyView::from_df(large_df())
2300            .filter(predicate.clone())
2301            .collect()
2302            .unwrap();
2303        let batched = LazyView::from_df(large_df())
2304            .filter(predicate)
2305            .collect_batched()
2306            .unwrap();
2307
2308        assert_df_eq(
2309            &eager.borrow(),
2310            &batched.borrow(),
2311            "large data filter parity",
2312        );
2313        // Verify batching actually processed >1 batch.
2314        assert!(eager.borrow().nrows() > 0);
2315    }
2316
2317    #[test]
2318    fn batched_large_data_chain_parity() {
2319        let predicate = DExpr::BinOp {
2320            op: DBinOp::Gt,
2321            left: Box::new(DExpr::Col("age".into())),
2322            right: Box::new(DExpr::LitInt(50)),
2323        };
2324        let assignments = vec![(
2325            "bonus".into(),
2326            DExpr::BinOp {
2327                op: DBinOp::Mul,
2328                left: Box::new(DExpr::Col("score".into())),
2329                right: Box::new(DExpr::LitFloat(1.5)),
2330            },
2331        )];
2332
2333        let eager = LazyView::from_df(large_df())
2334            .filter(predicate.clone())
2335            .mutate(assignments.clone())
2336            .select(vec!["name".into(), "bonus".into()])
2337            .collect()
2338            .unwrap();
2339        let batched = LazyView::from_df(large_df())
2340            .filter(predicate)
2341            .mutate(assignments)
2342            .select(vec!["name".into(), "bonus".into()])
2343            .collect_batched()
2344            .unwrap();
2345
2346        assert_df_eq(
2347            &eager.borrow(),
2348            &batched.borrow(),
2349            "large data chain parity",
2350        );
2351    }
2352
2353    #[test]
2354    fn batched_large_data_determinism() {
2355        let mut prev_ages: Option<Vec<i64>> = None;
2356        for _ in 0..3 {
2357            let result = LazyView::from_df(large_df())
2358                .filter(DExpr::BinOp {
2359                    op: DBinOp::Gt,
2360                    left: Box::new(DExpr::Col("age".into())),
2361                    right: Box::new(DExpr::LitInt(90)),
2362                })
2363                .mutate(vec![(
2364                    "double_age".into(),
2365                    DExpr::BinOp {
2366                        op: DBinOp::Mul,
2367                        left: Box::new(DExpr::Col("age".into())),
2368                        right: Box::new(DExpr::LitInt(2)),
2369                    },
2370                )])
2371                .collect_batched()
2372                .unwrap();
2373
2374            let df = result.borrow();
2375            let ages = match df.get_column("age").unwrap() {
2376                Column::Int(v) => v.clone(),
2377                _ => panic!("expected Int"),
2378            };
2379            if let Some(ref prev) = prev_ages {
2380                assert_eq!(prev, &ages, "determinism: ages differ across runs");
2381            }
2382            prev_ages = Some(ages);
2383        }
2384    }
2385
2386    // ── Batch splitting helper ──────────────────────────────────────────
2387
2388    #[test]
2389    fn split_batches_correct_count() {
2390        let df = large_df();
2391        let batches = split_batches(&df);
2392        // 10000 rows / 2048 = 4 full + 1 partial = 5 batches
2393        assert_eq!(batches.len(), 5);
2394        assert_eq!(batches[0].nrows, 2048);
2395        assert_eq!(batches[1].nrows, 2048);
2396        assert_eq!(batches[2].nrows, 2048);
2397        assert_eq!(batches[3].nrows, 2048);
2398        assert_eq!(batches[4].nrows, 10000 - 4 * 2048); // 1808
2399        let total: usize = batches.iter().map(|b| b.nrows).sum();
2400        assert_eq!(total, 10000);
2401    }
2402
2403    #[test]
2404    fn split_batches_small_df() {
2405        let df = test_df(); // 4 rows
2406        let batches = split_batches(&df);
2407        assert_eq!(batches.len(), 1);
2408        assert_eq!(batches[0].nrows, 4);
2409    }
2410
2411    #[test]
2412    fn merge_batches_roundtrip() {
2413        let df = large_df();
2414        let batches = split_batches(&df);
2415        let merged = merge_batches(batches).unwrap();
2416        assert_df_eq(&df, &merged, "merge roundtrip");
2417    }
2418}