1use crate::{ArrangeKey, Column, DExpr, DBinOp, DataFrame, TidyAgg, TidyError, TidyFrame};
14use std::collections::BTreeSet;
15use std::rc::Rc;
16
/// A node in the lazy query-plan tree.
///
/// `Scan` is the leaf that holds the source frame, `Join` has two inputs, and
/// every other variant wraps exactly one boxed `input` node.
#[derive(Debug, Clone)]
pub enum ViewNode {
    /// Leaf: a reference-counted source data frame.
    Scan { df: Rc<DataFrame> },
    /// Keeps the rows of `input` selected by `predicate`.
    Filter {
        input: Box<ViewNode>,
        predicate: DExpr,
    },
    /// Projects `input` onto the named `columns`.
    Select {
        input: Box<ViewNode>,
        columns: Vec<String>,
    },
    /// Applies `(column name, expression)` assignments to `input`.
    Mutate {
        input: Box<ViewNode>,
        assignments: Vec<(String, DExpr)>,
    },
    /// Sorts `input` by `keys`.
    Arrange {
        input: Box<ViewNode>,
        keys: Vec<ArrangeKey>,
    },
    /// Groups `input` by `group_keys` and computes the named `aggregations`.
    GroupSummarise {
        input: Box<ViewNode>,
        group_keys: Vec<String>,
        aggregations: Vec<(String, TidyAgg)>,
    },
    /// Same shape as `GroupSummarise`, but every aggregation has been
    /// converted to a `StreamingAgg`; produced by the optimizer and executed
    /// through `summarise_streaming`.
    StreamingGroupSummarise {
        input: Box<ViewNode>,
        group_keys: Vec<String>,
        aggregations: Vec<(String, crate::StreamingAgg)>,
    },
    /// De-duplicates `input` with respect to `columns`.
    Distinct {
        input: Box<ViewNode>,
        columns: Vec<String>,
    },
    /// Joins `left` with `right`; each `on` entry is a
    /// `(left key, right key)` pair and `kind` selects the join flavour.
    Join {
        left: Box<ViewNode>,
        right: Box<ViewNode>,
        on: Vec<(String, String)>,
        kind: JoinType,
    },
}
73
/// Join flavour carried by `ViewNode::Join`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Executed through the view's `inner_join`.
    Inner,
    /// Executed through the view's `left_join`.
    Left,
    /// Executed through the view's `semi_join`; the result is a view that is
    /// then materialized.
    Semi,
    /// Executed through the view's `anti_join`; the result is a view that is
    /// then materialized.
    Anti,
}
82
/// A deferred pipeline over a data frame.
///
/// Builder methods only extend the stored plan; nothing is executed until
/// `collect` (or `collect_batched`) is called.
pub struct LazyView {
    /// Root of the plan tree built so far.
    plan: ViewNode,
}
89
90impl LazyView {
91 pub fn from_df(df: DataFrame) -> Self {
93 LazyView {
94 plan: ViewNode::Scan { df: Rc::new(df) },
95 }
96 }
97
98 pub fn from_rc(df: Rc<DataFrame>) -> Self {
100 LazyView {
101 plan: ViewNode::Scan { df },
102 }
103 }
104
105 pub fn filter(self, predicate: DExpr) -> Self {
107 LazyView {
108 plan: ViewNode::Filter {
109 input: Box::new(self.plan),
110 predicate,
111 },
112 }
113 }
114
115 pub fn select(self, columns: Vec<String>) -> Self {
117 LazyView {
118 plan: ViewNode::Select {
119 input: Box::new(self.plan),
120 columns,
121 },
122 }
123 }
124
125 pub fn mutate(self, assignments: Vec<(String, DExpr)>) -> Self {
127 LazyView {
128 plan: ViewNode::Mutate {
129 input: Box::new(self.plan),
130 assignments,
131 },
132 }
133 }
134
135 pub fn arrange(self, keys: Vec<ArrangeKey>) -> Self {
137 LazyView {
138 plan: ViewNode::Arrange {
139 input: Box::new(self.plan),
140 keys,
141 },
142 }
143 }
144
145 pub fn group_summarise(
147 self,
148 group_keys: Vec<String>,
149 aggregations: Vec<(String, TidyAgg)>,
150 ) -> Self {
151 LazyView {
152 plan: ViewNode::GroupSummarise {
153 input: Box::new(self.plan),
154 group_keys,
155 aggregations,
156 },
157 }
158 }
159
160 pub fn distinct(self, columns: Vec<String>) -> Self {
162 LazyView {
163 plan: ViewNode::Distinct {
164 input: Box::new(self.plan),
165 columns,
166 },
167 }
168 }
169
170 pub fn join(self, right: LazyView, on: Vec<(String, String)>, kind: JoinType) -> Self {
172 LazyView {
173 plan: ViewNode::Join {
174 left: Box::new(self.plan),
175 right: Box::new(right.plan),
176 on,
177 kind,
178 },
179 }
180 }
181
182 pub fn collect(self) -> Result<TidyFrame, TidyError> {
184 let optimized = optimize(self.plan);
185 execute(optimized)
186 }
187
188 pub fn plan(&self) -> &ViewNode {
190 &self.plan
191 }
192
193 pub fn optimized_plan(self) -> ViewNode {
195 optimize(self.plan)
196 }
197}
198
199pub fn optimize(plan: ViewNode) -> ViewNode {
206 let plan = merge_filters(plan);
207 let plan = push_predicates_down(plan);
208 let plan = eliminate_redundant_selects(plan);
209 let plan = annotate_streamable_summarise(plan);
213 plan
214}
215
216fn try_streaming_agg(agg: &TidyAgg) -> Option<crate::StreamingAgg> {
228 use crate::StreamingAgg;
229 match agg {
230 TidyAgg::Count => Some(StreamingAgg::Count),
231 TidyAgg::Sum(c) => Some(StreamingAgg::Sum(c.clone())),
232 TidyAgg::Mean(c) => Some(StreamingAgg::Mean(c.clone())),
233 TidyAgg::Min(c) => Some(StreamingAgg::Min(c.clone())),
234 TidyAgg::Max(c) => Some(StreamingAgg::Max(c.clone())),
235 TidyAgg::Var(c) => Some(StreamingAgg::Var(c.clone())),
236 TidyAgg::Sd(c) => Some(StreamingAgg::Sd(c.clone())),
237 _ => None,
240 }
241}
242
243fn annotate_streamable_summarise(plan: ViewNode) -> ViewNode {
244 match plan {
245 ViewNode::GroupSummarise {
246 input,
247 group_keys,
248 aggregations,
249 } => {
250 let input = Box::new(annotate_streamable_summarise(*input));
251 let all_streaming: Option<Vec<(String, crate::StreamingAgg)>> = aggregations
255 .iter()
256 .map(|(name, agg)| try_streaming_agg(agg).map(|sa| (name.clone(), sa)))
257 .collect();
258 match all_streaming {
259 Some(streaming_aggs) => ViewNode::StreamingGroupSummarise {
260 input,
261 group_keys,
262 aggregations: streaming_aggs,
263 },
264 None => ViewNode::GroupSummarise {
265 input,
266 group_keys,
267 aggregations,
268 },
269 }
270 }
271 ViewNode::Filter { input, predicate } => ViewNode::Filter {
272 input: Box::new(annotate_streamable_summarise(*input)),
273 predicate,
274 },
275 ViewNode::Select { input, columns } => ViewNode::Select {
276 input: Box::new(annotate_streamable_summarise(*input)),
277 columns,
278 },
279 ViewNode::Mutate { input, assignments } => ViewNode::Mutate {
280 input: Box::new(annotate_streamable_summarise(*input)),
281 assignments,
282 },
283 ViewNode::Arrange { input, keys } => ViewNode::Arrange {
284 input: Box::new(annotate_streamable_summarise(*input)),
285 keys,
286 },
287 ViewNode::Distinct { input, columns } => ViewNode::Distinct {
288 input: Box::new(annotate_streamable_summarise(*input)),
289 columns,
290 },
291 ViewNode::Join {
292 left,
293 right,
294 on,
295 kind,
296 } => ViewNode::Join {
297 left: Box::new(annotate_streamable_summarise(*left)),
298 right: Box::new(annotate_streamable_summarise(*right)),
299 on,
300 kind,
301 },
302 ViewNode::StreamingGroupSummarise { .. } => plan,
303 ViewNode::Scan { .. } => plan,
304 }
305}
306
307fn merge_filters(plan: ViewNode) -> ViewNode {
313 match plan {
314 ViewNode::Filter { input, predicate } => {
315 let merged_input = merge_filters(*input);
316 match merged_input {
317 ViewNode::Filter {
318 input: inner,
319 predicate: inner_pred,
320 } => {
321 let combined = DExpr::BinOp {
323 op: DBinOp::And,
324 left: Box::new(inner_pred),
325 right: Box::new(predicate),
326 };
327 ViewNode::Filter {
328 input: inner,
329 predicate: combined,
330 }
331 }
332 other => ViewNode::Filter {
333 input: Box::new(other),
334 predicate,
335 },
336 }
337 }
338 ViewNode::Select { input, columns } => ViewNode::Select {
340 input: Box::new(merge_filters(*input)),
341 columns,
342 },
343 ViewNode::Mutate {
344 input,
345 assignments,
346 } => ViewNode::Mutate {
347 input: Box::new(merge_filters(*input)),
348 assignments,
349 },
350 ViewNode::Arrange { input, keys } => ViewNode::Arrange {
351 input: Box::new(merge_filters(*input)),
352 keys,
353 },
354 ViewNode::GroupSummarise {
355 input,
356 group_keys,
357 aggregations,
358 } => ViewNode::GroupSummarise {
359 input: Box::new(merge_filters(*input)),
360 group_keys,
361 aggregations,
362 },
363 ViewNode::Distinct { input, columns } => ViewNode::Distinct {
364 input: Box::new(merge_filters(*input)),
365 columns,
366 },
367 ViewNode::Join {
368 left,
369 right,
370 on,
371 kind,
372 } => ViewNode::Join {
373 left: Box::new(merge_filters(*left)),
374 right: Box::new(merge_filters(*right)),
375 on,
376 kind,
377 },
378 other => other, }
380}
381
382fn push_predicates_down(plan: ViewNode) -> ViewNode {
394 match plan {
395 ViewNode::Filter { input, predicate } => {
396 let optimized_input = push_predicates_down(*input);
397 push_filter_into(optimized_input, predicate)
398 }
399 ViewNode::Select { input, columns } => ViewNode::Select {
401 input: Box::new(push_predicates_down(*input)),
402 columns,
403 },
404 ViewNode::Mutate {
405 input,
406 assignments,
407 } => ViewNode::Mutate {
408 input: Box::new(push_predicates_down(*input)),
409 assignments,
410 },
411 ViewNode::Arrange { input, keys } => ViewNode::Arrange {
412 input: Box::new(push_predicates_down(*input)),
413 keys,
414 },
415 ViewNode::GroupSummarise {
416 input,
417 group_keys,
418 aggregations,
419 } => ViewNode::GroupSummarise {
420 input: Box::new(push_predicates_down(*input)),
421 group_keys,
422 aggregations,
423 },
424 ViewNode::Distinct { input, columns } => ViewNode::Distinct {
425 input: Box::new(push_predicates_down(*input)),
426 columns,
427 },
428 ViewNode::Join {
429 left,
430 right,
431 on,
432 kind,
433 } => ViewNode::Join {
434 left: Box::new(push_predicates_down(*left)),
435 right: Box::new(push_predicates_down(*right)),
436 on,
437 kind,
438 },
439 other => other,
440 }
441}
442
443fn push_filter_into(node: ViewNode, predicate: DExpr) -> ViewNode {
445 match node {
446 ViewNode::Select { input, columns } => ViewNode::Select {
449 input: Box::new(push_filter_into(*input, predicate)),
450 columns,
451 },
452
453 ViewNode::Arrange { input, keys } => ViewNode::Arrange {
455 input: Box::new(push_filter_into(*input, predicate)),
456 keys,
457 },
458
459 ViewNode::Mutate {
462 input,
463 assignments,
464 } => {
465 let pred_cols = expr_columns(&predicate);
466 let mutated_cols: BTreeSet<String> =
467 assignments.iter().map(|(name, _)| name.clone()).collect();
468 let references_mutated = pred_cols.iter().any(|c| mutated_cols.contains(c));
469
470 if references_mutated {
471 ViewNode::Filter {
473 input: Box::new(ViewNode::Mutate {
474 input,
475 assignments,
476 }),
477 predicate,
478 }
479 } else {
480 ViewNode::Mutate {
482 input: Box::new(push_filter_into(*input, predicate)),
483 assignments,
484 }
485 }
486 }
487
488 ViewNode::Join {
491 left,
492 right,
493 on,
494 kind,
495 } => {
496 let pred_cols = expr_columns(&predicate);
497 let left_cols = node_output_columns(&left);
498 let right_cols = node_output_columns(&right);
499
500 let all_in_left = pred_cols.iter().all(|c| left_cols.contains(c));
501 let all_in_right = pred_cols.iter().all(|c| right_cols.contains(c));
502
503 if all_in_left {
504 ViewNode::Join {
505 left: Box::new(push_filter_into(*left, predicate)),
506 right,
507 on,
508 kind,
509 }
510 } else if all_in_right {
511 ViewNode::Join {
512 left,
513 right: Box::new(push_filter_into(*right, predicate)),
514 on,
515 kind,
516 }
517 } else {
518 ViewNode::Filter {
520 input: Box::new(ViewNode::Join {
521 left,
522 right,
523 on,
524 kind,
525 }),
526 predicate,
527 }
528 }
529 }
530
531 other => ViewNode::Filter {
533 input: Box::new(other),
534 predicate,
535 },
536 }
537}
538
539fn eliminate_redundant_selects(plan: ViewNode) -> ViewNode {
544 match plan {
545 ViewNode::Select { input, columns } => {
546 let optimized_input = eliminate_redundant_selects(*input);
547 let input_cols = node_output_columns(&optimized_input);
548
549 let select_set: BTreeSet<&str> = columns.iter().map(|s| s.as_str()).collect();
551 let input_set: BTreeSet<&str> = input_cols.iter().map(|s| s.as_str()).collect();
552
553 if select_set == input_set {
554 optimized_input
555 } else {
556 ViewNode::Select {
557 input: Box::new(optimized_input),
558 columns,
559 }
560 }
561 }
562 ViewNode::Filter { input, predicate } => ViewNode::Filter {
563 input: Box::new(eliminate_redundant_selects(*input)),
564 predicate,
565 },
566 ViewNode::Mutate {
567 input,
568 assignments,
569 } => ViewNode::Mutate {
570 input: Box::new(eliminate_redundant_selects(*input)),
571 assignments,
572 },
573 ViewNode::Arrange { input, keys } => ViewNode::Arrange {
574 input: Box::new(eliminate_redundant_selects(*input)),
575 keys,
576 },
577 ViewNode::GroupSummarise {
578 input,
579 group_keys,
580 aggregations,
581 } => ViewNode::GroupSummarise {
582 input: Box::new(eliminate_redundant_selects(*input)),
583 group_keys,
584 aggregations,
585 },
586 ViewNode::Distinct { input, columns } => ViewNode::Distinct {
587 input: Box::new(eliminate_redundant_selects(*input)),
588 columns,
589 },
590 ViewNode::Join {
591 left,
592 right,
593 on,
594 kind,
595 } => ViewNode::Join {
596 left: Box::new(eliminate_redundant_selects(*left)),
597 right: Box::new(eliminate_redundant_selects(*right)),
598 on,
599 kind,
600 },
601 other => other,
602 }
603}
604
605fn execute(node: ViewNode) -> Result<TidyFrame, TidyError> {
609 match node {
610 ViewNode::Scan { df } => Ok(TidyFrame::from_df((*df).clone())),
611
612 ViewNode::Filter { input, predicate } => {
613 let frame = execute(*input)?;
614 let view = frame.view();
615 let filtered = view.filter(&predicate)?;
616 let df = filtered.materialize()?;
617 Ok(TidyFrame::from_df(df))
618 }
619
620 ViewNode::Select { input, columns } => {
621 let frame = execute(*input)?;
622 let view = frame.view();
623 let col_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
624 let selected = view.select(&col_refs)?;
625 let df = selected.materialize()?;
626 Ok(TidyFrame::from_df(df))
627 }
628
629 ViewNode::Mutate {
630 input,
631 assignments,
632 } => {
633 let frame = execute(*input)?;
634 let view = frame.view();
635 let assign_refs: Vec<(&str, DExpr)> = assignments
636 .into_iter()
637 .map(|(name, expr)| (leaked_str(&name), expr))
638 .collect();
639 let result = view.mutate(&assign_refs.iter().map(|(n, e)| (*n, e.clone())).collect::<Vec<_>>())?;
641 Ok(result)
642 }
643
644 ViewNode::Arrange { input, keys } => {
645 let frame = execute(*input)?;
646 let view = frame.view();
647 let arranged = view.arrange(&keys)?;
648 let df = arranged.materialize()?;
649 Ok(TidyFrame::from_df(df))
650 }
651
652 ViewNode::GroupSummarise {
653 input,
654 group_keys,
655 aggregations,
656 } => {
657 let frame = execute(*input)?;
658 let view = frame.view();
659 let key_refs: Vec<&str> = group_keys.iter().map(|s| s.as_str()).collect();
660 let grouped = view.group_by(&key_refs)?;
661 let agg_refs: Vec<(&str, TidyAgg)> = aggregations
662 .into_iter()
663 .map(|(name, agg)| (leaked_str(&name), agg))
664 .collect();
665 let result = grouped.summarise(
666 &agg_refs.iter().map(|(n, a)| (*n, a.clone())).collect::<Vec<_>>(),
667 )?;
668 Ok(result)
669 }
670
671 ViewNode::StreamingGroupSummarise {
672 input,
673 group_keys,
674 aggregations,
675 } => {
676 let frame = execute(*input)?;
677 let view = frame.view();
678 let key_refs: Vec<&str> = group_keys.iter().map(|s| s.as_str()).collect();
679 let agg_owned: Vec<(String, crate::StreamingAgg)> = aggregations;
680 let agg_refs: Vec<(&str, crate::StreamingAgg)> = agg_owned
681 .iter()
682 .map(|(name, sa)| (leaked_str(name), sa.clone()))
683 .collect();
684 view.summarise_streaming(&key_refs, &agg_refs)
685 }
686
687 ViewNode::Distinct { input, columns } => {
688 let frame = execute(*input)?;
689 let view = frame.view();
690 let col_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
691 let distinct = view.distinct(&col_refs)?;
692 let df = distinct.materialize()?;
693 Ok(TidyFrame::from_df(df))
694 }
695
696 ViewNode::Join {
697 left,
698 right,
699 on,
700 kind,
701 } => {
702 let left_frame = execute(*left)?;
703 let right_frame = execute(*right)?;
704 let left_view = left_frame.view();
705 let right_view = right_frame.view();
706 let on_refs: Vec<(&str, &str)> = on
707 .iter()
708 .map(|(l, r)| (l.as_str(), r.as_str()))
709 .collect();
710
711 match kind {
712 JoinType::Inner => left_view.inner_join(&right_view, &on_refs),
713 JoinType::Left => left_view.left_join(&right_view, &on_refs),
714 JoinType::Semi => {
715 let result = left_view.semi_join(&right_view, &on_refs)?;
716 let df = result.materialize()?;
717 Ok(TidyFrame::from_df(df))
718 }
719 JoinType::Anti => {
720 let result = left_view.anti_join(&right_view, &on_refs)?;
721 let df = result.materialize()?;
722 Ok(TidyFrame::from_df(df))
723 }
724 }
725 }
726 }
727}
728
729fn expr_columns(expr: &DExpr) -> BTreeSet<String> {
733 let mut cols = BTreeSet::new();
734 collect_expr_cols(expr, &mut cols);
735 cols
736}
737
738fn collect_expr_cols(expr: &DExpr, cols: &mut BTreeSet<String>) {
739 match expr {
740 DExpr::Col(name) => {
741 cols.insert(name.clone());
742 }
743 DExpr::BinOp { left, right, .. } => {
744 collect_expr_cols(left, cols);
745 collect_expr_cols(right, cols);
746 }
747 DExpr::Agg(_, inner) => collect_expr_cols(inner, cols),
748 DExpr::FnCall(_, args) => {
749 for arg in args {
750 collect_expr_cols(arg, cols);
751 }
752 }
753 DExpr::CumSum(e)
754 | DExpr::CumProd(e)
755 | DExpr::CumMax(e)
756 | DExpr::CumMin(e)
757 | DExpr::Lag(e, _)
758 | DExpr::Lead(e, _)
759 | DExpr::Rank(e)
760 | DExpr::DenseRank(e) => {
761 collect_expr_cols(e, cols);
762 }
763 DExpr::RollingSum(col, _)
765 | DExpr::RollingMean(col, _)
766 | DExpr::RollingMin(col, _)
767 | DExpr::RollingMax(col, _)
768 | DExpr::RollingVar(col, _)
769 | DExpr::RollingSd(col, _) => {
770 cols.insert(col.clone());
771 }
772 DExpr::LitInt(_)
773 | DExpr::LitFloat(_)
774 | DExpr::LitBool(_)
775 | DExpr::LitStr(_)
776 | DExpr::Count
777 | DExpr::RowNumber => {}
778 }
779}
780
781fn node_output_columns(node: &ViewNode) -> BTreeSet<String> {
788 match node {
789 ViewNode::Scan { df } => df.column_names().into_iter().map(|s| s.to_string()).collect(),
790 ViewNode::Filter { input, .. } => node_output_columns(input),
791 ViewNode::Select { columns, .. } => columns.iter().cloned().collect(),
792 ViewNode::Mutate {
793 input,
794 assignments,
795 } => {
796 let mut cols = node_output_columns(input);
797 for (name, _) in assignments {
798 cols.insert(name.clone());
799 }
800 cols
801 }
802 ViewNode::Arrange { input, .. } => node_output_columns(input),
803 ViewNode::GroupSummarise {
804 group_keys,
805 aggregations,
806 ..
807 } => {
808 let mut cols: BTreeSet<String> = group_keys.iter().cloned().collect();
809 for (name, _) in aggregations {
810 cols.insert(name.clone());
811 }
812 cols
813 }
814 ViewNode::StreamingGroupSummarise {
815 group_keys,
816 aggregations,
817 ..
818 } => {
819 let mut cols: BTreeSet<String> = group_keys.iter().cloned().collect();
820 for (name, _) in aggregations {
821 cols.insert(name.clone());
822 }
823 cols
824 }
825 ViewNode::Distinct { input, .. } => node_output_columns(input),
826 ViewNode::Join {
827 left, right, on, ..
828 } => {
829 let mut cols = node_output_columns(left);
830 let right_cols = node_output_columns(right);
831 let left_keys: BTreeSet<&str> = on.iter().map(|(l, _)| l.as_str()).collect();
833 let right_keys: BTreeSet<&str> = on.iter().map(|(_, r)| r.as_str()).collect();
834 for c in &right_cols {
835 if !right_keys.contains(c.as_str()) || !left_keys.contains(c.as_str()) {
836 cols.insert(c.clone());
837 }
838 }
839 cols
840 }
841 }
842}
843
/// Copies `s` onto the heap and deliberately leaks the allocation so the
/// returned slice has a `'static` lifetime.
///
/// Every call leaks `s.len()` bytes for the remainder of the process, so
/// callers should prefer borrowing whenever the borrow checker allows it.
fn leaked_str(s: &str) -> &'static str {
    let boxed: Box<str> = Box::from(s);
    Box::leak(boxed)
}
854
855impl ViewNode {
858 pub fn count_filters(&self) -> usize {
860 match self {
861 ViewNode::Filter { input, .. } => 1 + input.count_filters(),
862 ViewNode::Select { input, .. } => input.count_filters(),
863 ViewNode::Mutate { input, .. } => input.count_filters(),
864 ViewNode::Arrange { input, .. } => input.count_filters(),
865 ViewNode::GroupSummarise { input, .. } => input.count_filters(),
866 ViewNode::StreamingGroupSummarise { input, .. } => input.count_filters(),
867 ViewNode::Distinct { input, .. } => input.count_filters(),
868 ViewNode::Join { left, right, .. } => {
869 left.count_filters() + right.count_filters()
870 }
871 ViewNode::Scan { .. } => 0,
872 }
873 }
874
875 pub fn is_filter_on_scan(&self) -> bool {
878 match self {
879 ViewNode::Filter { input, .. } => matches!(input.as_ref(), ViewNode::Scan { .. }),
880 _ => false,
881 }
882 }
883
884 pub fn innermost(&self) -> &ViewNode {
886 match self {
887 ViewNode::Filter { input, .. }
888 | ViewNode::Select { input, .. }
889 | ViewNode::Mutate { input, .. }
890 | ViewNode::Arrange { input, .. }
891 | ViewNode::GroupSummarise { input, .. }
892 | ViewNode::StreamingGroupSummarise { input, .. }
893 | ViewNode::Distinct { input, .. } => input.innermost(),
894 ViewNode::Join { left, .. } => left.innermost(),
895 ViewNode::Scan { .. } => self,
896 }
897 }
898
899 pub fn kind(&self) -> &'static str {
901 match self {
902 ViewNode::Scan { .. } => "Scan",
903 ViewNode::Filter { .. } => "Filter",
904 ViewNode::Select { .. } => "Select",
905 ViewNode::Mutate { .. } => "Mutate",
906 ViewNode::Arrange { .. } => "Arrange",
907 ViewNode::GroupSummarise { .. } => "GroupSummarise",
908 ViewNode::StreamingGroupSummarise { .. } => "StreamingGroupSummarise",
909 ViewNode::Distinct { .. } => "Distinct",
910 ViewNode::Join { .. } => "Join",
911 }
912 }
913
914 pub fn node_kinds(&self) -> Vec<&'static str> {
916 let mut out = vec![self.kind()];
917 match self {
918 ViewNode::Filter { input, .. }
919 | ViewNode::Select { input, .. }
920 | ViewNode::Mutate { input, .. }
921 | ViewNode::Arrange { input, .. }
922 | ViewNode::GroupSummarise { input, .. }
923 | ViewNode::StreamingGroupSummarise { input, .. }
924 | ViewNode::Distinct { input, .. } => {
925 out.extend(input.node_kinds());
926 }
927 ViewNode::Join { left, right, .. } => {
928 out.extend(left.node_kinds());
929 out.extend(right.node_kinds());
930 }
931 ViewNode::Scan { .. } => {}
932 }
933 out
934 }
935}
936
/// Maximum number of rows placed in one `Batch` by `split_batches`.
const BATCH_SIZE: usize = 2048;
941
/// A horizontal slice of a data frame used by the batched executor.
#[derive(Debug, Clone)]
pub struct Batch {
    /// Named columns of this slice, in frame order.
    pub columns: Vec<(String, Column)>,
    /// Number of rows held by this batch.
    pub nrows: usize,
}
951
952impl Batch {
953 fn into_dataframe(self) -> DataFrame {
955 DataFrame {
956 columns: self.columns,
957 }
958 }
959
960 fn get_column(&self, name: &str) -> Option<&Column> {
962 self.columns.iter().find(|(n, _)| n == name).map(|(_, c)| c)
963 }
964
965 fn column_names(&self) -> Vec<&str> {
967 self.columns.iter().map(|(n, _)| n.as_str()).collect()
968 }
969}
970
971fn slice_column(col: &Column, start: usize, end: usize) -> Column {
973 if matches!(col, Column::CategoricalAdaptive(_)) {
974 return slice_column(&col.to_legacy_categorical(), start, end);
975 }
976 match col {
977 Column::Float(v) => Column::Float(v[start..end].to_vec()),
978 Column::Int(v) => Column::Int(v[start..end].to_vec()),
979 Column::Str(v) => Column::Str(v[start..end].to_vec()),
980 Column::Bool(v) => Column::Bool(v[start..end].to_vec()),
981 Column::Categorical { levels, codes } => Column::Categorical {
982 levels: levels.clone(),
983 codes: codes[start..end].to_vec(),
984 },
985 Column::DateTime(v) => Column::DateTime(v[start..end].to_vec()),
986 Column::CategoricalAdaptive(_) => unreachable!("handled by early return"),
987 }
988}
989
990fn split_batches(df: &DataFrame) -> Vec<Batch> {
992 let nrows = df.nrows();
993 if nrows == 0 {
994 return vec![Batch {
995 columns: df.columns.iter().map(|(n, c)| {
996 (n.clone(), slice_column(c, 0, 0))
997 }).collect(),
998 nrows: 0,
999 }];
1000 }
1001 let mut batches = Vec::new();
1002 let mut offset = 0;
1003 while offset < nrows {
1004 let end = (offset + BATCH_SIZE).min(nrows);
1005 let batch_cols = df
1006 .columns
1007 .iter()
1008 .map(|(name, col)| (name.clone(), slice_column(col, offset, end)))
1009 .collect();
1010 batches.push(Batch {
1011 columns: batch_cols,
1012 nrows: end - offset,
1013 });
1014 offset = end;
1015 }
1016 batches
1017}
1018
1019fn merge_batches(batches: Vec<Batch>) -> Result<DataFrame, TidyError> {
1023 if batches.is_empty() {
1024 return Ok(DataFrame::new());
1025 }
1026
1027 let schema: Vec<String> = batches[0].column_names().iter().map(|s| s.to_string()).collect();
1029 if schema.is_empty() {
1030 return Ok(DataFrame::new());
1031 }
1032
1033 let total_rows: usize = batches.iter().map(|b| b.nrows).sum();
1035 let mut merged_cols: Vec<(String, Column)> = schema
1036 .iter()
1037 .map(|name| {
1038 let first_col = batches[0].get_column(name).unwrap();
1040 let empty = match first_col {
1041 Column::Float(_) => Column::Float(Vec::with_capacity(total_rows)),
1042 Column::Int(_) => Column::Int(Vec::with_capacity(total_rows)),
1043 Column::Str(_) => Column::Str(Vec::with_capacity(total_rows)),
1044 Column::Bool(_) => Column::Bool(Vec::with_capacity(total_rows)),
1045 Column::Categorical { levels, .. } => Column::Categorical {
1046 levels: levels.clone(),
1047 codes: Vec::with_capacity(total_rows),
1048 },
1049 Column::CategoricalAdaptive(_) => {
1050 let legacy = first_col.to_legacy_categorical();
1052 if let Column::Categorical { levels, .. } = legacy {
1053 Column::Categorical {
1054 levels,
1055 codes: Vec::with_capacity(total_rows),
1056 }
1057 } else {
1058 Column::Str(Vec::with_capacity(total_rows))
1060 }
1061 }
1062 Column::DateTime(_) => Column::DateTime(Vec::with_capacity(total_rows)),
1063 };
1064 (name.clone(), empty)
1065 })
1066 .collect();
1067
1068 for batch in &batches {
1070 if batch.nrows == 0 {
1071 continue;
1072 }
1073 for (i, (name, merged_col)) in merged_cols.iter_mut().enumerate() {
1074 let batch_col = batch.get_column(name).ok_or_else(|| {
1075 TidyError::ColumnNotFound(format!(
1076 "batch merge: column '{}' missing in batch (index {})",
1077 name, i
1078 ))
1079 })?;
1080 append_column(merged_col, batch_col);
1081 }
1082 }
1083
1084 Ok(DataFrame { columns: merged_cols })
1085}
1086
1087fn append_column(dst: &mut Column, src: &Column) {
1089 match (dst, src) {
1090 (Column::Float(d), Column::Float(s)) => d.extend_from_slice(s),
1091 (Column::Int(d), Column::Int(s)) => d.extend_from_slice(s),
1092 (Column::Str(d), Column::Str(s)) => d.extend(s.iter().cloned()),
1093 (Column::Bool(d), Column::Bool(s)) => d.extend_from_slice(s),
1094 (Column::Categorical { codes: d, .. }, Column::Categorical { codes: s, .. }) => {
1095 d.extend_from_slice(s);
1096 }
1097 (Column::DateTime(d), Column::DateTime(s)) => d.extend_from_slice(s),
1098 _ => {} }
1100}
1101
/// An operation the batched executor applies to one batch at a time.
///
/// Mirrors the `Filter`, `Select` and `Mutate` plan nodes without their
/// `input`.
#[derive(Debug, Clone)]
enum StreamableOp {
    /// Keep the rows of the batch selected by `predicate`.
    Filter { predicate: DExpr },
    /// Keep only the named columns of the batch.
    Select { columns: Vec<String> },
    /// Apply `(column name, expression)` assignments to the batch.
    Mutate { assignments: Vec<(String, DExpr)> },
}
1111
1112fn is_pipeline_breaker(node: &ViewNode) -> bool {
1114 matches!(
1115 node,
1116 ViewNode::Arrange { .. }
1117 | ViewNode::GroupSummarise { .. }
1118 | ViewNode::StreamingGroupSummarise { .. }
1119 | ViewNode::Distinct { .. }
1120 | ViewNode::Join { .. }
1121 )
1122}
1123
1124fn collect_streamable_chain(node: ViewNode) -> (Vec<StreamableOp>, Box<ViewNode>) {
1130 let mut ops = Vec::new();
1131 let mut current = node;
1132
1133 loop {
1134 match current {
1135 ViewNode::Filter { input, predicate } => {
1136 ops.push(StreamableOp::Filter { predicate });
1137 current = *input;
1138 }
1139 ViewNode::Select { input, columns } => {
1140 ops.push(StreamableOp::Select { columns });
1141 current = *input;
1142 }
1143 ViewNode::Mutate { input, assignments } => {
1144 ops.push(StreamableOp::Mutate { assignments });
1145 current = *input;
1146 }
1147 other => {
1149 ops.reverse();
1151 return (ops, Box::new(other));
1152 }
1153 }
1154 }
1155}
1156
1157fn apply_op_to_batch(batch: Batch, op: &StreamableOp) -> Result<Batch, TidyError> {
1159 match op {
1160 StreamableOp::Filter { predicate } => {
1161 let df = batch.into_dataframe();
1163 if df.nrows() == 0 {
1164 return Ok(Batch {
1165 nrows: 0,
1166 columns: df.columns,
1167 });
1168 }
1169 let frame = TidyFrame::from_df(df);
1170 let view = frame.view();
1171 let filtered = view.filter(predicate)?;
1172 let result_df = filtered.materialize()?;
1173 let nrows = result_df.nrows();
1174 Ok(Batch {
1175 columns: result_df.columns,
1176 nrows,
1177 })
1178 }
1179 StreamableOp::Select { columns } => {
1180 let selected: Vec<(String, Column)> = columns
1182 .iter()
1183 .filter_map(|name| {
1184 batch
1185 .columns
1186 .iter()
1187 .find(|(n, _)| n == name)
1188 .cloned()
1189 })
1190 .collect();
1191 Ok(Batch {
1192 nrows: batch.nrows,
1193 columns: selected,
1194 })
1195 }
1196 StreamableOp::Mutate { assignments } => {
1197 let df = batch.into_dataframe();
1199 let frame = TidyFrame::from_df(df);
1200 let view = frame.view();
1201 let assign_refs: Vec<(&str, DExpr)> = assignments
1202 .iter()
1203 .map(|(name, expr)| (leaked_str(name), expr.clone()))
1204 .collect();
1205 let result = view.mutate(
1206 &assign_refs
1207 .iter()
1208 .map(|(n, e)| (*n, e.clone()))
1209 .collect::<Vec<_>>(),
1210 )?;
1211 let result_df = result.borrow().clone();
1212 let nrows = result_df.nrows();
1213 Ok(Batch {
1214 columns: result_df.columns,
1215 nrows,
1216 })
1217 }
1218 }
1219}
1220
1221fn apply_chain_batched(
1223 frame: &TidyFrame,
1224 chain: &[StreamableOp],
1225) -> Result<TidyFrame, TidyError> {
1226 let df = frame.borrow().clone();
1227 let batches = split_batches(&df);
1228
1229 let mut result_batches: Vec<Batch> = Vec::new();
1230 for batch in batches {
1231 let mut current = batch;
1232 for op in chain {
1233 current = apply_op_to_batch(current, op)?;
1234 }
1235 if current.nrows > 0 {
1236 result_batches.push(current);
1237 }
1238 }
1239
1240 if result_batches.is_empty() {
1241 let empty_df = DataFrame {
1243 columns: df
1244 .columns
1245 .iter()
1246 .map(|(name, col)| {
1247 (name.clone(), slice_column(col, 0, 0))
1248 })
1249 .collect(),
1250 };
1251 let mut result_cols: Option<Vec<String>> = None;
1253 for op in chain {
1254 if let StreamableOp::Select { columns } = op {
1255 result_cols = Some(columns.clone());
1256 }
1257 }
1258 if let Some(cols) = result_cols {
1259 let pruned: Vec<(String, Column)> = cols
1260 .iter()
1261 .filter_map(|name| {
1262 empty_df
1263 .columns
1264 .iter()
1265 .find(|(n, _)| n == name)
1266 .cloned()
1267 })
1268 .collect();
1269 return Ok(TidyFrame::from_df(DataFrame { columns: pruned }));
1270 }
1271 return Ok(TidyFrame::from_df(empty_df));
1272 }
1273
1274 let merged = merge_batches(result_batches)?;
1275 Ok(TidyFrame::from_df(merged))
1276}
1277
1278pub fn execute_batched(node: ViewNode) -> Result<TidyFrame, TidyError> {
1290 match &node {
1291 ViewNode::Scan { .. } => execute(node),
1293
1294 _ if !is_pipeline_breaker(&node) => {
1296 let (chain, base) = collect_streamable_chain(node);
1297 if chain.is_empty() {
1298 return execute_batched(*base);
1300 }
1301 let base_frame = execute_batched(*base)?;
1302 apply_chain_batched(&base_frame, &chain)
1303 }
1304
1305 _ => execute_breaker_batched(node),
1307 }
1308}
1309
1310fn execute_breaker_batched(node: ViewNode) -> Result<TidyFrame, TidyError> {
1313 match node {
1314 ViewNode::Arrange { input, keys } => {
1315 let frame = execute_batched(*input)?;
1316 let view = frame.view();
1317 let arranged = view.arrange(&keys)?;
1318 let df = arranged.materialize()?;
1319 Ok(TidyFrame::from_df(df))
1320 }
1321
1322 ViewNode::GroupSummarise {
1323 input,
1324 group_keys,
1325 aggregations,
1326 } => {
1327 let frame = execute_batched(*input)?;
1328 let view = frame.view();
1329 let key_refs: Vec<&str> = group_keys.iter().map(|s| s.as_str()).collect();
1330 let grouped = view.group_by(&key_refs)?;
1331 let agg_refs: Vec<(&str, TidyAgg)> = aggregations
1332 .into_iter()
1333 .map(|(name, agg)| (leaked_str(&name), agg))
1334 .collect();
1335 let result = grouped.summarise(
1336 &agg_refs
1337 .iter()
1338 .map(|(n, a)| (*n, a.clone()))
1339 .collect::<Vec<_>>(),
1340 )?;
1341 Ok(result)
1342 }
1343
1344 ViewNode::StreamingGroupSummarise {
1345 input,
1346 group_keys,
1347 aggregations,
1348 } => {
1349 let frame = execute_batched(*input)?;
1350 let view = frame.view();
1351 let key_refs: Vec<&str> = group_keys.iter().map(|s| s.as_str()).collect();
1352 let agg_owned: Vec<(String, crate::StreamingAgg)> = aggregations;
1353 let agg_refs: Vec<(&str, crate::StreamingAgg)> = agg_owned
1354 .iter()
1355 .map(|(name, sa)| (leaked_str(name), sa.clone()))
1356 .collect();
1357 view.summarise_streaming(&key_refs, &agg_refs)
1358 }
1359
1360 ViewNode::Distinct { input, columns } => {
1361 let frame = execute_batched(*input)?;
1362 let view = frame.view();
1363 let col_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
1364 let distinct = view.distinct(&col_refs)?;
1365 let df = distinct.materialize()?;
1366 Ok(TidyFrame::from_df(df))
1367 }
1368
1369 ViewNode::Join {
1370 left,
1371 right,
1372 on,
1373 kind,
1374 } => {
1375 let left_frame = execute_batched(*left)?;
1376 let right_frame = execute_batched(*right)?;
1377 let left_view = left_frame.view();
1378 let right_view = right_frame.view();
1379 let on_refs: Vec<(&str, &str)> =
1380 on.iter().map(|(l, r)| (l.as_str(), r.as_str())).collect();
1381
1382 match kind {
1383 JoinType::Inner => left_view.inner_join(&right_view, &on_refs),
1384 JoinType::Left => left_view.left_join(&right_view, &on_refs),
1385 JoinType::Semi => {
1386 let result = left_view.semi_join(&right_view, &on_refs)?;
1387 let df = result.materialize()?;
1388 Ok(TidyFrame::from_df(df))
1389 }
1390 JoinType::Anti => {
1391 let result = left_view.anti_join(&right_view, &on_refs)?;
1392 let df = result.materialize()?;
1393 Ok(TidyFrame::from_df(df))
1394 }
1395 }
1396 }
1397
1398 other => execute(other),
1400 }
1401}
1402
1403impl LazyView {
1404 pub fn collect_batched(self) -> Result<TidyFrame, TidyError> {
1416 let optimized = optimize(self.plan);
1417 execute_batched(optimized)
1418 }
1419}
1420
/// Unit tests for the lazy plan builder, the optimizer rewrites, and the
/// batched executor.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Column, DExpr, DBinOp, DataFrame, TidyAgg, ArrangeKey, TidyView};

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Converts string literals into the owned strings stored by `Column::Str`.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Four-row fixture with `name`, `age` and `score` columns.
    fn test_df() -> DataFrame {
        DataFrame {
            columns: vec![
                (
                    "name".to_string(),
                    Column::Str(strings(&["Alice", "Bob", "Carol", "Dave"])),
                ),
                ("age".to_string(), Column::Int(vec![30, 25, 35, 25])),
                (
                    "score".to_string(),
                    Column::Float(vec![90.0, 85.0, 95.0, 80.0]),
                ),
            ],
        }
    }

    /// Three-row fixture mapping `name` to `dept`; only Alice and Bob also
    /// appear in `test_df`.
    fn dept_df() -> DataFrame {
        DataFrame {
            columns: vec![
                (
                    "name".to_string(),
                    Column::Str(strings(&["Alice", "Bob", "Eve"])),
                ),
                (
                    "dept".to_string(),
                    Column::Str(strings(&["Eng", "Sales", "Eng"])),
                ),
            ],
        }
    }

    /// 10,000-row fixture: ages cycle through 18..=97, scores through 50.0..=99.0.
    fn large_df() -> DataFrame {
        let n = 10_000usize;
        let names: Vec<String> = (0..n).map(|i| format!("user_{}", i)).collect();
        let ages: Vec<i64> = (0..n).map(|i| (i % 80) as i64 + 18).collect();
        let scores: Vec<f64> = (0..n).map(|i| 50.0 + (i % 50) as f64).collect();
        DataFrame {
            columns: vec![
                ("name".to_string(), Column::Str(names)),
                ("age".to_string(), Column::Int(ages)),
                ("score".to_string(), Column::Float(scores)),
            ],
        }
    }

    // ---------------------------------------------------------------------
    // Expression helpers
    // ---------------------------------------------------------------------

    /// Column reference expression.
    fn col(name: &'static str) -> DExpr {
        DExpr::Col(name.into())
    }

    /// Binary expression `left <op> right`.
    fn binop(op: DBinOp, left: DExpr, right: DExpr) -> DExpr {
        DExpr::BinOp {
            op,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// `left > right`.
    fn gt(left: DExpr, right: DExpr) -> DExpr {
        binop(DBinOp::Gt, left, right)
    }

    /// `left < right`.
    fn lt(left: DExpr, right: DExpr) -> DExpr {
        binop(DBinOp::Lt, left, right)
    }

    /// `left * right`.
    fn mul(left: DExpr, right: DExpr) -> DExpr {
        binop(DBinOp::Mul, left, right)
    }

    // ---------------------------------------------------------------------
    // Column extraction helpers
    // ---------------------------------------------------------------------

    /// Copies out an integer column; panics if it is missing or not `Int`.
    fn int_values(df: &DataFrame, name: &str) -> Vec<i64> {
        match df.get_column(name).unwrap() {
            Column::Int(v) => v.clone(),
            _ => panic!("expected Int"),
        }
    }

    /// Copies out a string column; panics if it is missing or not `Str`.
    fn str_values(df: &DataFrame, name: &str) -> Vec<String> {
        match df.get_column(name).unwrap() {
            Column::Str(v) => v.clone(),
            _ => panic!("expected Str"),
        }
    }

    // ---------------------------------------------------------------------
    // Plan builders shared by several tests
    // ---------------------------------------------------------------------

    /// `test_df` joined with `dept_df` on `name` using the given join kind.
    fn join_on_name(kind: JoinType) -> LazyView {
        LazyView::from_df(test_df()).join(
            LazyView::from_df(dept_df()),
            vec![("name".into(), "name".into())],
            kind,
        )
    }

    /// `test_df` grouped by `age` with a single `count` aggregation.
    fn count_by_age() -> LazyView {
        LazyView::from_df(test_df()).group_summarise(
            vec!["age".into()],
            vec![("count".into(), TidyAgg::Count)],
        )
    }

    /// `test_df` with an extra `doubled_age = age * 2` column.
    fn with_doubled_age() -> LazyView {
        LazyView::from_df(test_df()).mutate(vec![(
            "doubled_age".into(),
            mul(col("age"), DExpr::LitInt(2)),
        )])
    }

    /// `age > 20`, then `bonus = score * 1.1`, then keep `name` and `bonus`.
    fn bonus_pipeline() -> LazyView {
        LazyView::from_df(test_df())
            .filter(gt(col("age"), DExpr::LitInt(20)))
            .mutate(vec![(
                "bonus".into(),
                mul(col("score"), DExpr::LitFloat(1.1)),
            )])
            .select(vec!["name".into(), "bonus".into()])
    }

    /// `age > 20`, keep `name` and `age`, sorted by `age` descending.
    fn adults_by_age_desc() -> LazyView {
        LazyView::from_df(test_df())
            .filter(gt(col("age"), DExpr::LitInt(20)))
            .select(vec!["name".into(), "age".into()])
            .arrange(vec![ArrangeKey::desc("age")])
    }

    // ---------------------------------------------------------------------
    // Frame comparison helpers
    // ---------------------------------------------------------------------

    /// Float equality used by the parity checks. Exact equality is tested
    /// first so identical infinities match (their difference is NaN), and two
    /// NaNs in the same position are treated as equal.
    fn floats_match(x: f64, y: f64) -> bool {
        x == y || (x.is_nan() && y.is_nan()) || (x - y).abs() < 1e-12
    }

    /// Asserts that two columns have the same variant and the same values.
    fn assert_col_eq(a: &Column, b: &Column, context: &str) {
        match (a, b) {
            (Column::Int(va), Column::Int(vb)) => assert_eq!(va, vb, "{}", context),
            (Column::Float(va), Column::Float(vb)) => {
                assert_eq!(va.len(), vb.len(), "{}: float len", context);
                for (i, (x, y)) in va.iter().zip(vb.iter()).enumerate() {
                    assert!(
                        floats_match(*x, *y),
                        "{}: float[{}] {} != {}",
                        context,
                        i,
                        x,
                        y
                    );
                }
            }
            (Column::Str(va), Column::Str(vb)) => assert_eq!(va, vb, "{}", context),
            (Column::Bool(va), Column::Bool(vb)) => assert_eq!(va, vb, "{}", context),
            _ => panic!("{}: column type mismatch", context),
        }
    }

    /// Asserts that two frames agree on row count, column names and every
    /// column's contents.
    fn assert_df_eq(a: &DataFrame, b: &DataFrame, context: &str) {
        assert_eq!(
            a.nrows(),
            b.nrows(),
            "{}: nrows differ ({} vs {})",
            context,
            a.nrows(),
            b.nrows()
        );
        assert_eq!(
            a.column_names(),
            b.column_names(),
            "{}: column names differ",
            context
        );
        for (name_a, col_a) in &a.columns {
            let col_b = b
                .get_column(name_a)
                .unwrap_or_else(|| panic!("{}: column '{}' missing in b", context, name_a));
            assert_col_eq(col_a, col_b, &format!("{} col '{}'", context, name_a));
        }
    }

    /// Builds the same plan twice, runs it once via `collect` and once via
    /// `collect_batched`, and asserts both frames are equal.
    fn assert_batched_parity(build: impl Fn() -> LazyView, context: &str) {
        let eager = build().collect().unwrap();
        let batched = build().collect_batched().unwrap();
        assert_df_eq(&eager.borrow(), &batched.borrow(), context);
    }

    // ---------------------------------------------------------------------
    // Lazy vs eager parity
    // ---------------------------------------------------------------------

    /// A lazy filter keeps the same rows as the eager `TidyView` filter.
    #[test]
    fn lazy_filter_matches_eager() {
        let df = test_df();
        let predicate = gt(col("age"), DExpr::LitInt(25));

        let eager_df = TidyView::from_df(df.clone())
            .filter(&predicate)
            .unwrap()
            .materialize()
            .unwrap();

        let lazy_frame = LazyView::from_df(df).filter(predicate).collect().unwrap();
        let lazy_df = lazy_frame.borrow();

        assert_eq!(eager_df.nrows(), lazy_df.nrows());
        assert_eq!(eager_df.nrows(), 2);
        assert_eq!(str_values(&eager_df, "name"), str_values(&lazy_df, "name"));
    }

    /// A lazy select yields the same columns as the eager select.
    #[test]
    fn lazy_select_matches_eager() {
        let df = test_df();

        let eager_df = TidyView::from_df(df.clone())
            .select(&["name", "age"])
            .unwrap()
            .materialize()
            .unwrap();

        let lazy_frame = LazyView::from_df(df)
            .select(vec!["name".into(), "age".into()])
            .collect()
            .unwrap();
        let lazy_df = lazy_frame.borrow();

        assert_eq!(eager_df.ncols(), 2);
        assert_eq!(lazy_df.ncols(), 2);
        assert_eq!(eager_df.column_names(), lazy_df.column_names());
    }

    /// A lazy arrange orders rows the same way as the eager arrange.
    #[test]
    fn lazy_arrange_matches_eager() {
        let df = test_df();
        let keys = vec![ArrangeKey::asc("age")];

        let eager_df = TidyView::from_df(df.clone())
            .arrange(&keys)
            .unwrap()
            .materialize()
            .unwrap();

        let lazy_frame = LazyView::from_df(df).arrange(keys).collect().unwrap();
        let lazy_df = lazy_frame.borrow();

        let eager_ages = int_values(&eager_df, "age");
        assert_eq!(eager_ages, int_values(&lazy_df, "age"));
        assert_eq!(eager_ages, vec![25, 25, 30, 35]);
    }

    /// A lazy group-summarise has the same shape as the eager one.
    #[test]
    fn lazy_group_summarise_matches_eager() {
        let eager_frame = TidyView::from_df(test_df())
            .group_by(&["age"])
            .unwrap()
            .summarise(&[("count", TidyAgg::Count)])
            .unwrap();
        let eager_df = eager_frame.borrow();

        let lazy_frame = count_by_age().collect().unwrap();
        let lazy_df = lazy_frame.borrow();

        assert_eq!(eager_df.nrows(), lazy_df.nrows());
        assert_eq!(eager_df.column_names(), lazy_df.column_names());
    }

    // ---------------------------------------------------------------------
    // Optimizer: predicate pushdown
    // ---------------------------------------------------------------------

    /// A filter written after a select ends up below it.
    #[test]
    fn predicate_pushdown_past_select() {
        let lazy = LazyView::from_df(test_df())
            .select(vec!["name".into(), "age".into()])
            .filter(gt(col("age"), DExpr::LitInt(25)));

        let optimized = lazy.optimized_plan();

        let kinds = optimized.node_kinds();
        assert_eq!(kinds, vec!["Select", "Filter", "Scan"]);
    }

    /// A filter written after an arrange ends up below it.
    #[test]
    fn predicate_pushdown_past_arrange() {
        let lazy = LazyView::from_df(test_df())
            .arrange(vec![ArrangeKey::asc("age")])
            .filter(gt(col("age"), DExpr::LitInt(25)));

        let optimized = lazy.optimized_plan();

        let kinds = optimized.node_kinds();
        assert_eq!(kinds, vec!["Arrange", "Filter", "Scan"]);
    }

    /// A filter that reads a mutated column must stay above the mutate.
    #[test]
    fn predicate_not_pushed_past_mutate_when_dependent() {
        let lazy = with_doubled_age().filter(gt(col("doubled_age"), DExpr::LitInt(50)));

        let optimized = lazy.optimized_plan();

        let kinds = optimized.node_kinds();
        assert_eq!(kinds, vec!["Filter", "Mutate", "Scan"]);
    }

    /// A filter that does not read the mutated column moves below the mutate.
    #[test]
    fn predicate_pushed_past_mutate_when_independent() {
        let lazy = with_doubled_age().filter(gt(col("score"), DExpr::LitFloat(85.0)));

        let optimized = lazy.optimized_plan();

        let kinds = optimized.node_kinds();
        assert_eq!(kinds, vec!["Mutate", "Filter", "Scan"]);
    }

    /// A filter on an aggregate output must stay above the grouping node.
    #[test]
    fn predicate_not_pushed_past_group_summarise() {
        let lazy = count_by_age().filter(gt(col("count"), DExpr::LitInt(1)));

        let optimized = lazy.optimized_plan();

        let kinds = optimized.node_kinds();
        assert!(
            kinds == vec!["Filter", "GroupSummarise", "Scan"]
                || kinds == vec!["Filter", "StreamingGroupSummarise", "Scan"],
            "filter must stay above the group node, got {:?}",
            kinds
        );
    }

    /// A filter above an inner join is pushed into the left input when it only
    /// reads left-side columns.
    #[test]
    fn predicate_pushdown_into_join_left_side() {
        let lazy = join_on_name(JoinType::Inner).filter(gt(col("age"), DExpr::LitInt(25)));

        let optimized = lazy.optimized_plan();

        let kinds = optimized.node_kinds();
        assert_eq!(kinds[0], "Join");
        match &optimized {
            ViewNode::Join { left, right, .. } => {
                assert_eq!(left.kind(), "Filter");
                assert_eq!(right.kind(), "Scan");
            }
            _ => panic!("expected Join at top"),
        }
    }

    // ---------------------------------------------------------------------
    // Optimizer: filter merging and select elimination
    // ---------------------------------------------------------------------

    /// Two stacked filters collapse into one node and still apply both predicates.
    #[test]
    fn consecutive_filters_merged() {
        let build = || {
            LazyView::from_df(test_df())
                .filter(gt(col("age"), DExpr::LitInt(20)))
                .filter(lt(col("score"), DExpr::LitFloat(95.0)))
        };

        let optimized = build().optimized_plan();
        assert_eq!(optimized.count_filters(), 1);

        let result = build().collect().unwrap();
        let result_df = result.borrow();
        assert_eq!(result_df.nrows(), 3);
    }

    /// Selecting every column in the original order is dropped from the plan.
    #[test]
    fn redundant_select_eliminated() {
        let lazy = LazyView::from_df(test_df())
            .select(vec!["name".into(), "age".into(), "score".into()]);

        let optimized = lazy.optimized_plan();

        assert_eq!(optimized.kind(), "Scan");
    }

    /// Selecting a strict subset of columns is kept in the plan.
    #[test]
    fn non_redundant_select_kept() {
        let lazy = LazyView::from_df(test_df()).select(vec!["name".into(), "age".into()]);

        let optimized = lazy.optimized_plan();

        assert_eq!(optimized.kind(), "Select");
    }

    // ---------------------------------------------------------------------
    // End-to-end lazy execution
    // ---------------------------------------------------------------------

    /// The same pipeline yields the same rows in the same order on every run.
    #[test]
    fn determinism_3_runs_identical() {
        for _ in 0..3 {
            let result = adults_by_age_desc().collect().unwrap();
            let result_df = result.borrow();

            assert_eq!(result_df.nrows(), 4);
            assert_eq!(int_values(&result_df, "age"), vec![35, 30, 25, 25]);
            assert_eq!(
                str_values(&result_df, "name"),
                vec!["Carol", "Alice", "Bob", "Dave"]
            );
        }
    }

    /// Inner join keeps only matching names and brings in the right-hand column.
    #[test]
    fn lazy_inner_join() {
        let result = join_on_name(JoinType::Inner).collect().unwrap();

        let result_df = result.borrow();
        assert_eq!(result_df.nrows(), 2);
        assert!(result_df.get_column("dept").is_some());
    }

    /// Semi join keeps matching left rows without adding right-hand columns.
    #[test]
    fn lazy_semi_join() {
        let result = join_on_name(JoinType::Semi).collect().unwrap();

        let result_df = result.borrow();
        assert_eq!(result_df.nrows(), 2);
        assert!(result_df.get_column("dept").is_none());
    }

    /// Anti join keeps the left rows that have no match on the right.
    #[test]
    fn lazy_anti_join() {
        let result = join_on_name(JoinType::Anti).collect().unwrap();

        let result_df = result.borrow();
        assert_eq!(result_df.nrows(), 2);
    }

    /// Distinct on `age` collapses the duplicated age.
    #[test]
    fn lazy_distinct() {
        let result = LazyView::from_df(test_df())
            .distinct(vec!["age".into()])
            .collect()
            .unwrap();

        let result_df = result.borrow();
        assert_eq!(result_df.nrows(), 3);
    }

    /// Filter, mutate, select and arrange compose into a single lazy chain.
    #[test]
    fn complex_lazy_chain() {
        let result = bonus_pipeline()
            .arrange(vec![ArrangeKey::desc("bonus")])
            .collect()
            .unwrap();

        let result_df = result.borrow();
        assert_eq!(result_df.nrows(), 4);
        assert_eq!(result_df.ncols(), 2);
        assert_eq!(result_df.column_names(), vec!["name", "bonus"]);
    }

    // ---------------------------------------------------------------------
    // Batched executor parity with `collect`
    // ---------------------------------------------------------------------

    #[test]
    fn batched_filter_parity() {
        assert_batched_parity(
            || LazyView::from_df(test_df()).filter(gt(col("age"), DExpr::LitInt(25))),
            "filter parity",
        );
    }

    #[test]
    fn batched_select_parity() {
        assert_batched_parity(
            || LazyView::from_df(test_df()).select(vec!["name".into(), "score".into()]),
            "select parity",
        );
    }

    #[test]
    fn batched_mutate_parity() {
        assert_batched_parity(
            || {
                LazyView::from_df(test_df()).mutate(vec![(
                    "doubled".into(),
                    mul(col("age"), DExpr::LitInt(2)),
                )])
            },
            "mutate parity",
        );
    }

    #[test]
    fn batched_filter_select_mutate_chain_parity() {
        assert_batched_parity(bonus_pipeline, "filter+mutate+select chain parity");
    }

    #[test]
    fn batched_group_summarise_parity() {
        assert_batched_parity(count_by_age, "group_summarise parity");
    }

    #[test]
    fn batched_arrange_parity() {
        assert_batched_parity(
            || LazyView::from_df(test_df()).arrange(vec![ArrangeKey::asc("age")]),
            "arrange parity",
        );
    }

    #[test]
    fn batched_distinct_parity() {
        assert_batched_parity(
            || LazyView::from_df(test_df()).distinct(vec!["age".into()]),
            "distinct parity",
        );
    }

    #[test]
    fn batched_join_parity() {
        assert_batched_parity(|| join_on_name(JoinType::Inner), "join parity");
    }

    #[test]
    fn batched_complex_pipeline_parity() {
        assert_batched_parity(
            || bonus_pipeline().arrange(vec![ArrangeKey::desc("bonus")]),
            "complex pipeline parity",
        );
    }

    /// The batched executor returns identical, correctly ordered rows on every run.
    #[test]
    fn batched_determinism_3_runs() {
        let mut results: Vec<Vec<i64>> = Vec::new();
        let mut results_names: Vec<Vec<String>> = Vec::new();

        for _ in 0..3 {
            let result = adults_by_age_desc().collect_batched().unwrap();
            let df = result.borrow();
            results.push(int_values(&df, "age"));
            results_names.push(str_values(&df, "name"));
        }

        assert_eq!(results[0], results[1]);
        assert_eq!(results[1], results[2]);
        assert_eq!(results_names[0], results_names[1]);
        assert_eq!(results_names[1], results_names[2]);
        assert_eq!(results[0], vec![35, 30, 25, 25]);
        assert_eq!(results_names[0], vec!["Carol", "Alice", "Bob", "Dave"]);
    }

    // ---------------------------------------------------------------------
    // Batched executor on the large fixture
    // ---------------------------------------------------------------------

    #[test]
    fn batched_large_data_filter_parity() {
        let predicate = gt(col("age"), DExpr::LitInt(50));

        let eager = LazyView::from_df(large_df())
            .filter(predicate.clone())
            .collect()
            .unwrap();
        let batched = LazyView::from_df(large_df())
            .filter(predicate)
            .collect_batched()
            .unwrap();

        assert_df_eq(
            &eager.borrow(),
            &batched.borrow(),
            "large data filter parity",
        );
        // Guard against a vacuous comparison of two empty frames.
        assert!(eager.borrow().nrows() > 0);
    }

    #[test]
    fn batched_large_data_chain_parity() {
        assert_batched_parity(
            || {
                LazyView::from_df(large_df())
                    .filter(gt(col("age"), DExpr::LitInt(50)))
                    .mutate(vec![(
                        "bonus".into(),
                        mul(col("score"), DExpr::LitFloat(1.5)),
                    )])
                    .select(vec!["name".into(), "bonus".into()])
            },
            "large data chain parity",
        );
    }

    /// Repeated batched runs over the large fixture yield the same `age` column.
    #[test]
    fn batched_large_data_determinism() {
        let mut runs: Vec<Vec<i64>> = Vec::new();
        for _ in 0..3 {
            let result = LazyView::from_df(large_df())
                .filter(gt(col("age"), DExpr::LitInt(90)))
                .mutate(vec![(
                    "double_age".into(),
                    mul(col("age"), DExpr::LitInt(2)),
                )])
                .collect_batched()
                .unwrap();

            let df = result.borrow();
            runs.push(int_values(&df, "age"));
        }

        for pair in runs.windows(2) {
            assert_eq!(pair[0], pair[1], "determinism: ages differ across runs");
        }
    }

    // ---------------------------------------------------------------------
    // Batch splitting and merging
    // ---------------------------------------------------------------------

    /// 10,000 rows split into four full batches of 2048 plus one remainder batch.
    #[test]
    fn split_batches_correct_count() {
        let df = large_df();
        let batches = split_batches(&df);
        assert_eq!(batches.len(), 5);
        for batch in batches.iter().take(4) {
            assert_eq!(batch.nrows, 2048);
        }
        assert_eq!(batches[4].nrows, 10000 - 4 * 2048);

        let total: usize = batches.iter().map(|b| b.nrows).sum();
        assert_eq!(total, 10000);
    }

    /// A frame smaller than one batch yields exactly one batch.
    #[test]
    fn split_batches_small_df() {
        let df = test_df();
        let batches = split_batches(&df);
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].nrows, 4);
    }

    /// Splitting and then merging reproduces the original frame.
    #[test]
    fn merge_batches_roundtrip() {
        let df = large_df();
        let batches = split_batches(&df);
        let merged = merge_batches(batches).unwrap();
        assert_df_eq(&df, &merged, "merge roundtrip");
    }
}