1use crate::{ArrangeKey, Column, DExpr, DBinOp, DataFrame, TidyAgg, TidyError, TidyFrame};
14use std::collections::BTreeSet;
15use std::rc::Rc;
16
/// One node of a lazy query plan.
///
/// Every variant except `Scan` owns its input plan(s) through a `Box`, so a
/// complete plan is a tree whose leaves are `Scan` nodes.
#[derive(Debug, Clone)]
pub enum ViewNode {
    /// Leaf: reads the reference-counted source `DataFrame`.
    Scan { df: Rc<DataFrame> },
    /// Row filter; executed by passing `predicate` to the view's `filter`.
    Filter {
        input: Box<ViewNode>,
        predicate: DExpr,
    },
    /// Column projection; `columns` are the names handed to the view's `select`.
    Select {
        input: Box<ViewNode>,
        columns: Vec<String>,
    },
    /// Column assignment; each entry is `(output column name, expression)`.
    Mutate {
        input: Box<ViewNode>,
        assignments: Vec<(String, DExpr)>,
    },
    /// Row ordering by `keys`; executed through the view's `arrange`.
    Arrange {
        input: Box<ViewNode>,
        keys: Vec<ArrangeKey>,
    },
    /// Grouping by `group_keys` followed by `summarise`; each aggregation is
    /// `(output column name, aggregate)`.
    GroupSummarise {
        input: Box<ViewNode>,
        group_keys: Vec<String>,
        aggregations: Vec<(String, TidyAgg)>,
    },
    /// De-duplication; `columns` are handed to the view's `distinct`.
    Distinct {
        input: Box<ViewNode>,
        columns: Vec<String>,
    },
    /// Join of two plans; each `on` entry is `(left key name, right key name)`.
    Join {
        left: Box<ViewNode>,
        right: Box<ViewNode>,
        on: Vec<(String, String)>,
        kind: JoinType,
    },
}
63
/// Which join a `ViewNode::Join` performs.
///
/// At execution time each variant dispatches to the view method of the same
/// name: `inner_join`, `left_join`, `semi_join`, `anti_join`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    Inner,
    Left,
    Semi,
    Anti,
}
72
/// Builder for a lazy query: wraps the `ViewNode` plan accumulated so far.
///
/// Nothing is evaluated until one of the `collect*` methods runs the plan.
pub struct LazyView {
    // Root (outermost) node of the plan built so far.
    plan: ViewNode,
}
79
80impl LazyView {
81 pub fn from_df(df: DataFrame) -> Self {
83 LazyView {
84 plan: ViewNode::Scan { df: Rc::new(df) },
85 }
86 }
87
88 pub fn from_rc(df: Rc<DataFrame>) -> Self {
90 LazyView {
91 plan: ViewNode::Scan { df },
92 }
93 }
94
95 pub fn filter(self, predicate: DExpr) -> Self {
97 LazyView {
98 plan: ViewNode::Filter {
99 input: Box::new(self.plan),
100 predicate,
101 },
102 }
103 }
104
105 pub fn select(self, columns: Vec<String>) -> Self {
107 LazyView {
108 plan: ViewNode::Select {
109 input: Box::new(self.plan),
110 columns,
111 },
112 }
113 }
114
115 pub fn mutate(self, assignments: Vec<(String, DExpr)>) -> Self {
117 LazyView {
118 plan: ViewNode::Mutate {
119 input: Box::new(self.plan),
120 assignments,
121 },
122 }
123 }
124
125 pub fn arrange(self, keys: Vec<ArrangeKey>) -> Self {
127 LazyView {
128 plan: ViewNode::Arrange {
129 input: Box::new(self.plan),
130 keys,
131 },
132 }
133 }
134
135 pub fn group_summarise(
137 self,
138 group_keys: Vec<String>,
139 aggregations: Vec<(String, TidyAgg)>,
140 ) -> Self {
141 LazyView {
142 plan: ViewNode::GroupSummarise {
143 input: Box::new(self.plan),
144 group_keys,
145 aggregations,
146 },
147 }
148 }
149
150 pub fn distinct(self, columns: Vec<String>) -> Self {
152 LazyView {
153 plan: ViewNode::Distinct {
154 input: Box::new(self.plan),
155 columns,
156 },
157 }
158 }
159
160 pub fn join(self, right: LazyView, on: Vec<(String, String)>, kind: JoinType) -> Self {
162 LazyView {
163 plan: ViewNode::Join {
164 left: Box::new(self.plan),
165 right: Box::new(right.plan),
166 on,
167 kind,
168 },
169 }
170 }
171
172 pub fn collect(self) -> Result<TidyFrame, TidyError> {
174 let optimized = optimize(self.plan);
175 execute(optimized)
176 }
177
178 pub fn plan(&self) -> &ViewNode {
180 &self.plan
181 }
182
183 pub fn optimized_plan(self) -> ViewNode {
185 optimize(self.plan)
186 }
187}
188
189pub fn optimize(plan: ViewNode) -> ViewNode {
196 let plan = merge_filters(plan);
197 let plan = push_predicates_down(plan);
198 let plan = eliminate_redundant_selects(plan);
199 plan
200}
201
202fn merge_filters(plan: ViewNode) -> ViewNode {
208 match plan {
209 ViewNode::Filter { input, predicate } => {
210 let merged_input = merge_filters(*input);
211 match merged_input {
212 ViewNode::Filter {
213 input: inner,
214 predicate: inner_pred,
215 } => {
216 let combined = DExpr::BinOp {
218 op: DBinOp::And,
219 left: Box::new(inner_pred),
220 right: Box::new(predicate),
221 };
222 ViewNode::Filter {
223 input: inner,
224 predicate: combined,
225 }
226 }
227 other => ViewNode::Filter {
228 input: Box::new(other),
229 predicate,
230 },
231 }
232 }
233 ViewNode::Select { input, columns } => ViewNode::Select {
235 input: Box::new(merge_filters(*input)),
236 columns,
237 },
238 ViewNode::Mutate {
239 input,
240 assignments,
241 } => ViewNode::Mutate {
242 input: Box::new(merge_filters(*input)),
243 assignments,
244 },
245 ViewNode::Arrange { input, keys } => ViewNode::Arrange {
246 input: Box::new(merge_filters(*input)),
247 keys,
248 },
249 ViewNode::GroupSummarise {
250 input,
251 group_keys,
252 aggregations,
253 } => ViewNode::GroupSummarise {
254 input: Box::new(merge_filters(*input)),
255 group_keys,
256 aggregations,
257 },
258 ViewNode::Distinct { input, columns } => ViewNode::Distinct {
259 input: Box::new(merge_filters(*input)),
260 columns,
261 },
262 ViewNode::Join {
263 left,
264 right,
265 on,
266 kind,
267 } => ViewNode::Join {
268 left: Box::new(merge_filters(*left)),
269 right: Box::new(merge_filters(*right)),
270 on,
271 kind,
272 },
273 other => other, }
275}
276
277fn push_predicates_down(plan: ViewNode) -> ViewNode {
289 match plan {
290 ViewNode::Filter { input, predicate } => {
291 let optimized_input = push_predicates_down(*input);
292 push_filter_into(optimized_input, predicate)
293 }
294 ViewNode::Select { input, columns } => ViewNode::Select {
296 input: Box::new(push_predicates_down(*input)),
297 columns,
298 },
299 ViewNode::Mutate {
300 input,
301 assignments,
302 } => ViewNode::Mutate {
303 input: Box::new(push_predicates_down(*input)),
304 assignments,
305 },
306 ViewNode::Arrange { input, keys } => ViewNode::Arrange {
307 input: Box::new(push_predicates_down(*input)),
308 keys,
309 },
310 ViewNode::GroupSummarise {
311 input,
312 group_keys,
313 aggregations,
314 } => ViewNode::GroupSummarise {
315 input: Box::new(push_predicates_down(*input)),
316 group_keys,
317 aggregations,
318 },
319 ViewNode::Distinct { input, columns } => ViewNode::Distinct {
320 input: Box::new(push_predicates_down(*input)),
321 columns,
322 },
323 ViewNode::Join {
324 left,
325 right,
326 on,
327 kind,
328 } => ViewNode::Join {
329 left: Box::new(push_predicates_down(*left)),
330 right: Box::new(push_predicates_down(*right)),
331 on,
332 kind,
333 },
334 other => other,
335 }
336}
337
338fn push_filter_into(node: ViewNode, predicate: DExpr) -> ViewNode {
340 match node {
341 ViewNode::Select { input, columns } => ViewNode::Select {
344 input: Box::new(push_filter_into(*input, predicate)),
345 columns,
346 },
347
348 ViewNode::Arrange { input, keys } => ViewNode::Arrange {
350 input: Box::new(push_filter_into(*input, predicate)),
351 keys,
352 },
353
354 ViewNode::Mutate {
357 input,
358 assignments,
359 } => {
360 let pred_cols = expr_columns(&predicate);
361 let mutated_cols: BTreeSet<String> =
362 assignments.iter().map(|(name, _)| name.clone()).collect();
363 let references_mutated = pred_cols.iter().any(|c| mutated_cols.contains(c));
364
365 if references_mutated {
366 ViewNode::Filter {
368 input: Box::new(ViewNode::Mutate {
369 input,
370 assignments,
371 }),
372 predicate,
373 }
374 } else {
375 ViewNode::Mutate {
377 input: Box::new(push_filter_into(*input, predicate)),
378 assignments,
379 }
380 }
381 }
382
383 ViewNode::Join {
386 left,
387 right,
388 on,
389 kind,
390 } => {
391 let pred_cols = expr_columns(&predicate);
392 let left_cols = node_output_columns(&left);
393 let right_cols = node_output_columns(&right);
394
395 let all_in_left = pred_cols.iter().all(|c| left_cols.contains(c));
396 let all_in_right = pred_cols.iter().all(|c| right_cols.contains(c));
397
398 if all_in_left {
399 ViewNode::Join {
400 left: Box::new(push_filter_into(*left, predicate)),
401 right,
402 on,
403 kind,
404 }
405 } else if all_in_right {
406 ViewNode::Join {
407 left,
408 right: Box::new(push_filter_into(*right, predicate)),
409 on,
410 kind,
411 }
412 } else {
413 ViewNode::Filter {
415 input: Box::new(ViewNode::Join {
416 left,
417 right,
418 on,
419 kind,
420 }),
421 predicate,
422 }
423 }
424 }
425
426 other => ViewNode::Filter {
428 input: Box::new(other),
429 predicate,
430 },
431 }
432}
433
434fn eliminate_redundant_selects(plan: ViewNode) -> ViewNode {
439 match plan {
440 ViewNode::Select { input, columns } => {
441 let optimized_input = eliminate_redundant_selects(*input);
442 let input_cols = node_output_columns(&optimized_input);
443
444 let select_set: BTreeSet<&str> = columns.iter().map(|s| s.as_str()).collect();
446 let input_set: BTreeSet<&str> = input_cols.iter().map(|s| s.as_str()).collect();
447
448 if select_set == input_set {
449 optimized_input
450 } else {
451 ViewNode::Select {
452 input: Box::new(optimized_input),
453 columns,
454 }
455 }
456 }
457 ViewNode::Filter { input, predicate } => ViewNode::Filter {
458 input: Box::new(eliminate_redundant_selects(*input)),
459 predicate,
460 },
461 ViewNode::Mutate {
462 input,
463 assignments,
464 } => ViewNode::Mutate {
465 input: Box::new(eliminate_redundant_selects(*input)),
466 assignments,
467 },
468 ViewNode::Arrange { input, keys } => ViewNode::Arrange {
469 input: Box::new(eliminate_redundant_selects(*input)),
470 keys,
471 },
472 ViewNode::GroupSummarise {
473 input,
474 group_keys,
475 aggregations,
476 } => ViewNode::GroupSummarise {
477 input: Box::new(eliminate_redundant_selects(*input)),
478 group_keys,
479 aggregations,
480 },
481 ViewNode::Distinct { input, columns } => ViewNode::Distinct {
482 input: Box::new(eliminate_redundant_selects(*input)),
483 columns,
484 },
485 ViewNode::Join {
486 left,
487 right,
488 on,
489 kind,
490 } => ViewNode::Join {
491 left: Box::new(eliminate_redundant_selects(*left)),
492 right: Box::new(eliminate_redundant_selects(*right)),
493 on,
494 kind,
495 },
496 other => other,
497 }
498}
499
500fn execute(node: ViewNode) -> Result<TidyFrame, TidyError> {
504 match node {
505 ViewNode::Scan { df } => Ok(TidyFrame::from_df((*df).clone())),
506
507 ViewNode::Filter { input, predicate } => {
508 let frame = execute(*input)?;
509 let view = frame.view();
510 let filtered = view.filter(&predicate)?;
511 let df = filtered.materialize()?;
512 Ok(TidyFrame::from_df(df))
513 }
514
515 ViewNode::Select { input, columns } => {
516 let frame = execute(*input)?;
517 let view = frame.view();
518 let col_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
519 let selected = view.select(&col_refs)?;
520 let df = selected.materialize()?;
521 Ok(TidyFrame::from_df(df))
522 }
523
524 ViewNode::Mutate {
525 input,
526 assignments,
527 } => {
528 let frame = execute(*input)?;
529 let view = frame.view();
530 let assign_refs: Vec<(&str, DExpr)> = assignments
531 .into_iter()
532 .map(|(name, expr)| (leaked_str(&name), expr))
533 .collect();
534 let result = view.mutate(&assign_refs.iter().map(|(n, e)| (*n, e.clone())).collect::<Vec<_>>())?;
536 Ok(result)
537 }
538
539 ViewNode::Arrange { input, keys } => {
540 let frame = execute(*input)?;
541 let view = frame.view();
542 let arranged = view.arrange(&keys)?;
543 let df = arranged.materialize()?;
544 Ok(TidyFrame::from_df(df))
545 }
546
547 ViewNode::GroupSummarise {
548 input,
549 group_keys,
550 aggregations,
551 } => {
552 let frame = execute(*input)?;
553 let view = frame.view();
554 let key_refs: Vec<&str> = group_keys.iter().map(|s| s.as_str()).collect();
555 let grouped = view.group_by(&key_refs)?;
556 let agg_refs: Vec<(&str, TidyAgg)> = aggregations
557 .into_iter()
558 .map(|(name, agg)| (leaked_str(&name), agg))
559 .collect();
560 let result = grouped.summarise(
561 &agg_refs.iter().map(|(n, a)| (*n, a.clone())).collect::<Vec<_>>(),
562 )?;
563 Ok(result)
564 }
565
566 ViewNode::Distinct { input, columns } => {
567 let frame = execute(*input)?;
568 let view = frame.view();
569 let col_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
570 let distinct = view.distinct(&col_refs)?;
571 let df = distinct.materialize()?;
572 Ok(TidyFrame::from_df(df))
573 }
574
575 ViewNode::Join {
576 left,
577 right,
578 on,
579 kind,
580 } => {
581 let left_frame = execute(*left)?;
582 let right_frame = execute(*right)?;
583 let left_view = left_frame.view();
584 let right_view = right_frame.view();
585 let on_refs: Vec<(&str, &str)> = on
586 .iter()
587 .map(|(l, r)| (l.as_str(), r.as_str()))
588 .collect();
589
590 match kind {
591 JoinType::Inner => left_view.inner_join(&right_view, &on_refs),
592 JoinType::Left => left_view.left_join(&right_view, &on_refs),
593 JoinType::Semi => {
594 let result = left_view.semi_join(&right_view, &on_refs)?;
595 let df = result.materialize()?;
596 Ok(TidyFrame::from_df(df))
597 }
598 JoinType::Anti => {
599 let result = left_view.anti_join(&right_view, &on_refs)?;
600 let df = result.materialize()?;
601 Ok(TidyFrame::from_df(df))
602 }
603 }
604 }
605 }
606}
607
608fn expr_columns(expr: &DExpr) -> BTreeSet<String> {
612 let mut cols = BTreeSet::new();
613 collect_expr_cols(expr, &mut cols);
614 cols
615}
616
617fn collect_expr_cols(expr: &DExpr, cols: &mut BTreeSet<String>) {
618 match expr {
619 DExpr::Col(name) => {
620 cols.insert(name.clone());
621 }
622 DExpr::BinOp { left, right, .. } => {
623 collect_expr_cols(left, cols);
624 collect_expr_cols(right, cols);
625 }
626 DExpr::Agg(_, inner) => collect_expr_cols(inner, cols),
627 DExpr::FnCall(_, args) => {
628 for arg in args {
629 collect_expr_cols(arg, cols);
630 }
631 }
632 DExpr::CumSum(e)
633 | DExpr::CumProd(e)
634 | DExpr::CumMax(e)
635 | DExpr::CumMin(e)
636 | DExpr::Lag(e, _)
637 | DExpr::Lead(e, _)
638 | DExpr::Rank(e)
639 | DExpr::DenseRank(e) => {
640 collect_expr_cols(e, cols);
641 }
642 DExpr::RollingSum(col, _)
644 | DExpr::RollingMean(col, _)
645 | DExpr::RollingMin(col, _)
646 | DExpr::RollingMax(col, _)
647 | DExpr::RollingVar(col, _)
648 | DExpr::RollingSd(col, _) => {
649 cols.insert(col.clone());
650 }
651 DExpr::LitInt(_)
652 | DExpr::LitFloat(_)
653 | DExpr::LitBool(_)
654 | DExpr::LitStr(_)
655 | DExpr::Count
656 | DExpr::RowNumber => {}
657 }
658}
659
660fn node_output_columns(node: &ViewNode) -> BTreeSet<String> {
667 match node {
668 ViewNode::Scan { df } => df.column_names().into_iter().map(|s| s.to_string()).collect(),
669 ViewNode::Filter { input, .. } => node_output_columns(input),
670 ViewNode::Select { columns, .. } => columns.iter().cloned().collect(),
671 ViewNode::Mutate {
672 input,
673 assignments,
674 } => {
675 let mut cols = node_output_columns(input);
676 for (name, _) in assignments {
677 cols.insert(name.clone());
678 }
679 cols
680 }
681 ViewNode::Arrange { input, .. } => node_output_columns(input),
682 ViewNode::GroupSummarise {
683 group_keys,
684 aggregations,
685 ..
686 } => {
687 let mut cols: BTreeSet<String> = group_keys.iter().cloned().collect();
688 for (name, _) in aggregations {
689 cols.insert(name.clone());
690 }
691 cols
692 }
693 ViewNode::Distinct { input, .. } => node_output_columns(input),
694 ViewNode::Join {
695 left, right, on, ..
696 } => {
697 let mut cols = node_output_columns(left);
698 let right_cols = node_output_columns(right);
699 let left_keys: BTreeSet<&str> = on.iter().map(|(l, _)| l.as_str()).collect();
701 let right_keys: BTreeSet<&str> = on.iter().map(|(_, r)| r.as_str()).collect();
702 for c in &right_cols {
703 if !right_keys.contains(c.as_str()) || !left_keys.contains(c.as_str()) {
704 cols.insert(c.clone());
705 }
706 }
707 cols
708 }
709 }
710}
711
/// Copies `s` to the heap and leaks the allocation to obtain a `&'static str`.
///
/// The memory is never reclaimed, so every call permanently grows the
/// process's heap usage by the length of `s`.
fn leaked_str(s: &str) -> &'static str {
    let owned: Box<str> = Box::from(s);
    Box::leak(owned)
}
722
723impl ViewNode {
726 pub fn count_filters(&self) -> usize {
728 match self {
729 ViewNode::Filter { input, .. } => 1 + input.count_filters(),
730 ViewNode::Select { input, .. } => input.count_filters(),
731 ViewNode::Mutate { input, .. } => input.count_filters(),
732 ViewNode::Arrange { input, .. } => input.count_filters(),
733 ViewNode::GroupSummarise { input, .. } => input.count_filters(),
734 ViewNode::Distinct { input, .. } => input.count_filters(),
735 ViewNode::Join { left, right, .. } => {
736 left.count_filters() + right.count_filters()
737 }
738 ViewNode::Scan { .. } => 0,
739 }
740 }
741
742 pub fn is_filter_on_scan(&self) -> bool {
745 match self {
746 ViewNode::Filter { input, .. } => matches!(input.as_ref(), ViewNode::Scan { .. }),
747 _ => false,
748 }
749 }
750
751 pub fn innermost(&self) -> &ViewNode {
753 match self {
754 ViewNode::Filter { input, .. }
755 | ViewNode::Select { input, .. }
756 | ViewNode::Mutate { input, .. }
757 | ViewNode::Arrange { input, .. }
758 | ViewNode::GroupSummarise { input, .. }
759 | ViewNode::Distinct { input, .. } => input.innermost(),
760 ViewNode::Join { left, .. } => left.innermost(),
761 ViewNode::Scan { .. } => self,
762 }
763 }
764
765 pub fn kind(&self) -> &'static str {
767 match self {
768 ViewNode::Scan { .. } => "Scan",
769 ViewNode::Filter { .. } => "Filter",
770 ViewNode::Select { .. } => "Select",
771 ViewNode::Mutate { .. } => "Mutate",
772 ViewNode::Arrange { .. } => "Arrange",
773 ViewNode::GroupSummarise { .. } => "GroupSummarise",
774 ViewNode::Distinct { .. } => "Distinct",
775 ViewNode::Join { .. } => "Join",
776 }
777 }
778
779 pub fn node_kinds(&self) -> Vec<&'static str> {
781 let mut out = vec![self.kind()];
782 match self {
783 ViewNode::Filter { input, .. }
784 | ViewNode::Select { input, .. }
785 | ViewNode::Mutate { input, .. }
786 | ViewNode::Arrange { input, .. }
787 | ViewNode::GroupSummarise { input, .. }
788 | ViewNode::Distinct { input, .. } => {
789 out.extend(input.node_kinds());
790 }
791 ViewNode::Join { left, right, .. } => {
792 out.extend(left.node_kinds());
793 out.extend(right.node_kinds());
794 }
795 ViewNode::Scan { .. } => {}
796 }
797 out
798 }
799}
800
/// Maximum number of rows per `Batch` produced by `split_batches`.
const BATCH_SIZE: usize = 2048;
805
/// A horizontal slice of a frame processed as one unit by the batched executor.
#[derive(Debug, Clone)]
pub struct Batch {
    /// `(column name, column data)` pairs.
    pub columns: Vec<(String, Column)>,
    /// Row count recorded by whoever built the batch; it is stored, not
    /// derived from `columns`.
    pub nrows: usize,
}
815
816impl Batch {
817 fn into_dataframe(self) -> DataFrame {
819 DataFrame {
820 columns: self.columns,
821 }
822 }
823
824 fn get_column(&self, name: &str) -> Option<&Column> {
826 self.columns.iter().find(|(n, _)| n == name).map(|(_, c)| c)
827 }
828
829 fn column_names(&self) -> Vec<&str> {
831 self.columns.iter().map(|(n, _)| n.as_str()).collect()
832 }
833}
834
835fn slice_column(col: &Column, start: usize, end: usize) -> Column {
837 match col {
838 Column::Float(v) => Column::Float(v[start..end].to_vec()),
839 Column::Int(v) => Column::Int(v[start..end].to_vec()),
840 Column::Str(v) => Column::Str(v[start..end].to_vec()),
841 Column::Bool(v) => Column::Bool(v[start..end].to_vec()),
842 Column::Categorical { levels, codes } => Column::Categorical {
843 levels: levels.clone(),
844 codes: codes[start..end].to_vec(),
845 },
846 Column::DateTime(v) => Column::DateTime(v[start..end].to_vec()),
847 }
848}
849
850fn split_batches(df: &DataFrame) -> Vec<Batch> {
852 let nrows = df.nrows();
853 if nrows == 0 {
854 return vec![Batch {
855 columns: df.columns.iter().map(|(n, c)| {
856 (n.clone(), slice_column(c, 0, 0))
857 }).collect(),
858 nrows: 0,
859 }];
860 }
861 let mut batches = Vec::new();
862 let mut offset = 0;
863 while offset < nrows {
864 let end = (offset + BATCH_SIZE).min(nrows);
865 let batch_cols = df
866 .columns
867 .iter()
868 .map(|(name, col)| (name.clone(), slice_column(col, offset, end)))
869 .collect();
870 batches.push(Batch {
871 columns: batch_cols,
872 nrows: end - offset,
873 });
874 offset = end;
875 }
876 batches
877}
878
879fn merge_batches(batches: Vec<Batch>) -> Result<DataFrame, TidyError> {
883 if batches.is_empty() {
884 return Ok(DataFrame::new());
885 }
886
887 let schema: Vec<String> = batches[0].column_names().iter().map(|s| s.to_string()).collect();
889 if schema.is_empty() {
890 return Ok(DataFrame::new());
891 }
892
893 let total_rows: usize = batches.iter().map(|b| b.nrows).sum();
895 let mut merged_cols: Vec<(String, Column)> = schema
896 .iter()
897 .map(|name| {
898 let first_col = batches[0].get_column(name).unwrap();
900 let empty = match first_col {
901 Column::Float(_) => Column::Float(Vec::with_capacity(total_rows)),
902 Column::Int(_) => Column::Int(Vec::with_capacity(total_rows)),
903 Column::Str(_) => Column::Str(Vec::with_capacity(total_rows)),
904 Column::Bool(_) => Column::Bool(Vec::with_capacity(total_rows)),
905 Column::Categorical { levels, .. } => Column::Categorical {
906 levels: levels.clone(),
907 codes: Vec::with_capacity(total_rows),
908 },
909 Column::DateTime(_) => Column::DateTime(Vec::with_capacity(total_rows)),
910 };
911 (name.clone(), empty)
912 })
913 .collect();
914
915 for batch in &batches {
917 if batch.nrows == 0 {
918 continue;
919 }
920 for (i, (name, merged_col)) in merged_cols.iter_mut().enumerate() {
921 let batch_col = batch.get_column(name).ok_or_else(|| {
922 TidyError::ColumnNotFound(format!(
923 "batch merge: column '{}' missing in batch (index {})",
924 name, i
925 ))
926 })?;
927 append_column(merged_col, batch_col);
928 }
929 }
930
931 Ok(DataFrame { columns: merged_cols })
932}
933
934fn append_column(dst: &mut Column, src: &Column) {
936 match (dst, src) {
937 (Column::Float(d), Column::Float(s)) => d.extend_from_slice(s),
938 (Column::Int(d), Column::Int(s)) => d.extend_from_slice(s),
939 (Column::Str(d), Column::Str(s)) => d.extend(s.iter().cloned()),
940 (Column::Bool(d), Column::Bool(s)) => d.extend_from_slice(s),
941 (Column::Categorical { codes: d, .. }, Column::Categorical { codes: s, .. }) => {
942 d.extend_from_slice(s);
943 }
944 (Column::DateTime(d), Column::DateTime(s)) => d.extend_from_slice(s),
945 _ => {} }
947}
948
/// An operation the batched executor applies to one `Batch` at a time.
/// Mirrors the payloads of the `Filter`, `Select` and `Mutate` plan nodes,
/// without their input.
#[derive(Debug, Clone)]
enum StreamableOp {
    Filter { predicate: DExpr },
    Select { columns: Vec<String> },
    Mutate { assignments: Vec<(String, DExpr)> },
}
958
959fn is_pipeline_breaker(node: &ViewNode) -> bool {
961 matches!(
962 node,
963 ViewNode::Arrange { .. }
964 | ViewNode::GroupSummarise { .. }
965 | ViewNode::Distinct { .. }
966 | ViewNode::Join { .. }
967 )
968}
969
970fn collect_streamable_chain(node: ViewNode) -> (Vec<StreamableOp>, Box<ViewNode>) {
976 let mut ops = Vec::new();
977 let mut current = node;
978
979 loop {
980 match current {
981 ViewNode::Filter { input, predicate } => {
982 ops.push(StreamableOp::Filter { predicate });
983 current = *input;
984 }
985 ViewNode::Select { input, columns } => {
986 ops.push(StreamableOp::Select { columns });
987 current = *input;
988 }
989 ViewNode::Mutate { input, assignments } => {
990 ops.push(StreamableOp::Mutate { assignments });
991 current = *input;
992 }
993 other => {
995 ops.reverse();
997 return (ops, Box::new(other));
998 }
999 }
1000 }
1001}
1002
1003fn apply_op_to_batch(batch: Batch, op: &StreamableOp) -> Result<Batch, TidyError> {
1005 match op {
1006 StreamableOp::Filter { predicate } => {
1007 let df = batch.into_dataframe();
1009 if df.nrows() == 0 {
1010 return Ok(Batch {
1011 nrows: 0,
1012 columns: df.columns,
1013 });
1014 }
1015 let frame = TidyFrame::from_df(df);
1016 let view = frame.view();
1017 let filtered = view.filter(predicate)?;
1018 let result_df = filtered.materialize()?;
1019 let nrows = result_df.nrows();
1020 Ok(Batch {
1021 columns: result_df.columns,
1022 nrows,
1023 })
1024 }
1025 StreamableOp::Select { columns } => {
1026 let selected: Vec<(String, Column)> = columns
1028 .iter()
1029 .filter_map(|name| {
1030 batch
1031 .columns
1032 .iter()
1033 .find(|(n, _)| n == name)
1034 .cloned()
1035 })
1036 .collect();
1037 Ok(Batch {
1038 nrows: batch.nrows,
1039 columns: selected,
1040 })
1041 }
1042 StreamableOp::Mutate { assignments } => {
1043 let df = batch.into_dataframe();
1045 let frame = TidyFrame::from_df(df);
1046 let view = frame.view();
1047 let assign_refs: Vec<(&str, DExpr)> = assignments
1048 .iter()
1049 .map(|(name, expr)| (leaked_str(name), expr.clone()))
1050 .collect();
1051 let result = view.mutate(
1052 &assign_refs
1053 .iter()
1054 .map(|(n, e)| (*n, e.clone()))
1055 .collect::<Vec<_>>(),
1056 )?;
1057 let result_df = result.borrow().clone();
1058 let nrows = result_df.nrows();
1059 Ok(Batch {
1060 columns: result_df.columns,
1061 nrows,
1062 })
1063 }
1064 }
1065}
1066
1067fn apply_chain_batched(
1069 frame: &TidyFrame,
1070 chain: &[StreamableOp],
1071) -> Result<TidyFrame, TidyError> {
1072 let df = frame.borrow().clone();
1073 let batches = split_batches(&df);
1074
1075 let mut result_batches: Vec<Batch> = Vec::new();
1076 for batch in batches {
1077 let mut current = batch;
1078 for op in chain {
1079 current = apply_op_to_batch(current, op)?;
1080 }
1081 if current.nrows > 0 {
1082 result_batches.push(current);
1083 }
1084 }
1085
1086 if result_batches.is_empty() {
1087 let empty_df = DataFrame {
1089 columns: df
1090 .columns
1091 .iter()
1092 .map(|(name, col)| {
1093 (name.clone(), slice_column(col, 0, 0))
1094 })
1095 .collect(),
1096 };
1097 let mut result_cols: Option<Vec<String>> = None;
1099 for op in chain {
1100 if let StreamableOp::Select { columns } = op {
1101 result_cols = Some(columns.clone());
1102 }
1103 }
1104 if let Some(cols) = result_cols {
1105 let pruned: Vec<(String, Column)> = cols
1106 .iter()
1107 .filter_map(|name| {
1108 empty_df
1109 .columns
1110 .iter()
1111 .find(|(n, _)| n == name)
1112 .cloned()
1113 })
1114 .collect();
1115 return Ok(TidyFrame::from_df(DataFrame { columns: pruned }));
1116 }
1117 return Ok(TidyFrame::from_df(empty_df));
1118 }
1119
1120 let merged = merge_batches(result_batches)?;
1121 Ok(TidyFrame::from_df(merged))
1122}
1123
1124pub fn execute_batched(node: ViewNode) -> Result<TidyFrame, TidyError> {
1136 match &node {
1137 ViewNode::Scan { .. } => execute(node),
1139
1140 _ if !is_pipeline_breaker(&node) => {
1142 let (chain, base) = collect_streamable_chain(node);
1143 if chain.is_empty() {
1144 return execute_batched(*base);
1146 }
1147 let base_frame = execute_batched(*base)?;
1148 apply_chain_batched(&base_frame, &chain)
1149 }
1150
1151 _ => execute_breaker_batched(node),
1153 }
1154}
1155
1156fn execute_breaker_batched(node: ViewNode) -> Result<TidyFrame, TidyError> {
1159 match node {
1160 ViewNode::Arrange { input, keys } => {
1161 let frame = execute_batched(*input)?;
1162 let view = frame.view();
1163 let arranged = view.arrange(&keys)?;
1164 let df = arranged.materialize()?;
1165 Ok(TidyFrame::from_df(df))
1166 }
1167
1168 ViewNode::GroupSummarise {
1169 input,
1170 group_keys,
1171 aggregations,
1172 } => {
1173 let frame = execute_batched(*input)?;
1174 let view = frame.view();
1175 let key_refs: Vec<&str> = group_keys.iter().map(|s| s.as_str()).collect();
1176 let grouped = view.group_by(&key_refs)?;
1177 let agg_refs: Vec<(&str, TidyAgg)> = aggregations
1178 .into_iter()
1179 .map(|(name, agg)| (leaked_str(&name), agg))
1180 .collect();
1181 let result = grouped.summarise(
1182 &agg_refs
1183 .iter()
1184 .map(|(n, a)| (*n, a.clone()))
1185 .collect::<Vec<_>>(),
1186 )?;
1187 Ok(result)
1188 }
1189
1190 ViewNode::Distinct { input, columns } => {
1191 let frame = execute_batched(*input)?;
1192 let view = frame.view();
1193 let col_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
1194 let distinct = view.distinct(&col_refs)?;
1195 let df = distinct.materialize()?;
1196 Ok(TidyFrame::from_df(df))
1197 }
1198
1199 ViewNode::Join {
1200 left,
1201 right,
1202 on,
1203 kind,
1204 } => {
1205 let left_frame = execute_batched(*left)?;
1206 let right_frame = execute_batched(*right)?;
1207 let left_view = left_frame.view();
1208 let right_view = right_frame.view();
1209 let on_refs: Vec<(&str, &str)> =
1210 on.iter().map(|(l, r)| (l.as_str(), r.as_str())).collect();
1211
1212 match kind {
1213 JoinType::Inner => left_view.inner_join(&right_view, &on_refs),
1214 JoinType::Left => left_view.left_join(&right_view, &on_refs),
1215 JoinType::Semi => {
1216 let result = left_view.semi_join(&right_view, &on_refs)?;
1217 let df = result.materialize()?;
1218 Ok(TidyFrame::from_df(df))
1219 }
1220 JoinType::Anti => {
1221 let result = left_view.anti_join(&right_view, &on_refs)?;
1222 let df = result.materialize()?;
1223 Ok(TidyFrame::from_df(df))
1224 }
1225 }
1226 }
1227
1228 other => execute(other),
1230 }
1231}
1232
1233impl LazyView {
1234 pub fn collect_batched(self) -> Result<TidyFrame, TidyError> {
1246 let optimized = optimize(self.plan);
1247 execute_batched(optimized)
1248 }
1249}
1250
1251#[cfg(test)]
1256mod tests {
1257 use super::*;
1258 use crate::{Column, DExpr, DBinOp, DataFrame, TidyAgg, ArrangeKey, TidyView};
1259
1260 fn test_df() -> DataFrame {
1262 DataFrame {
1263 columns: vec![
1264 (
1265 "name".to_string(),
1266 Column::Str(vec![
1267 "Alice".into(),
1268 "Bob".into(),
1269 "Carol".into(),
1270 "Dave".into(),
1271 ]),
1272 ),
1273 ("age".to_string(), Column::Int(vec![30, 25, 35, 25])),
1274 (
1275 "score".to_string(),
1276 Column::Float(vec![90.0, 85.0, 95.0, 80.0]),
1277 ),
1278 ],
1279 }
1280 }
1281
1282 fn dept_df() -> DataFrame {
1284 DataFrame {
1285 columns: vec![
1286 (
1287 "name".to_string(),
1288 Column::Str(vec!["Alice".into(), "Bob".into(), "Eve".into()]),
1289 ),
1290 (
1291 "dept".to_string(),
1292 Column::Str(vec!["Eng".into(), "Sales".into(), "Eng".into()]),
1293 ),
1294 ],
1295 }
1296 }
1297
1298 #[test]
1301 fn lazy_filter_matches_eager() {
1302 let df = test_df();
1303 let predicate = DExpr::BinOp {
1304 op: DBinOp::Gt,
1305 left: Box::new(DExpr::Col("age".into())),
1306 right: Box::new(DExpr::LitInt(25)),
1307 };
1308
1309 let eager_view = TidyView::from_df(df.clone());
1311 let eager_filtered = eager_view.filter(&predicate).unwrap();
1312 let eager_df = eager_filtered.materialize().unwrap();
1313
1314 let lazy_frame = LazyView::from_df(df)
1316 .filter(predicate)
1317 .collect()
1318 .unwrap();
1319 let lazy_df = lazy_frame.borrow();
1320
1321 assert_eq!(eager_df.nrows(), lazy_df.nrows());
1322 assert_eq!(eager_df.nrows(), 2); let eager_names: Vec<String> = match eager_df.get_column("name").unwrap() {
1326 Column::Str(v) => v.clone(),
1327 _ => panic!("expected Str"),
1328 };
1329 let lazy_names: Vec<String> = match lazy_df.get_column("name").unwrap() {
1330 Column::Str(v) => v.clone(),
1331 _ => panic!("expected Str"),
1332 };
1333 assert_eq!(eager_names, lazy_names);
1334 }
1335
1336 #[test]
1337 fn lazy_select_matches_eager() {
1338 let df = test_df();
1339
1340 let eager_view = TidyView::from_df(df.clone());
1342 let eager_selected = eager_view.select(&["name", "age"]).unwrap();
1343 let eager_df = eager_selected.materialize().unwrap();
1344
1345 let lazy_frame = LazyView::from_df(df)
1347 .select(vec!["name".into(), "age".into()])
1348 .collect()
1349 .unwrap();
1350 let lazy_df = lazy_frame.borrow();
1351
1352 assert_eq!(eager_df.ncols(), 2);
1353 assert_eq!(lazy_df.ncols(), 2);
1354 assert_eq!(eager_df.column_names(), lazy_df.column_names());
1355 }
1356
1357 #[test]
1358 fn lazy_arrange_matches_eager() {
1359 let df = test_df();
1360 let keys = vec![ArrangeKey::asc("age")];
1361
1362 let eager_view = TidyView::from_df(df.clone());
1364 let eager_arranged = eager_view.arrange(&keys).unwrap();
1365 let eager_df = eager_arranged.materialize().unwrap();
1366
1367 let lazy_frame = LazyView::from_df(df)
1369 .arrange(keys)
1370 .collect()
1371 .unwrap();
1372 let lazy_df = lazy_frame.borrow();
1373
1374 let eager_ages = match eager_df.get_column("age").unwrap() {
1375 Column::Int(v) => v.clone(),
1376 _ => panic!("expected Int"),
1377 };
1378 let lazy_ages = match lazy_df.get_column("age").unwrap() {
1379 Column::Int(v) => v.clone(),
1380 _ => panic!("expected Int"),
1381 };
1382 assert_eq!(eager_ages, lazy_ages);
1383 assert_eq!(eager_ages, vec![25, 25, 30, 35]);
1385 }
1386
1387 #[test]
1388 fn lazy_group_summarise_matches_eager() {
1389 let df = test_df();
1390
1391 let eager_view = TidyView::from_df(df.clone());
1393 let grouped = eager_view.group_by(&["age"]).unwrap();
1394 let eager_frame = grouped
1395 .summarise(&[("count", TidyAgg::Count)])
1396 .unwrap();
1397 let eager_df = eager_frame.borrow();
1398
1399 let lazy_frame = LazyView::from_df(df)
1401 .group_summarise(
1402 vec!["age".into()],
1403 vec![("count".into(), TidyAgg::Count)],
1404 )
1405 .collect()
1406 .unwrap();
1407 let lazy_df = lazy_frame.borrow();
1408
1409 assert_eq!(eager_df.nrows(), lazy_df.nrows());
1410 assert_eq!(eager_df.column_names(), lazy_df.column_names());
1411 }
1412
/// A filter written after a select is expected to show up between the select and the
/// scan in the optimized plan.
#[test]
fn predicate_pushdown_past_select() {
    let source = test_df();
    let age_over_25 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(25)),
    };
    let kept_columns: Vec<String> = vec!["name".to_string(), "age".to_string()];

    let view = LazyView::from_df(source)
        .select(kept_columns)
        .filter(age_over_25);
    let plan = view.optimized_plan();

    // The optimized plan must report its nodes in exactly this order.
    assert_eq!(plan.node_kinds(), vec!["Select", "Filter", "Scan"]);
}
1436
/// A filter written after an arrange is expected to show up between the arrange and the
/// scan in the optimized plan.
#[test]
fn predicate_pushdown_past_arrange() {
    let source = test_df();
    let age_over_25 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(25)),
    };
    let sort_keys = vec![ArrangeKey::asc("age")];

    let view = LazyView::from_df(source)
        .arrange(sort_keys)
        .filter(age_over_25);
    let plan = view.optimized_plan();

    // The optimized plan must report its nodes in exactly this order.
    assert_eq!(plan.node_kinds(), vec!["Arrange", "Filter", "Scan"]);
}
1457
/// A filter that reads a column created by the preceding mutate must stay above that
/// mutate in the optimized plan.
#[test]
fn predicate_not_pushed_past_mutate_when_dependent() {
    let source = test_df();

    // The mutate defines `doubled_age = age * 2` ...
    let doubled_age = DExpr::BinOp {
        op: DBinOp::Mul,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(2)),
    };
    // ... and the filter reads that new column.
    let doubled_over_50 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("doubled_age".into())),
        right: Box::new(DExpr::LitInt(50)),
    };

    let view = LazyView::from_df(source)
        .mutate(vec![("doubled_age".to_string(), doubled_age)])
        .filter(doubled_over_50);
    let plan = view.optimized_plan();

    assert_eq!(plan.node_kinds(), vec!["Filter", "Mutate", "Scan"]);
}
1486
/// A filter that only reads `score` does not depend on the mutate's new `doubled_age`
/// column, so the optimized plan is expected to place it beneath the mutate.
#[test]
fn predicate_pushed_past_mutate_when_independent() {
    let source = test_df();

    // The mutate defines `doubled_age = age * 2`.
    let doubled_age = DExpr::BinOp {
        op: DBinOp::Mul,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(2)),
    };
    // The filter only touches `score`.
    let score_over_85 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("score".into())),
        right: Box::new(DExpr::LitFloat(85.0)),
    };

    let view = LazyView::from_df(source)
        .mutate(vec![("doubled_age".to_string(), doubled_age)])
        .filter(score_over_85);
    let plan = view.optimized_plan();

    assert_eq!(plan.node_kinds(), vec!["Mutate", "Filter", "Scan"]);
}
1515
/// A filter on the aggregated `count` column must remain above the group-summarise node
/// in the optimized plan.
#[test]
fn predicate_not_pushed_past_group_summarise() {
    let source = test_df();
    let count_over_1 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("count".into())),
        right: Box::new(DExpr::LitInt(1)),
    };
    let group_keys: Vec<String> = vec!["age".to_string()];
    let aggregations: Vec<(String, TidyAgg)> = vec![("count".to_string(), TidyAgg::Count)];

    let view = LazyView::from_df(source)
        .group_summarise(group_keys, aggregations)
        .filter(count_over_1);
    let plan = view.optimized_plan();

    assert_eq!(plan.node_kinds(), vec!["Filter", "GroupSummarise", "Scan"]);
}
1538
/// Two back-to-back filters must collapse into a single filter node in the optimized
/// plan, and executing the same pipeline must return three rows.
#[test]
fn consecutive_filters_merged() {
    // Builds `age > 20` followed by `score < 95.0` over a fresh copy of the test data.
    let two_filters = || {
        LazyView::from_df(test_df())
            .filter(DExpr::BinOp {
                op: DBinOp::Gt,
                left: Box::new(DExpr::Col("age".into())),
                right: Box::new(DExpr::LitInt(20)),
            })
            .filter(DExpr::BinOp {
                op: DBinOp::Lt,
                left: Box::new(DExpr::Col("score".into())),
                right: Box::new(DExpr::LitFloat(95.0)),
            })
    };

    // Plan shape: only one filter node may remain.
    let view = two_filters();
    let plan = view.optimized_plan();
    assert_eq!(plan.count_filters(), 1);

    // Execution result of the same pipeline.
    let collected = two_filters().collect().unwrap();
    let out = collected.borrow();
    assert_eq!(out.nrows(), 3);
}
1582
/// Selecting `name`, `age` and `score` from the test frame is expected to be dropped by
/// the optimizer, leaving the bare scan as the plan root.
#[test]
fn redundant_select_eliminated() {
    let every_column: Vec<String> = ["name", "age", "score"]
        .iter()
        .map(|column| column.to_string())
        .collect();

    let view = LazyView::from_df(test_df()).select(every_column);
    let plan = view.optimized_plan();

    assert_eq!(plan.kind(), "Scan");
}
1598
/// Selecting only `name` and `age` must survive optimization as the plan root.
#[test]
fn non_redundant_select_kept() {
    let subset: Vec<String> = ["name", "age"]
        .iter()
        .map(|column| column.to_string())
        .collect();

    let view = LazyView::from_df(test_df()).select(subset);
    let plan = view.optimized_plan();

    assert_eq!(plan.kind(), "Select");
}
1610
/// Runs the same filter, select and descending-arrange pipeline three times and checks
/// every run against one fixed expected result, including the order of tied ages.
#[test]
fn determinism_3_runs_identical() {
    for _run in 0..3 {
        let age_over_20 = DExpr::BinOp {
            op: DBinOp::Gt,
            left: Box::new(DExpr::Col("age".into())),
            right: Box::new(DExpr::LitInt(20)),
        };
        let collected = LazyView::from_df(test_df())
            .filter(age_over_20)
            .select(vec!["name".to_string(), "age".to_string()])
            .arrange(vec![ArrangeKey::desc("age")])
            .collect()
            .unwrap();

        let out = collected.borrow();
        assert_eq!(out.nrows(), 4);

        let ages = if let Column::Int(values) = out.get_column("age").unwrap() {
            values.clone()
        } else {
            panic!("expected Int")
        };
        assert_eq!(ages, vec![35, 30, 25, 25]);

        let names = if let Column::Str(values) = out.get_column("name").unwrap() {
            values.clone()
        } else {
            panic!("expected Str")
        };
        assert_eq!(names, vec!["Carol", "Alice", "Bob", "Dave"]);
    }
}
1645
/// An inner join on `name` must return two rows and expose the right side's `dept`
/// column.
#[test]
fn lazy_inner_join() {
    let people = LazyView::from_df(test_df());
    let departments = LazyView::from_df(dept_df());

    let joined = people
        .join(
            departments,
            vec![("name".into(), "name".into())],
            JoinType::Inner,
        )
        .collect()
        .unwrap();

    let out = joined.borrow();
    assert_eq!(out.nrows(), 2);
    assert!(out.get_column("dept").is_some());
}
1667
/// A semi join on `name` must return two rows and must not add the right side's `dept`
/// column.
#[test]
fn lazy_semi_join() {
    let people = LazyView::from_df(test_df());
    let departments = LazyView::from_df(dept_df());

    let joined = people
        .join(
            departments,
            vec![("name".into(), "name".into())],
            JoinType::Semi,
        )
        .collect()
        .unwrap();

    let out = joined.borrow();
    assert_eq!(out.nrows(), 2);
    assert!(out.get_column("dept").is_none());
}
1688
/// An anti join on `name` must return two rows.
#[test]
fn lazy_anti_join() {
    let people = LazyView::from_df(test_df());
    let departments = LazyView::from_df(dept_df());

    let joined = people
        .join(
            departments,
            vec![("name".into(), "name".into())],
            JoinType::Anti,
        )
        .collect()
        .unwrap();

    let out = joined.borrow();
    assert_eq!(out.nrows(), 2);
}
1707
/// Distinct on `age` over the test frame must return three rows.
#[test]
fn lazy_distinct() {
    let view = LazyView::from_df(test_df()).distinct(vec!["age".into()]);
    let collected = view.collect().unwrap();

    let out = collected.borrow();
    assert_eq!(out.nrows(), 3);
}
1723
/// End-to-end chain of filter, mutate, select and arrange. The result must have four
/// rows and exactly the columns `name` and `bonus`, in that order.
#[test]
fn complex_lazy_chain() {
    let age_over_20 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(20)),
    };
    // bonus = score * 1.1
    let bonus = DExpr::BinOp {
        op: DBinOp::Mul,
        left: Box::new(DExpr::Col("score".into())),
        right: Box::new(DExpr::LitFloat(1.1)),
    };

    let collected = LazyView::from_df(test_df())
        .filter(age_over_20)
        .mutate(vec![("bonus".to_string(), bonus)])
        .select(vec!["name".to_string(), "bonus".to_string()])
        .arrange(vec![ArrangeKey::desc("bonus")])
        .collect()
        .unwrap();

    let out = collected.borrow();
    assert_eq!(out.nrows(), 4);
    assert_eq!(out.ncols(), 2);
    assert_eq!(out.column_names(), vec!["name", "bonus"]);
}
1755
/// A filter on `age` placed above an inner join is expected to end up on the join's left
/// input, leaving the right input as a bare scan.
#[test]
fn predicate_pushdown_into_join_left_side() {
    let people = LazyView::from_df(test_df());
    let departments = LazyView::from_df(dept_df());
    let age_over_25 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(25)),
    };

    let view = people
        .join(
            departments,
            vec![("name".into(), "name".into())],
            JoinType::Inner,
        )
        .filter(age_over_25);
    let plan = view.optimized_plan();

    // The join must be the first reported node.
    let kinds = plan.node_kinds();
    assert_eq!(kinds[0], "Join");

    // Inspect both inputs of the join directly.
    match &plan {
        ViewNode::Join { left, right, .. } => {
            assert_eq!(left.kind(), "Filter");
            assert_eq!(right.kind(), "Scan");
        }
        _ => panic!("expected Join at top"),
    }
}
1790
1791 fn assert_df_eq(a: &DataFrame, b: &DataFrame, context: &str) {
1797 assert_eq!(
1798 a.nrows(),
1799 b.nrows(),
1800 "{}: nrows differ ({} vs {})",
1801 context,
1802 a.nrows(),
1803 b.nrows()
1804 );
1805 assert_eq!(
1806 a.column_names(),
1807 b.column_names(),
1808 "{}: column names differ",
1809 context
1810 );
1811 for (name_a, col_a) in &a.columns {
1812 let col_b = b.get_column(name_a).unwrap_or_else(|| {
1813 panic!("{}: column '{}' missing in b", context, name_a)
1814 });
1815 assert_col_eq(col_a, col_b, &format!("{} col '{}'", context, name_a));
1816 }
1817 }
1818
1819 fn assert_col_eq(a: &Column, b: &Column, context: &str) {
1820 match (a, b) {
1821 (Column::Int(va), Column::Int(vb)) => assert_eq!(va, vb, "{}", context),
1822 (Column::Float(va), Column::Float(vb)) => {
1823 assert_eq!(va.len(), vb.len(), "{}: float len", context);
1824 for (i, (x, y)) in va.iter().zip(vb.iter()).enumerate() {
1825 assert!(
1826 (x - y).abs() < 1e-12,
1827 "{}: float[{}] {} != {}",
1828 context,
1829 i,
1830 x,
1831 y
1832 );
1833 }
1834 }
1835 (Column::Str(va), Column::Str(vb)) => assert_eq!(va, vb, "{}", context),
1836 (Column::Bool(va), Column::Bool(vb)) => assert_eq!(va, vb, "{}", context),
1837 _ => panic!("{}: column type mismatch", context),
1838 }
1839 }
1840
/// `collect` and `collect_batched` must agree for a single filter.
#[test]
fn batched_filter_parity() {
    let age_over_25 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(25)),
    };
    // Each call builds the identical plan over a fresh copy of the test data.
    let plan = || LazyView::from_df(test_df()).filter(age_over_25.clone());

    let eager = plan().collect().unwrap();
    let batched = plan().collect_batched().unwrap();

    assert_df_eq(&eager.borrow(), &batched.borrow(), "filter parity");
}
1862
/// `collect` and `collect_batched` must agree for a single select.
#[test]
fn batched_select_parity() {
    let wanted: Vec<String> = vec!["name".to_string(), "score".to_string()];
    // Each call builds the identical plan over a fresh copy of the test data.
    let plan = || LazyView::from_df(test_df()).select(wanted.clone());

    let eager = plan().collect().unwrap();
    let batched = plan().collect_batched().unwrap();

    assert_df_eq(&eager.borrow(), &batched.borrow(), "select parity");
}
1878
/// `collect` and `collect_batched` must agree for a single mutate (`doubled = age * 2`).
#[test]
fn batched_mutate_parity() {
    let doubled = DExpr::BinOp {
        op: DBinOp::Mul,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(2)),
    };
    let new_columns: Vec<(String, DExpr)> = vec![("doubled".to_string(), doubled)];
    // Each call builds the identical plan over a fresh copy of the test data.
    let plan = || LazyView::from_df(test_df()).mutate(new_columns.clone());

    let eager = plan().collect().unwrap();
    let batched = plan().collect_batched().unwrap();

    assert_df_eq(&eager.borrow(), &batched.borrow(), "mutate parity");
}
1901
/// `collect` and `collect_batched` must agree for a filter, mutate and select chain.
#[test]
fn batched_filter_select_mutate_chain_parity() {
    let age_over_20 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(20)),
    };
    // bonus = score * 1.1
    let bonus = DExpr::BinOp {
        op: DBinOp::Mul,
        left: Box::new(DExpr::Col("score".into())),
        right: Box::new(DExpr::LitFloat(1.1)),
    };
    let new_columns: Vec<(String, DExpr)> = vec![("bonus".to_string(), bonus)];

    // Each call builds the identical plan over a fresh copy of the test data.
    let plan = || {
        LazyView::from_df(test_df())
            .filter(age_over_20.clone())
            .mutate(new_columns.clone())
            .select(vec!["name".to_string(), "bonus".to_string()])
    };

    let eager = plan().collect().unwrap();
    let batched = plan().collect_batched().unwrap();

    assert_df_eq(
        &eager.borrow(),
        &batched.borrow(),
        "filter+mutate+select chain parity",
    );
}
1937
/// `collect` and `collect_batched` must agree for a group-by-`age` count.
#[test]
fn batched_group_summarise_parity() {
    // Each call builds the identical plan over a fresh copy of the test data.
    let plan = || {
        LazyView::from_df(test_df()).group_summarise(
            vec!["age".to_string()],
            vec![("count".to_string(), TidyAgg::Count)],
        )
    };

    let eager = plan().collect().unwrap();
    let batched = plan().collect_batched().unwrap();

    assert_df_eq(
        &eager.borrow(),
        &batched.borrow(),
        "group_summarise parity",
    );
}
1961
/// `collect` and `collect_batched` must agree for an ascending arrange on `age`.
#[test]
fn batched_arrange_parity() {
    let sort_keys = vec![ArrangeKey::asc("age")];
    // Each call builds the identical plan over a fresh copy of the test data.
    let plan = || LazyView::from_df(test_df()).arrange(sort_keys.clone());

    let eager = plan().collect().unwrap();
    let batched = plan().collect_batched().unwrap();

    assert_df_eq(&eager.borrow(), &batched.borrow(), "arrange parity");
}
1977
/// `collect` and `collect_batched` must agree for distinct on `age`.
#[test]
fn batched_distinct_parity() {
    // Each call builds the identical plan over a fresh copy of the test data.
    let plan = || LazyView::from_df(test_df()).distinct(vec!["age".into()]);

    let eager = plan().collect().unwrap();
    let batched = plan().collect_batched().unwrap();

    assert_df_eq(&eager.borrow(), &batched.borrow(), "distinct parity");
}
1991
/// `collect` and `collect_batched` must agree for an inner join on `name`.
#[test]
fn batched_join_parity() {
    // Each call builds the identical join over fresh copies of both inputs.
    let plan = || {
        LazyView::from_df(test_df()).join(
            LazyView::from_df(dept_df()),
            vec![("name".into(), "name".into())],
            JoinType::Inner,
        )
    };

    let eager = plan().collect().unwrap();
    let batched = plan().collect_batched().unwrap();

    assert_df_eq(&eager.borrow(), &batched.borrow(), "join parity");
}
2013
/// `collect` and `collect_batched` must agree for a filter, mutate, select and arrange
/// pipeline.
#[test]
fn batched_complex_pipeline_parity() {
    let age_over_20 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(20)),
    };
    // bonus = score * 1.1
    let bonus = DExpr::BinOp {
        op: DBinOp::Mul,
        left: Box::new(DExpr::Col("score".into())),
        right: Box::new(DExpr::LitFloat(1.1)),
    };
    let new_columns: Vec<(String, DExpr)> = vec![("bonus".to_string(), bonus)];

    // Each call builds the identical plan over a fresh copy of the test data.
    let plan = || {
        LazyView::from_df(test_df())
            .filter(age_over_20.clone())
            .mutate(new_columns.clone())
            .select(vec!["name".to_string(), "bonus".to_string()])
            .arrange(vec![ArrangeKey::desc("bonus")])
    };

    let eager = plan().collect().unwrap();
    let batched = plan().collect_batched().unwrap();

    assert_df_eq(
        &eager.borrow(),
        &batched.borrow(),
        "complex pipeline parity",
    );
}
2052
/// Runs the same batched filter, select and descending-arrange pipeline three times.
/// All runs must return identical `age` and `name` columns, and the first run must match
/// the fixed expected output.
#[test]
fn batched_determinism_3_runs() {
    let mut age_runs: Vec<Vec<i64>> = Vec::with_capacity(3);
    let mut name_runs: Vec<Vec<String>> = Vec::with_capacity(3);

    for _run in 0..3 {
        let age_over_20 = DExpr::BinOp {
            op: DBinOp::Gt,
            left: Box::new(DExpr::Col("age".into())),
            right: Box::new(DExpr::LitInt(20)),
        };
        let collected = LazyView::from_df(test_df())
            .filter(age_over_20)
            .select(vec!["name".to_string(), "age".to_string()])
            .arrange(vec![ArrangeKey::desc("age")])
            .collect_batched()
            .unwrap();

        let out = collected.borrow();
        let ages = if let Column::Int(values) = out.get_column("age").unwrap() {
            values.clone()
        } else {
            panic!("expected Int")
        };
        let names = if let Column::Str(values) = out.get_column("name").unwrap() {
            values.clone()
        } else {
            panic!("expected Str")
        };
        age_runs.push(ages);
        name_runs.push(names);
    }

    // Run-to-run stability.
    assert_eq!(age_runs[0], age_runs[1]);
    assert_eq!(age_runs[1], age_runs[2]);
    assert_eq!(name_runs[0], name_runs[1]);
    assert_eq!(name_runs[1], name_runs[2]);

    // Fixed expected output.
    assert_eq!(age_runs[0], vec![35, 30, 25, 25]);
    assert_eq!(name_runs[0], vec!["Carol", "Alice", "Bob", "Dave"]);
}
2094
2095 fn large_df() -> DataFrame {
2099 let n = 10_000usize;
2100 let names: Vec<String> = (0..n).map(|i| format!("user_{}", i)).collect();
2101 let ages: Vec<i64> = (0..n).map(|i| (i % 80) as i64 + 18).collect();
2102 let scores: Vec<f64> = (0..n).map(|i| 50.0 + (i % 50) as f64).collect();
2103 DataFrame {
2104 columns: vec![
2105 ("name".to_string(), Column::Str(names)),
2106 ("age".to_string(), Column::Int(ages)),
2107 ("score".to_string(), Column::Float(scores)),
2108 ],
2109 }
2110 }
2111
/// `collect` and `collect_batched` must agree for a filter over the 10,000-row frame,
/// and the filter must keep at least one row.
#[test]
fn batched_large_data_filter_parity() {
    let age_over_50 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(50)),
    };
    // Each call builds the identical plan over a freshly generated large frame.
    let plan = || LazyView::from_df(large_df()).filter(age_over_50.clone());

    let eager = plan().collect().unwrap();
    let batched = plan().collect_batched().unwrap();

    assert_df_eq(
        &eager.borrow(),
        &batched.borrow(),
        "large data filter parity",
    );
    // Guard against the comparison passing on two empty frames.
    assert!(eager.borrow().nrows() > 0);
}
2137
/// `collect` and `collect_batched` must agree for a filter, mutate and select chain over
/// the 10,000-row frame.
#[test]
fn batched_large_data_chain_parity() {
    let age_over_50 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(50)),
    };
    // bonus = score * 1.5
    let bonus = DExpr::BinOp {
        op: DBinOp::Mul,
        left: Box::new(DExpr::Col("score".into())),
        right: Box::new(DExpr::LitFloat(1.5)),
    };
    let new_columns: Vec<(String, DExpr)> = vec![("bonus".to_string(), bonus)];

    // Each call builds the identical plan over a freshly generated large frame.
    let plan = || {
        LazyView::from_df(large_df())
            .filter(age_over_50.clone())
            .mutate(new_columns.clone())
            .select(vec!["name".to_string(), "bonus".to_string()])
    };

    let eager = plan().collect().unwrap();
    let batched = plan().collect_batched().unwrap();

    assert_df_eq(
        &eager.borrow(),
        &batched.borrow(),
        "large data chain parity",
    );
}
2173
/// Runs the same batched filter and mutate pipeline over the 10,000-row frame three
/// times and requires every run to produce an identical frame.
///
/// The whole frame is compared, including the computed `double_age` column, so
/// run-to-run differences in the mutate step are caught as well as differences in the
/// filtered rows. The result must also be non-empty so the comparison cannot pass
/// vacuously (`large_df` generates ages up to 97, so `age > 90` is expected to keep
/// rows).
#[test]
fn batched_large_data_determinism() {
    let runs: Vec<_> = (0..3)
        .map(|_| {
            LazyView::from_df(large_df())
                .filter(DExpr::BinOp {
                    op: DBinOp::Gt,
                    left: Box::new(DExpr::Col("age".into())),
                    right: Box::new(DExpr::LitInt(90)),
                })
                .mutate(vec![(
                    "double_age".into(),
                    DExpr::BinOp {
                        op: DBinOp::Mul,
                        left: Box::new(DExpr::Col("age".into())),
                        right: Box::new(DExpr::LitInt(2)),
                    },
                )])
                .collect_batched()
                .unwrap()
        })
        .collect();

    let baseline = runs[0].borrow();
    assert!(
        baseline.nrows() > 0,
        "determinism: filter unexpectedly removed every row"
    );

    // Compare every later run against the first one, column by column.
    for (run, frame) in runs.iter().enumerate().skip(1) {
        assert_df_eq(
            &baseline,
            &frame.borrow(),
            &format!("determinism: run {} differs from run 0", run),
        );
    }
}
2206
/// Splitting the 10,000-row frame must give five batches: four of 2048 rows and a final
/// one holding the remainder, together covering every row.
#[test]
fn split_batches_correct_count() {
    let source = large_df();
    let batches = split_batches(&source);
    assert_eq!(batches.len(), 5);

    // The first four batches are full.
    for full_batch in batches.iter().take(4) {
        assert_eq!(full_batch.nrows, 2048);
    }
    // The last batch carries whatever is left over.
    assert_eq!(batches[4].nrows, 10000 - 4 * 2048);

    let total: usize = batches.iter().map(|batch| batch.nrows).sum();
    assert_eq!(total, 10000);
}
2223
/// The small test frame must come back as a single batch of four rows.
#[test]
fn split_batches_small_df() {
    let source = test_df();
    let batches = split_batches(&source);

    assert_eq!(batches.len(), 1);
    assert_eq!(batches[0].nrows, 4);
}
2231
/// Splitting the large frame into batches and merging them back must reproduce the
/// original frame.
#[test]
fn merge_batches_roundtrip() {
    let original = large_df();
    let rebuilt = merge_batches(split_batches(&original)).unwrap();

    assert_df_eq(&original, &rebuilt, "merge roundtrip");
}
2239}