Skip to main content

cjc_data/
lazy.rs

//! Lazy evaluation IR for TidyView.
//!
//! A `LazyView` records a chain of operations as a tree of `ViewNode`s.
//! On `.collect()`, the tree is rewritten by rule-based optimization passes
//! and then executed.
//!
//! # Determinism
//!
//! - All plan transformations are pure, deterministic functions.
//! - No HashMap/HashSet usage -- only Vec and BTreeSet for column tracking.
//! - Execution delegates to TidyView/TidyFrame methods which already guarantee
//!   deterministic output (Kahan summation, BTreeMap groups, stable sorts).
12
13use crate::{ArrangeKey, Column, DExpr, DBinOp, DataFrame, TidyAgg, TidyError, TidyFrame};
14use std::collections::BTreeSet;
15use std::rc::Rc;
16
17// ── ViewNode IR ──────────────────────────────────────────────────────────────
18
19/// A node in the lazy evaluation tree.
20#[derive(Debug, Clone)]
21pub enum ViewNode {
22    /// Leaf: scan a base DataFrame.
23    Scan { df: Rc<DataFrame> },
24    /// Filter rows by predicate.
25    Filter {
26        input: Box<ViewNode>,
27        predicate: DExpr,
28    },
29    /// Project to subset of columns.
30    Select {
31        input: Box<ViewNode>,
32        columns: Vec<String>,
33    },
34    /// Add/replace columns via expressions.
35    Mutate {
36        input: Box<ViewNode>,
37        assignments: Vec<(String, DExpr)>,
38    },
39    /// Sort by keys.
40    Arrange {
41        input: Box<ViewNode>,
42        keys: Vec<ArrangeKey>,
43    },
44    /// Group + summarise (pipeline breaker).
45    GroupSummarise {
46        input: Box<ViewNode>,
47        group_keys: Vec<String>,
48        aggregations: Vec<(String, TidyAgg)>,
49    },
50    /// Distinct on columns.
51    Distinct {
52        input: Box<ViewNode>,
53        columns: Vec<String>,
54    },
55    /// Join two inputs.
56    Join {
57        left: Box<ViewNode>,
58        right: Box<ViewNode>,
59        on: Vec<(String, String)>,
60        kind: JoinType,
61    },
62}
63
/// Which join a `ViewNode::Join` performs.
///
/// Each variant is executed by the TidyView method of the same name
/// (`inner_join`, `left_join`, `semi_join`, `anti_join`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JoinType {
    /// Executed via `TidyView::inner_join`.
    Inner,
    /// Executed via `TidyView::left_join`.
    Left,
    /// Executed via `TidyView::semi_join`.
    Semi,
    /// Executed via `TidyView::anti_join`.
    Anti,
}
72
73// ── LazyView builder ─────────────────────────────────────────────────────────
74
75/// A lazy view that captures operations without executing them.
76pub struct LazyView {
77    plan: ViewNode,
78}
79
80impl LazyView {
81    /// Create from a DataFrame (takes ownership, wraps in Rc).
82    pub fn from_df(df: DataFrame) -> Self {
83        LazyView {
84            plan: ViewNode::Scan { df: Rc::new(df) },
85        }
86    }
87
88    /// Create from an Rc<DataFrame>.
89    pub fn from_rc(df: Rc<DataFrame>) -> Self {
90        LazyView {
91            plan: ViewNode::Scan { df },
92        }
93    }
94
95    /// Filter rows by a DExpr predicate.
96    pub fn filter(self, predicate: DExpr) -> Self {
97        LazyView {
98            plan: ViewNode::Filter {
99                input: Box::new(self.plan),
100                predicate,
101            },
102        }
103    }
104
105    /// Project to a subset of columns.
106    pub fn select(self, columns: Vec<String>) -> Self {
107        LazyView {
108            plan: ViewNode::Select {
109                input: Box::new(self.plan),
110                columns,
111            },
112        }
113    }
114
115    /// Add or replace columns via expressions.
116    pub fn mutate(self, assignments: Vec<(String, DExpr)>) -> Self {
117        LazyView {
118            plan: ViewNode::Mutate {
119                input: Box::new(self.plan),
120                assignments,
121            },
122        }
123    }
124
125    /// Sort rows by keys.
126    pub fn arrange(self, keys: Vec<ArrangeKey>) -> Self {
127        LazyView {
128            plan: ViewNode::Arrange {
129                input: Box::new(self.plan),
130                keys,
131            },
132        }
133    }
134
135    /// Group by keys and aggregate.
136    pub fn group_summarise(
137        self,
138        group_keys: Vec<String>,
139        aggregations: Vec<(String, TidyAgg)>,
140    ) -> Self {
141        LazyView {
142            plan: ViewNode::GroupSummarise {
143                input: Box::new(self.plan),
144                group_keys,
145                aggregations,
146            },
147        }
148    }
149
150    /// Keep distinct rows by columns.
151    pub fn distinct(self, columns: Vec<String>) -> Self {
152        LazyView {
153            plan: ViewNode::Distinct {
154                input: Box::new(self.plan),
155                columns,
156            },
157        }
158    }
159
160    /// Join with another LazyView.
161    pub fn join(self, right: LazyView, on: Vec<(String, String)>, kind: JoinType) -> Self {
162        LazyView {
163            plan: ViewNode::Join {
164                left: Box::new(self.plan),
165                right: Box::new(right.plan),
166                on,
167                kind,
168            },
169        }
170    }
171
172    /// Optimize and execute the plan, returning a TidyFrame.
173    pub fn collect(self) -> Result<TidyFrame, TidyError> {
174        let optimized = optimize(self.plan);
175        execute(optimized)
176    }
177
178    /// Inspect the plan tree (for testing/debugging).
179    pub fn plan(&self) -> &ViewNode {
180        &self.plan
181    }
182
183    /// Consume and return the optimized plan without executing (for testing).
184    pub fn optimized_plan(self) -> ViewNode {
185        optimize(self.plan)
186    }
187}
188
189// ── Optimizer ────────────────────────────────────────────────────────────────
190
191/// Apply all optimization passes to a ViewNode tree.
192///
193/// Pass order matters: merge filters first so pushdown sees fewer nodes,
194/// then push predicates toward scans, then prune redundant selects.
195pub fn optimize(plan: ViewNode) -> ViewNode {
196    let plan = merge_filters(plan);
197    let plan = push_predicates_down(plan);
198    let plan = eliminate_redundant_selects(plan);
199    plan
200}
201
202// ── Pass 1: Filter Merging ───────────────────────────────────────────────────
203
204/// Merge consecutive Filter nodes into a single Filter with AND predicate.
205///
206/// `Filter(Filter(input, p1), p2)` becomes `Filter(input, p1 AND p2)`.
207fn merge_filters(plan: ViewNode) -> ViewNode {
208    match plan {
209        ViewNode::Filter { input, predicate } => {
210            let merged_input = merge_filters(*input);
211            match merged_input {
212                ViewNode::Filter {
213                    input: inner,
214                    predicate: inner_pred,
215                } => {
216                    // Combine: inner_pred AND predicate
217                    let combined = DExpr::BinOp {
218                        op: DBinOp::And,
219                        left: Box::new(inner_pred),
220                        right: Box::new(predicate),
221                    };
222                    ViewNode::Filter {
223                        input: inner,
224                        predicate: combined,
225                    }
226                }
227                other => ViewNode::Filter {
228                    input: Box::new(other),
229                    predicate,
230                },
231            }
232        }
233        // Recurse into all other node types
234        ViewNode::Select { input, columns } => ViewNode::Select {
235            input: Box::new(merge_filters(*input)),
236            columns,
237        },
238        ViewNode::Mutate {
239            input,
240            assignments,
241        } => ViewNode::Mutate {
242            input: Box::new(merge_filters(*input)),
243            assignments,
244        },
245        ViewNode::Arrange { input, keys } => ViewNode::Arrange {
246            input: Box::new(merge_filters(*input)),
247            keys,
248        },
249        ViewNode::GroupSummarise {
250            input,
251            group_keys,
252            aggregations,
253        } => ViewNode::GroupSummarise {
254            input: Box::new(merge_filters(*input)),
255            group_keys,
256            aggregations,
257        },
258        ViewNode::Distinct { input, columns } => ViewNode::Distinct {
259            input: Box::new(merge_filters(*input)),
260            columns,
261        },
262        ViewNode::Join {
263            left,
264            right,
265            on,
266            kind,
267        } => ViewNode::Join {
268            left: Box::new(merge_filters(*left)),
269            right: Box::new(merge_filters(*right)),
270            on,
271            kind,
272        },
273        other => other, // Scan
274    }
275}
276
277// ── Pass 2: Predicate Pushdown ───────────────────────────────────────────────
278
279/// Push Filter nodes closer to Scan nodes.
280///
281/// Rules:
282/// - Filter past Select: always safe (filter refs columns that must exist).
283/// - Filter past Mutate: only if predicate does NOT reference any mutated column.
284/// - Filter into Join: push to the side that owns ALL referenced columns.
285/// - Filter past Arrange: always safe (sort order preserved after filter).
286/// - Do NOT push past GroupSummarise (aggregation changes row identity).
287/// - Do NOT push past Distinct (distinct changes row identity).
288fn push_predicates_down(plan: ViewNode) -> ViewNode {
289    match plan {
290        ViewNode::Filter { input, predicate } => {
291            let optimized_input = push_predicates_down(*input);
292            push_filter_into(optimized_input, predicate)
293        }
294        // Recurse into all other nodes
295        ViewNode::Select { input, columns } => ViewNode::Select {
296            input: Box::new(push_predicates_down(*input)),
297            columns,
298        },
299        ViewNode::Mutate {
300            input,
301            assignments,
302        } => ViewNode::Mutate {
303            input: Box::new(push_predicates_down(*input)),
304            assignments,
305        },
306        ViewNode::Arrange { input, keys } => ViewNode::Arrange {
307            input: Box::new(push_predicates_down(*input)),
308            keys,
309        },
310        ViewNode::GroupSummarise {
311            input,
312            group_keys,
313            aggregations,
314        } => ViewNode::GroupSummarise {
315            input: Box::new(push_predicates_down(*input)),
316            group_keys,
317            aggregations,
318        },
319        ViewNode::Distinct { input, columns } => ViewNode::Distinct {
320            input: Box::new(push_predicates_down(*input)),
321            columns,
322        },
323        ViewNode::Join {
324            left,
325            right,
326            on,
327            kind,
328        } => ViewNode::Join {
329            left: Box::new(push_predicates_down(*left)),
330            right: Box::new(push_predicates_down(*right)),
331            on,
332            kind,
333        },
334        other => other,
335    }
336}
337
338/// Try to push a filter predicate below the given node.
339fn push_filter_into(node: ViewNode, predicate: DExpr) -> ViewNode {
340    match node {
341        // Push filter past Select (always safe -- filter references columns
342        // that are in the select list or the query is malformed anyway).
343        ViewNode::Select { input, columns } => ViewNode::Select {
344            input: Box::new(push_filter_into(*input, predicate)),
345            columns,
346        },
347
348        // Push filter past Arrange (filter doesn't affect sort order).
349        ViewNode::Arrange { input, keys } => ViewNode::Arrange {
350            input: Box::new(push_filter_into(*input, predicate)),
351            keys,
352        },
353
354        // Push filter past Mutate only if the predicate does NOT reference
355        // any column that Mutate introduces/replaces.
356        ViewNode::Mutate {
357            input,
358            assignments,
359        } => {
360            let pred_cols = expr_columns(&predicate);
361            let mutated_cols: BTreeSet<String> =
362                assignments.iter().map(|(name, _)| name.clone()).collect();
363            let references_mutated = pred_cols.iter().any(|c| mutated_cols.contains(c));
364
365            if references_mutated {
366                // Cannot push -- predicate depends on mutated columns.
367                ViewNode::Filter {
368                    input: Box::new(ViewNode::Mutate {
369                        input,
370                        assignments,
371                    }),
372                    predicate,
373                }
374            } else {
375                // Safe to push below.
376                ViewNode::Mutate {
377                    input: Box::new(push_filter_into(*input, predicate)),
378                    assignments,
379                }
380            }
381        }
382
383        // Push filter into Join: if predicate references only left-side columns,
384        // push into left; if only right-side, push into right; otherwise keep above.
385        ViewNode::Join {
386            left,
387            right,
388            on,
389            kind,
390        } => {
391            let pred_cols = expr_columns(&predicate);
392            let left_cols = node_output_columns(&left);
393            let right_cols = node_output_columns(&right);
394
395            let all_in_left = pred_cols.iter().all(|c| left_cols.contains(c));
396            let all_in_right = pred_cols.iter().all(|c| right_cols.contains(c));
397
398            if all_in_left {
399                ViewNode::Join {
400                    left: Box::new(push_filter_into(*left, predicate)),
401                    right,
402                    on,
403                    kind,
404                }
405            } else if all_in_right {
406                ViewNode::Join {
407                    left,
408                    right: Box::new(push_filter_into(*right, predicate)),
409                    on,
410                    kind,
411                }
412            } else {
413                // Predicate spans both sides -- keep above.
414                ViewNode::Filter {
415                    input: Box::new(ViewNode::Join {
416                        left,
417                        right,
418                        on,
419                        kind,
420                    }),
421                    predicate,
422                }
423            }
424        }
425
426        // Do NOT push past GroupSummarise or Distinct -- they change row identity.
427        other => ViewNode::Filter {
428            input: Box::new(other),
429            predicate,
430        },
431    }
432}
433
434// ── Pass 3: Redundant Select Elimination ─────────────────────────────────────
435
436/// Remove Select nodes that select all columns from their input
437/// (i.e., the select list matches the input's output columns exactly).
438fn eliminate_redundant_selects(plan: ViewNode) -> ViewNode {
439    match plan {
440        ViewNode::Select { input, columns } => {
441            let optimized_input = eliminate_redundant_selects(*input);
442            let input_cols = node_output_columns(&optimized_input);
443
444            // If the select list matches all input columns (same set), remove it.
445            let select_set: BTreeSet<&str> = columns.iter().map(|s| s.as_str()).collect();
446            let input_set: BTreeSet<&str> = input_cols.iter().map(|s| s.as_str()).collect();
447
448            if select_set == input_set {
449                optimized_input
450            } else {
451                ViewNode::Select {
452                    input: Box::new(optimized_input),
453                    columns,
454                }
455            }
456        }
457        ViewNode::Filter { input, predicate } => ViewNode::Filter {
458            input: Box::new(eliminate_redundant_selects(*input)),
459            predicate,
460        },
461        ViewNode::Mutate {
462            input,
463            assignments,
464        } => ViewNode::Mutate {
465            input: Box::new(eliminate_redundant_selects(*input)),
466            assignments,
467        },
468        ViewNode::Arrange { input, keys } => ViewNode::Arrange {
469            input: Box::new(eliminate_redundant_selects(*input)),
470            keys,
471        },
472        ViewNode::GroupSummarise {
473            input,
474            group_keys,
475            aggregations,
476        } => ViewNode::GroupSummarise {
477            input: Box::new(eliminate_redundant_selects(*input)),
478            group_keys,
479            aggregations,
480        },
481        ViewNode::Distinct { input, columns } => ViewNode::Distinct {
482            input: Box::new(eliminate_redundant_selects(*input)),
483            columns,
484        },
485        ViewNode::Join {
486            left,
487            right,
488            on,
489            kind,
490        } => ViewNode::Join {
491            left: Box::new(eliminate_redundant_selects(*left)),
492            right: Box::new(eliminate_redundant_selects(*right)),
493            on,
494            kind,
495        },
496        other => other,
497    }
498}
499
500// ── Executor ─────────────────────────────────────────────────────────────────
501
502/// Execute an optimized ViewNode tree, producing a TidyFrame.
503fn execute(node: ViewNode) -> Result<TidyFrame, TidyError> {
504    match node {
505        ViewNode::Scan { df } => Ok(TidyFrame::from_df((*df).clone())),
506
507        ViewNode::Filter { input, predicate } => {
508            let frame = execute(*input)?;
509            let view = frame.view();
510            let filtered = view.filter(&predicate)?;
511            let df = filtered.materialize()?;
512            Ok(TidyFrame::from_df(df))
513        }
514
515        ViewNode::Select { input, columns } => {
516            let frame = execute(*input)?;
517            let view = frame.view();
518            let col_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
519            let selected = view.select(&col_refs)?;
520            let df = selected.materialize()?;
521            Ok(TidyFrame::from_df(df))
522        }
523
524        ViewNode::Mutate {
525            input,
526            assignments,
527        } => {
528            let frame = execute(*input)?;
529            let view = frame.view();
530            let assign_refs: Vec<(&str, DExpr)> = assignments
531                .into_iter()
532                .map(|(name, expr)| (leaked_str(&name), expr))
533                .collect();
534            // Use TidyView::mutate which materializes + applies assignments.
535            let result = view.mutate(&assign_refs.iter().map(|(n, e)| (*n, e.clone())).collect::<Vec<_>>())?;
536            Ok(result)
537        }
538
539        ViewNode::Arrange { input, keys } => {
540            let frame = execute(*input)?;
541            let view = frame.view();
542            let arranged = view.arrange(&keys)?;
543            let df = arranged.materialize()?;
544            Ok(TidyFrame::from_df(df))
545        }
546
547        ViewNode::GroupSummarise {
548            input,
549            group_keys,
550            aggregations,
551        } => {
552            let frame = execute(*input)?;
553            let view = frame.view();
554            let key_refs: Vec<&str> = group_keys.iter().map(|s| s.as_str()).collect();
555            let grouped = view.group_by(&key_refs)?;
556            let agg_refs: Vec<(&str, TidyAgg)> = aggregations
557                .into_iter()
558                .map(|(name, agg)| (leaked_str(&name), agg))
559                .collect();
560            let result = grouped.summarise(
561                &agg_refs.iter().map(|(n, a)| (*n, a.clone())).collect::<Vec<_>>(),
562            )?;
563            Ok(result)
564        }
565
566        ViewNode::Distinct { input, columns } => {
567            let frame = execute(*input)?;
568            let view = frame.view();
569            let col_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
570            let distinct = view.distinct(&col_refs)?;
571            let df = distinct.materialize()?;
572            Ok(TidyFrame::from_df(df))
573        }
574
575        ViewNode::Join {
576            left,
577            right,
578            on,
579            kind,
580        } => {
581            let left_frame = execute(*left)?;
582            let right_frame = execute(*right)?;
583            let left_view = left_frame.view();
584            let right_view = right_frame.view();
585            let on_refs: Vec<(&str, &str)> = on
586                .iter()
587                .map(|(l, r)| (l.as_str(), r.as_str()))
588                .collect();
589
590            match kind {
591                JoinType::Inner => left_view.inner_join(&right_view, &on_refs),
592                JoinType::Left => left_view.left_join(&right_view, &on_refs),
593                JoinType::Semi => {
594                    let result = left_view.semi_join(&right_view, &on_refs)?;
595                    let df = result.materialize()?;
596                    Ok(TidyFrame::from_df(df))
597                }
598                JoinType::Anti => {
599                    let result = left_view.anti_join(&right_view, &on_refs)?;
600                    let df = result.materialize()?;
601                    Ok(TidyFrame::from_df(df))
602                }
603            }
604        }
605    }
606}
607
608// ── Helper: column reference extraction from DExpr ───────────────────────────
609
610/// Collect all column names referenced by an expression.
611fn expr_columns(expr: &DExpr) -> BTreeSet<String> {
612    let mut cols = BTreeSet::new();
613    collect_expr_cols(expr, &mut cols);
614    cols
615}
616
617fn collect_expr_cols(expr: &DExpr, cols: &mut BTreeSet<String>) {
618    match expr {
619        DExpr::Col(name) => {
620            cols.insert(name.clone());
621        }
622        DExpr::BinOp { left, right, .. } => {
623            collect_expr_cols(left, cols);
624            collect_expr_cols(right, cols);
625        }
626        DExpr::Agg(_, inner) => collect_expr_cols(inner, cols),
627        DExpr::FnCall(_, args) => {
628            for arg in args {
629                collect_expr_cols(arg, cols);
630            }
631        }
632        DExpr::CumSum(e)
633        | DExpr::CumProd(e)
634        | DExpr::CumMax(e)
635        | DExpr::CumMin(e)
636        | DExpr::Lag(e, _)
637        | DExpr::Lead(e, _)
638        | DExpr::Rank(e)
639        | DExpr::DenseRank(e) => {
640            collect_expr_cols(e, cols);
641        }
642        // Rolling window functions reference a column by name (String field).
643        DExpr::RollingSum(col, _)
644        | DExpr::RollingMean(col, _)
645        | DExpr::RollingMin(col, _)
646        | DExpr::RollingMax(col, _)
647        | DExpr::RollingVar(col, _)
648        | DExpr::RollingSd(col, _) => {
649            cols.insert(col.clone());
650        }
651        DExpr::LitInt(_)
652        | DExpr::LitFloat(_)
653        | DExpr::LitBool(_)
654        | DExpr::LitStr(_)
655        | DExpr::Count
656        | DExpr::RowNumber => {}
657    }
658}
659
660// ── Helper: infer output columns of a ViewNode ───────────────────────────────
661
662/// Return the set of column names that a node produces.
663///
664/// For Scan nodes, reads from the DataFrame directly.
665/// For others, infers from the node type.
666fn node_output_columns(node: &ViewNode) -> BTreeSet<String> {
667    match node {
668        ViewNode::Scan { df } => df.column_names().into_iter().map(|s| s.to_string()).collect(),
669        ViewNode::Filter { input, .. } => node_output_columns(input),
670        ViewNode::Select { columns, .. } => columns.iter().cloned().collect(),
671        ViewNode::Mutate {
672            input,
673            assignments,
674        } => {
675            let mut cols = node_output_columns(input);
676            for (name, _) in assignments {
677                cols.insert(name.clone());
678            }
679            cols
680        }
681        ViewNode::Arrange { input, .. } => node_output_columns(input),
682        ViewNode::GroupSummarise {
683            group_keys,
684            aggregations,
685            ..
686        } => {
687            let mut cols: BTreeSet<String> = group_keys.iter().cloned().collect();
688            for (name, _) in aggregations {
689                cols.insert(name.clone());
690            }
691            cols
692        }
693        ViewNode::Distinct { input, .. } => node_output_columns(input),
694        ViewNode::Join {
695            left, right, on, ..
696        } => {
697            let mut cols = node_output_columns(left);
698            let right_cols = node_output_columns(right);
699            // Right join keys that duplicate left keys are excluded in output
700            let left_keys: BTreeSet<&str> = on.iter().map(|(l, _)| l.as_str()).collect();
701            let right_keys: BTreeSet<&str> = on.iter().map(|(_, r)| r.as_str()).collect();
702            for c in &right_cols {
703                if !right_keys.contains(c.as_str()) || !left_keys.contains(c.as_str()) {
704                    cols.insert(c.clone());
705                }
706            }
707            cols
708        }
709    }
710}
711
712// ── Helper: leak a String into a &'static str for API compatibility ──────────
713
/// Turn a borrowed string into a `&'static str` by leaking a heap copy of it.
///
/// The copy is never freed. Every call permanently grows the process's memory
/// by `s.len()` bytes (plus allocator overhead), so avoid it on any code path
/// that runs repeatedly.
fn leaked_str(s: &str) -> &'static str {
    let owned: Box<str> = Box::from(s);
    Box::leak(owned)
}
722
723// ── Plan inspection helpers (for testing) ────────────────────────────────────
724
725impl ViewNode {
726    /// Count the number of Filter nodes in the tree.
727    pub fn count_filters(&self) -> usize {
728        match self {
729            ViewNode::Filter { input, .. } => 1 + input.count_filters(),
730            ViewNode::Select { input, .. } => input.count_filters(),
731            ViewNode::Mutate { input, .. } => input.count_filters(),
732            ViewNode::Arrange { input, .. } => input.count_filters(),
733            ViewNode::GroupSummarise { input, .. } => input.count_filters(),
734            ViewNode::Distinct { input, .. } => input.count_filters(),
735            ViewNode::Join { left, right, .. } => {
736                left.count_filters() + right.count_filters()
737            }
738            ViewNode::Scan { .. } => 0,
739        }
740    }
741
742    /// Check if the immediate child (input) of the outermost node is a Scan.
743    /// Useful for verifying predicate pushdown moved a filter near the scan.
744    pub fn is_filter_on_scan(&self) -> bool {
745        match self {
746            ViewNode::Filter { input, .. } => matches!(input.as_ref(), ViewNode::Scan { .. }),
747            _ => false,
748        }
749    }
750
751    /// Return the innermost node (the leaf Scan) by walking `input` chains.
752    pub fn innermost(&self) -> &ViewNode {
753        match self {
754            ViewNode::Filter { input, .. }
755            | ViewNode::Select { input, .. }
756            | ViewNode::Mutate { input, .. }
757            | ViewNode::Arrange { input, .. }
758            | ViewNode::GroupSummarise { input, .. }
759            | ViewNode::Distinct { input, .. } => input.innermost(),
760            ViewNode::Join { left, .. } => left.innermost(),
761            ViewNode::Scan { .. } => self,
762        }
763    }
764
765    /// Return the node kind name (for test assertions).
766    pub fn kind(&self) -> &'static str {
767        match self {
768            ViewNode::Scan { .. } => "Scan",
769            ViewNode::Filter { .. } => "Filter",
770            ViewNode::Select { .. } => "Select",
771            ViewNode::Mutate { .. } => "Mutate",
772            ViewNode::Arrange { .. } => "Arrange",
773            ViewNode::GroupSummarise { .. } => "GroupSummarise",
774            ViewNode::Distinct { .. } => "Distinct",
775            ViewNode::Join { .. } => "Join",
776        }
777    }
778
779    /// Walk the plan tree depth-first and collect node kinds top-down.
780    pub fn node_kinds(&self) -> Vec<&'static str> {
781        let mut out = vec![self.kind()];
782        match self {
783            ViewNode::Filter { input, .. }
784            | ViewNode::Select { input, .. }
785            | ViewNode::Mutate { input, .. }
786            | ViewNode::Arrange { input, .. }
787            | ViewNode::GroupSummarise { input, .. }
788            | ViewNode::Distinct { input, .. } => {
789                out.extend(input.node_kinds());
790            }
791            ViewNode::Join { left, right, .. } => {
792                out.extend(left.node_kinds());
793                out.extend(right.node_kinds());
794            }
795            ViewNode::Scan { .. } => {}
796        }
797        out
798    }
799}
800
801// ── Batch Executor ────────────────────────────────────────────────────────────
802
/// Upper bound on the number of rows held by one `Batch` (2^11 = 2048).
const BATCH_SIZE: usize = 1 << 11;
805
806/// A chunk of up to `BATCH_SIZE` rows for vectorized processing.
807///
808/// Batches are processed sequentially in order (batch 0, batch 1, ...)
809/// to preserve deterministic row ordering.
810#[derive(Debug, Clone)]
811pub struct Batch {
812    pub columns: Vec<(String, Column)>,
813    pub nrows: usize,
814}
815
816impl Batch {
817    /// Convert this batch into a DataFrame.
818    fn into_dataframe(self) -> DataFrame {
819        DataFrame {
820            columns: self.columns,
821        }
822    }
823
824    /// Get a column by name.
825    fn get_column(&self, name: &str) -> Option<&Column> {
826        self.columns.iter().find(|(n, _)| n == name).map(|(_, c)| c)
827    }
828
829    /// Column names in order.
830    fn column_names(&self) -> Vec<&str> {
831        self.columns.iter().map(|(n, _)| n.as_str()).collect()
832    }
833}
834
835/// Slice a column from `start..end`.
836fn slice_column(col: &Column, start: usize, end: usize) -> Column {
837    match col {
838        Column::Float(v) => Column::Float(v[start..end].to_vec()),
839        Column::Int(v) => Column::Int(v[start..end].to_vec()),
840        Column::Str(v) => Column::Str(v[start..end].to_vec()),
841        Column::Bool(v) => Column::Bool(v[start..end].to_vec()),
842        Column::Categorical { levels, codes } => Column::Categorical {
843            levels: levels.clone(),
844            codes: codes[start..end].to_vec(),
845        },
846        Column::DateTime(v) => Column::DateTime(v[start..end].to_vec()),
847    }
848}
849
850/// Split a DataFrame into batches of up to `BATCH_SIZE` rows.
851fn split_batches(df: &DataFrame) -> Vec<Batch> {
852    let nrows = df.nrows();
853    if nrows == 0 {
854        return vec![Batch {
855            columns: df.columns.iter().map(|(n, c)| {
856                (n.clone(), slice_column(c, 0, 0))
857            }).collect(),
858            nrows: 0,
859        }];
860    }
861    let mut batches = Vec::new();
862    let mut offset = 0;
863    while offset < nrows {
864        let end = (offset + BATCH_SIZE).min(nrows);
865        let batch_cols = df
866            .columns
867            .iter()
868            .map(|(name, col)| (name.clone(), slice_column(col, offset, end)))
869            .collect();
870        batches.push(Batch {
871            columns: batch_cols,
872            nrows: end - offset,
873        });
874        offset = end;
875    }
876    batches
877}
878
879/// Merge a vector of batches back into a single DataFrame.
880///
881/// Batches must have identical column schemas. Empty batches are skipped.
882fn merge_batches(batches: Vec<Batch>) -> Result<DataFrame, TidyError> {
883    if batches.is_empty() {
884        return Ok(DataFrame::new());
885    }
886
887    // Determine schema from first non-empty batch (or just first batch).
888    let schema: Vec<String> = batches[0].column_names().iter().map(|s| s.to_string()).collect();
889    if schema.is_empty() {
890        return Ok(DataFrame::new());
891    }
892
893    // Pre-allocate merged columns.
894    let total_rows: usize = batches.iter().map(|b| b.nrows).sum();
895    let mut merged_cols: Vec<(String, Column)> = schema
896        .iter()
897        .map(|name| {
898            // Determine type from first batch's column.
899            let first_col = batches[0].get_column(name).unwrap();
900            let empty = match first_col {
901                Column::Float(_) => Column::Float(Vec::with_capacity(total_rows)),
902                Column::Int(_) => Column::Int(Vec::with_capacity(total_rows)),
903                Column::Str(_) => Column::Str(Vec::with_capacity(total_rows)),
904                Column::Bool(_) => Column::Bool(Vec::with_capacity(total_rows)),
905                Column::Categorical { levels, .. } => Column::Categorical {
906                    levels: levels.clone(),
907                    codes: Vec::with_capacity(total_rows),
908                },
909                Column::DateTime(_) => Column::DateTime(Vec::with_capacity(total_rows)),
910            };
911            (name.clone(), empty)
912        })
913        .collect();
914
915    // Append each batch's data.
916    for batch in &batches {
917        if batch.nrows == 0 {
918            continue;
919        }
920        for (i, (name, merged_col)) in merged_cols.iter_mut().enumerate() {
921            let batch_col = batch.get_column(name).ok_or_else(|| {
922                TidyError::ColumnNotFound(format!(
923                    "batch merge: column '{}' missing in batch (index {})",
924                    name, i
925                ))
926            })?;
927            append_column(merged_col, batch_col);
928        }
929    }
930
931    Ok(DataFrame { columns: merged_cols })
932}
933
934/// Append all rows from `src` into `dst` (same type assumed).
935fn append_column(dst: &mut Column, src: &Column) {
936    match (dst, src) {
937        (Column::Float(d), Column::Float(s)) => d.extend_from_slice(s),
938        (Column::Int(d), Column::Int(s)) => d.extend_from_slice(s),
939        (Column::Str(d), Column::Str(s)) => d.extend(s.iter().cloned()),
940        (Column::Bool(d), Column::Bool(s)) => d.extend_from_slice(s),
941        (Column::Categorical { codes: d, .. }, Column::Categorical { codes: s, .. }) => {
942            d.extend_from_slice(s);
943        }
944        (Column::DateTime(d), Column::DateTime(s)) => d.extend_from_slice(s),
945        _ => {} // Type mismatch: should not happen if schema is consistent.
946    }
947}
948
949// ── Streamable operation representation ──────────────────────────────────────
950
951/// A streamable (non-breaking) operation that can be applied per-batch.
952#[derive(Debug, Clone)]
953enum StreamableOp {
954    Filter { predicate: DExpr },
955    Select { columns: Vec<String> },
956    Mutate { assignments: Vec<(String, DExpr)> },
957}
958
959/// Returns true if the node is a pipeline breaker (requires full materialization).
960fn is_pipeline_breaker(node: &ViewNode) -> bool {
961    matches!(
962        node,
963        ViewNode::Arrange { .. }
964            | ViewNode::GroupSummarise { .. }
965            | ViewNode::Distinct { .. }
966            | ViewNode::Join { .. }
967    )
968}
969
970/// Walk the plan tree and collect a chain of streamable operations from the top.
971///
972/// Returns `(streamable_ops_in_execution_order, base_node)`.
973/// The chain is collected top-down (outermost first), then reversed so the
974/// innermost (closest to scan) operation is applied first.
975fn collect_streamable_chain(node: ViewNode) -> (Vec<StreamableOp>, Box<ViewNode>) {
976    let mut ops = Vec::new();
977    let mut current = node;
978
979    loop {
980        match current {
981            ViewNode::Filter { input, predicate } => {
982                ops.push(StreamableOp::Filter { predicate });
983                current = *input;
984            }
985            ViewNode::Select { input, columns } => {
986                ops.push(StreamableOp::Select { columns });
987                current = *input;
988            }
989            ViewNode::Mutate { input, assignments } => {
990                ops.push(StreamableOp::Mutate { assignments });
991                current = *input;
992            }
993            // Any other node is the base (Scan or a pipeline breaker).
994            other => {
995                // Reverse: we collected outermost-first, but need to apply innermost-first.
996                ops.reverse();
997                return (ops, Box::new(other));
998            }
999        }
1000    }
1001}
1002
1003/// Apply a single streamable operation to a batch.
1004fn apply_op_to_batch(batch: Batch, op: &StreamableOp) -> Result<Batch, TidyError> {
1005    match op {
1006        StreamableOp::Filter { predicate } => {
1007            // Materialize batch into a temporary DataFrame, apply filter via TidyView.
1008            let df = batch.into_dataframe();
1009            if df.nrows() == 0 {
1010                return Ok(Batch {
1011                    nrows: 0,
1012                    columns: df.columns,
1013                });
1014            }
1015            let frame = TidyFrame::from_df(df);
1016            let view = frame.view();
1017            let filtered = view.filter(predicate)?;
1018            let result_df = filtered.materialize()?;
1019            let nrows = result_df.nrows();
1020            Ok(Batch {
1021                columns: result_df.columns,
1022                nrows,
1023            })
1024        }
1025        StreamableOp::Select { columns } => {
1026            // Keep only the named columns, in the requested order.
1027            let selected: Vec<(String, Column)> = columns
1028                .iter()
1029                .filter_map(|name| {
1030                    batch
1031                        .columns
1032                        .iter()
1033                        .find(|(n, _)| n == name)
1034                        .cloned()
1035                })
1036                .collect();
1037            Ok(Batch {
1038                nrows: batch.nrows,
1039                columns: selected,
1040            })
1041        }
1042        StreamableOp::Mutate { assignments } => {
1043            // Materialize batch into a temporary DataFrame, apply mutate via TidyView.
1044            let df = batch.into_dataframe();
1045            let frame = TidyFrame::from_df(df);
1046            let view = frame.view();
1047            let assign_refs: Vec<(&str, DExpr)> = assignments
1048                .iter()
1049                .map(|(name, expr)| (leaked_str(name), expr.clone()))
1050                .collect();
1051            let result = view.mutate(
1052                &assign_refs
1053                    .iter()
1054                    .map(|(n, e)| (*n, e.clone()))
1055                    .collect::<Vec<_>>(),
1056            )?;
1057            let result_df = result.borrow().clone();
1058            let nrows = result_df.nrows();
1059            Ok(Batch {
1060                columns: result_df.columns,
1061                nrows,
1062            })
1063        }
1064    }
1065}
1066
1067/// Apply a chain of streamable operations to a DataFrame in batches.
1068fn apply_chain_batched(
1069    frame: &TidyFrame,
1070    chain: &[StreamableOp],
1071) -> Result<TidyFrame, TidyError> {
1072    let df = frame.borrow().clone();
1073    let batches = split_batches(&df);
1074
1075    let mut result_batches: Vec<Batch> = Vec::new();
1076    for batch in batches {
1077        let mut current = batch;
1078        for op in chain {
1079            current = apply_op_to_batch(current, op)?;
1080        }
1081        if current.nrows > 0 {
1082            result_batches.push(current);
1083        }
1084    }
1085
1086    if result_batches.is_empty() {
1087        // Preserve schema from original DataFrame but with zero rows.
1088        let empty_df = DataFrame {
1089            columns: df
1090                .columns
1091                .iter()
1092                .map(|(name, col)| {
1093                    (name.clone(), slice_column(col, 0, 0))
1094                })
1095                .collect(),
1096        };
1097        // If chain includes a Select, apply column pruning to the empty frame.
1098        let mut result_cols: Option<Vec<String>> = None;
1099        for op in chain {
1100            if let StreamableOp::Select { columns } = op {
1101                result_cols = Some(columns.clone());
1102            }
1103        }
1104        if let Some(cols) = result_cols {
1105            let pruned: Vec<(String, Column)> = cols
1106                .iter()
1107                .filter_map(|name| {
1108                    empty_df
1109                        .columns
1110                        .iter()
1111                        .find(|(n, _)| n == name)
1112                        .cloned()
1113                })
1114                .collect();
1115            return Ok(TidyFrame::from_df(DataFrame { columns: pruned }));
1116        }
1117        return Ok(TidyFrame::from_df(empty_df));
1118    }
1119
1120    let merged = merge_batches(result_batches)?;
1121    Ok(TidyFrame::from_df(merged))
1122}
1123
1124/// Execute an optimized ViewNode tree using batch processing where possible.
1125///
1126/// Streamable operations (Filter, Select, Mutate) are fused into a single
1127/// batch pass over 2048-row chunks. Pipeline breakers (Arrange, GroupSummarise,
1128/// Distinct, Join) force full materialization before proceeding.
1129///
1130/// # Determinism
1131///
1132/// Batches are processed sequentially in row order. The merged output has
1133/// rows in the same order as non-batched execution. Float reductions use
1134/// Kahan summation (delegated to TidyView internals).
1135pub fn execute_batched(node: ViewNode) -> Result<TidyFrame, TidyError> {
1136    match &node {
1137        // Leaf: just materialize.
1138        ViewNode::Scan { .. } => execute(node),
1139
1140        // Streamable operations: collect chain and batch-execute.
1141        _ if !is_pipeline_breaker(&node) => {
1142            let (chain, base) = collect_streamable_chain(node);
1143            if chain.is_empty() {
1144                // Shouldn't happen, but handle gracefully.
1145                return execute_batched(*base);
1146            }
1147            let base_frame = execute_batched(*base)?;
1148            apply_chain_batched(&base_frame, &chain)
1149        }
1150
1151        // Pipeline breakers: execute children batched, then apply breaker eagerly.
1152        _ => execute_breaker_batched(node),
1153    }
1154}
1155
1156/// Execute a pipeline-breaking node by first executing its children via
1157/// batch processing, then applying the breaker operation eagerly.
1158fn execute_breaker_batched(node: ViewNode) -> Result<TidyFrame, TidyError> {
1159    match node {
1160        ViewNode::Arrange { input, keys } => {
1161            let frame = execute_batched(*input)?;
1162            let view = frame.view();
1163            let arranged = view.arrange(&keys)?;
1164            let df = arranged.materialize()?;
1165            Ok(TidyFrame::from_df(df))
1166        }
1167
1168        ViewNode::GroupSummarise {
1169            input,
1170            group_keys,
1171            aggregations,
1172        } => {
1173            let frame = execute_batched(*input)?;
1174            let view = frame.view();
1175            let key_refs: Vec<&str> = group_keys.iter().map(|s| s.as_str()).collect();
1176            let grouped = view.group_by(&key_refs)?;
1177            let agg_refs: Vec<(&str, TidyAgg)> = aggregations
1178                .into_iter()
1179                .map(|(name, agg)| (leaked_str(&name), agg))
1180                .collect();
1181            let result = grouped.summarise(
1182                &agg_refs
1183                    .iter()
1184                    .map(|(n, a)| (*n, a.clone()))
1185                    .collect::<Vec<_>>(),
1186            )?;
1187            Ok(result)
1188        }
1189
1190        ViewNode::Distinct { input, columns } => {
1191            let frame = execute_batched(*input)?;
1192            let view = frame.view();
1193            let col_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
1194            let distinct = view.distinct(&col_refs)?;
1195            let df = distinct.materialize()?;
1196            Ok(TidyFrame::from_df(df))
1197        }
1198
1199        ViewNode::Join {
1200            left,
1201            right,
1202            on,
1203            kind,
1204        } => {
1205            let left_frame = execute_batched(*left)?;
1206            let right_frame = execute_batched(*right)?;
1207            let left_view = left_frame.view();
1208            let right_view = right_frame.view();
1209            let on_refs: Vec<(&str, &str)> =
1210                on.iter().map(|(l, r)| (l.as_str(), r.as_str())).collect();
1211
1212            match kind {
1213                JoinType::Inner => left_view.inner_join(&right_view, &on_refs),
1214                JoinType::Left => left_view.left_join(&right_view, &on_refs),
1215                JoinType::Semi => {
1216                    let result = left_view.semi_join(&right_view, &on_refs)?;
1217                    let df = result.materialize()?;
1218                    Ok(TidyFrame::from_df(df))
1219                }
1220                JoinType::Anti => {
1221                    let result = left_view.anti_join(&right_view, &on_refs)?;
1222                    let df = result.materialize()?;
1223                    Ok(TidyFrame::from_df(df))
1224                }
1225            }
1226        }
1227
1228        // Non-breakers should not reach here, but handle gracefully.
1229        other => execute(other),
1230    }
1231}
1232
1233impl LazyView {
1234    /// Optimize and execute the plan using batch processing.
1235    ///
1236    /// This is an alternative to `collect()` that processes data in
1237    /// 2048-row batches, fusing chains of streamable operations
1238    /// (Filter, Select, Mutate) into a single pass per batch.
1239    ///
1240    /// Pipeline breakers (Arrange, GroupSummarise, Distinct, Join) cause
1241    /// full materialization at that point.
1242    ///
1243    /// The output is identical to `collect()` -- this is purely an
1244    /// execution strategy optimization.
1245    pub fn collect_batched(self) -> Result<TidyFrame, TidyError> {
1246        let optimized = optimize(self.plan);
1247        execute_batched(optimized)
1248    }
1249}
1250
1251// ══════════════════════════════════════════════════════════════════════════════
1252// Tests
1253// ══════════════════════════════════════════════════════════════════════════════
1254
1255#[cfg(test)]
1256mod tests {
1257    use super::*;
1258    use crate::{Column, DExpr, DBinOp, DataFrame, TidyAgg, ArrangeKey, TidyView};
1259
1260    /// Build a small test DataFrame: name(Str), age(Int), score(Float).
1261    fn test_df() -> DataFrame {
1262        DataFrame {
1263            columns: vec![
1264                (
1265                    "name".to_string(),
1266                    Column::Str(vec![
1267                        "Alice".into(),
1268                        "Bob".into(),
1269                        "Carol".into(),
1270                        "Dave".into(),
1271                    ]),
1272                ),
1273                ("age".to_string(), Column::Int(vec![30, 25, 35, 25])),
1274                (
1275                    "score".to_string(),
1276                    Column::Float(vec![90.0, 85.0, 95.0, 80.0]),
1277                ),
1278            ],
1279        }
1280    }
1281
1282    /// Build a second DataFrame for join tests: name(Str), dept(Str).
1283    fn dept_df() -> DataFrame {
1284        DataFrame {
1285            columns: vec![
1286                (
1287                    "name".to_string(),
1288                    Column::Str(vec!["Alice".into(), "Bob".into(), "Eve".into()]),
1289                ),
1290                (
1291                    "dept".to_string(),
1292                    Column::Str(vec!["Eng".into(), "Sales".into(), "Eng".into()]),
1293                ),
1294            ],
1295        }
1296    }
1297
1298    // ── Basic lazy chain produces same result as eager ────────────────────
1299
1300    #[test]
1301    fn lazy_filter_matches_eager() {
1302        let df = test_df();
1303        let predicate = DExpr::BinOp {
1304            op: DBinOp::Gt,
1305            left: Box::new(DExpr::Col("age".into())),
1306            right: Box::new(DExpr::LitInt(25)),
1307        };
1308
1309        // Eager
1310        let eager_view = TidyView::from_df(df.clone());
1311        let eager_filtered = eager_view.filter(&predicate).unwrap();
1312        let eager_df = eager_filtered.materialize().unwrap();
1313
1314        // Lazy
1315        let lazy_frame = LazyView::from_df(df)
1316            .filter(predicate)
1317            .collect()
1318            .unwrap();
1319        let lazy_df = lazy_frame.borrow();
1320
1321        assert_eq!(eager_df.nrows(), lazy_df.nrows());
1322        assert_eq!(eager_df.nrows(), 2); // Alice(30) and Carol(35)
1323
1324        // Verify same names
1325        let eager_names: Vec<String> = match eager_df.get_column("name").unwrap() {
1326            Column::Str(v) => v.clone(),
1327            _ => panic!("expected Str"),
1328        };
1329        let lazy_names: Vec<String> = match lazy_df.get_column("name").unwrap() {
1330            Column::Str(v) => v.clone(),
1331            _ => panic!("expected Str"),
1332        };
1333        assert_eq!(eager_names, lazy_names);
1334    }
1335
1336    #[test]
1337    fn lazy_select_matches_eager() {
1338        let df = test_df();
1339
1340        // Eager
1341        let eager_view = TidyView::from_df(df.clone());
1342        let eager_selected = eager_view.select(&["name", "age"]).unwrap();
1343        let eager_df = eager_selected.materialize().unwrap();
1344
1345        // Lazy
1346        let lazy_frame = LazyView::from_df(df)
1347            .select(vec!["name".into(), "age".into()])
1348            .collect()
1349            .unwrap();
1350        let lazy_df = lazy_frame.borrow();
1351
1352        assert_eq!(eager_df.ncols(), 2);
1353        assert_eq!(lazy_df.ncols(), 2);
1354        assert_eq!(eager_df.column_names(), lazy_df.column_names());
1355    }
1356
1357    #[test]
1358    fn lazy_arrange_matches_eager() {
1359        let df = test_df();
1360        let keys = vec![ArrangeKey::asc("age")];
1361
1362        // Eager
1363        let eager_view = TidyView::from_df(df.clone());
1364        let eager_arranged = eager_view.arrange(&keys).unwrap();
1365        let eager_df = eager_arranged.materialize().unwrap();
1366
1367        // Lazy
1368        let lazy_frame = LazyView::from_df(df)
1369            .arrange(keys)
1370            .collect()
1371            .unwrap();
1372        let lazy_df = lazy_frame.borrow();
1373
1374        let eager_ages = match eager_df.get_column("age").unwrap() {
1375            Column::Int(v) => v.clone(),
1376            _ => panic!("expected Int"),
1377        };
1378        let lazy_ages = match lazy_df.get_column("age").unwrap() {
1379            Column::Int(v) => v.clone(),
1380            _ => panic!("expected Int"),
1381        };
1382        assert_eq!(eager_ages, lazy_ages);
1383        // Should be sorted ascending
1384        assert_eq!(eager_ages, vec![25, 25, 30, 35]);
1385    }
1386
1387    #[test]
1388    fn lazy_group_summarise_matches_eager() {
1389        let df = test_df();
1390
1391        // Eager
1392        let eager_view = TidyView::from_df(df.clone());
1393        let grouped = eager_view.group_by(&["age"]).unwrap();
1394        let eager_frame = grouped
1395            .summarise(&[("count", TidyAgg::Count)])
1396            .unwrap();
1397        let eager_df = eager_frame.borrow();
1398
1399        // Lazy
1400        let lazy_frame = LazyView::from_df(df)
1401            .group_summarise(
1402                vec!["age".into()],
1403                vec![("count".into(), TidyAgg::Count)],
1404            )
1405            .collect()
1406            .unwrap();
1407        let lazy_df = lazy_frame.borrow();
1408
1409        assert_eq!(eager_df.nrows(), lazy_df.nrows());
1410        assert_eq!(eager_df.column_names(), lazy_df.column_names());
1411    }
1412
1413    // ── Predicate pushdown ───────────────────────────────────────────────
1414
1415    #[test]
1416    fn predicate_pushdown_past_select() {
1417        let df = test_df();
1418        let predicate = DExpr::BinOp {
1419            op: DBinOp::Gt,
1420            left: Box::new(DExpr::Col("age".into())),
1421            right: Box::new(DExpr::LitInt(25)),
1422        };
1423
1424        // Build: Scan -> Select -> Filter
1425        let lazy = LazyView::from_df(df)
1426            .select(vec!["name".into(), "age".into()])
1427            .filter(predicate);
1428
1429        let optimized = lazy.optimized_plan();
1430
1431        // After pushdown, the filter should be below the select.
1432        // Plan should be: Select -> Filter -> Scan
1433        let kinds = optimized.node_kinds();
1434        assert_eq!(kinds, vec!["Select", "Filter", "Scan"]);
1435    }
1436
1437    #[test]
1438    fn predicate_pushdown_past_arrange() {
1439        let df = test_df();
1440        let predicate = DExpr::BinOp {
1441            op: DBinOp::Gt,
1442            left: Box::new(DExpr::Col("age".into())),
1443            right: Box::new(DExpr::LitInt(25)),
1444        };
1445
1446        // Build: Scan -> Arrange -> Filter
1447        let lazy = LazyView::from_df(df)
1448            .arrange(vec![ArrangeKey::asc("age")])
1449            .filter(predicate);
1450
1451        let optimized = lazy.optimized_plan();
1452
1453        // Filter should be pushed below Arrange.
1454        let kinds = optimized.node_kinds();
1455        assert_eq!(kinds, vec!["Arrange", "Filter", "Scan"]);
1456    }
1457
1458    #[test]
1459    fn predicate_not_pushed_past_mutate_when_dependent() {
1460        let df = test_df();
1461        // Mutate adds "doubled_age" = age * 2
1462        // Filter on "doubled_age" > 50 -- references mutated column, cannot push.
1463        let predicate = DExpr::BinOp {
1464            op: DBinOp::Gt,
1465            left: Box::new(DExpr::Col("doubled_age".into())),
1466            right: Box::new(DExpr::LitInt(50)),
1467        };
1468
1469        let lazy = LazyView::from_df(df)
1470            .mutate(vec![(
1471                "doubled_age".into(),
1472                DExpr::BinOp {
1473                    op: DBinOp::Mul,
1474                    left: Box::new(DExpr::Col("age".into())),
1475                    right: Box::new(DExpr::LitInt(2)),
1476                },
1477            )])
1478            .filter(predicate);
1479
1480        let optimized = lazy.optimized_plan();
1481
1482        // Filter should stay ABOVE Mutate (cannot push).
1483        let kinds = optimized.node_kinds();
1484        assert_eq!(kinds, vec!["Filter", "Mutate", "Scan"]);
1485    }
1486
1487    #[test]
1488    fn predicate_pushed_past_mutate_when_independent() {
1489        let df = test_df();
1490        // Mutate adds "doubled_age" = age * 2
1491        // Filter on "score" > 85 -- does NOT reference mutated column, can push.
1492        let predicate = DExpr::BinOp {
1493            op: DBinOp::Gt,
1494            left: Box::new(DExpr::Col("score".into())),
1495            right: Box::new(DExpr::LitFloat(85.0)),
1496        };
1497
1498        let lazy = LazyView::from_df(df)
1499            .mutate(vec![(
1500                "doubled_age".into(),
1501                DExpr::BinOp {
1502                    op: DBinOp::Mul,
1503                    left: Box::new(DExpr::Col("age".into())),
1504                    right: Box::new(DExpr::LitInt(2)),
1505                },
1506            )])
1507            .filter(predicate);
1508
1509        let optimized = lazy.optimized_plan();
1510
1511        // Filter should be pushed below Mutate.
1512        let kinds = optimized.node_kinds();
1513        assert_eq!(kinds, vec!["Mutate", "Filter", "Scan"]);
1514    }
1515
1516    #[test]
1517    fn predicate_not_pushed_past_group_summarise() {
1518        let df = test_df();
1519        let predicate = DExpr::BinOp {
1520            op: DBinOp::Gt,
1521            left: Box::new(DExpr::Col("count".into())),
1522            right: Box::new(DExpr::LitInt(1)),
1523        };
1524
1525        let lazy = LazyView::from_df(df)
1526            .group_summarise(
1527                vec!["age".into()],
1528                vec![("count".into(), TidyAgg::Count)],
1529            )
1530            .filter(predicate);
1531
1532        let optimized = lazy.optimized_plan();
1533
1534        // Filter must stay above GroupSummarise.
1535        let kinds = optimized.node_kinds();
1536        assert_eq!(kinds, vec!["Filter", "GroupSummarise", "Scan"]);
1537    }
1538
1539    // ── Filter merging ───────────────────────────────────────────────────
1540
1541    #[test]
1542    fn consecutive_filters_merged() {
1543        let df = test_df();
1544        let pred1 = DExpr::BinOp {
1545            op: DBinOp::Gt,
1546            left: Box::new(DExpr::Col("age".into())),
1547            right: Box::new(DExpr::LitInt(20)),
1548        };
1549        let pred2 = DExpr::BinOp {
1550            op: DBinOp::Lt,
1551            left: Box::new(DExpr::Col("score".into())),
1552            right: Box::new(DExpr::LitFloat(95.0)),
1553        };
1554
1555        let lazy = LazyView::from_df(df).filter(pred1).filter(pred2);
1556
1557        let optimized = lazy.optimized_plan();
1558
1559        // Should have only 1 filter node (merged), not 2.
1560        assert_eq!(optimized.count_filters(), 1);
1561
1562        // The merged filter should produce correct results.
1563        // age > 20 AND score < 95 => Alice(30,90), Bob(25,85), Dave(25,80)
1564        let df2 = test_df();
1565        let result = LazyView::from_df(df2)
1566            .filter(DExpr::BinOp {
1567                op: DBinOp::Gt,
1568                left: Box::new(DExpr::Col("age".into())),
1569                right: Box::new(DExpr::LitInt(20)),
1570            })
1571            .filter(DExpr::BinOp {
1572                op: DBinOp::Lt,
1573                left: Box::new(DExpr::Col("score".into())),
1574                right: Box::new(DExpr::LitFloat(95.0)),
1575            })
1576            .collect()
1577            .unwrap();
1578
1579        let result_df = result.borrow();
1580        assert_eq!(result_df.nrows(), 3);
1581    }
1582
1583    // ── Redundant select elimination ─────────────────────────────────────
1584
1585    #[test]
1586    fn redundant_select_eliminated() {
1587        let df = test_df();
1588
1589        // Select all 3 columns (same as input) -- should be eliminated.
1590        let lazy = LazyView::from_df(df)
1591            .select(vec!["name".into(), "age".into(), "score".into()]);
1592
1593        let optimized = lazy.optimized_plan();
1594
1595        // Should be just a Scan (redundant Select removed).
1596        assert_eq!(optimized.kind(), "Scan");
1597    }
1598
1599    #[test]
1600    fn non_redundant_select_kept() {
1601        let df = test_df();
1602
1603        // Select only 2 of 3 columns -- should NOT be eliminated.
1604        let lazy = LazyView::from_df(df).select(vec!["name".into(), "age".into()]);
1605
1606        let optimized = lazy.optimized_plan();
1607
1608        assert_eq!(optimized.kind(), "Select");
1609    }
1610
1611    // ── Determinism ──────────────────────────────────────────────────────
1612
1613    #[test]
1614    fn determinism_3_runs_identical() {
1615        for _ in 0..3 {
1616            let df = test_df();
1617            let result = LazyView::from_df(df)
1618                .filter(DExpr::BinOp {
1619                    op: DBinOp::Gt,
1620                    left: Box::new(DExpr::Col("age".into())),
1621                    right: Box::new(DExpr::LitInt(20)),
1622                })
1623                .select(vec!["name".into(), "age".into()])
1624                .arrange(vec![ArrangeKey::desc("age")])
1625                .collect()
1626                .unwrap();
1627
1628            let result_df = result.borrow();
1629            assert_eq!(result_df.nrows(), 4);
1630
1631            let ages = match result_df.get_column("age").unwrap() {
1632                Column::Int(v) => v.clone(),
1633                _ => panic!("expected Int"),
1634            };
1635            // Descending: 35, 30, 25, 25
1636            assert_eq!(ages, vec![35, 30, 25, 25]);
1637
1638            let names = match result_df.get_column("name").unwrap() {
1639                Column::Str(v) => v.clone(),
1640                _ => panic!("expected Str"),
1641            };
1642            assert_eq!(names, vec!["Carol", "Alice", "Bob", "Dave"]);
1643        }
1644    }
1645
1646    // ── Join execution ───────────────────────────────────────────────────
1647
1648    #[test]
1649    fn lazy_inner_join() {
1650        let left = test_df();
1651        let right = dept_df();
1652
1653        let result = LazyView::from_df(left)
1654            .join(
1655                LazyView::from_df(right),
1656                vec![("name".into(), "name".into())],
1657                JoinType::Inner,
1658            )
1659            .collect()
1660            .unwrap();
1661
1662        let result_df = result.borrow();
1663        // Only Alice and Bob match
1664        assert_eq!(result_df.nrows(), 2);
1665        assert!(result_df.get_column("dept").is_some());
1666    }
1667
1668    #[test]
1669    fn lazy_semi_join() {
1670        let left = test_df();
1671        let right = dept_df();
1672
1673        let result = LazyView::from_df(left)
1674            .join(
1675                LazyView::from_df(right),
1676                vec![("name".into(), "name".into())],
1677                JoinType::Semi,
1678            )
1679            .collect()
1680            .unwrap();
1681
1682        let result_df = result.borrow();
1683        // Semi join: Alice and Bob from left
1684        assert_eq!(result_df.nrows(), 2);
1685        // Semi join should NOT include right columns
1686        assert!(result_df.get_column("dept").is_none());
1687    }
1688
1689    #[test]
1690    fn lazy_anti_join() {
1691        let left = test_df();
1692        let right = dept_df();
1693
1694        let result = LazyView::from_df(left)
1695            .join(
1696                LazyView::from_df(right),
1697                vec![("name".into(), "name".into())],
1698                JoinType::Anti,
1699            )
1700            .collect()
1701            .unwrap();
1702
1703        let result_df = result.borrow();
1704        // Anti join: Carol and Dave (not in right)
1705        assert_eq!(result_df.nrows(), 2);
1706    }
1707
1708    // ── Distinct ─────────────────────────────────────────────────────────
1709
1710    #[test]
1711    fn lazy_distinct() {
1712        let df = test_df();
1713
1714        let result = LazyView::from_df(df)
1715            .distinct(vec!["age".into()])
1716            .collect()
1717            .unwrap();
1718
1719        let result_df = result.borrow();
1720        // 3 distinct ages: 25, 30, 35
1721        assert_eq!(result_df.nrows(), 3);
1722    }
1723
1724    // ── Complex chain ────────────────────────────────────────────────────
1725
1726    #[test]
1727    fn complex_lazy_chain() {
1728        let df = test_df();
1729
1730        // filter(age > 20) -> mutate(bonus = score * 1.1) -> select(name, bonus) -> arrange(bonus desc)
1731        let result = LazyView::from_df(df)
1732            .filter(DExpr::BinOp {
1733                op: DBinOp::Gt,
1734                left: Box::new(DExpr::Col("age".into())),
1735                right: Box::new(DExpr::LitInt(20)),
1736            })
1737            .mutate(vec![(
1738                "bonus".into(),
1739                DExpr::BinOp {
1740                    op: DBinOp::Mul,
1741                    left: Box::new(DExpr::Col("score".into())),
1742                    right: Box::new(DExpr::LitFloat(1.1)),
1743                },
1744            )])
1745            .select(vec!["name".into(), "bonus".into()])
1746            .arrange(vec![ArrangeKey::desc("bonus")])
1747            .collect()
1748            .unwrap();
1749
1750        let result_df = result.borrow();
1751        assert_eq!(result_df.nrows(), 4);
1752        assert_eq!(result_df.ncols(), 2);
1753        assert_eq!(result_df.column_names(), vec!["name", "bonus"]);
1754    }
1755
1756    // ── Predicate pushdown into join ─────────────────────────────────────
1757
1758    #[test]
1759    fn predicate_pushdown_into_join_left_side() {
1760        let left = test_df();
1761        let right = dept_df();
1762
1763        // Join then filter on "age" > 25 -- "age" only exists in left.
1764        let lazy = LazyView::from_df(left)
1765            .join(
1766                LazyView::from_df(right),
1767                vec![("name".into(), "name".into())],
1768                JoinType::Inner,
1769            )
1770            .filter(DExpr::BinOp {
1771                op: DBinOp::Gt,
1772                left: Box::new(DExpr::Col("age".into())),
1773                right: Box::new(DExpr::LitInt(25)),
1774            });
1775
1776        let optimized = lazy.optimized_plan();
1777
1778        // The filter should be pushed into the left side of the join.
1779        let kinds = optimized.node_kinds();
1780        // Expect: Join -> [Filter -> Scan (left), Scan (right)]
1781        assert_eq!(kinds[0], "Join");
1782        // The left subtree should contain a Filter.
1783        if let ViewNode::Join { left, right, .. } = &optimized {
1784            assert_eq!(left.kind(), "Filter");
1785            assert_eq!(right.kind(), "Scan");
1786        } else {
1787            panic!("expected Join at top");
1788        }
1789    }
1790
1791    // ══════════════════════════════════════════════════════════════════════
1792    // Batch executor tests
1793    // ══════════════════════════════════════════════════════════════════════
1794
1795    /// Helper: compare two DataFrames column-by-column for equality.
1796    fn assert_df_eq(a: &DataFrame, b: &DataFrame, context: &str) {
1797        assert_eq!(
1798            a.nrows(),
1799            b.nrows(),
1800            "{}: nrows differ ({} vs {})",
1801            context,
1802            a.nrows(),
1803            b.nrows()
1804        );
1805        assert_eq!(
1806            a.column_names(),
1807            b.column_names(),
1808            "{}: column names differ",
1809            context
1810        );
1811        for (name_a, col_a) in &a.columns {
1812            let col_b = b.get_column(name_a).unwrap_or_else(|| {
1813                panic!("{}: column '{}' missing in b", context, name_a)
1814            });
1815            assert_col_eq(col_a, col_b, &format!("{} col '{}'", context, name_a));
1816        }
1817    }
1818
1819    fn assert_col_eq(a: &Column, b: &Column, context: &str) {
1820        match (a, b) {
1821            (Column::Int(va), Column::Int(vb)) => assert_eq!(va, vb, "{}", context),
1822            (Column::Float(va), Column::Float(vb)) => {
1823                assert_eq!(va.len(), vb.len(), "{}: float len", context);
1824                for (i, (x, y)) in va.iter().zip(vb.iter()).enumerate() {
1825                    assert!(
1826                        (x - y).abs() < 1e-12,
1827                        "{}: float[{}] {} != {}",
1828                        context,
1829                        i,
1830                        x,
1831                        y
1832                    );
1833                }
1834            }
1835            (Column::Str(va), Column::Str(vb)) => assert_eq!(va, vb, "{}", context),
1836            (Column::Bool(va), Column::Bool(vb)) => assert_eq!(va, vb, "{}", context),
1837            _ => panic!("{}: column type mismatch", context),
1838        }
1839    }
1840
1841    // ── Parity: collect_batched == collect ───────────────────────────────
1842
1843    #[test]
1844    fn batched_filter_parity() {
1845        let predicate = DExpr::BinOp {
1846            op: DBinOp::Gt,
1847            left: Box::new(DExpr::Col("age".into())),
1848            right: Box::new(DExpr::LitInt(25)),
1849        };
1850
1851        let eager = LazyView::from_df(test_df())
1852            .filter(predicate.clone())
1853            .collect()
1854            .unwrap();
1855        let batched = LazyView::from_df(test_df())
1856            .filter(predicate)
1857            .collect_batched()
1858            .unwrap();
1859
1860        assert_df_eq(&eager.borrow(), &batched.borrow(), "filter parity");
1861    }
1862
1863    #[test]
1864    fn batched_select_parity() {
1865        let cols = vec!["name".into(), "score".into()];
1866
1867        let eager = LazyView::from_df(test_df())
1868            .select(cols.clone())
1869            .collect()
1870            .unwrap();
1871        let batched = LazyView::from_df(test_df())
1872            .select(cols)
1873            .collect_batched()
1874            .unwrap();
1875
1876        assert_df_eq(&eager.borrow(), &batched.borrow(), "select parity");
1877    }
1878
1879    #[test]
1880    fn batched_mutate_parity() {
1881        let assignments = vec![(
1882            "doubled".into(),
1883            DExpr::BinOp {
1884                op: DBinOp::Mul,
1885                left: Box::new(DExpr::Col("age".into())),
1886                right: Box::new(DExpr::LitInt(2)),
1887            },
1888        )];
1889
1890        let eager = LazyView::from_df(test_df())
1891            .mutate(assignments.clone())
1892            .collect()
1893            .unwrap();
1894        let batched = LazyView::from_df(test_df())
1895            .mutate(assignments)
1896            .collect_batched()
1897            .unwrap();
1898
1899        assert_df_eq(&eager.borrow(), &batched.borrow(), "mutate parity");
1900    }
1901
1902    #[test]
1903    fn batched_filter_select_mutate_chain_parity() {
1904        let predicate = DExpr::BinOp {
1905            op: DBinOp::Gt,
1906            left: Box::new(DExpr::Col("age".into())),
1907            right: Box::new(DExpr::LitInt(20)),
1908        };
1909        let assignments = vec![(
1910            "bonus".into(),
1911            DExpr::BinOp {
1912                op: DBinOp::Mul,
1913                left: Box::new(DExpr::Col("score".into())),
1914                right: Box::new(DExpr::LitFloat(1.1)),
1915            },
1916        )];
1917
1918        let eager = LazyView::from_df(test_df())
1919            .filter(predicate.clone())
1920            .mutate(assignments.clone())
1921            .select(vec!["name".into(), "bonus".into()])
1922            .collect()
1923            .unwrap();
1924        let batched = LazyView::from_df(test_df())
1925            .filter(predicate)
1926            .mutate(assignments)
1927            .select(vec!["name".into(), "bonus".into()])
1928            .collect_batched()
1929            .unwrap();
1930
1931        assert_df_eq(
1932            &eager.borrow(),
1933            &batched.borrow(),
1934            "filter+mutate+select chain parity",
1935        );
1936    }
1937
1938    #[test]
1939    fn batched_group_summarise_parity() {
1940        let eager = LazyView::from_df(test_df())
1941            .group_summarise(
1942                vec!["age".into()],
1943                vec![("count".into(), TidyAgg::Count)],
1944            )
1945            .collect()
1946            .unwrap();
1947        let batched = LazyView::from_df(test_df())
1948            .group_summarise(
1949                vec!["age".into()],
1950                vec![("count".into(), TidyAgg::Count)],
1951            )
1952            .collect_batched()
1953            .unwrap();
1954
1955        assert_df_eq(
1956            &eager.borrow(),
1957            &batched.borrow(),
1958            "group_summarise parity",
1959        );
1960    }
1961
1962    #[test]
1963    fn batched_arrange_parity() {
1964        let keys = vec![ArrangeKey::asc("age")];
1965
1966        let eager = LazyView::from_df(test_df())
1967            .arrange(keys.clone())
1968            .collect()
1969            .unwrap();
1970        let batched = LazyView::from_df(test_df())
1971            .arrange(keys)
1972            .collect_batched()
1973            .unwrap();
1974
1975        assert_df_eq(&eager.borrow(), &batched.borrow(), "arrange parity");
1976    }
1977
1978    #[test]
1979    fn batched_distinct_parity() {
1980        let eager = LazyView::from_df(test_df())
1981            .distinct(vec!["age".into()])
1982            .collect()
1983            .unwrap();
1984        let batched = LazyView::from_df(test_df())
1985            .distinct(vec!["age".into()])
1986            .collect_batched()
1987            .unwrap();
1988
1989        assert_df_eq(&eager.borrow(), &batched.borrow(), "distinct parity");
1990    }
1991
1992    #[test]
1993    fn batched_join_parity() {
1994        let eager = LazyView::from_df(test_df())
1995            .join(
1996                LazyView::from_df(dept_df()),
1997                vec![("name".into(), "name".into())],
1998                JoinType::Inner,
1999            )
2000            .collect()
2001            .unwrap();
2002        let batched = LazyView::from_df(test_df())
2003            .join(
2004                LazyView::from_df(dept_df()),
2005                vec![("name".into(), "name".into())],
2006                JoinType::Inner,
2007            )
2008            .collect_batched()
2009            .unwrap();
2010
2011        assert_df_eq(&eager.borrow(), &batched.borrow(), "join parity");
2012    }
2013
2014    #[test]
2015    fn batched_complex_pipeline_parity() {
2016        // filter -> mutate -> select -> arrange (has a pipeline breaker at end)
2017        let predicate = DExpr::BinOp {
2018            op: DBinOp::Gt,
2019            left: Box::new(DExpr::Col("age".into())),
2020            right: Box::new(DExpr::LitInt(20)),
2021        };
2022        let assignments = vec![(
2023            "bonus".into(),
2024            DExpr::BinOp {
2025                op: DBinOp::Mul,
2026                left: Box::new(DExpr::Col("score".into())),
2027                right: Box::new(DExpr::LitFloat(1.1)),
2028            },
2029        )];
2030
2031        let eager = LazyView::from_df(test_df())
2032            .filter(predicate.clone())
2033            .mutate(assignments.clone())
2034            .select(vec!["name".into(), "bonus".into()])
2035            .arrange(vec![ArrangeKey::desc("bonus")])
2036            .collect()
2037            .unwrap();
2038        let batched = LazyView::from_df(test_df())
2039            .filter(predicate)
2040            .mutate(assignments)
2041            .select(vec!["name".into(), "bonus".into()])
2042            .arrange(vec![ArrangeKey::desc("bonus")])
2043            .collect_batched()
2044            .unwrap();
2045
2046        assert_df_eq(
2047            &eager.borrow(),
2048            &batched.borrow(),
2049            "complex pipeline parity",
2050        );
2051    }
2052
2053    // ── Determinism: 3 runs identical ───────────────────────────────────
2054
2055    #[test]
2056    fn batched_determinism_3_runs() {
2057        let mut results: Vec<Vec<i64>> = Vec::new();
2058        let mut results_names: Vec<Vec<String>> = Vec::new();
2059
2060        for _ in 0..3 {
2061            let result = LazyView::from_df(test_df())
2062                .filter(DExpr::BinOp {
2063                    op: DBinOp::Gt,
2064                    left: Box::new(DExpr::Col("age".into())),
2065                    right: Box::new(DExpr::LitInt(20)),
2066                })
2067                .select(vec!["name".into(), "age".into()])
2068                .arrange(vec![ArrangeKey::desc("age")])
2069                .collect_batched()
2070                .unwrap();
2071
2072            let df = result.borrow();
2073            let ages = match df.get_column("age").unwrap() {
2074                Column::Int(v) => v.clone(),
2075                _ => panic!("expected Int"),
2076            };
2077            let names = match df.get_column("name").unwrap() {
2078                Column::Str(v) => v.clone(),
2079                _ => panic!("expected Str"),
2080            };
2081            results.push(ages);
2082            results_names.push(names);
2083        }
2084
2085        // All 3 runs must be identical.
2086        assert_eq!(results[0], results[1]);
2087        assert_eq!(results[1], results[2]);
2088        assert_eq!(results_names[0], results_names[1]);
2089        assert_eq!(results_names[1], results_names[2]);
2090        // Verify expected values.
2091        assert_eq!(results[0], vec![35, 30, 25, 25]);
2092        assert_eq!(results_names[0], vec!["Carol", "Alice", "Bob", "Dave"]);
2093    }
2094
2095    // ── Large data: batching actually kicks in (>2048 rows) ─────────────
2096
2097    /// Build a large DataFrame with 10,000 rows.
2098    fn large_df() -> DataFrame {
2099        let n = 10_000usize;
2100        let names: Vec<String> = (0..n).map(|i| format!("user_{}", i)).collect();
2101        let ages: Vec<i64> = (0..n).map(|i| (i % 80) as i64 + 18).collect();
2102        let scores: Vec<f64> = (0..n).map(|i| 50.0 + (i % 50) as f64).collect();
2103        DataFrame {
2104            columns: vec![
2105                ("name".to_string(), Column::Str(names)),
2106                ("age".to_string(), Column::Int(ages)),
2107                ("score".to_string(), Column::Float(scores)),
2108            ],
2109        }
2110    }
2111
2112    #[test]
2113    fn batched_large_data_filter_parity() {
2114        let predicate = DExpr::BinOp {
2115            op: DBinOp::Gt,
2116            left: Box::new(DExpr::Col("age".into())),
2117            right: Box::new(DExpr::LitInt(50)),
2118        };
2119
2120        let eager = LazyView::from_df(large_df())
2121            .filter(predicate.clone())
2122            .collect()
2123            .unwrap();
2124        let batched = LazyView::from_df(large_df())
2125            .filter(predicate)
2126            .collect_batched()
2127            .unwrap();
2128
2129        assert_df_eq(
2130            &eager.borrow(),
2131            &batched.borrow(),
2132            "large data filter parity",
2133        );
2134        // Verify batching actually processed >1 batch.
2135        assert!(eager.borrow().nrows() > 0);
2136    }
2137
2138    #[test]
2139    fn batched_large_data_chain_parity() {
2140        let predicate = DExpr::BinOp {
2141            op: DBinOp::Gt,
2142            left: Box::new(DExpr::Col("age".into())),
2143            right: Box::new(DExpr::LitInt(50)),
2144        };
2145        let assignments = vec![(
2146            "bonus".into(),
2147            DExpr::BinOp {
2148                op: DBinOp::Mul,
2149                left: Box::new(DExpr::Col("score".into())),
2150                right: Box::new(DExpr::LitFloat(1.5)),
2151            },
2152        )];
2153
2154        let eager = LazyView::from_df(large_df())
2155            .filter(predicate.clone())
2156            .mutate(assignments.clone())
2157            .select(vec!["name".into(), "bonus".into()])
2158            .collect()
2159            .unwrap();
2160        let batched = LazyView::from_df(large_df())
2161            .filter(predicate)
2162            .mutate(assignments)
2163            .select(vec!["name".into(), "bonus".into()])
2164            .collect_batched()
2165            .unwrap();
2166
2167        assert_df_eq(
2168            &eager.borrow(),
2169            &batched.borrow(),
2170            "large data chain parity",
2171        );
2172    }
2173
2174    #[test]
2175    fn batched_large_data_determinism() {
2176        let mut prev_ages: Option<Vec<i64>> = None;
2177        for _ in 0..3 {
2178            let result = LazyView::from_df(large_df())
2179                .filter(DExpr::BinOp {
2180                    op: DBinOp::Gt,
2181                    left: Box::new(DExpr::Col("age".into())),
2182                    right: Box::new(DExpr::LitInt(90)),
2183                })
2184                .mutate(vec![(
2185                    "double_age".into(),
2186                    DExpr::BinOp {
2187                        op: DBinOp::Mul,
2188                        left: Box::new(DExpr::Col("age".into())),
2189                        right: Box::new(DExpr::LitInt(2)),
2190                    },
2191                )])
2192                .collect_batched()
2193                .unwrap();
2194
2195            let df = result.borrow();
2196            let ages = match df.get_column("age").unwrap() {
2197                Column::Int(v) => v.clone(),
2198                _ => panic!("expected Int"),
2199            };
2200            if let Some(ref prev) = prev_ages {
2201                assert_eq!(prev, &ages, "determinism: ages differ across runs");
2202            }
2203            prev_ages = Some(ages);
2204        }
2205    }
2206
2207    // ── Batch splitting helper ──────────────────────────────────────────
2208
2209    #[test]
2210    fn split_batches_correct_count() {
2211        let df = large_df();
2212        let batches = split_batches(&df);
2213        // 10000 rows / 2048 = 4 full + 1 partial = 5 batches
2214        assert_eq!(batches.len(), 5);
2215        assert_eq!(batches[0].nrows, 2048);
2216        assert_eq!(batches[1].nrows, 2048);
2217        assert_eq!(batches[2].nrows, 2048);
2218        assert_eq!(batches[3].nrows, 2048);
2219        assert_eq!(batches[4].nrows, 10000 - 4 * 2048); // 1808
2220        let total: usize = batches.iter().map(|b| b.nrows).sum();
2221        assert_eq!(total, 10000);
2222    }
2223
2224    #[test]
2225    fn split_batches_small_df() {
2226        let df = test_df(); // 4 rows
2227        let batches = split_batches(&df);
2228        assert_eq!(batches.len(), 1);
2229        assert_eq!(batches[0].nrows, 4);
2230    }
2231
2232    #[test]
2233    fn merge_batches_roundtrip() {
2234        let df = large_df();
2235        let batches = split_batches(&df);
2236        let merged = merge_batches(batches).unwrap();
2237        assert_df_eq(&df, &merged, "merge roundtrip");
2238    }
2239}