1use cjc_repro::kahan_sum_f64;
10use std::cell::RefCell;
11use std::collections::{BTreeMap, BTreeSet, VecDeque};
12use std::fmt;
13use std::rc::Rc;
14
15mod csv;
16pub use csv::{CsvConfig, CsvReader, StreamingCsvProcessor};
17
18pub mod agg_kernels;
19pub mod column_meta;
20pub mod dict_encoding;
21pub mod lazy;
22pub mod tidy_dispatch;
23
/// A single typed column of a [`DataFrame`].
///
/// Every variant stores one value per row. `Categorical` stores per-row
/// `codes` that index into its `levels` table.
#[derive(Clone, Debug)]
pub enum Column {
    /// 64-bit signed integers.
    Int(Vec<i64>),
    /// 64-bit floating point values.
    Float(Vec<f64>),
    /// Owned strings.
    Str(Vec<String>),
    /// Booleans.
    Bool(Vec<bool>),
    /// Dictionary-encoded strings: row `i` holds `levels[codes[i]]`.
    Categorical {
        levels: Vec<String>,
        codes: Vec<u32>,
    },
    /// Timestamps stored as `i64`; `get_display` renders them with an `ms` suffix.
    DateTime(Vec<i64>),
}

impl Column {
    /// Number of rows stored in this column.
    pub fn len(&self) -> usize {
        match self {
            Self::Int(values) | Self::DateTime(values) => values.len(),
            Self::Float(values) => values.len(),
            Self::Str(values) => values.len(),
            Self::Bool(values) => values.len(),
            Self::Categorical { codes, .. } => codes.len(),
        }
    }

    /// `true` when the column holds no rows.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Short, static name of the column's variant (e.g. `"Int"`).
    pub fn type_name(&self) -> &'static str {
        match self {
            Self::Int(_) => "Int",
            Self::Float(_) => "Float",
            Self::Str(_) => "Str",
            Self::Bool(_) => "Bool",
            Self::Categorical { .. } => "Categorical",
            Self::DateTime(_) => "DateTime",
        }
    }

    /// Human-readable rendering of the value at row `idx`.
    ///
    /// # Panics
    /// Panics if `idx` is out of range, or if a categorical code has no level.
    pub fn get_display(&self, idx: usize) -> String {
        match self {
            Self::Int(values) => values[idx].to_string(),
            Self::Float(values) => values[idx].to_string(),
            Self::Str(values) => values[idx].clone(),
            Self::Bool(values) => values[idx].to_string(),
            Self::Categorical { levels, codes } => {
                let level_idx = codes[idx] as usize;
                levels[level_idx].clone()
            }
            Self::DateTime(values) => format!("{}ms", values[idx]),
        }
    }
}
88
89#[derive(Debug, Clone)]
93pub struct DataFrame {
94 pub columns: Vec<(String, Column)>,
95}
96
97impl DataFrame {
98 pub fn new() -> Self {
100 Self {
101 columns: Vec::new(),
102 }
103 }
104
105 pub fn from_columns(columns: Vec<(String, Column)>) -> Result<Self, DataError> {
109 if columns.is_empty() {
110 return Ok(Self { columns });
111 }
112 let len = columns[0].1.len();
113 for (name, col) in &columns {
114 if col.len() != len {
115 return Err(DataError::ColumnLengthMismatch {
116 expected: len,
117 got: col.len(),
118 column: name.clone(),
119 });
120 }
121 }
122 Ok(Self { columns })
123 }
124
125 pub fn nrows(&self) -> usize {
127 self.columns.first().map(|(_, c)| c.len()).unwrap_or(0)
128 }
129
130 pub fn ncols(&self) -> usize {
132 self.columns.len()
133 }
134
135 pub fn column_names(&self) -> Vec<&str> {
137 self.columns.iter().map(|(n, _)| n.as_str()).collect()
138 }
139
140 pub fn get_column(&self, name: &str) -> Option<&Column> {
142 self.columns
143 .iter()
144 .find(|(n, _)| n == name)
145 .map(|(_, c)| c)
146 }
147
148 pub fn to_tensor_data(&self, col_names: &[&str]) -> Result<(Vec<f64>, Vec<usize>), DataError> {
150 let nrows = self.nrows();
151 let ncols = col_names.len();
152 let mut data = Vec::with_capacity(nrows * ncols);
153
154 for row in 0..nrows {
155 for &col_name in col_names {
156 let col = self
157 .get_column(col_name)
158 .ok_or_else(|| DataError::ColumnNotFound(col_name.to_string()))?;
159 let val = match col {
160 Column::Float(v) => v[row],
161 Column::Int(v) => v[row] as f64,
162 _ => {
163 return Err(DataError::InvalidOperation(format!(
164 "column `{}` is not numeric",
165 col_name
166 )))
167 }
168 };
169 data.push(val);
170 }
171 }
172
173 Ok((data, vec![nrows, ncols]))
174 }
175}
176
177impl Default for DataFrame {
178 fn default() -> Self {
179 Self::new()
180 }
181}
182
183impl fmt::Display for DataFrame {
184 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
185 if self.columns.is_empty() {
186 return write!(f, "(empty DataFrame)");
187 }
188
189 let names: Vec<&str> = self.columns.iter().map(|(n, _)| n.as_str()).collect();
191 let mut col_widths: Vec<usize> = names.iter().map(|n| n.len()).collect();
192
193 let nrows = self.nrows();
195 for (col_idx, (_, col)) in self.columns.iter().enumerate() {
196 for row in 0..nrows {
197 let s = col.get_display(row);
198 col_widths[col_idx] = col_widths[col_idx].max(s.len());
199 }
200 }
201
202 for (i, name) in names.iter().enumerate() {
204 if i > 0 {
205 write!(f, " | ")?;
206 }
207 write!(f, "{:>width$}", name, width = col_widths[i])?;
208 }
209 writeln!(f)?;
210
211 for (i, &w) in col_widths.iter().enumerate() {
213 if i > 0 {
214 write!(f, "-+-")?;
215 }
216 write!(f, "{}", "-".repeat(w))?;
217 }
218 writeln!(f)?;
219
220 for row in 0..nrows {
222 for (col_idx, (_, col)) in self.columns.iter().enumerate() {
223 if col_idx > 0 {
224 write!(f, " | ")?;
225 }
226 let s = col.get_display(row);
227 write!(f, "{:>width$}", s, width = col_widths[col_idx])?;
228 }
229 writeln!(f)?;
230 }
231
232 Ok(())
233 }
234}
235
236#[derive(Debug, Clone)]
240pub enum DExpr {
241 Col(String),
243 LitInt(i64),
245 LitFloat(f64),
247 LitBool(bool),
249 LitStr(String),
251 BinOp {
253 op: DBinOp,
254 left: Box<DExpr>,
255 right: Box<DExpr>,
256 },
257 Agg(AggFunc, Box<DExpr>),
259 Count,
261 FnCall(String, Vec<DExpr>),
263 CumSum(Box<DExpr>),
265 CumProd(Box<DExpr>),
267 CumMax(Box<DExpr>),
269 CumMin(Box<DExpr>),
271 Lag(Box<DExpr>, usize),
273 Lead(Box<DExpr>, usize),
275 Rank(Box<DExpr>),
277 DenseRank(Box<DExpr>),
279 RowNumber,
281 RollingSum(String, usize),
283 RollingMean(String, usize),
285 RollingMin(String, usize),
287 RollingMax(String, usize),
289 RollingVar(String, usize),
291 RollingSd(String, usize),
293}
294
/// Binary operators usable in a [`DExpr::BinOp`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum DBinOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `>`
    Gt,
    /// `<`
    Lt,
    /// `>=`
    Ge,
    /// `<=`
    Le,
    /// `==`
    Eq,
    /// `!=`
    Ne,
    /// Logical `&&`
    And,
    /// Logical `||`
    Or,
}
323
/// Aggregate functions usable in a [`DExpr::Agg`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum AggFunc {
    /// Sum of the values.
    Sum,
    /// Arithmetic mean of the values.
    Mean,
    /// Smallest value.
    Min,
    /// Largest value.
    Max,
    /// Number of values.
    Count,
}
338
339impl fmt::Display for DExpr {
340 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
341 match self {
342 DExpr::Col(name) => write!(f, "col(\"{}\")", name),
343 DExpr::LitInt(v) => write!(f, "{}", v),
344 DExpr::LitFloat(v) => write!(f, "{}", v),
345 DExpr::LitBool(b) => write!(f, "{}", b),
346 DExpr::LitStr(s) => write!(f, "\"{}\"", s),
347 DExpr::BinOp { op, left, right } => {
348 let op_str = match op {
349 DBinOp::Add => "+",
350 DBinOp::Sub => "-",
351 DBinOp::Mul => "*",
352 DBinOp::Div => "/",
353 DBinOp::Gt => ">",
354 DBinOp::Lt => "<",
355 DBinOp::Ge => ">=",
356 DBinOp::Le => "<=",
357 DBinOp::Eq => "==",
358 DBinOp::Ne => "!=",
359 DBinOp::And => "&&",
360 DBinOp::Or => "||",
361 };
362 write!(f, "({} {} {})", left, op_str, right)
363 }
364 DExpr::Agg(func, expr) => {
365 let name = match func {
366 AggFunc::Sum => "sum",
367 AggFunc::Mean => "mean",
368 AggFunc::Min => "min",
369 AggFunc::Max => "max",
370 AggFunc::Count => "count",
371 };
372 write!(f, "{}({})", name, expr)
373 }
374 DExpr::Count => write!(f, "count()"),
375 DExpr::FnCall(name, args) => {
376 let args_str: Vec<String> = args.iter().map(|a| format!("{}", a)).collect();
377 write!(f, "{}({})", name, args_str.join(", "))
378 }
379 DExpr::CumSum(e) => write!(f, "cumsum({})", e),
380 DExpr::CumProd(e) => write!(f, "cumprod({})", e),
381 DExpr::CumMax(e) => write!(f, "cummax({})", e),
382 DExpr::CumMin(e) => write!(f, "cummin({})", e),
383 DExpr::Lag(e, k) => write!(f, "lag({}, {})", e, k),
384 DExpr::Lead(e, k) => write!(f, "lead({}, {})", e, k),
385 DExpr::Rank(e) => write!(f, "rank({})", e),
386 DExpr::DenseRank(e) => write!(f, "dense_rank({})", e),
387 DExpr::RowNumber => write!(f, "row_number()"),
388 DExpr::RollingSum(col, w) => write!(f, "rolling_sum(\"{}\", {})", col, w),
389 DExpr::RollingMean(col, w) => write!(f, "rolling_mean(\"{}\", {})", col, w),
390 DExpr::RollingMin(col, w) => write!(f, "rolling_min(\"{}\", {})", col, w),
391 DExpr::RollingMax(col, w) => write!(f, "rolling_max(\"{}\", {})", col, w),
392 DExpr::RollingVar(col, w) => write!(f, "rolling_var(\"{}\", {})", col, w),
393 DExpr::RollingSd(col, w) => write!(f, "rolling_sd(\"{}\", {})", col, w),
394 }
395 }
396}
397
398#[derive(Debug, Clone)]
402pub enum LogicalPlan {
403 Scan {
405 source: DataFrame,
406 },
407 Filter {
409 input: Box<LogicalPlan>,
410 predicate: DExpr,
411 },
412 GroupBy {
414 input: Box<LogicalPlan>,
415 keys: Vec<String>,
416 },
417 Aggregate {
419 input: Box<LogicalPlan>,
420 keys: Vec<String>,
421 aggs: Vec<(String, DExpr)>,
422 },
423 Project {
425 input: Box<LogicalPlan>,
426 columns: Vec<String>,
427 },
428 InnerJoin {
430 left: Box<LogicalPlan>,
431 right: Box<LogicalPlan>,
432 left_on: String,
433 right_on: String,
434 },
435 LeftJoin {
437 left: Box<LogicalPlan>,
438 right: Box<LogicalPlan>,
439 left_on: String,
440 right_on: String,
441 },
442 CrossJoin {
444 left: Box<LogicalPlan>,
445 right: Box<LogicalPlan>,
446 },
447}
448
449impl LogicalPlan {
450 pub fn referenced_columns(&self) -> Vec<String> {
452 let mut cols = Vec::new();
453 self.collect_columns(&mut cols);
454 cols.sort();
455 cols.dedup();
456 cols
457 }
458
459 fn collect_columns(&self, cols: &mut Vec<String>) {
460 match self {
461 LogicalPlan::Scan { .. } => {}
462 LogicalPlan::Filter { input, predicate } => {
463 input.collect_columns(cols);
464 collect_expr_columns(predicate, cols);
465 }
466 LogicalPlan::GroupBy { input, keys } => {
467 input.collect_columns(cols);
468 cols.extend(keys.clone());
469 }
470 LogicalPlan::Aggregate {
471 input, keys, aggs, ..
472 } => {
473 input.collect_columns(cols);
474 cols.extend(keys.clone());
475 for (_, expr) in aggs {
476 collect_expr_columns(expr, cols);
477 }
478 }
479 LogicalPlan::Project { input, columns } => {
480 input.collect_columns(cols);
481 cols.extend(columns.clone());
482 }
483 LogicalPlan::InnerJoin {
484 left,
485 right,
486 left_on,
487 right_on,
488 }
489 | LogicalPlan::LeftJoin {
490 left,
491 right,
492 left_on,
493 right_on,
494 } => {
495 left.collect_columns(cols);
496 right.collect_columns(cols);
497 cols.push(left_on.clone());
498 cols.push(right_on.clone());
499 }
500 LogicalPlan::CrossJoin { left, right } => {
501 left.collect_columns(cols);
502 right.collect_columns(cols);
503 }
504 }
505 }
506}
507
508fn collect_expr_columns(expr: &DExpr, cols: &mut Vec<String>) {
509 match expr {
510 DExpr::Col(name) => cols.push(name.clone()),
511 DExpr::BinOp { left, right, .. } => {
512 collect_expr_columns(left, cols);
513 collect_expr_columns(right, cols);
514 }
515 DExpr::Agg(_, inner) => collect_expr_columns(inner, cols),
516 DExpr::FnCall(_, args) => {
517 for arg in args {
518 collect_expr_columns(arg, cols);
519 }
520 }
521 DExpr::CumSum(e) | DExpr::CumProd(e) | DExpr::CumMax(e) | DExpr::CumMin(e)
522 | DExpr::Lag(e, _) | DExpr::Lead(e, _) | DExpr::Rank(e) | DExpr::DenseRank(e) => {
523 collect_expr_columns(e, cols);
524 }
525 DExpr::RollingSum(col, _) | DExpr::RollingMean(col, _)
526 | DExpr::RollingMin(col, _) | DExpr::RollingMax(col, _)
527 | DExpr::RollingVar(col, _) | DExpr::RollingSd(col, _) => {
528 cols.push(col.clone());
529 }
530 _ => {}
531 }
532}
533
534pub fn optimize(plan: LogicalPlan) -> LogicalPlan {
538 let plan = push_down_predicates(plan);
539 let plan = prune_columns(plan);
540 plan
541}
542
543fn push_down_predicates(plan: LogicalPlan) -> LogicalPlan {
545 match plan {
546 LogicalPlan::Filter {
547 input,
548 predicate,
549 } => {
550 let optimized_input = push_down_predicates(*input);
551 match optimized_input {
552 LogicalPlan::GroupBy {
554 input: inner,
555 keys,
556 } => {
557 let pred_cols = {
558 let mut c = Vec::new();
559 collect_expr_columns(&predicate, &mut c);
560 c
561 };
562 let can_push = pred_cols.iter().all(|c| !keys.contains(c))
563 || pred_cols.iter().all(|c| {
564 !keys.contains(c) || keys.contains(c)
566 });
567 if can_push && pred_cols.iter().all(|c| !keys.contains(c)) {
569 LogicalPlan::GroupBy {
570 input: Box::new(LogicalPlan::Filter {
571 input: inner,
572 predicate,
573 }),
574 keys,
575 }
576 } else {
577 LogicalPlan::Filter {
578 input: Box::new(LogicalPlan::GroupBy {
579 input: inner,
580 keys,
581 }),
582 predicate,
583 }
584 }
585 }
586 other => LogicalPlan::Filter {
587 input: Box::new(other),
588 predicate,
589 },
590 }
591 }
592 LogicalPlan::GroupBy { input, keys } => LogicalPlan::GroupBy {
593 input: Box::new(push_down_predicates(*input)),
594 keys,
595 },
596 LogicalPlan::Aggregate {
597 input,
598 keys,
599 aggs,
600 } => LogicalPlan::Aggregate {
601 input: Box::new(push_down_predicates(*input)),
602 keys,
603 aggs,
604 },
605 LogicalPlan::Project { input, columns } => LogicalPlan::Project {
606 input: Box::new(push_down_predicates(*input)),
607 columns,
608 },
609 LogicalPlan::InnerJoin {
610 left,
611 right,
612 left_on,
613 right_on,
614 } => LogicalPlan::InnerJoin {
615 left: Box::new(push_down_predicates(*left)),
616 right: Box::new(push_down_predicates(*right)),
617 left_on,
618 right_on,
619 },
620 LogicalPlan::LeftJoin {
621 left,
622 right,
623 left_on,
624 right_on,
625 } => LogicalPlan::LeftJoin {
626 left: Box::new(push_down_predicates(*left)),
627 right: Box::new(push_down_predicates(*right)),
628 left_on,
629 right_on,
630 },
631 LogicalPlan::CrossJoin { left, right } => LogicalPlan::CrossJoin {
632 left: Box::new(push_down_predicates(*left)),
633 right: Box::new(push_down_predicates(*right)),
634 },
635 other => other,
636 }
637}
638
639fn prune_columns(plan: LogicalPlan) -> LogicalPlan {
641 plan
644}
645
646pub fn execute(plan: &LogicalPlan) -> Result<DataFrame, DataError> {
650 match plan {
651 LogicalPlan::Scan { source } => Ok(source.clone()),
652
653 LogicalPlan::Filter { input, predicate } => {
654 let df = execute(input)?;
655 execute_filter(&df, predicate)
656 }
657
658 LogicalPlan::GroupBy { input, keys: _ } => {
659 let df = execute(input)?;
661 Ok(df)
663 }
664
665 LogicalPlan::Aggregate { input, keys, aggs } => {
666 let df = execute(input)?;
667 execute_aggregate(&df, keys, aggs)
668 }
669
670 LogicalPlan::Project { input, columns } => {
671 let df = execute(input)?;
672 let projected = df
673 .columns
674 .into_iter()
675 .filter(|(name, _)| columns.contains(name))
676 .collect();
677 Ok(DataFrame { columns: projected })
678 }
679
680 LogicalPlan::InnerJoin {
681 left,
682 right,
683 left_on,
684 right_on,
685 } => {
686 let left_df = execute(left)?;
687 let right_df = execute(right)?;
688 execute_inner_join(&left_df, &right_df, left_on, right_on)
689 }
690
691 LogicalPlan::LeftJoin {
692 left,
693 right,
694 left_on,
695 right_on,
696 } => {
697 let left_df = execute(left)?;
698 let right_df = execute(right)?;
699 execute_left_join(&left_df, &right_df, left_on, right_on)
700 }
701
702 LogicalPlan::CrossJoin { left, right } => {
703 let left_df = execute(left)?;
704 let right_df = execute(right)?;
705 execute_cross_join(&left_df, &right_df)
706 }
707 }
708}
709
710fn execute_filter(df: &DataFrame, predicate: &DExpr) -> Result<DataFrame, DataError> {
711 let nrows = df.nrows();
712 let mut mask = vec![false; nrows];
713
714 for row in 0..nrows {
715 let val = eval_expr_row(df, predicate, row)?;
716 mask[row] = match val {
717 ExprValue::Bool(b) => b,
718 _ => return Err(DataError::InvalidOperation("filter predicate must be boolean".into())),
719 };
720 }
721
722 let mut new_columns = Vec::new();
723 for (name, col) in &df.columns {
724 let filtered = filter_column(col, &mask);
725 new_columns.push((name.clone(), filtered));
726 }
727
728 Ok(DataFrame {
729 columns: new_columns,
730 })
731}
732
733fn filter_column(col: &Column, mask: &[bool]) -> Column {
734 match col {
735 Column::Int(v) => Column::Int(
736 v.iter()
737 .zip(mask)
738 .filter(|(_, &m)| m)
739 .map(|(v, _)| *v)
740 .collect(),
741 ),
742 Column::Float(v) => Column::Float(
743 v.iter()
744 .zip(mask)
745 .filter(|(_, &m)| m)
746 .map(|(v, _)| *v)
747 .collect(),
748 ),
749 Column::Str(v) => Column::Str(
750 v.iter()
751 .zip(mask)
752 .filter(|(_, &m)| m)
753 .map(|(v, _)| v.clone())
754 .collect(),
755 ),
756 Column::Bool(v) => Column::Bool(
757 v.iter()
758 .zip(mask)
759 .filter(|(_, &m)| m)
760 .map(|(v, _)| *v)
761 .collect(),
762 ),
763 Column::Categorical { levels, codes } => Column::Categorical {
764 levels: levels.clone(),
765 codes: codes
766 .iter()
767 .zip(mask)
768 .filter(|(_, &m)| m)
769 .map(|(v, _)| *v)
770 .collect(),
771 },
772 Column::DateTime(v) => Column::DateTime(
773 v.iter()
774 .zip(mask)
775 .filter(|(_, &m)| m)
776 .map(|(v, _)| *v)
777 .collect(),
778 ),
779 }
780}
781
782fn execute_aggregate(
783 df: &DataFrame,
784 keys: &[String],
785 aggs: &[(String, DExpr)],
786) -> Result<DataFrame, DataError> {
787 let nrows = df.nrows();
789 let mut groups: BTreeMap<Vec<String>, Vec<usize>> = BTreeMap::new();
790
791 for row in 0..nrows {
792 let key: Vec<String> = keys
793 .iter()
794 .map(|k| {
795 df.get_column(k)
796 .map(|col| col.get_display(row))
797 .ok_or_else(|| DataError::ColumnNotFound(k.to_string()))
798 })
799 .collect::<Result<Vec<String>, DataError>>()?;
800 groups.entry(key).or_default().push(row);
801 }
802
803 let mut sorted_groups: Vec<(Vec<String>, Vec<usize>)> = groups.into_iter().collect();
805 sorted_groups.sort_by(|a, b| a.0.cmp(&b.0));
806
807 let mut result_columns: Vec<(String, Column)> = Vec::new();
809
810 for (key_idx, key_name) in keys.iter().enumerate() {
812 let values: Vec<String> = sorted_groups
813 .iter()
814 .map(|(key, _)| key[key_idx].clone())
815 .collect();
816 let source_col = df.get_column(key_name).ok_or_else(|| {
818 DataError::ColumnNotFound(key_name.clone())
819 })?;
820 match source_col {
821 Column::Int(_) => {
822 let int_vals: Vec<i64> = values.iter().map(|s| s.parse().unwrap_or(0)).collect();
823 result_columns.push((key_name.clone(), Column::Int(int_vals)));
824 }
825 Column::Str(_) => {
826 result_columns.push((key_name.clone(), Column::Str(values)));
827 }
828 _ => {
829 result_columns.push((key_name.clone(), Column::Str(values)));
830 }
831 }
832 }
833
834 for (agg_name, agg_expr) in aggs {
836 let mut values = Vec::new();
837 for (_, row_indices) in &sorted_groups {
838 let val = eval_agg_expr(df, agg_expr, row_indices)?;
839 values.push(val);
840 }
841 result_columns.push((agg_name.clone(), Column::Float(values)));
842 }
843
844 Ok(DataFrame {
845 columns: result_columns,
846 })
847}
848
/// A single scalar produced while evaluating a [`DExpr`] for one row.
#[derive(Clone, Debug)]
enum ExprValue {
    /// Integer (also used for `DateTime` cells).
    Int(i64),
    /// Floating point value.
    Float(f64),
    /// String (also used for `Categorical` cells).
    Str(String),
    /// Boolean.
    Bool(bool),
}
858
859fn eval_expr_row(df: &DataFrame, expr: &DExpr, row: usize) -> Result<ExprValue, DataError> {
860 match expr {
861 DExpr::Col(name) => {
862 let col = df
863 .get_column(name)
864 .ok_or_else(|| DataError::ColumnNotFound(name.clone()))?;
865 match col {
866 Column::Int(v) => Ok(ExprValue::Int(v[row])),
867 Column::Float(v) => Ok(ExprValue::Float(v[row])),
868 Column::Str(v) => Ok(ExprValue::Str(v[row].clone())),
869 Column::Bool(v) => Ok(ExprValue::Bool(v[row])),
870 Column::Categorical { levels, codes } => {
871 Ok(ExprValue::Str(levels[codes[row] as usize].clone()))
872 }
873 Column::DateTime(v) => Ok(ExprValue::Int(v[row])),
874 }
875 }
876 DExpr::LitInt(v) => Ok(ExprValue::Int(*v)),
877 DExpr::LitFloat(v) => Ok(ExprValue::Float(*v)),
878 DExpr::LitBool(b) => Ok(ExprValue::Bool(*b)),
879 DExpr::LitStr(s) => Ok(ExprValue::Str(s.clone())),
880 DExpr::BinOp { op, left, right } => {
881 let lv = eval_expr_row(df, left, row)?;
882 let rv = eval_expr_row(df, right, row)?;
883 eval_binop(*op, lv, rv)
884 }
885 DExpr::Agg(_, _) | DExpr::Count => Err(DataError::InvalidOperation(
886 "aggregation not allowed in row context".into(),
887 )),
888 DExpr::FnCall(name, args) => {
889 if args.len() != 1 {
890 return Err(DataError::InvalidOperation(
891 format!("FnCall '{}' requires exactly 1 argument, got {}", name, args.len()),
892 ));
893 }
894 let val = eval_expr_row(df, &args[0], row)?;
895 let x = match val {
896 ExprValue::Float(f) => f,
897 ExprValue::Int(i) => i as f64,
898 _ => return Err(DataError::InvalidOperation(
899 format!("FnCall '{}' requires numeric argument", name),
900 )),
901 };
902 let result = match name.as_str() {
903 "log" => x.ln(),
904 "exp" => x.exp(),
905 "sqrt" => x.sqrt(),
906 "abs" => x.abs(),
907 "ceil" => x.ceil(),
908 "floor" => x.floor(),
909 "round" => x.round(),
910 "sin" => x.sin(),
911 "cos" => x.cos(),
912 "tan" => x.tan(),
913 other => return Err(DataError::InvalidOperation(
914 format!("unknown DExpr function: {}", other),
915 )),
916 };
917 Ok(ExprValue::Float(result))
918 }
919 DExpr::CumSum(_) | DExpr::CumProd(_) | DExpr::CumMax(_) | DExpr::CumMin(_)
920 | DExpr::Lag(_, _) | DExpr::Lead(_, _) | DExpr::Rank(_) | DExpr::DenseRank(_)
921 | DExpr::RowNumber
922 | DExpr::RollingSum(..) | DExpr::RollingMean(..) | DExpr::RollingMin(..)
923 | DExpr::RollingMax(..) | DExpr::RollingVar(..) | DExpr::RollingSd(..) => {
924 Err(DataError::InvalidOperation(
925 "window function not allowed in row context; use eval_expr_column".into(),
926 ))
927 }
928 }
929}
930
931fn eval_binop(op: DBinOp, left: ExprValue, right: ExprValue) -> Result<ExprValue, DataError> {
932 match (left, right) {
933 (ExprValue::Int(a), ExprValue::Int(b)) => match op {
934 DBinOp::Add => Ok(ExprValue::Int(a + b)),
935 DBinOp::Sub => Ok(ExprValue::Int(a - b)),
936 DBinOp::Mul => Ok(ExprValue::Int(a * b)),
937 DBinOp::Div => Ok(ExprValue::Int(a / b)),
938 DBinOp::Gt => Ok(ExprValue::Bool(a > b)),
939 DBinOp::Lt => Ok(ExprValue::Bool(a < b)),
940 DBinOp::Ge => Ok(ExprValue::Bool(a >= b)),
941 DBinOp::Le => Ok(ExprValue::Bool(a <= b)),
942 DBinOp::Eq => Ok(ExprValue::Bool(a == b)),
943 DBinOp::Ne => Ok(ExprValue::Bool(a != b)),
944 _ => Err(DataError::InvalidOperation(format!(
945 "unsupported op {:?} on Int",
946 op
947 ))),
948 },
949 (ExprValue::Float(a), ExprValue::Float(b)) => match op {
950 DBinOp::Add => Ok(ExprValue::Float(a + b)),
951 DBinOp::Sub => Ok(ExprValue::Float(a - b)),
952 DBinOp::Mul => Ok(ExprValue::Float(a * b)),
953 DBinOp::Div => Ok(ExprValue::Float(a / b)),
954 DBinOp::Gt => Ok(ExprValue::Bool(a > b)),
955 DBinOp::Lt => Ok(ExprValue::Bool(a < b)),
956 DBinOp::Ge => Ok(ExprValue::Bool(a >= b)),
957 DBinOp::Le => Ok(ExprValue::Bool(a <= b)),
958 DBinOp::Eq => Ok(ExprValue::Bool(a == b)),
959 DBinOp::Ne => Ok(ExprValue::Bool(a != b)),
960 _ => Err(DataError::InvalidOperation(format!(
961 "unsupported op {:?} on Float",
962 op
963 ))),
964 },
965 (ExprValue::Int(a), ExprValue::Float(b)) => {
967 eval_binop(op, ExprValue::Float(a as f64), ExprValue::Float(b))
968 }
969 (ExprValue::Float(a), ExprValue::Int(b)) => {
970 eval_binop(op, ExprValue::Float(a), ExprValue::Float(b as f64))
971 }
972 (ExprValue::Bool(a), ExprValue::Bool(b)) => match op {
973 DBinOp::And => Ok(ExprValue::Bool(a && b)),
974 DBinOp::Or => Ok(ExprValue::Bool(a || b)),
975 DBinOp::Eq => Ok(ExprValue::Bool(a == b)),
976 DBinOp::Ne => Ok(ExprValue::Bool(a != b)),
977 _ => Err(DataError::InvalidOperation(format!(
978 "unsupported op {:?} on Bool",
979 op
980 ))),
981 },
982 (ExprValue::Str(a), ExprValue::Str(b)) => match op {
983 DBinOp::Eq => Ok(ExprValue::Bool(a == b)),
984 DBinOp::Ne => Ok(ExprValue::Bool(a != b)),
985 _ => Err(DataError::InvalidOperation(format!(
986 "unsupported op {:?} on String",
987 op
988 ))),
989 },
990 _ => Err(DataError::InvalidOperation(
991 "type mismatch in binary operation".into(),
992 )),
993 }
994}
995
996fn eval_agg_expr(
997 df: &DataFrame,
998 expr: &DExpr,
999 rows: &[usize],
1000) -> Result<f64, DataError> {
1001 match expr {
1002 DExpr::Agg(func, inner) => {
1003 let values = extract_float_values(df, inner, rows)?;
1004 match func {
1005 AggFunc::Sum => Ok(kahan_sum_f64(&values)),
1006 AggFunc::Mean => {
1007 if values.is_empty() {
1008 Ok(0.0)
1009 } else {
1010 Ok(kahan_sum_f64(&values) / values.len() as f64)
1011 }
1012 }
1013 AggFunc::Min => Ok(values
1014 .iter()
1015 .cloned()
1016 .fold(f64::INFINITY, f64::min)),
1017 AggFunc::Max => Ok(values
1018 .iter()
1019 .cloned()
1020 .fold(f64::NEG_INFINITY, f64::max)),
1021 AggFunc::Count => Ok(values.len() as f64),
1022 }
1023 }
1024 DExpr::Count => Ok(rows.len() as f64),
1025 _ => Err(DataError::InvalidOperation(
1026 "expected aggregation expression".into(),
1027 )),
1028 }
1029}
1030
1031fn extract_float_values(
1032 df: &DataFrame,
1033 expr: &DExpr,
1034 rows: &[usize],
1035) -> Result<Vec<f64>, DataError> {
1036 match expr {
1037 DExpr::Col(name) => {
1038 let col = df
1039 .get_column(name)
1040 .ok_or_else(|| DataError::ColumnNotFound(name.clone()))?;
1041 let vals: Vec<f64> = match col {
1042 Column::Float(v) => rows.iter().map(|&r| v[r]).collect(),
1043 Column::Int(v) => rows.iter().map(|&r| v[r] as f64).collect(),
1044 _ => {
1045 return Err(DataError::InvalidOperation(format!(
1046 "cannot aggregate non-numeric column `{}`",
1047 name
1048 )))
1049 }
1050 };
1051 Ok(vals)
1052 }
1053 _ => Err(DataError::InvalidOperation(
1054 "expected column reference in aggregation".into(),
1055 )),
1056 }
1057}
1058
1059pub struct Pipeline {
1063 plan: LogicalPlan,
1064}
1065
1066impl Pipeline {
1067 pub fn scan(df: DataFrame) -> Self {
1069 Self {
1070 plan: LogicalPlan::Scan { source: df },
1071 }
1072 }
1073
1074 pub fn filter(self, predicate: DExpr) -> Self {
1076 Self {
1077 plan: LogicalPlan::Filter {
1078 input: Box::new(self.plan),
1079 predicate,
1080 },
1081 }
1082 }
1083
1084 pub fn group_by(self, keys: Vec<String>) -> Self {
1086 Self {
1087 plan: LogicalPlan::GroupBy {
1088 input: Box::new(self.plan),
1089 keys,
1090 },
1091 }
1092 }
1093
1094 pub fn summarize(self, keys: Vec<String>, aggs: Vec<(String, DExpr)>) -> Self {
1096 Self {
1097 plan: LogicalPlan::Aggregate {
1098 input: Box::new(self.plan),
1099 keys,
1100 aggs,
1101 },
1102 }
1103 }
1104
1105 pub fn select(self, columns: Vec<String>) -> Self {
1107 Self {
1108 plan: LogicalPlan::Project {
1109 input: Box::new(self.plan),
1110 columns,
1111 },
1112 }
1113 }
1114
1115 pub fn inner_join(self, right: DataFrame, left_on: &str, right_on: &str) -> Self {
1117 Self {
1118 plan: LogicalPlan::InnerJoin {
1119 left: Box::new(self.plan),
1120 right: Box::new(LogicalPlan::Scan { source: right }),
1121 left_on: left_on.to_string(),
1122 right_on: right_on.to_string(),
1123 },
1124 }
1125 }
1126
1127 pub fn left_join(self, right: DataFrame, left_on: &str, right_on: &str) -> Self {
1129 Self {
1130 plan: LogicalPlan::LeftJoin {
1131 left: Box::new(self.plan),
1132 right: Box::new(LogicalPlan::Scan { source: right }),
1133 left_on: left_on.to_string(),
1134 right_on: right_on.to_string(),
1135 },
1136 }
1137 }
1138
1139 pub fn cross_join(self, right: DataFrame) -> Self {
1141 Self {
1142 plan: LogicalPlan::CrossJoin {
1143 left: Box::new(self.plan),
1144 right: Box::new(LogicalPlan::Scan { source: right }),
1145 },
1146 }
1147 }
1148
1149 pub fn collect(self) -> Result<DataFrame, DataError> {
1151 let optimized = optimize(self.plan);
1152 execute(&optimized)
1153 }
1154
1155 pub fn plan(&self) -> &LogicalPlan {
1157 &self.plan
1158 }
1159}
1160
/// Errors produced by data-frame construction and query execution.
#[derive(Clone, Debug)]
pub enum DataError {
    /// A referenced column name does not exist.
    ColumnNotFound(String),
    /// A column's length differs from the expected row count.
    ColumnLengthMismatch {
        /// Expected number of rows.
        expected: usize,
        /// Actual number of rows in `column`.
        got: usize,
        /// Name of the offending column.
        column: String,
    },
    /// Any other invalid request; the message describes the problem.
    InvalidOperation(String),
}

impl fmt::Display for DataError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ColumnNotFound(name) => write!(f, "column `{}` not found", name),
            Self::ColumnLengthMismatch {
                expected,
                got,
                column,
            } => write!(f, "column `{}` has {} rows, expected {}", column, got, expected),
            Self::InvalidOperation(msg) => write!(f, "invalid operation: {}", msg),
        }
    }
}

impl std::error::Error for DataError {}
1200
1201fn column_value_str(col: &Column, row: usize) -> String {
1205 match col {
1206 Column::Int(v) => v[row].to_string(),
1207 Column::Float(v) => v[row].to_string(),
1208 Column::Str(v) => v[row].clone(),
1209 Column::Bool(v) => v[row].to_string(),
1210 Column::Categorical { levels, codes } => levels[codes[row] as usize].clone(),
1211 Column::DateTime(v) => v[row].to_string(),
1212 }
1213}
1214
1215fn execute_inner_join(
1216 left: &DataFrame,
1217 right: &DataFrame,
1218 left_on: &str,
1219 right_on: &str,
1220) -> Result<DataFrame, DataError> {
1221 let left_col = left.get_column(left_on)
1222 .ok_or_else(|| DataError::InvalidOperation(format!("join key `{}` not found in left", left_on)))?;
1223 let right_col = right.get_column(right_on)
1224 .ok_or_else(|| DataError::InvalidOperation(format!("join key `{}` not found in right", right_on)))?;
1225
1226 let right_nrows = right.nrows();
1228 let mut index: std::collections::BTreeMap<String, Vec<usize>> = std::collections::BTreeMap::new();
1229 for i in 0..right_nrows {
1230 let key = column_value_str(right_col, i);
1231 index.entry(key).or_default().push(i);
1232 }
1233
1234 let left_nrows = left.nrows();
1235 let mut left_indices = Vec::new();
1236 let mut right_indices = Vec::new();
1237
1238 for i in 0..left_nrows {
1239 let key = column_value_str(left_col, i);
1240 if let Some(matches) = index.get(&key) {
1241 for &j in matches {
1242 left_indices.push(i);
1243 right_indices.push(j);
1244 }
1245 }
1246 }
1247
1248 build_join_result(left, right, &left_indices, &right_indices, right_on)
1249}
1250
1251fn execute_left_join(
1252 left: &DataFrame,
1253 right: &DataFrame,
1254 left_on: &str,
1255 right_on: &str,
1256) -> Result<DataFrame, DataError> {
1257 let left_col = left.get_column(left_on)
1258 .ok_or_else(|| DataError::InvalidOperation(format!("join key `{}` not found in left", left_on)))?;
1259 let right_col = right.get_column(right_on)
1260 .ok_or_else(|| DataError::InvalidOperation(format!("join key `{}` not found in right", right_on)))?;
1261
1262 let right_nrows = right.nrows();
1263 let mut index: std::collections::BTreeMap<String, Vec<usize>> = std::collections::BTreeMap::new();
1264 for i in 0..right_nrows {
1265 let key = column_value_str(right_col, i);
1266 index.entry(key).or_default().push(i);
1267 }
1268
1269 let left_nrows = left.nrows();
1270 let mut left_indices = Vec::new();
1271 let mut right_indices: Vec<Option<usize>> = Vec::new();
1272
1273 for i in 0..left_nrows {
1274 let key = column_value_str(left_col, i);
1275 if let Some(matches) = index.get(&key) {
1276 for &j in matches {
1277 left_indices.push(i);
1278 right_indices.push(Some(j));
1279 }
1280 } else {
1281 left_indices.push(i);
1282 right_indices.push(None);
1283 }
1284 }
1285
1286 build_left_join_result(left, right, &left_indices, &right_indices, right_on)
1287}
1288
1289fn execute_cross_join(left: &DataFrame, right: &DataFrame) -> Result<DataFrame, DataError> {
1290 let left_nrows = left.nrows();
1291 let right_nrows = right.nrows();
1292 let mut left_indices = Vec::with_capacity(left_nrows * right_nrows);
1293 let mut right_indices = Vec::with_capacity(left_nrows * right_nrows);
1294
1295 for i in 0..left_nrows {
1296 for j in 0..right_nrows {
1297 left_indices.push(i);
1298 right_indices.push(j);
1299 }
1300 }
1301
1302 build_join_result(left, right, &left_indices, &right_indices, "")
1303}
1304
1305fn build_join_result(
1306 left: &DataFrame,
1307 right: &DataFrame,
1308 left_indices: &[usize],
1309 right_indices: &[usize],
1310 right_on: &str,
1311) -> Result<DataFrame, DataError> {
1312 let mut columns = Vec::new();
1313
1314 for (name, col) in &left.columns {
1316 columns.push((name.clone(), gather_column(col, left_indices)));
1317 }
1318
1319 for (name, col) in &right.columns {
1321 if name == right_on {
1322 continue;
1323 }
1324 let out_name = if left.get_column(name).is_some() {
1325 format!("{}_right", name)
1326 } else {
1327 name.clone()
1328 };
1329 columns.push((out_name, gather_column(col, right_indices)));
1330 }
1331
1332 Ok(DataFrame { columns })
1333}
1334
1335fn build_left_join_result(
1336 left: &DataFrame,
1337 right: &DataFrame,
1338 left_indices: &[usize],
1339 right_indices: &[Option<usize>],
1340 right_on: &str,
1341) -> Result<DataFrame, DataError> {
1342 let mut columns = Vec::new();
1343
1344 for (name, col) in &left.columns {
1345 columns.push((name.clone(), gather_column(col, left_indices)));
1346 }
1347
1348 for (name, col) in &right.columns {
1349 if name == right_on {
1350 continue;
1351 }
1352 let out_name = if left.get_column(name).is_some() {
1353 format!("{}_right", name)
1354 } else {
1355 name.clone()
1356 };
1357 columns.push((out_name, gather_column_nullable(col, right_indices)));
1358 }
1359
1360 Ok(DataFrame { columns })
1361}
1362
1363fn gather_column(col: &Column, indices: &[usize]) -> Column {
1364 match col {
1365 Column::Int(v) => Column::Int(indices.iter().map(|&i| v[i]).collect()),
1366 Column::Float(v) => Column::Float(indices.iter().map(|&i| v[i]).collect()),
1367 Column::Str(v) => Column::Str(indices.iter().map(|&i| v[i].clone()).collect()),
1368 Column::Bool(v) => Column::Bool(indices.iter().map(|&i| v[i]).collect()),
1369 Column::Categorical { levels, codes } => Column::Categorical {
1370 levels: levels.clone(),
1371 codes: indices.iter().map(|&i| codes[i]).collect(),
1372 },
1373 Column::DateTime(v) => Column::DateTime(indices.iter().map(|&i| v[i]).collect()),
1374 }
1375}
1376
1377fn gather_column_nullable(col: &Column, indices: &[Option<usize>]) -> Column {
1378 match col {
1379 Column::Int(v) => Column::Int(indices.iter().map(|opt| opt.map_or(0, |i| v[i])).collect()),
1380 Column::Float(v) => Column::Float(indices.iter().map(|opt| opt.map_or(f64::NAN, |i| v[i])).collect()),
1381 Column::Str(v) => Column::Str(indices.iter().map(|opt| opt.map_or_else(String::new, |i| v[i].clone())).collect()),
1382 Column::Bool(v) => Column::Bool(indices.iter().map(|opt| opt.map_or(false, |i| v[i])).collect()),
1383 Column::Categorical { levels, codes } => Column::Categorical {
1384 levels: levels.clone(),
1385 codes: indices.iter().map(|opt| opt.map_or(0, |i| codes[i])).collect(),
1386 },
1387 Column::DateTime(v) => Column::DateTime(indices.iter().map(|opt| opt.map_or(0, |i| v[i])).collect()),
1388 }
1389}
1390
#[cfg(test)]
mod tests {
    use super::*;

    /// Six-employee fixture shared by the pipeline tests.
    ///
    /// Rows (name, dept, salary, tenure):
    /// Alice/eng/95000/3, Bob/eng/102000/7, Carol/sales/78000/2,
    /// Dave/eng/110000/10, Eve/sales/82000/1, Frank/eng/98000/5.
    fn sample_df() -> DataFrame {
        DataFrame::from_columns(vec![
            (
                "name".into(),
                Column::Str(vec![
                    "Alice".into(),
                    "Bob".into(),
                    "Carol".into(),
                    "Dave".into(),
                    "Eve".into(),
                    "Frank".into(),
                ]),
            ),
            (
                "dept".into(),
                Column::Str(vec![
                    "eng".into(),
                    "eng".into(),
                    "sales".into(),
                    "eng".into(),
                    "sales".into(),
                    "eng".into(),
                ]),
            ),
            (
                "salary".into(),
                Column::Float(vec![95000.0, 102000.0, 78000.0, 110000.0, 82000.0, 98000.0]),
            ),
            (
                "tenure".into(),
                Column::Int(vec![3, 7, 2, 10, 1, 5]),
            ),
        ])
        .unwrap()
    }

    #[test]
    fn test_dataframe_creation() {
        let df = sample_df();
        assert_eq!(df.nrows(), 6);
        assert_eq!(df.ncols(), 4);
        assert_eq!(
            df.column_names(),
            vec!["name", "dept", "salary", "tenure"]
        );
    }

    #[test]
    fn test_filter() {
        let df = sample_df();

        let result = Pipeline::scan(df)
            .filter(DExpr::BinOp {
                op: DBinOp::Gt,
                left: Box::new(DExpr::Col("tenure".into())),
                right: Box::new(DExpr::LitInt(2)),
            })
            .collect()
            .unwrap();

        // tenure > 2 keeps Alice (3), Bob (7), Dave (10) and Frank (5).
        assert_eq!(result.nrows(), 4);
    }

    #[test]
    fn test_group_by_summarize() {
        let df = sample_df();

        let result = Pipeline::scan(df)
            .summarize(
                vec!["dept".into()],
                vec![
                    (
                        "avg_salary".into(),
                        DExpr::Agg(AggFunc::Mean, Box::new(DExpr::Col("salary".into()))),
                    ),
                    ("headcount".into(), DExpr::Count),
                ],
            )
            .collect()
            .unwrap();

        // One output row per distinct dept ("eng", "sales").
        assert_eq!(result.nrows(), 2);
        let dept_col = result.get_column("dept").unwrap();
        let avg_col = result.get_column("avg_salary").unwrap();
        let count_col = result.get_column("headcount").unwrap();

        // Group order is not asserted; rows are located by dept name instead.
        if let (Column::Str(depts), Column::Float(avgs), Column::Float(counts)) =
            (dept_col, avg_col, count_col)
        {
            let eng_idx = depts.iter().position(|d| d == "eng").unwrap();
            let sales_idx = depts.iter().position(|d| d == "sales").unwrap();

            // eng: (95000 + 102000 + 110000 + 98000) / 4 = 101250, 4 people.
            assert!((avgs[eng_idx] - 101250.0).abs() < 0.01);
            assert!((counts[eng_idx] - 4.0).abs() < 0.01);

            // sales: (78000 + 82000) / 2 = 80000, 2 people.
            assert!((avgs[sales_idx] - 80000.0).abs() < 0.01);
            assert!((counts[sales_idx] - 2.0).abs() < 0.01);
        } else {
            panic!("unexpected column types");
        }
    }

    #[test]
    fn test_filter_then_aggregate() {
        let df = sample_df();

        let result = Pipeline::scan(df)
            .filter(DExpr::BinOp {
                op: DBinOp::Gt,
                left: Box::new(DExpr::Col("tenure".into())),
                right: Box::new(DExpr::LitInt(2)),
            })
            .summarize(
                vec!["dept".into()],
                vec![
                    (
                        "avg_salary".into(),
                        DExpr::Agg(AggFunc::Mean, Box::new(DExpr::Col("salary".into()))),
                    ),
                    (
                        "max_tenure".into(),
                        DExpr::Agg(AggFunc::Max, Box::new(DExpr::Col("tenure".into()))),
                    ),
                    ("headcount".into(), DExpr::Count),
                ],
            )
            .collect()
            .unwrap();

        // Every row with tenure > 2 is in "eng", so a single group remains.
        assert_eq!(result.nrows(), 1);

        // NOTE(review): these `if let` checks pass silently if the column type differs.
        if let Column::Float(avgs) = result.get_column("avg_salary").unwrap() {
            assert!((avgs[0] - 101250.0).abs() < 0.01);
        }
        if let Column::Float(maxes) = result.get_column("max_tenure").unwrap() {
            assert!((maxes[0] - 10.0).abs() < 0.01);
        }
    }

    #[test]
    fn test_to_tensor_data() {
        let df = DataFrame::from_columns(vec![
            ("x".into(), Column::Float(vec![1.0, 2.0, 3.0])),
            ("y".into(), Column::Float(vec![4.0, 5.0, 6.0])),
        ])
        .unwrap();

        // Expected layout is row-major: each row's x, y values are adjacent.
        let (data, shape) = df.to_tensor_data(&["x", "y"]).unwrap();
        assert_eq!(shape, vec![3, 2]);
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_display() {
        let df = DataFrame::from_columns(vec![
            ("x".into(), Column::Int(vec![1, 2, 3])),
            ("y".into(), Column::Float(vec![4.5, 5.5, 6.5])),
        ])
        .unwrap();

        let output = format!("{}", df);
        assert!(output.contains("x"));
        assert!(output.contains("y"));
        assert!(output.contains("4.5"));
    }

    #[test]
    fn test_column_not_found() {
        let df = sample_df();
        let result = Pipeline::scan(df)
            .filter(DExpr::BinOp {
                op: DBinOp::Gt,
                left: Box::new(DExpr::Col("nonexistent".into())),
                right: Box::new(DExpr::LitInt(0)),
            })
            .collect();

        assert!(result.is_err());
    }

    #[test]
    fn test_aggregation_functions() {
        let df = DataFrame::from_columns(vec![
            ("group".into(), Column::Str(vec!["a".into(), "a".into(), "a".into()])),
            ("val".into(), Column::Float(vec![10.0, 20.0, 30.0])),
        ])
        .unwrap();

        // Single group "a" over values 10, 20, 30.
        let result = Pipeline::scan(df)
            .summarize(
                vec!["group".into()],
                vec![
                    ("total".into(), DExpr::Agg(AggFunc::Sum, Box::new(DExpr::Col("val".into())))),
                    ("avg".into(), DExpr::Agg(AggFunc::Mean, Box::new(DExpr::Col("val".into())))),
                    ("lo".into(), DExpr::Agg(AggFunc::Min, Box::new(DExpr::Col("val".into())))),
                    ("hi".into(), DExpr::Agg(AggFunc::Max, Box::new(DExpr::Col("val".into())))),
                    ("n".into(), DExpr::Count),
                ],
            )
            .collect()
            .unwrap();

        if let Column::Float(totals) = result.get_column("total").unwrap() {
            assert!((totals[0] - 60.0).abs() < 0.01);
        }
        if let Column::Float(avgs) = result.get_column("avg").unwrap() {
            assert!((avgs[0] - 20.0).abs() < 0.01);
        }
        if let Column::Float(mins) = result.get_column("lo").unwrap() {
            assert!((mins[0] - 10.0).abs() < 0.01);
        }
        if let Column::Float(maxs) = result.get_column("hi").unwrap() {
            assert!((maxs[0] - 30.0).abs() < 0.01);
        }
        if let Column::Float(counts) = result.get_column("n").unwrap() {
            assert!((counts[0] - 3.0).abs() < 0.01);
        }
    }

    #[test]
    fn test_empty_dataframe() {
        let df = DataFrame::new();
        assert_eq!(df.nrows(), 0);
        assert_eq!(df.ncols(), 0);
    }

    #[test]
    fn test_expr_display() {
        let expr = DExpr::BinOp {
            op: DBinOp::Gt,
            left: Box::new(DExpr::Col("age".into())),
            right: Box::new(DExpr::LitInt(18)),
        };
        assert_eq!(format!("{}", expr), "(col(\"age\") > 18)");
    }

    #[test]
    fn test_categorical_column_basics() {
        // codes index into levels: 1 -> "cat", 2 -> "dog", 0 -> "bird".
        let col = Column::Categorical {
            levels: vec!["bird".into(), "cat".into(), "dog".into()],
            codes: vec![1, 2, 1, 0],
        };
        assert_eq!(col.len(), 4);
        assert_eq!(col.type_name(), "Categorical");
        assert_eq!(col.get_display(0), "cat");
        assert_eq!(col.get_display(1), "dog");
        assert_eq!(col.get_display(2), "cat");
        assert_eq!(col.get_display(3), "bird");
    }

    #[test]
    fn test_datetime_column_basics() {
        let col = Column::DateTime(vec![1000, 2000, 3000]);
        assert_eq!(col.len(), 3);
        assert_eq!(col.type_name(), "DateTime");
        assert_eq!(col.get_display(0), "1000ms");
        assert_eq!(col.get_display(1), "2000ms");
    }

    #[test]
    fn test_label_encode() {
        // Levels are expected in sorted order, codes index into them.
        let data: Vec<String> = vec!["cat".into(), "dog".into(), "cat".into(), "bird".into()];
        let (levels, codes) = label_encode(&data);
        assert_eq!(levels, vec!["bird", "cat", "dog"]);
        assert_eq!(codes, vec![1, 2, 1, 0]);
    }

    #[test]
    fn test_label_encode_empty() {
        let data: Vec<String> = vec![];
        let (levels, codes) = label_encode(&data);
        assert!(levels.is_empty());
        assert!(codes.is_empty());
    }

    #[test]
    fn test_label_encode_single_level() {
        let data: Vec<String> = vec!["x".into(), "x".into(), "x".into()];
        let (levels, codes) = label_encode(&data);
        assert_eq!(levels, vec!["x"]);
        assert_eq!(codes, vec![0, 0, 0]);
    }

    #[test]
    fn test_label_encode_deterministic() {
        // Two encodings of the same input must agree exactly.
        let data: Vec<String> = vec!["z".into(), "a".into(), "m".into(), "a".into(), "z".into()];
        let (levels1, codes1) = label_encode(&data);
        let (levels2, codes2) = label_encode(&data);
        assert_eq!(levels1, levels2);
        assert_eq!(codes1, codes2);
        assert_eq!(levels1, vec!["a", "m", "z"]);
    }

    #[test]
    fn test_ordinal_encode() {
        // Levels follow the caller-supplied order, not sorted order.
        let data: Vec<String> = vec!["low".into(), "high".into(), "medium".into(), "low".into()];
        let order: Vec<String> = vec!["low".into(), "medium".into(), "high".into()];
        let (levels, codes) = ordinal_encode(&data, &order).unwrap();
        assert_eq!(levels, vec!["low", "medium", "high"]);
        assert_eq!(codes, vec![0, 2, 1, 0]);
    }

    #[test]
    fn test_ordinal_encode_missing_value() {
        // A value absent from `order` is an error whose message names the value.
        let data: Vec<String> = vec!["low".into(), "unknown".into()];
        let order: Vec<String> = vec!["low".into(), "medium".into(), "high".into()];
        let result = ordinal_encode(&data, &order);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("unknown"));
    }

    #[test]
    fn test_one_hot_encode() {
        let levels = vec!["bird".to_string(), "cat".to_string(), "dog".to_string()];
        let codes = vec![1u32, 2, 1, 0];
        let (names, cols) = one_hot_encode(&levels, &codes);
        assert_eq!(names, vec!["bird", "cat", "dog"]);
        assert_eq!(cols.len(), 3);
        // One indicator column per level: bird, cat, dog.
        assert_eq!(cols[0], vec![false, false, false, true]);
        assert_eq!(cols[1], vec![true, false, true, false]);
        assert_eq!(cols[2], vec![false, true, false, false]);

        // Each row is hot in exactly one indicator column.
        for row in 0..4 {
            let count: usize = cols.iter().map(|c| if c[row] { 1 } else { 0 }).sum();
            assert_eq!(count, 1, "row {} should have exactly one true", row);
        }
    }

    #[test]
    fn test_one_hot_encode_empty() {
        // No rows: still one (empty) indicator column per level.
        let levels = vec!["a".to_string(), "b".to_string()];
        let codes: Vec<u32> = vec![];
        let (names, cols) = one_hot_encode(&levels, &codes);
        assert_eq!(names.len(), 2);
        assert!(cols[0].is_empty());
        assert!(cols[1].is_empty());
    }

    #[test]
    fn test_categorical_column_in_dataframe() {
        let data: Vec<String> = vec!["cat".into(), "dog".into(), "cat".into()];
        let (levels, codes) = label_encode(&data);
        let df = DataFrame::from_columns(vec![
            ("animal".into(), Column::Categorical { levels, codes }),
            ("score".into(), Column::Float(vec![1.0, 2.0, 3.0])),
        ])
        .unwrap();
        assert_eq!(df.nrows(), 3);
        assert_eq!(df.ncols(), 2);
        assert_eq!(df.get_column("animal").unwrap().type_name(), "Categorical");
    }

    #[test]
    fn test_datetime_column_in_dataframe() {
        let df = DataFrame::from_columns(vec![
            ("ts".into(), Column::DateTime(vec![1000, 2000, 3000])),
            ("val".into(), Column::Float(vec![1.0, 2.0, 3.0])),
        ])
        .unwrap();
        assert_eq!(df.nrows(), 3);
        assert_eq!(df.get_column("ts").unwrap().type_name(), "DateTime");
    }

    #[test]
    fn test_label_encode_to_column_roundtrip() {
        // Displaying each categorical cell must reproduce the original string.
        let data: Vec<String> = vec!["cat".into(), "dog".into(), "cat".into(), "bird".into()];
        let (levels, codes) = label_encode(&data);
        let col = Column::Categorical { levels: levels.clone(), codes: codes.clone() };
        for (i, original) in data.iter().enumerate() {
            assert_eq!(col.get_display(i), *original);
        }
    }
}
1788
1789impl DataFrame {
1794 pub fn to_tensor(
1801 &self,
1802 col_names: &[&str],
1803 ) -> Result<cjc_runtime::Tensor, DataError> {
1804 let (data, shape) = self.to_tensor_data(col_names)?;
1805 cjc_runtime::Tensor::from_vec(data, &shape)
1806 .map_err(|e| DataError::InvalidOperation(format!("tensor conversion: {}", e)))
1807 }
1808
1809 pub fn push_row(&mut self, values: &[&str]) -> Result<(), DataError> {
1818 if values.len() != self.ncols() {
1819 return Err(DataError::ColumnLengthMismatch {
1820 expected: self.ncols(),
1821 got: values.len(),
1822 column: "row".to_string(),
1823 });
1824 }
1825 for (i, (_, col)) in self.columns.iter_mut().enumerate() {
1826 let s = values[i];
1827 match col {
1828 Column::Float(v) => v.push(s.trim().parse::<f64>().unwrap_or(0.0)),
1829 Column::Int(v) => v.push(s.trim().parse::<i64>().unwrap_or(0)),
1830 Column::Str(v) => v.push(s.to_string()),
1831 Column::Bool(v) => v.push(matches!(s.trim(), "true" | "1")),
1832 Column::Categorical { .. } => {
1833 }
1835 Column::DateTime(v) => v.push(s.trim().parse::<i64>().unwrap_or(0)),
1836 }
1837 }
1838 Ok(())
1839 }
1840}
1841
1842#[derive(Debug, Clone, PartialEq, Eq)]
1858pub struct BitMask {
1859 words: Vec<u64>,
1860 nrows: usize,
1861}
1862
1863impl BitMask {
1864 pub fn all_true(nrows: usize) -> Self {
1866 let nwords = nwords_for(nrows);
1867 let mut words = vec![u64::MAX; nwords];
1868 if nrows % 64 != 0 && nwords > 0 {
1870 let tail = nrows % 64;
1871 words[nwords - 1] = (1u64 << tail) - 1;
1872 }
1873 BitMask { words, nrows }
1874 }
1875
1876 pub fn all_false(nrows: usize) -> Self {
1878 let nwords = nwords_for(nrows);
1879 BitMask {
1880 words: vec![0u64; nwords],
1881 nrows,
1882 }
1883 }
1884
1885 pub fn from_bools(bools: &[bool]) -> Self {
1887 let nrows = bools.len();
1888 let nwords = nwords_for(nrows);
1889 let mut words = vec![0u64; nwords];
1890 for (i, &b) in bools.iter().enumerate() {
1891 if b {
1892 words[i / 64] |= 1u64 << (i % 64);
1893 }
1894 }
1895 BitMask { words, nrows }
1896 }
1897
1898 #[inline]
1900 pub fn get(&self, i: usize) -> bool {
1901 debug_assert!(i < self.nrows);
1902 (self.words[i / 64] >> (i % 64)) & 1 == 1
1903 }
1904
1905 pub fn count_ones(&self) -> usize {
1907 self.words.iter().map(|w| w.count_ones() as usize).sum()
1908 }
1909
1910 pub fn and(&self, other: &BitMask) -> BitMask {
1914 assert_eq!(
1915 self.nrows, other.nrows,
1916 "BitMask::and: nrows mismatch ({} vs {})",
1917 self.nrows, other.nrows
1918 );
1919 let words = self
1920 .words
1921 .iter()
1922 .zip(other.words.iter())
1923 .map(|(a, b)| a & b)
1924 .collect();
1925 BitMask {
1926 words,
1927 nrows: self.nrows,
1928 }
1929 }
1930
1931 pub fn iter_set(&self) -> impl Iterator<Item = usize> + '_ {
1933 (0..self.nrows).filter(move |&i| self.get(i))
1934 }
1935
1936 pub fn nrows(&self) -> usize {
1938 self.nrows
1939 }
1940
1941 pub fn nwords(&self) -> usize {
1943 self.words.len()
1944 }
1945}
1946
/// Number of 64-bit words needed to hold `nrows` bits (ceiling of `nrows / 64`).
///
/// This is written as quotient plus a remainder flag rather than `(nrows + 63) / 64`,
/// so it cannot overflow for `nrows` close to `usize::MAX`.
#[inline]
fn nwords_for(nrows: usize) -> usize {
    nrows / 64 + usize::from(nrows % 64 != 0)
}
1951
/// An ordered list of column positions into a base frame's `columns`.
///
/// The order of `indices` is the order in which columns are exposed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProjectionMap {
    indices: Vec<usize>,
}

impl ProjectionMap {
    /// Projection exposing columns `0..ncols` in their original order.
    pub fn identity(ncols: usize) -> Self {
        Self::from_indices((0..ncols).collect::<Vec<_>>())
    }

    /// Wraps an explicit list of column positions; positions are not validated here.
    pub fn from_indices(indices: Vec<usize>) -> Self {
        Self { indices }
    }

    /// Number of exposed columns.
    pub fn len(&self) -> usize {
        self.indices().len()
    }

    /// True when no column is exposed.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The exposed column positions, in output order.
    pub fn indices(&self) -> &[usize] {
        self.indices.as_slice()
    }
}
1993
/// A lightweight view over a shared `DataFrame`: a row mask plus a column projection.
///
/// `filter` and `select` produce new views that share `base` via `Rc` and only
/// replace `mask` or `proj`. `materialize` copies the visible rows and columns into
/// a fresh `DataFrame`.
#[derive(Debug, Clone)]
pub struct TidyView {
    // Shared base frame; derived views clone the `Rc`, not the data.
    base: Rc<DataFrame>,
    // One bit per base row; set bits are the rows visible through this view.
    mask: BitMask,
    // Positions into `base.columns` of the visible columns, in output order.
    proj: ProjectionMap,
}
2010
/// Tries to evaluate `predicate` over whole columns at once instead of row by row.
///
/// The supported subset is:
/// - `col <cmp> literal` or `literal <cmp> col`, where `<cmp>` is one of
///   `> < >= <= == !=`, the column is Int or Float, and the literal is
///   `LitInt` or `LitFloat`.
/// - `And` / `Or` combinations of supported predicates.
///
/// The result is intersected with `existing_mask`, so rows already filtered out
/// stay out. Returns `None` as soon as any sub-expression is outside the subset
/// (or names a column `base` does not have). Callers then fall back to row-wise
/// evaluation.
fn try_eval_predicate_columnar(
    base: &DataFrame,
    predicate: &DExpr,
    existing_mask: &BitMask,
) -> Option<BitMask> {
    match predicate {
        DExpr::BinOp {
            // Conjunction: narrow by the left side, then narrow that result by the right side.
            op: DBinOp::And,
            left,
            right,
        } => {
            let lmask = try_eval_predicate_columnar(base, left, existing_mask)?;
            let rmask = try_eval_predicate_columnar(base, right, &lmask)?;
            Some(rmask)
        }
        DExpr::BinOp {
            // Disjunction: evaluate both sides over all rows, OR them, then AND with
            // the existing mask.
            op: DBinOp::Or,
            left,
            right,
        } => {
            let all_mask = BitMask::all_true(existing_mask.nrows);
            let lmask = try_eval_predicate_columnar(base, left, &all_mask)?;
            let rmask = try_eval_predicate_columnar(base, right, &all_mask)?;
            let nrows = existing_mask.nrows;
            let or_words: Vec<u64> = lmask
                .words
                .iter()
                .zip(rmask.words.iter())
                .map(|(a, b)| a | b)
                .collect();
            let final_words: Vec<u64> = or_words
                .iter()
                .zip(existing_mask.words.iter())
                .map(|(a, b)| a & b)
                .collect();
            Some(BitMask {
                words: final_words,
                nrows,
            })
        }
        DExpr::BinOp { op, left, right } => {
            // Only the six comparison operators have a columnar kernel.
            if !matches!(
                op,
                DBinOp::Gt | DBinOp::Lt | DBinOp::Ge | DBinOp::Le | DBinOp::Eq | DBinOp::Ne
            ) {
                return None;
            }

            // The literal operand, kept in its source numeric type.
            enum LitVal {
                F(f64),
                I(i64),
            }

            // Accept `col <op> lit` or `lit <op> col`; `reversed` marks the latter form.
            let (col_name, lit, reversed) = match (left.as_ref(), right.as_ref()) {
                (DExpr::Col(name), DExpr::LitFloat(v)) => (name.as_str(), LitVal::F(*v), false),
                (DExpr::LitFloat(v), DExpr::Col(name)) => (name.as_str(), LitVal::F(*v), true),
                (DExpr::Col(name), DExpr::LitInt(v)) => (name.as_str(), LitVal::I(*v), false),
                (DExpr::LitInt(v), DExpr::Col(name)) => (name.as_str(), LitVal::I(*v), true),
                _ => return None,
            };

            let column = base.get_column(col_name)?;

            // Mirror the operator for `lit <op> col` so the kernels always compute
            // `column <op> lit`. Eq and Ne are symmetric and pass through unchanged.
            let effective_op = if reversed {
                match op {
                    DBinOp::Gt => DBinOp::Lt,
                    DBinOp::Lt => DBinOp::Gt,
                    DBinOp::Ge => DBinOp::Le,
                    DBinOp::Le => DBinOp::Ge,
                    other => *other,
                }
            } else {
                *op
            };

            let nrows = existing_mask.nrows;
            let nwords = nwords_for(nrows);
            let mut words = vec![0u64; nwords];

            match (column, &lit) {
                (Column::Float(data), LitVal::F(v)) => {
                    columnar_cmp_f64(data, *v, effective_op, &mut words);
                }
                (Column::Float(data), LitVal::I(v)) => {
                    // Int literal against a Float column: compare in f64.
                    columnar_cmp_f64(data, *v as f64, effective_op, &mut words);
                }
                (Column::Int(data), LitVal::I(v)) => {
                    columnar_cmp_i64(data, *v, effective_op, &mut words);
                }
                (Column::Int(data), LitVal::F(v)) => {
                    // Float literal against an Int column: widen the column to f64
                    // (extra allocation; `as f64` may round very large i64 values).
                    let floats: Vec<f64> = data.iter().map(|&x| x as f64).collect();
                    columnar_cmp_f64(&floats, *v, effective_op, &mut words);
                }
                // Str, Bool, Categorical and DateTime columns have no columnar path.
                _ => return None,
            }

            // Rows already excluded by the incoming mask stay excluded.
            for (w, ew) in words.iter_mut().zip(existing_mask.words.iter()) {
                *w &= *ew;
            }

            Some(BitMask { words, nrows })
        }
        _ => None,
    }
}
2147
2148#[inline]
2152fn columnar_cmp_f64(data: &[f64], lit: f64, op: DBinOp, out_words: &mut [u64]) {
2153 for (i, &val) in data.iter().enumerate() {
2154 let pass = match op {
2155 DBinOp::Gt => val > lit,
2156 DBinOp::Lt => val < lit,
2157 DBinOp::Ge => val >= lit,
2158 DBinOp::Le => val <= lit,
2159 DBinOp::Eq => val == lit,
2160 DBinOp::Ne => val != lit,
2161 _ => false,
2162 };
2163 if pass {
2164 out_words[i / 64] |= 1u64 << (i % 64);
2165 }
2166 }
2167}
2168
2169#[inline]
2172fn columnar_cmp_i64(data: &[i64], lit: i64, op: DBinOp, out_words: &mut [u64]) {
2173 for (i, &val) in data.iter().enumerate() {
2174 let pass = match op {
2175 DBinOp::Gt => val > lit,
2176 DBinOp::Lt => val < lit,
2177 DBinOp::Ge => val >= lit,
2178 DBinOp::Le => val <= lit,
2179 DBinOp::Eq => val == lit,
2180 DBinOp::Ne => val != lit,
2181 _ => false,
2182 };
2183 if pass {
2184 out_words[i / 64] |= 1u64 << (i % 64);
2185 }
2186 }
2187}
2188
2189impl TidyView {
2190 pub fn from_df(df: DataFrame) -> Self {
2194 let nrows = df.nrows();
2195 let ncols = df.ncols();
2196 TidyView {
2197 base: Rc::new(df),
2198 mask: BitMask::all_true(nrows),
2199 proj: ProjectionMap::identity(ncols),
2200 }
2201 }
2202
2203 pub fn from_rc(df: Rc<DataFrame>) -> Self {
2205 let nrows = df.nrows();
2206 let ncols = df.ncols();
2207 TidyView {
2208 base: df,
2209 mask: BitMask::all_true(nrows),
2210 proj: ProjectionMap::identity(ncols),
2211 }
2212 }
2213
2214 pub fn nrows(&self) -> usize {
2218 self.mask.count_ones()
2219 }
2220
2221 pub fn ncols(&self) -> usize {
2223 self.proj.len()
2224 }
2225
2226 pub fn column_names(&self) -> Vec<&str> {
2228 self.proj
2229 .indices()
2230 .iter()
2231 .map(|&ci| self.base.columns[ci].0.as_str())
2232 .collect()
2233 }
2234
2235 pub fn filter(&self, predicate: &DExpr) -> Result<TidyView, TidyError> {
2248 validate_expr_columns_proj(predicate, &self.base, &self.proj)?;
2250
2251 if let Some(new_mask) = try_eval_predicate_columnar(&self.base, predicate, &self.mask) {
2253 return Ok(TidyView {
2254 base: Rc::clone(&self.base),
2255 mask: new_mask,
2256 proj: self.proj.clone(),
2257 });
2258 }
2259
2260 let nrows_base = self.base.nrows();
2262 let mut new_words = self.mask.words.clone();
2263
2264 for row in self.mask.iter_set() {
2267 let b = eval_expr_row_proj(&self.base, predicate, row, &self.proj)?;
2268 let pass = match b {
2269 ExprValue::Bool(v) => v,
2270 _ => {
2271 return Err(TidyError::PredicateNotBool {
2272 got: b.type_name().to_string(),
2273 })
2274 }
2275 };
2276 if !pass {
2277 new_words[row / 64] &= !(1u64 << (row % 64));
2278 }
2279 }
2280
2281 Ok(TidyView {
2282 base: Rc::clone(&self.base),
2283 mask: BitMask {
2284 words: new_words,
2285 nrows: nrows_base,
2286 },
2287 proj: self.proj.clone(),
2288 })
2289 }
2290
2291 pub fn select(&self, cols: &[&str]) -> Result<TidyView, TidyError> {
2304 {
2306 let mut seen = std::collections::BTreeSet::new();
2307 for &name in cols {
2308 if !seen.insert(name) {
2309 return Err(TidyError::DuplicateColumn(name.to_string()));
2310 }
2311 }
2312 }
2313
2314 let mut new_indices = Vec::with_capacity(cols.len());
2316 for &name in cols {
2317 let idx = self
2318 .base
2319 .columns
2320 .iter()
2321 .position(|(n, _)| n == name)
2322 .ok_or_else(|| TidyError::ColumnNotFound(name.to_string()))?;
2323 new_indices.push(idx);
2324 }
2325
2326 Ok(TidyView {
2327 base: Rc::clone(&self.base),
2328 mask: self.mask.clone(),
2329 proj: ProjectionMap::from_indices(new_indices),
2330 })
2331 }
2332
2333 pub fn mutate(&self, assignments: &[(&str, DExpr)]) -> Result<TidyFrame, TidyError> {
2353 {
2355 let mut seen = std::collections::BTreeSet::new();
2356 for &(name, _) in assignments {
2357 if !seen.insert(name) {
2358 return Err(TidyError::DuplicateColumn(name.to_string()));
2359 }
2360 }
2361 }
2362
2363 let mut df = self.materialize()?;
2365
2366 let snapshot_names: Vec<String> = df.columns.iter().map(|(n, _)| n.clone()).collect();
2368
2369 for &(col_name, ref expr) in assignments {
2370 validate_expr_columns_snapshot(expr, &snapshot_names)?;
2372
2373 let nrows = df.nrows();
2374 let new_col = eval_expr_column(&df, expr, nrows)?;
2376
2377 if let Some(pos) = df.columns.iter().position(|(n, _)| n == col_name) {
2379 df.columns[pos].1 = new_col;
2380 } else {
2381 df.columns.push((col_name.to_string(), new_col));
2382 }
2383 }
2384
2385 Ok(TidyFrame {
2386 inner: Rc::new(RefCell::new(df)),
2387 })
2388 }
2389
2390 pub fn materialize(&self) -> Result<DataFrame, TidyError> {
2402 let row_indices: Vec<usize> = self.mask.iter_set().collect();
2403
2404 let mut columns = Vec::with_capacity(self.proj.len());
2405 for &ci in self.proj.indices() {
2406 let (name, col) = &self.base.columns[ci];
2407 let new_col = gather_column(col, &row_indices);
2408 columns.push((name.clone(), new_col));
2409 }
2410
2411 DataFrame::from_columns(columns)
2412 .map_err(|e| TidyError::Internal(e.to_string()))
2413 }
2414
2415 pub fn to_tensor(&self, col_names: &[&str]) -> Result<cjc_runtime::Tensor, TidyError> {
2419 let df = self.materialize()?;
2420 df.to_tensor(col_names)
2421 .map_err(|e| TidyError::Internal(e.to_string()))
2422 }
2423
2424 pub fn mask(&self) -> &BitMask {
2426 &self.mask
2427 }
2428
2429 pub fn proj(&self) -> &ProjectionMap {
2431 &self.proj
2432 }
2433
2434 pub fn base_column(&self, name: &str) -> Option<&Column> {
2439 self.base.columns.iter()
2440 .find(|(n, _)| n == name)
2441 .map(|(_, c)| c)
2442 }
2443}
2444
/// An owned, mutable frame behind `Rc<RefCell<_>>`.
///
/// Cloning a `TidyFrame` shares the same underlying frame. `TidyFrame::mutate`
/// first detaches the frame by deep-cloning it when the `Rc` is shared
/// (`strong_count > 1`), so clones never observe each other's mutations.
#[derive(Debug, Clone)]
pub struct TidyFrame {
    inner: Rc<RefCell<DataFrame>>,
}
2456
2457impl TidyFrame {
2458 pub fn from_df(df: DataFrame) -> Self {
2460 TidyFrame {
2461 inner: Rc::new(RefCell::new(df)),
2462 }
2463 }
2464
2465 pub fn borrow(&self) -> std::cell::Ref<'_, DataFrame> {
2467 self.inner.borrow()
2468 }
2469
2470 pub fn view(&self) -> TidyView {
2472 let df = self.inner.borrow().clone();
2473 TidyView::from_df(df)
2474 }
2475
2476 pub fn mutate(&mut self, assignments: &[(&str, DExpr)]) -> Result<(), TidyError> {
2478 if Rc::strong_count(&self.inner) > 1 {
2480 let cloned = self.inner.borrow().clone();
2481 self.inner = Rc::new(RefCell::new(cloned));
2482 }
2483
2484 {
2486 let mut seen = std::collections::BTreeSet::new();
2487 for &(name, _) in assignments {
2488 if !seen.insert(name) {
2489 return Err(TidyError::DuplicateColumn(name.to_string()));
2490 }
2491 }
2492 }
2493
2494 let mut df = self.inner.borrow_mut();
2495
2496 let snapshot_names: Vec<String> = df.columns.iter().map(|(n, _)| n.clone()).collect();
2498
2499 for &(col_name, ref expr) in assignments {
2500 validate_expr_columns_snapshot(expr, &snapshot_names)?;
2501
2502 let nrows = df.nrows();
2503 let new_col = eval_expr_column(&df, expr, nrows)?;
2504
2505 if let Some(pos) = df.columns.iter().position(|(n, _)| n == col_name) {
2506 df.columns[pos].1 = new_col;
2507 } else {
2508 df.columns.push((col_name.to_string(), new_col));
2509 }
2510 }
2511
2512 Ok(())
2513 }
2514}
2515
/// Errors produced by `TidyView` / `TidyFrame` operations.
#[derive(Debug, Clone, PartialEq)]
pub enum TidyError {
    // A referenced column name does not exist (e.g. in `TidyView::select`).
    ColumnNotFound(String),
    // The same column name was given twice (in `select` or `mutate` targets).
    DuplicateColumn(String),
    // A filter predicate produced a non-Bool value; `got` is that value's type name.
    PredicateNotBool { got: String },
    // A value had the wrong type, e.g. a non-numeric column where numbers are required.
    TypeMismatch { expected: String, got: String },
    // A row-count mismatch between `expected` and `got`.
    LengthMismatch { expected: usize, got: usize },
    // Wraps a failure from a lower layer (e.g. a `DataError` during materialization).
    Internal(String),
    // Displayed as "aggregation on empty group".
    EmptyGroup,
    // Displayed as a factor capacity error: `got` distinct levels exceeded `limit`.
    CapacityExceeded { limit: usize, got: usize },
}
2538
2539impl fmt::Display for TidyError {
2540 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2541 match self {
2542 TidyError::ColumnNotFound(n) => write!(f, "column `{}` not found", n),
2543 TidyError::DuplicateColumn(n) => write!(f, "duplicate column `{}`", n),
2544 TidyError::PredicateNotBool { got } => {
2545 write!(f, "filter predicate must be Bool, got {}", got)
2546 }
2547 TidyError::TypeMismatch { expected, got } => {
2548 write!(f, "type mismatch: expected {}, got {}", expected, got)
2549 }
2550 TidyError::LengthMismatch { expected, got } => {
2551 write!(
2552 f,
2553 "length mismatch: expected {} rows, got {}",
2554 expected, got
2555 )
2556 }
2557 TidyError::Internal(msg) => write!(f, "internal error: {}", msg),
2558 TidyError::EmptyGroup => write!(f, "aggregation on empty group"),
2559 TidyError::CapacityExceeded { limit, got } => {
2560 write!(f, "factor capacity exceeded: limit {} distinct levels, got {}", limit, got)
2561 }
2562 }
2563 }
2564}
2565
2566impl std::error::Error for TidyError {}
2567
2568fn extract_f64_column(df: &DataFrame, expr: &DExpr, nrows: usize) -> Result<Vec<f64>, TidyError> {
2580 let col = eval_expr_column(df, expr, nrows)?;
2581 match col {
2582 Column::Float(v) => Ok(v),
2583 Column::Int(v) => Ok(v.into_iter().map(|i| i as f64).collect()),
2584 _ => Err(TidyError::TypeMismatch {
2585 expected: "numeric".into(),
2586 got: "non-numeric".into(),
2587 }),
2588 }
2589}
2590
2591fn eval_window_column(
2594 df: &DataFrame,
2595 expr: &DExpr,
2596 nrows: usize,
2597) -> Result<Option<Column>, TidyError> {
2598 match expr {
2599 DExpr::RowNumber => {
2600 let vals: Vec<i64> = (1..=nrows as i64).collect();
2601 Ok(Some(Column::Int(vals)))
2602 }
2603 DExpr::CumSum(inner) => {
2604 let src = extract_f64_column(df, inner, nrows)?;
2605 let mut result = Vec::with_capacity(nrows);
2606 let mut sum = 0.0_f64;
2607 let mut comp = 0.0_f64; for &v in &src {
2609 let y = v - comp;
2610 let t = sum + y;
2611 comp = (t - sum) - y;
2612 sum = t;
2613 result.push(sum);
2614 }
2615 Ok(Some(Column::Float(result)))
2616 }
2617 DExpr::CumProd(inner) => {
2618 let src = extract_f64_column(df, inner, nrows)?;
2619 let mut result = Vec::with_capacity(nrows);
2620 let mut prod = 1.0_f64;
2621 for &v in &src {
2622 prod *= v;
2623 result.push(prod);
2624 }
2625 Ok(Some(Column::Float(result)))
2626 }
2627 DExpr::CumMax(inner) => {
2628 let src = extract_f64_column(df, inner, nrows)?;
2629 let mut result = Vec::with_capacity(nrows);
2630 let mut max = f64::NEG_INFINITY;
2631 for &v in &src {
2632 if v > max { max = v; }
2633 result.push(max);
2634 }
2635 Ok(Some(Column::Float(result)))
2636 }
2637 DExpr::CumMin(inner) => {
2638 let src = extract_f64_column(df, inner, nrows)?;
2639 let mut result = Vec::with_capacity(nrows);
2640 let mut min = f64::INFINITY;
2641 for &v in &src {
2642 if v < min { min = v; }
2643 result.push(min);
2644 }
2645 Ok(Some(Column::Float(result)))
2646 }
2647 DExpr::Lag(inner, k) => {
2648 let src = extract_f64_column(df, inner, nrows)?;
2649 let mut result = Vec::with_capacity(nrows);
2650 for i in 0..nrows {
2651 if i < *k {
2652 result.push(f64::NAN);
2653 } else {
2654 result.push(src[i - k]);
2655 }
2656 }
2657 Ok(Some(Column::Float(result)))
2658 }
2659 DExpr::Lead(inner, k) => {
2660 let src = extract_f64_column(df, inner, nrows)?;
2661 let mut result = Vec::with_capacity(nrows);
2662 for i in 0..nrows {
2663 if i + k >= nrows {
2664 result.push(f64::NAN);
2665 } else {
2666 result.push(src[i + k]);
2667 }
2668 }
2669 Ok(Some(Column::Float(result)))
2670 }
2671 DExpr::Rank(inner) => {
2672 let src = extract_f64_column(df, inner, nrows)?;
2673 let mut indexed: Vec<(usize, f64)> = src.iter().cloned().enumerate().collect();
2675 indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
2676 let mut ranks = vec![0.0_f64; nrows];
2677 let mut i = 0;
2678 while i < nrows {
2679 let mut j = i;
2680 while j < nrows && indexed[j].1 == indexed[i].1 {
2681 j += 1;
2682 }
2683 let avg_rank = (i + 1 + j) as f64 / 2.0; for idx in i..j {
2685 ranks[indexed[idx].0] = avg_rank;
2686 }
2687 i = j;
2688 }
2689 Ok(Some(Column::Float(ranks)))
2690 }
2691 DExpr::DenseRank(inner) => {
2692 let src = extract_f64_column(df, inner, nrows)?;
2693 let mut indexed: Vec<(usize, f64)> = src.iter().cloned().enumerate().collect();
2694 indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
2695 let mut ranks = vec![0_i64; nrows];
2696 let mut rank = 0_i64;
2697 let mut prev: Option<f64> = None;
2698 for &(orig_idx, val) in &indexed {
2699 if prev.is_none() || prev.unwrap() != val {
2700 rank += 1;
2701 }
2702 ranks[orig_idx] = rank;
2703 prev = Some(val);
2704 }
2705 Ok(Some(Column::Int(ranks)))
2706 }
2707 DExpr::RollingSum(col_name, window) => {
2708 let vals = rolling_get_floats(df, col_name)?;
2709 let n = vals.len();
2710 let w = *window;
2711 let mut result = Vec::with_capacity(n);
2712 let mut sum = 0.0_f64;
2713 let mut comp = 0.0_f64;
2714 for i in 0..n {
2715 let y = vals[i] - comp;
2717 let t = sum + y;
2718 comp = (t - sum) - y;
2719 sum = t;
2720 if i >= w {
2722 let y2 = -vals[i - w] - comp;
2723 let t2 = sum + y2;
2724 comp = (t2 - sum) - y2;
2725 sum = t2;
2726 }
2727 result.push(sum);
2728 }
2729 Ok(Some(Column::Float(result)))
2730 }
2731 DExpr::RollingMean(col_name, window) => {
2732 let vals = rolling_get_floats(df, col_name)?;
2733 let n = vals.len();
2734 let w = *window;
2735 let mut result = Vec::with_capacity(n);
2736 let mut sum = 0.0_f64;
2737 let mut comp = 0.0_f64;
2738 for i in 0..n {
2739 let y = vals[i] - comp;
2740 let t = sum + y;
2741 comp = (t - sum) - y;
2742 sum = t;
2743 if i >= w {
2744 let y2 = -vals[i - w] - comp;
2745 let t2 = sum + y2;
2746 comp = (t2 - sum) - y2;
2747 sum = t2;
2748 }
2749 let count = if i < w { i + 1 } else { w };
2750 result.push(sum / count as f64);
2751 }
2752 Ok(Some(Column::Float(result)))
2753 }
2754 DExpr::RollingMin(col_name, window) => {
2755 let vals = rolling_get_floats(df, col_name)?;
2756 let n = vals.len();
2757 let w = *window;
2758 let mut result = Vec::with_capacity(n);
2759 let mut deque: VecDeque<usize> = VecDeque::new();
2760 for i in 0..n {
2761 while !deque.is_empty() && *deque.front().unwrap() + w <= i {
2763 deque.pop_front();
2764 }
2765 while !deque.is_empty() && vals[*deque.back().unwrap()] >= vals[i] {
2767 deque.pop_back();
2768 }
2769 deque.push_back(i);
2770 result.push(vals[*deque.front().unwrap()]);
2771 }
2772 Ok(Some(Column::Float(result)))
2773 }
2774 DExpr::RollingMax(col_name, window) => {
2775 let vals = rolling_get_floats(df, col_name)?;
2776 let n = vals.len();
2777 let w = *window;
2778 let mut result = Vec::with_capacity(n);
2779 let mut deque: VecDeque<usize> = VecDeque::new();
2780 for i in 0..n {
2781 while !deque.is_empty() && *deque.front().unwrap() + w <= i {
2782 deque.pop_front();
2783 }
2784 while !deque.is_empty() && vals[*deque.back().unwrap()] <= vals[i] {
2786 deque.pop_back();
2787 }
2788 deque.push_back(i);
2789 result.push(vals[*deque.front().unwrap()]);
2790 }
2791 Ok(Some(Column::Float(result)))
2792 }
2793 DExpr::RollingVar(col_name, window) => {
2794 let vals = rolling_get_floats(df, col_name)?;
2795 let n = vals.len();
2796 let w = *window;
2797 let mut result = Vec::with_capacity(n);
2798 let mut count = 0_usize;
2800 let mut mean = 0.0_f64;
2801 let mut m2 = 0.0_f64;
2802 for i in 0..n {
2803 count += 1;
2805 let delta = vals[i] - mean;
2806 mean += delta / count as f64;
2807 let delta2 = vals[i] - mean;
2808 m2 += delta * delta2;
2809 if i >= w {
2811 let old = vals[i - w];
2812 count -= 1;
2813 if count == 0 {
2814 mean = 0.0;
2815 m2 = 0.0;
2816 } else {
2817 let delta_old = old - mean;
2818 mean -= delta_old / count as f64;
2819 let delta_old2 = old - mean;
2820 m2 -= delta_old * delta_old2;
2821 }
2822 }
2823 if count < 2 {
2824 result.push(0.0);
2825 } else {
2826 result.push(m2 / (count - 1) as f64);
2829 }
2830 }
2831 Ok(Some(Column::Float(result)))
2832 }
2833 DExpr::RollingSd(col_name, window) => {
2834 let vals = rolling_get_floats(df, col_name)?;
2835 let n = vals.len();
2836 let w = *window;
2837 let mut result = Vec::with_capacity(n);
2838 let mut count = 0_usize;
2839 let mut mean = 0.0_f64;
2840 let mut m2 = 0.0_f64;
2841 for i in 0..n {
2842 count += 1;
2843 let delta = vals[i] - mean;
2844 mean += delta / count as f64;
2845 let delta2 = vals[i] - mean;
2846 m2 += delta * delta2;
2847 if i >= w {
2848 let old = vals[i - w];
2849 count -= 1;
2850 if count == 0 {
2851 mean = 0.0;
2852 m2 = 0.0;
2853 } else {
2854 let delta_old = old - mean;
2855 mean -= delta_old / count as f64;
2856 let delta_old2 = old - mean;
2857 m2 -= delta_old * delta_old2;
2858 }
2859 }
2860 if count < 2 {
2861 result.push(0.0);
2862 } else {
2863 result.push((m2 / (count - 1) as f64).sqrt());
2864 }
2865 }
2866 Ok(Some(Column::Float(result)))
2867 }
2868 _ => Ok(None),
2869 }
2870}
2871
2872fn rolling_get_floats(df: &DataFrame, col_name: &str) -> Result<Vec<f64>, TidyError> {
2874 let col = df
2875 .get_column(col_name)
2876 .ok_or_else(|| TidyError::ColumnNotFound(col_name.to_string()))?;
2877 match col {
2878 Column::Float(v) => Ok(v.clone()),
2879 Column::Int(v) => Ok(v.iter().map(|&i| i as f64).collect()),
2880 _ => Err(TidyError::TypeMismatch {
2881 expected: "numeric".into(),
2882 got: "non-numeric".into(),
2883 }),
2884 }
2885}
2886
2887fn vectorized_binop(op: DBinOp, left: &Column, right: &Column) -> Result<Column, TidyError> {
2892 match (left, right) {
2893 (Column::Int(a), Column::Int(b)) => {
2894 let n = a.len();
2895 match op {
2896 DBinOp::Add => { let mut r = vec![0i64; n]; for i in 0..n { r[i] = a[i] + b[i]; } Ok(Column::Int(r)) }
2897 DBinOp::Sub => { let mut r = vec![0i64; n]; for i in 0..n { r[i] = a[i] - b[i]; } Ok(Column::Int(r)) }
2898 DBinOp::Mul => { let mut r = vec![0i64; n]; for i in 0..n { r[i] = a[i] * b[i]; } Ok(Column::Int(r)) }
2899 DBinOp::Div => { let mut r = vec![0i64; n]; for i in 0..n { r[i] = a[i] / b[i]; } Ok(Column::Int(r)) }
2900 DBinOp::Gt => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] > b[i]; } Ok(Column::Bool(r)) }
2901 DBinOp::Lt => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] < b[i]; } Ok(Column::Bool(r)) }
2902 DBinOp::Ge => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] >= b[i]; } Ok(Column::Bool(r)) }
2903 DBinOp::Le => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] <= b[i]; } Ok(Column::Bool(r)) }
2904 DBinOp::Eq => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] == b[i]; } Ok(Column::Bool(r)) }
2905 DBinOp::Ne => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] != b[i]; } Ok(Column::Bool(r)) }
2906 _ => Err(TidyError::Internal(format!("unsupported op {:?} on Int", op))),
2907 }
2908 }
2909 (Column::Float(a), Column::Float(b)) => {
2910 let n = a.len();
2911 match op {
2912 DBinOp::Add => { let mut r = vec![0.0f64; n]; for i in 0..n { r[i] = a[i] + b[i]; } Ok(Column::Float(r)) }
2913 DBinOp::Sub => { let mut r = vec![0.0f64; n]; for i in 0..n { r[i] = a[i] - b[i]; } Ok(Column::Float(r)) }
2914 DBinOp::Mul => { let mut r = vec![0.0f64; n]; for i in 0..n { r[i] = a[i] * b[i]; } Ok(Column::Float(r)) }
2915 DBinOp::Div => { let mut r = vec![0.0f64; n]; for i in 0..n { r[i] = a[i] / b[i]; } Ok(Column::Float(r)) }
2916 DBinOp::Gt => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] > b[i]; } Ok(Column::Bool(r)) }
2917 DBinOp::Lt => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] < b[i]; } Ok(Column::Bool(r)) }
2918 DBinOp::Ge => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] >= b[i]; } Ok(Column::Bool(r)) }
2919 DBinOp::Le => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] <= b[i]; } Ok(Column::Bool(r)) }
2920 DBinOp::Eq => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] == b[i]; } Ok(Column::Bool(r)) }
2921 DBinOp::Ne => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] != b[i]; } Ok(Column::Bool(r)) }
2922 _ => Err(TidyError::Internal(format!("unsupported op {:?} on Float", op))),
2923 }
2924 }
2925 (Column::Int(a), Column::Float(_b)) => {
2926 let promoted: Vec<f64> = a.iter().map(|&v| v as f64).collect();
2927 vectorized_binop(op, &Column::Float(promoted), right)
2928 }
2929 (Column::Float(_a), Column::Int(b)) => {
2930 let promoted: Vec<f64> = b.iter().map(|&v| v as f64).collect();
2931 vectorized_binop(op, left, &Column::Float(promoted))
2932 }
2933 (Column::Bool(a), Column::Bool(b)) => {
2934 let n = a.len();
2935 match op {
2936 DBinOp::And => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] && b[i]; } Ok(Column::Bool(r)) }
2937 DBinOp::Or => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] || b[i]; } Ok(Column::Bool(r)) }
2938 DBinOp::Eq => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] == b[i]; } Ok(Column::Bool(r)) }
2939 DBinOp::Ne => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] != b[i]; } Ok(Column::Bool(r)) }
2940 _ => Err(TidyError::Internal(format!("unsupported op {:?} on Bool", op))),
2941 }
2942 }
2943 (Column::Str(a), Column::Str(b)) => {
2944 let n = a.len();
2945 match op {
2946 DBinOp::Eq => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] == b[i]; } Ok(Column::Bool(r)) }
2947 DBinOp::Ne => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] != b[i]; } Ok(Column::Bool(r)) }
2948 _ => Err(TidyError::Internal(format!("unsupported op {:?} on String", op))),
2949 }
2950 }
2951 _ => Err(TidyError::Internal("type mismatch in binary operation".into())),
2952 }
2953}
2954
2955fn vectorized_fn_call(name: &str, arg: &Column) -> Result<Column, TidyError> {
2958 let floats: Vec<f64> = match arg {
2959 Column::Float(v) => v.clone(),
2960 Column::Int(v) => v.iter().map(|&i| i as f64).collect(),
2961 _ => return Err(TidyError::Internal(format!(
2962 "FnCall '{}' requires numeric argument", name
2963 ))),
2964 };
2965 let f: fn(f64) -> f64 = match name {
2966 "log" => f64::ln,
2967 "exp" => f64::exp,
2968 "sqrt" => f64::sqrt,
2969 "abs" => f64::abs,
2970 "ceil" => f64::ceil,
2971 "floor" => f64::floor,
2972 "round" => f64::round,
2973 "sin" => f64::sin,
2974 "cos" => f64::cos,
2975 "tan" => f64::tan,
2976 _ => return Err(TidyError::Internal(format!(
2977 "unknown DExpr function: {}", name
2978 ))),
2979 };
2980 let mut result = vec![0.0f64; floats.len()];
2981 for i in 0..floats.len() {
2982 result[i] = f(floats[i]);
2983 }
2984 Ok(Column::Float(result))
2985}
2986
2987fn try_eval_expr_column_vectorized(
2991 df: &DataFrame,
2992 expr: &DExpr,
2993 nrows: usize,
2994) -> Option<Result<Column, TidyError>> {
2995 match expr {
2996 DExpr::Col(name) => {
2997 let col = df.get_column(name)?;
2998 let result = match col {
2999 Column::Int(v) => Column::Int(v[..nrows].to_vec()),
3000 Column::Float(v) => Column::Float(v[..nrows].to_vec()),
3001 Column::Str(v) => Column::Str(v[..nrows].to_vec()),
3002 Column::Bool(v) => Column::Bool(v[..nrows].to_vec()),
3003 Column::Categorical { levels, codes } => {
3004 let strs: Vec<String> = codes[..nrows]
3005 .iter()
3006 .map(|&c| levels[c as usize].clone())
3007 .collect();
3008 Column::Str(strs)
3009 }
3010 Column::DateTime(v) => Column::Int(v[..nrows].to_vec()),
3011 };
3012 Some(Ok(result))
3013 }
3014 DExpr::LitFloat(v) => Some(Ok(Column::Float(vec![*v; nrows]))),
3015 DExpr::LitInt(v) => Some(Ok(Column::Int(vec![*v; nrows]))),
3016 DExpr::LitBool(b) => Some(Ok(Column::Bool(vec![*b; nrows]))),
3017 DExpr::LitStr(s) => Some(Ok(Column::Str(vec![s.clone(); nrows]))),
3018 DExpr::BinOp { op, left, right } => {
3019 let left_col = try_eval_expr_column_vectorized(df, left, nrows)?.ok()?;
3020 let right_col = try_eval_expr_column_vectorized(df, right, nrows)?.ok()?;
3021 Some(vectorized_binop(*op, &left_col, &right_col))
3022 }
3023 DExpr::FnCall(name, args) if args.len() == 1 => {
3024 let arg_col = try_eval_expr_column_vectorized(df, &args[0], nrows)?.ok()?;
3025 Some(vectorized_fn_call(name, &arg_col))
3026 }
3027 _ => None,
3028 }
3029}
3030
3031fn eval_expr_column(df: &DataFrame, expr: &DExpr, nrows: usize) -> Result<Column, TidyError> {
3032 if nrows == 0 {
3033 return Ok(Column::Float(vec![]));
3035 }
3036
3037 if let Some(col) = eval_window_column(df, expr, nrows)? {
3039 return Ok(col);
3040 }
3041
3042 if let Some(result) = try_eval_expr_column_vectorized(df, expr, nrows) {
3044 return result;
3045 }
3046
3047 let sample = eval_dexpr_row(df, expr, 0)?;
3049 match sample {
3050 ExprValue::Int(_) => {
3051 let vals: Result<Vec<i64>, TidyError> = (0..nrows)
3052 .map(|r| {
3053 eval_dexpr_row(df, expr, r).and_then(|v| match v {
3054 ExprValue::Int(i) => Ok(i),
3055 ExprValue::Float(f) => Ok(f as i64),
3056 other => Err(TidyError::TypeMismatch {
3057 expected: "Int".into(),
3058 got: other.type_name().into(),
3059 }),
3060 })
3061 })
3062 .collect();
3063 Ok(Column::Int(vals?))
3064 }
3065 ExprValue::Float(_) => {
3066 let vals: Result<Vec<f64>, TidyError> = (0..nrows)
3067 .map(|r| {
3068 eval_dexpr_row(df, expr, r).and_then(|v| match v {
3069 ExprValue::Float(f) => Ok(f),
3070 ExprValue::Int(i) => Ok(i as f64),
3071 other => Err(TidyError::TypeMismatch {
3072 expected: "Float".into(),
3073 got: other.type_name().into(),
3074 }),
3075 })
3076 })
3077 .collect();
3078 Ok(Column::Float(vals?))
3079 }
3080 ExprValue::Bool(_) => {
3081 let vals: Result<Vec<bool>, TidyError> = (0..nrows)
3082 .map(|r| {
3083 eval_dexpr_row(df, expr, r).and_then(|v| match v {
3084 ExprValue::Bool(b) => Ok(b),
3085 other => Err(TidyError::TypeMismatch {
3086 expected: "Bool".into(),
3087 got: other.type_name().into(),
3088 }),
3089 })
3090 })
3091 .collect();
3092 Ok(Column::Bool(vals?))
3093 }
3094 ExprValue::Str(_) => {
3095 let vals: Result<Vec<String>, TidyError> = (0..nrows)
3096 .map(|r| {
3097 eval_dexpr_row(df, expr, r).and_then(|v| match v {
3098 ExprValue::Str(s) => Ok(s),
3099 other => Err(TidyError::TypeMismatch {
3100 expected: "Str".into(),
3101 got: other.type_name().into(),
3102 }),
3103 })
3104 })
3105 .collect();
3106 Ok(Column::Str(vals?))
3107 }
3108 }
3109}
3110
3111fn eval_dexpr_row(df: &DataFrame, expr: &DExpr, row: usize) -> Result<ExprValue, TidyError> {
3113 eval_expr_row(df, expr, row).map_err(|e| TidyError::Internal(e.to_string()))
3114}
3115
3116fn eval_expr_row_proj(
3118 base: &DataFrame,
3119 expr: &DExpr,
3120 row: usize,
3121 _proj: &ProjectionMap,
3122) -> Result<ExprValue, TidyError> {
3123 eval_expr_row(base, expr, row).map_err(|e| TidyError::Internal(e.to_string()))
3127}
3128
3129fn validate_expr_columns_proj(
3136 expr: &DExpr,
3137 base: &DataFrame,
3138 _proj: &ProjectionMap,
3139) -> Result<(), TidyError> {
3140 let mut refs = Vec::new();
3141 collect_expr_columns(expr, &mut refs);
3142 for col_name in refs {
3143 if base.get_column(&col_name).is_none() {
3144 return Err(TidyError::ColumnNotFound(col_name));
3145 }
3146 }
3147 Ok(())
3148}
3149
3150fn validate_expr_columns_snapshot(
3152 expr: &DExpr,
3153 snapshot_names: &[String],
3154) -> Result<(), TidyError> {
3155 let mut refs = Vec::new();
3156 collect_expr_columns(expr, &mut refs);
3157 for col_name in refs {
3158 if !snapshot_names.iter().any(|n| n == &col_name) {
3159 return Err(TidyError::ColumnNotFound(col_name));
3160 }
3161 }
3162 Ok(())
3163}
3164
3165impl ExprValue {
3166 fn type_name(&self) -> &'static str {
3167 match self {
3168 ExprValue::Int(_) => "Int",
3169 ExprValue::Float(_) => "Float",
3170 ExprValue::Str(_) => "Str",
3171 ExprValue::Bool(_) => "Bool",
3172 }
3173 }
3174}
3175
3176impl DataFrame {
3179 pub fn tidy(self) -> TidyView {
3183 TidyView::from_df(self)
3184 }
3185
3186 pub fn tidy_mut(self) -> TidyFrame {
3188 TidyFrame::from_df(self)
3189 }
3190}
3191
/// An ordered list of row indices.
///
/// NOTE(review): presumably these are rows of a base frame selected by a view; confirm at
/// the call sites.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RowIndexMap {
    pub(crate) indices: Vec<usize>,
}

impl RowIndexMap {
    /// Wraps `indices` exactly as given; no sorting, deduplication or bounds checks.
    pub fn new(indices: Vec<usize>) -> Self {
        Self { indices }
    }

    /// Number of stored row indices.
    pub fn len(&self) -> usize {
        self.as_slice().len()
    }

    /// `true` when no row indices are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrows the indices in their stored order.
    pub fn as_slice(&self) -> &[usize] {
        self.indices.as_slice()
    }
}
3277
/// One group produced by [`GroupIndex::build`].
#[derive(Debug, Clone)]
pub struct GroupMeta {
    /// `Column::get_display` rendering of each key column for this group, in the same
    /// order as `GroupIndex::key_names`.
    pub key_values: Vec<String>,
    /// Row indices belonging to this group, in the order they were visited.
    pub row_indices: Vec<usize>,
}
3288
3289#[derive(Debug, Clone)]
3300pub struct GroupIndex {
3301 pub groups: Vec<GroupMeta>,
3303 pub key_names: Vec<String>,
3305}
3306
3307impl GroupIndex {
3308 pub fn build(
3313 base: &DataFrame,
3314 key_col_indices: &[usize],
3315 visible_rows: &[usize],
3316 key_names: Vec<String>,
3317 ) -> Self {
3318 let mut group_order: Vec<Vec<String>> = Vec::new(); let mut group_map: Vec<(Vec<String>, usize)> = Vec::new(); for &row in visible_rows {
3324 let key: Vec<String> = key_col_indices
3325 .iter()
3326 .map(|&ci| base.columns[ci].1.get_display(row))
3327 .collect();
3328
3329 let slot = group_map
3331 .iter()
3332 .position(|(k, _)| k == &key)
3333 .unwrap_or_else(|| {
3334 let s = group_map.len();
3335 group_map.push((key.clone(), s));
3336 group_order.push(key);
3337 s
3338 });
3339
3340 let _ = slot; }
3342
3343 let mut groups: Vec<GroupMeta> = group_order
3345 .iter()
3346 .map(|k| GroupMeta {
3347 key_values: k.clone(),
3348 row_indices: Vec::new(),
3349 })
3350 .collect();
3351
3352 let key_to_slot: Vec<(Vec<String>, usize)> = group_order
3354 .iter()
3355 .enumerate()
3356 .map(|(i, k)| (k.clone(), i))
3357 .collect();
3358
3359 for &row in visible_rows {
3360 let key: Vec<String> = key_col_indices
3361 .iter()
3362 .map(|&ci| base.columns[ci].1.get_display(row))
3363 .collect();
3364 if let Some((_, slot)) = key_to_slot.iter().find(|(k, _)| k == &key) {
3365 groups[*slot].row_indices.push(row);
3366 }
3367 }
3368
3369 GroupIndex { groups, key_names }
3370 }
3371}
3372
3373#[derive(Debug, Clone)]
3383pub struct GroupedTidyView {
3389 view: TidyView,
3390 index: GroupIndex,
3391}
3392
3393impl GroupedTidyView {
3394 pub fn ngroups(&self) -> usize {
3396 self.index.groups.len()
3397 }
3398
3399 pub fn ungroup(self) -> TidyView {
3401 self.view
3402 }
3403
3404 pub fn group_index(&self) -> &GroupIndex {
3406 &self.index
3407 }
3408
3409 pub fn summarise(
3428 &self,
3429 assignments: &[(&str, TidyAgg)],
3430 ) -> Result<TidyFrame, TidyError> {
3431 {
3433 let mut seen = std::collections::BTreeSet::new();
3434 for &(name, _) in assignments {
3435 if !seen.insert(name) {
3436 return Err(TidyError::DuplicateColumn(name.to_string()));
3437 }
3438 }
3439 }
3440
3441 let base = &self.view.base;
3442 let n_groups = self.index.groups.len();
3443
3444 let mut result_columns: Vec<(String, Column)> = Vec::new();
3446
3447 for key_name in &self.index.key_names {
3448 let base_col = base
3449 .get_column(key_name)
3450 .ok_or_else(|| TidyError::ColumnNotFound(key_name.clone()))?;
3451
3452 let col = match base_col {
3453 Column::Int(_) => {
3454 let vals: Vec<i64> = self
3455 .index
3456 .groups
3457 .iter()
3458 .map(|g| {
3459 g.key_values[self
3460 .index
3461 .key_names
3462 .iter()
3463 .position(|k| k == key_name)
3464 .unwrap()]
3465 .parse::<i64>()
3466 .unwrap_or(0)
3467 })
3468 .collect();
3469 Column::Int(vals)
3470 }
3471 Column::Bool(_) => {
3472 let vals: Vec<bool> = self
3473 .index
3474 .groups
3475 .iter()
3476 .map(|g| {
3477 let s = &g.key_values[self
3478 .index
3479 .key_names
3480 .iter()
3481 .position(|k| k == key_name)
3482 .unwrap()];
3483 matches!(s.as_str(), "true" | "1")
3484 })
3485 .collect();
3486 Column::Bool(vals)
3487 }
3488 _ => {
3489 let vals: Vec<String> = self
3491 .index
3492 .groups
3493 .iter()
3494 .map(|g| {
3495 g.key_values[self
3496 .index
3497 .key_names
3498 .iter()
3499 .position(|k| k == key_name)
3500 .unwrap()]
3501 .clone()
3502 })
3503 .collect();
3504 Column::Str(vals)
3505 }
3506 };
3507 result_columns.push((key_name.clone(), col));
3508 }
3509
3510 for &(out_name, ref agg) in assignments {
3512 let col_vals = self.eval_agg_over_groups_fast(agg, n_groups, base)?;
3513 result_columns.push((out_name.to_string(), col_vals));
3514 }
3515
3516 let df = DataFrame::from_columns(result_columns)
3517 .map_err(|e| TidyError::Internal(e.to_string()))?;
3518 Ok(TidyFrame::from_df(df))
3519 }
3520
3521 #[allow(dead_code)]
3523 fn eval_agg_over_groups(
3524 &self,
3525 agg: &TidyAgg,
3526 n_groups: usize,
3527 base: &DataFrame,
3528 ) -> Result<Column, TidyError> {
3529 match agg {
3530 TidyAgg::Count => {
3531 let counts: Vec<i64> = self
3532 .index
3533 .groups
3534 .iter()
3535 .map(|g| g.row_indices.len() as i64)
3536 .collect();
3537 Ok(Column::Int(counts))
3538 }
3539
3540 TidyAgg::Sum(col_name) | TidyAgg::Mean(col_name)
3541 | TidyAgg::Min(col_name) | TidyAgg::Max(col_name)
3542 | TidyAgg::First(col_name) | TidyAgg::Last(col_name)
3543 | TidyAgg::Median(col_name) | TidyAgg::Sd(col_name)
3544 | TidyAgg::Var(col_name) | TidyAgg::Quantile(col_name, _)
3545 | TidyAgg::NDistinct(col_name) | TidyAgg::Iqr(col_name) => {
3546 let src = base
3547 .get_column(col_name)
3548 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
3549
3550 let mut vals = Vec::with_capacity(n_groups);
3551 for group in &self.index.groups {
3552 let v = agg_reduce(agg, src, &group.row_indices)?;
3553 vals.push(v);
3554 }
3555 Ok(Column::Float(vals))
3556 }
3557 }
3558 }
3559
3560 fn eval_agg_over_groups_fast(
3563 &self,
3564 agg: &TidyAgg,
3565 n_groups: usize,
3566 base: &DataFrame,
3567 ) -> Result<Column, TidyError> {
3568 match agg {
3569 TidyAgg::Count => {
3570 let counts: Vec<i64> = self
3571 .index
3572 .groups
3573 .iter()
3574 .map(|g| g.row_indices.len() as i64)
3575 .collect();
3576 Ok(Column::Int(counts))
3577 }
3578 TidyAgg::Sum(col_name) => {
3579 let src = base.get_column(col_name)
3580 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
3581 Ok(Column::Float(fast_agg_sum(&self.index.groups, src)?))
3582 }
3583 TidyAgg::Mean(col_name) => {
3584 let src = base.get_column(col_name)
3585 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
3586 Ok(Column::Float(fast_agg_mean(&self.index.groups, src)?))
3587 }
3588 TidyAgg::Min(col_name) => {
3589 let src = base.get_column(col_name)
3590 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
3591 Ok(Column::Float(fast_agg_min(&self.index.groups, src)?))
3592 }
3593 TidyAgg::Max(col_name) => {
3594 let src = base.get_column(col_name)
3595 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
3596 Ok(Column::Float(fast_agg_max(&self.index.groups, src)?))
3597 }
3598 TidyAgg::First(col_name) => {
3599 let src = base.get_column(col_name)
3600 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
3601 Ok(Column::Float(fast_agg_first(&self.index.groups, src)?))
3602 }
3603 TidyAgg::Last(col_name) => {
3604 let src = base.get_column(col_name)
3605 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
3606 Ok(Column::Float(fast_agg_last(&self.index.groups, src)?))
3607 }
3608 TidyAgg::Var(col_name)
3609 | TidyAgg::Sd(col_name)
3610 | TidyAgg::Median(col_name)
3611 | TidyAgg::Quantile(col_name, _)
3612 | TidyAgg::NDistinct(col_name)
3613 | TidyAgg::Iqr(col_name) => {
3614 let src = base.get_column(col_name)
3615 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
3616 Ok(Column::Float(fast_agg_arena(
3617 agg, &self.index.groups, src, n_groups,
3618 )?))
3619 }
3620 }
3621 }
3622}
3623
/// Borrowed, type-tagged view of a numeric column's data, used by the aggregation kernels.
enum ColRef<'a> {
    /// Data of a `Column::Float`.
    Float(&'a [f64]),
    /// Data of a `Column::Int`; kernels widen each element with `as f64`.
    Int(&'a [i64]),
}
3630
3631fn col_to_ref(col: &Column) -> Result<ColRef<'_>, TidyError> {
3632 match col {
3633 Column::Float(v) => Ok(ColRef::Float(v)),
3634 Column::Int(v) => Ok(ColRef::Int(v)),
3635 _ => Err(TidyError::TypeMismatch {
3636 expected: "numeric (Int or Float)".into(),
3637 got: col.type_name().into(),
3638 }),
3639 }
3640}
3641
3642fn fast_agg_sum(groups: &[GroupMeta], col: &Column) -> Result<Vec<f64>, TidyError> {
3643 use cjc_repro::kahan::KahanAccumulatorF64;
3644 let cr = col_to_ref(col)?;
3645 Ok(groups.iter().map(|g| {
3646 let mut acc = KahanAccumulatorF64::new();
3647 match cr {
3648 ColRef::Float(d) => { for &i in &g.row_indices { acc.add(d[i]); } }
3649 ColRef::Int(d) => { for &i in &g.row_indices { acc.add(d[i] as f64); } }
3650 }
3651 acc.finalize()
3652 }).collect())
3653}
3654
3655fn fast_agg_mean(groups: &[GroupMeta], col: &Column) -> Result<Vec<f64>, TidyError> {
3656 use cjc_repro::kahan::KahanAccumulatorF64;
3657 let cr = col_to_ref(col)?;
3658 Ok(groups.iter().map(|g| {
3659 if g.row_indices.is_empty() { return f64::NAN; }
3660 let mut acc = KahanAccumulatorF64::new();
3661 match cr {
3662 ColRef::Float(d) => { for &i in &g.row_indices { acc.add(d[i]); } }
3663 ColRef::Int(d) => { for &i in &g.row_indices { acc.add(d[i] as f64); } }
3664 }
3665 acc.finalize() / g.row_indices.len() as f64
3666 }).collect())
3667}
3668
3669fn fast_agg_min(groups: &[GroupMeta], col: &Column) -> Result<Vec<f64>, TidyError> {
3670 let cr = col_to_ref(col)?;
3671 Ok(groups.iter().map(|g| {
3672 if g.row_indices.is_empty() { return f64::NAN; }
3673 match cr {
3674 ColRef::Float(d) => g.row_indices.iter().fold(f64::INFINITY, |a, &i| {
3675 let b = d[i]; if b.is_nan() || b < a { b } else { a }
3676 }),
3677 ColRef::Int(d) => g.row_indices.iter().fold(f64::INFINITY, |a, &i| {
3678 let b = d[i] as f64; if b.is_nan() || b < a { b } else { a }
3679 }),
3680 }
3681 }).collect())
3682}
3683
3684fn fast_agg_max(groups: &[GroupMeta], col: &Column) -> Result<Vec<f64>, TidyError> {
3685 let cr = col_to_ref(col)?;
3686 Ok(groups.iter().map(|g| {
3687 if g.row_indices.is_empty() { return f64::NAN; }
3688 match cr {
3689 ColRef::Float(d) => g.row_indices.iter().fold(f64::NEG_INFINITY, |a, &i| {
3690 let b = d[i]; if b.is_nan() || b > a { b } else { a }
3691 }),
3692 ColRef::Int(d) => g.row_indices.iter().fold(f64::NEG_INFINITY, |a, &i| {
3693 let b = d[i] as f64; if b.is_nan() || b > a { b } else { a }
3694 }),
3695 }
3696 }).collect())
3697}
3698
3699fn fast_agg_first(groups: &[GroupMeta], col: &Column) -> Result<Vec<f64>, TidyError> {
3700 let cr = col_to_ref(col)?;
3701 let mut vals = Vec::with_capacity(groups.len());
3702 for g in groups {
3703 if g.row_indices.is_empty() { return Err(TidyError::EmptyGroup); }
3704 match cr {
3705 ColRef::Float(d) => vals.push(d[g.row_indices[0]]),
3706 ColRef::Int(d) => vals.push(d[g.row_indices[0]] as f64),
3707 }
3708 }
3709 Ok(vals)
3710}
3711
3712fn fast_agg_last(groups: &[GroupMeta], col: &Column) -> Result<Vec<f64>, TidyError> {
3713 let cr = col_to_ref(col)?;
3714 let mut vals = Vec::with_capacity(groups.len());
3715 for g in groups {
3716 if g.row_indices.is_empty() { return Err(TidyError::EmptyGroup); }
3717 let last = *g.row_indices.last().unwrap();
3718 match cr {
3719 ColRef::Float(d) => vals.push(d[last]),
3720 ColRef::Int(d) => vals.push(d[last] as f64),
3721 }
3722 }
3723 Ok(vals)
3724}
3725
3726fn fast_agg_arena(
3729 agg: &TidyAgg,
3730 groups: &[GroupMeta],
3731 col: &Column,
3732 n_groups: usize,
3733) -> Result<Vec<f64>, TidyError> {
3734 let cr = col_to_ref(col)?;
3735 let max_size = groups.iter().map(|g| g.row_indices.len()).max().unwrap_or(0);
3736 let mut arena: Vec<f64> = Vec::with_capacity(max_size);
3737 let mut results = Vec::with_capacity(n_groups);
3738 for group in groups {
3739 arena.clear();
3740 match cr {
3741 ColRef::Float(d) => { for &i in &group.row_indices { arena.push(d[i]); } }
3742 ColRef::Int(d) => { for &i in &group.row_indices { arena.push(d[i] as f64); } }
3743 }
3744 let val = agg_reduce_slice(agg, &mut arena)?;
3745 results.push(val);
3746 }
3747 Ok(results)
3748}
3749
3750fn agg_reduce_slice(agg: &TidyAgg, values: &mut [f64]) -> Result<f64, TidyError> {
3753 match agg {
3754 TidyAgg::Var(_) => {
3755 if values.len() < 2 {
3756 Ok(f64::NAN)
3757 } else {
3758 let n = values.len() as f64;
3759 let mean = kahan_sum_f64(values) / n;
3760 let sq_diffs: Vec<f64> = values.iter().map(|v| (v - mean) * (v - mean)).collect();
3761 Ok(kahan_sum_f64(&sq_diffs) / (n - 1.0))
3762 }
3763 }
3764 TidyAgg::Sd(_) => {
3765 if values.len() < 2 {
3766 Ok(f64::NAN)
3767 } else {
3768 let n = values.len() as f64;
3769 let mean = kahan_sum_f64(values) / n;
3770 let sq_diffs: Vec<f64> = values.iter().map(|v| (v - mean) * (v - mean)).collect();
3771 Ok((kahan_sum_f64(&sq_diffs) / (n - 1.0)).sqrt())
3772 }
3773 }
3774 TidyAgg::Median(_) => {
3775 if values.is_empty() {
3776 Ok(f64::NAN)
3777 } else {
3778 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3779 let n = values.len();
3780 if n % 2 == 1 { Ok(values[n / 2]) }
3781 else { Ok((values[n / 2 - 1] + values[n / 2]) / 2.0) }
3782 }
3783 }
3784 TidyAgg::Quantile(_, p) => {
3785 if values.is_empty() {
3786 Ok(f64::NAN)
3787 } else {
3788 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3789 let n = values.len();
3790 if n == 1 { return Ok(values[0]); }
3791 let pos = p * (n as f64 - 1.0);
3792 let lo = pos.floor() as usize;
3793 let hi = pos.ceil() as usize;
3794 let frac = pos - lo as f64;
3795 if lo == hi || hi >= n { Ok(values[lo.min(n - 1)]) }
3796 else { Ok(values[lo] + frac * (values[hi] - values[lo])) }
3797 }
3798 }
3799 TidyAgg::NDistinct(_) => {
3800 let distinct: std::collections::BTreeSet<u64> = values.iter().map(|v| v.to_bits()).collect();
3801 Ok(distinct.len() as f64)
3802 }
3803 TidyAgg::Iqr(_) => {
3804 if values.is_empty() {
3805 Ok(f64::NAN)
3806 } else {
3807 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3808 let n = values.len();
3809 if n == 1 { return Ok(0.0); }
3810 let q1 = {
3811 let pos = 0.25 * (n as f64 - 1.0);
3812 let lo = pos.floor() as usize;
3813 let hi = pos.ceil() as usize;
3814 let frac = pos - lo as f64;
3815 if lo == hi || hi >= n { values[lo.min(n - 1)] }
3816 else { values[lo] + frac * (values[hi] - values[lo]) }
3817 };
3818 let q3 = {
3819 let pos = 0.75 * (n as f64 - 1.0);
3820 let lo = pos.floor() as usize;
3821 let hi = pos.ceil() as usize;
3822 let frac = pos - lo as f64;
3823 if lo == hi || hi >= n { values[lo.min(n - 1)] }
3824 else { values[lo] + frac * (values[hi] - values[lo]) }
3825 };
3826 Ok(q3 - q1)
3827 }
3828 }
3829 _ => unreachable!("agg_reduce_slice called for non-arena aggregator"),
3830 }
3831}
3832
3833#[allow(dead_code)]
3835fn agg_reduce(
3836 agg: &TidyAgg,
3837 col: &Column,
3838 rows: &[usize],
3839) -> Result<f64, TidyError> {
3840 let values: Vec<f64> = match col {
3842 Column::Int(v) => rows.iter().map(|&r| v[r] as f64).collect(),
3843 Column::Float(v) => rows.iter().map(|&r| v[r]).collect(),
3844 _ => {
3845 return Err(TidyError::TypeMismatch {
3846 expected: "numeric (Int or Float)".into(),
3847 got: col.type_name().into(),
3848 })
3849 }
3850 };
3851
3852 match agg {
3853 TidyAgg::Sum(_) => Ok(kahan_sum_f64(&values)),
3854 TidyAgg::Mean(_) => {
3855 if values.is_empty() {
3856 Ok(f64::NAN)
3857 } else {
3858 Ok(kahan_sum_f64(&values) / values.len() as f64)
3859 }
3860 }
3861 TidyAgg::Min(_) => {
3862 if values.is_empty() {
3863 Ok(f64::NAN)
3864 } else {
3865 Ok(values.iter().cloned().fold(f64::INFINITY, |a, b| {
3866 if b.is_nan() || b < a { b } else { a }
3867 }))
3868 }
3869 }
3870 TidyAgg::Max(_) => {
3871 if values.is_empty() {
3872 Ok(f64::NAN)
3873 } else {
3874 Ok(values.iter().cloned().fold(f64::NEG_INFINITY, |a, b| {
3875 if b.is_nan() || b > a { b } else { a }
3876 }))
3877 }
3878 }
3879 TidyAgg::First(_) => {
3880 if values.is_empty() {
3881 Err(TidyError::EmptyGroup)
3882 } else {
3883 Ok(values[0])
3884 }
3885 }
3886 TidyAgg::Last(_) => {
3887 if values.is_empty() {
3888 Err(TidyError::EmptyGroup)
3889 } else {
3890 Ok(*values.last().unwrap())
3891 }
3892 }
3893 TidyAgg::Count => Ok(values.len() as f64),
3894 TidyAgg::Median(_) => {
3895 if values.is_empty() {
3896 Ok(f64::NAN)
3897 } else {
3898 let mut sorted = values.clone();
3899 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3900 let n = sorted.len();
3901 if n % 2 == 1 {
3902 Ok(sorted[n / 2])
3903 } else {
3904 Ok((sorted[n / 2 - 1] + sorted[n / 2]) / 2.0)
3905 }
3906 }
3907 }
3908 TidyAgg::Var(_) => {
3909 if values.len() < 2 {
3910 Ok(f64::NAN)
3911 } else {
3912 let n = values.len() as f64;
3913 let mean = kahan_sum_f64(&values) / n;
3914 let sq_diffs: Vec<f64> = values.iter().map(|v| (v - mean) * (v - mean)).collect();
3915 Ok(kahan_sum_f64(&sq_diffs) / (n - 1.0))
3916 }
3917 }
3918 TidyAgg::Sd(_) => {
3919 if values.len() < 2 {
3920 Ok(f64::NAN)
3921 } else {
3922 let n = values.len() as f64;
3923 let mean = kahan_sum_f64(&values) / n;
3924 let sq_diffs: Vec<f64> = values.iter().map(|v| (v - mean) * (v - mean)).collect();
3925 Ok((kahan_sum_f64(&sq_diffs) / (n - 1.0)).sqrt())
3926 }
3927 }
3928 TidyAgg::Quantile(_, p) => {
3929 if values.is_empty() {
3930 Ok(f64::NAN)
3931 } else {
3932 let mut sorted = values.clone();
3933 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3934 let n = sorted.len();
3935 if n == 1 {
3936 return Ok(sorted[0]);
3937 }
3938 let pos = p * (n as f64 - 1.0);
3939 let lo = pos.floor() as usize;
3940 let hi = pos.ceil() as usize;
3941 let frac = pos - lo as f64;
3942 if lo == hi || hi >= n {
3943 Ok(sorted[lo.min(n - 1)])
3944 } else {
3945 Ok(sorted[lo] + frac * (sorted[hi] - sorted[lo]))
3946 }
3947 }
3948 }
3949 TidyAgg::NDistinct(_) => {
3950 use std::collections::BTreeSet;
3951 let distinct: BTreeSet<u64> = values.iter().map(|v| v.to_bits()).collect();
3952 Ok(distinct.len() as f64)
3953 }
3954 TidyAgg::Iqr(_) => {
3955 if values.is_empty() {
3956 Ok(f64::NAN)
3957 } else {
3958 let mut sorted = values.clone();
3959 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3960 let n = sorted.len();
3961 if n == 1 {
3962 return Ok(0.0);
3963 }
3964 let q1 = {
3965 let pos = 0.25 * (n as f64 - 1.0);
3966 let lo = pos.floor() as usize;
3967 let hi = pos.ceil() as usize;
3968 let frac = pos - lo as f64;
3969 if lo == hi || hi >= n { sorted[lo.min(n - 1)] }
3970 else { sorted[lo] + frac * (sorted[hi] - sorted[lo]) }
3971 };
3972 let q3 = {
3973 let pos = 0.75 * (n as f64 - 1.0);
3974 let lo = pos.floor() as usize;
3975 let hi = pos.ceil() as usize;
3976 let frac = pos - lo as f64;
3977 if lo == hi || hi >= n { sorted[lo.min(n - 1)] }
3978 else { sorted[lo] + frac * (sorted[hi] - sorted[lo]) }
3979 };
3980 Ok(q3 - q1)
3981 }
3982 }
3983 }
3984}
3985
/// An aggregation computed over the values of one group.
///
/// The `String` payload presumably names the input column — TODO confirm
/// against callers. The notes below follow the evaluation arms visible in this
/// file; the arms for `Sum`, `Mean` and `Min` are not shown in this section.
#[derive(Debug, Clone)]
pub enum TidyAgg {
    /// Number of values in the group.
    Count,
    /// Sum aggregate (evaluation arm not visible here).
    Sum(String),
    /// Mean aggregate (evaluation arm not visible here).
    Mean(String),
    /// Min aggregate (evaluation arm not visible here).
    Min(String),
    /// Largest value; NaN for an empty group, and any NaN input propagates.
    Max(String),
    /// First value; `TidyError::EmptyGroup` for an empty group.
    First(String),
    /// Last value; `TidyError::EmptyGroup` for an empty group.
    Last(String),
    /// Middle value (mean of the two middle values for an even count); NaN if empty.
    Median(String),
    /// Sample standard deviation (n - 1 denominator); NaN for fewer than two values.
    Sd(String),
    /// Sample variance (n - 1 denominator, Kahan-summed); NaN for fewer than two values.
    Var(String),
    /// Quantile `p` by linear interpolation at position `p * (n - 1)` of the
    /// sorted values; NaN if empty.
    Quantile(String, f64),
    /// Count of distinct `f64` bit patterns (so `0.0` and `-0.0` count separately).
    NDistinct(String),
    /// Interquartile range `q3 - q1` using the same interpolation as `Quantile`;
    /// 0 for a single value, NaN if empty.
    Iqr(String),
}
4018
/// One sort key for [`TidyView::arrange`].
#[derive(Debug, Clone)]
pub struct ArrangeKey {
    /// Name of the base column to sort by.
    pub col_name: String,
    /// Sort this key in descending order when `true`.
    pub descending: bool,
}
4029
4030impl ArrangeKey {
4031 pub fn asc(col_name: &str) -> Self {
4033 ArrangeKey { col_name: col_name.to_string(), descending: false }
4034 }
4035 pub fn desc(col_name: &str) -> Self {
4037 ArrangeKey { col_name: col_name.to_string(), descending: true }
4038 }
4039}
4040
4041impl TidyView {
4044
4045 pub fn group_by(&self, keys: &[&str]) -> Result<GroupedTidyView, TidyError> {
4059 let mut key_col_indices = Vec::with_capacity(keys.len());
4061 for &key in keys {
4062 let idx = self
4063 .base
4064 .columns
4065 .iter()
4066 .position(|(n, _)| n == key)
4067 .ok_or_else(|| TidyError::ColumnNotFound(key.to_string()))?;
4068 key_col_indices.push(idx);
4069 }
4070
4071 let visible_rows: Vec<usize> = self.mask.iter_set().collect();
4072 let key_names: Vec<String> = keys.iter().map(|s| s.to_string()).collect();
4073
4074 let index = GroupIndex::build_fast(&self.base, &key_col_indices, &visible_rows, key_names);
4076
4077 Ok(GroupedTidyView {
4078 view: self.clone(),
4079 index,
4080 })
4081 }
4082
4083 pub fn arrange(&self, keys: &[ArrangeKey]) -> Result<TidyView, TidyError> {
4103 for key in keys {
4105 if self.base.get_column(&key.col_name).is_none() {
4106 return Err(TidyError::ColumnNotFound(key.col_name.clone()));
4107 }
4108 }
4109
4110 let mut row_indices: Vec<usize> = self.mask.iter_set().collect();
4112
4113 row_indices.sort_by(|&a, &b| {
4115 for key in keys {
4116 let col = self.base.get_column(&key.col_name).unwrap();
4117 let ord = compare_column_rows(col, a, b);
4118 let ord = if key.descending { ord.reverse() } else { ord };
4119 if ord != std::cmp::Ordering::Equal {
4120 return ord;
4121 }
4122 }
4123 std::cmp::Ordering::Equal
4124 });
4125
4126 let mut new_columns = Vec::with_capacity(self.proj.len());
4128 for &ci in self.proj.indices() {
4129 let (name, col) = &self.base.columns[ci];
4130 let new_col = gather_column(col, &row_indices);
4131 new_columns.push((name.clone(), new_col));
4132 }
4133 let mut sorted_all_cols = Vec::with_capacity(self.base.ncols());
4136 for (name, col) in &self.base.columns {
4137 sorted_all_cols.push((name.clone(), gather_column(col, &row_indices)));
4138 }
4139
4140 let new_base = DataFrame::from_columns(sorted_all_cols)
4141 .map_err(|e| TidyError::Internal(e.to_string()))?;
4142 let nrows = new_base.nrows();
4143 let new_proj = self.proj.clone();
4144
4145 Ok(TidyView {
4146 base: Rc::new(new_base),
4147 mask: BitMask::all_true(nrows),
4148 proj: new_proj,
4149 })
4150 }
4151
4152 pub fn slice(&self, start: usize, end: usize) -> TidyView {
4159 let visible: Vec<usize> = self.mask.iter_set().collect();
4160 let n = visible.len();
4161 let s = start.min(n);
4162 let e = end.min(n);
4163 let selected = if s >= e { vec![] } else { visible[s..e].to_vec() };
4164 self.view_from_row_indices(selected)
4165 }
4166
4167 pub fn slice_head(&self, n: usize) -> TidyView {
4169 self.slice(0, n)
4170 }
4171
4172 pub fn slice_tail(&self, n: usize) -> TidyView {
4174 let total = self.mask.count_ones();
4175 let start = total.saturating_sub(n);
4176 self.slice(start, total)
4177 }
4178
4179 pub fn slice_sample(&self, n: usize, seed: u64) -> TidyView {
4184 let mut visible: Vec<usize> = self.mask.iter_set().collect();
4185 let total = visible.len();
4186 if n >= total {
4187 return self.view_from_row_indices(visible);
4188 }
4189 let mut rng = seed;
4191 let selected_count = n;
4192 for i in 0..selected_count {
4193 rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
4195 let j = i + (rng as usize % (total - i));
4196 visible.swap(i, j);
4197 }
4198 visible.truncate(selected_count);
4199 visible.sort_unstable();
4201 self.view_from_row_indices(visible)
4202 }
4203
4204 pub fn distinct(&self, cols: &[&str]) -> Result<TidyView, TidyError> {
4216 let mut col_indices = Vec::with_capacity(cols.len());
4218 for &name in cols {
4219 let idx = self
4220 .base
4221 .columns
4222 .iter()
4223 .position(|(n, _)| n == name)
4224 .ok_or_else(|| TidyError::ColumnNotFound(name.to_string()))?;
4225 col_indices.push(idx);
4226 }
4227
4228 let mut seen_keys: BTreeSet<Vec<String>> = BTreeSet::new();
4230 let mut selected_rows: Vec<usize> = Vec::new();
4231
4232 for row in self.mask.iter_set() {
4233 let key: Vec<String> = if col_indices.is_empty() {
4234 vec!["__all__".into()]
4235 } else {
4236 col_indices
4237 .iter()
4238 .map(|&ci| self.base.columns[ci].1.get_display(row))
4239 .collect()
4240 };
4241
4242 if seen_keys.insert(key) {
4243 selected_rows.push(row);
4244 }
4245 }
4246
4247 Ok(self.view_from_row_indices(selected_rows))
4248 }
4249
4250 pub fn inner_join(
4263 &self,
4264 right: &TidyView,
4265 on: &[(&str, &str)],
4266 ) -> Result<TidyFrame, TidyError> {
4267 let (left_rows, right_rows) = join_match_rows(self, right, on, JoinKind::Inner)?;
4268 build_join_frame(self, right, &left_rows, &right_rows, on, false)
4269 }
4270
4271 pub fn left_join(
4275 &self,
4276 right: &TidyView,
4277 on: &[(&str, &str)],
4278 ) -> Result<TidyFrame, TidyError> {
4279 let (left_rows, right_rows_opt) =
4280 join_match_rows_optional(self, right, on, JoinKind::Left)?;
4281 build_left_join_frame(self, right, &left_rows, &right_rows_opt, on)
4282 }
4283
4284 pub fn semi_join(
4288 &self,
4289 right: &TidyView,
4290 on: &[(&str, &str)],
4291 ) -> Result<TidyView, TidyError> {
4292 let included = semi_anti_match_rows(self, right, on, true)?;
4293 Ok(self.view_from_row_indices(included))
4294 }
4295
4296 pub fn anti_join(
4300 &self,
4301 right: &TidyView,
4302 on: &[(&str, &str)],
4303 ) -> Result<TidyView, TidyError> {
4304 let included = semi_anti_match_rows(self, right, on, false)?;
4305 Ok(self.view_from_row_indices(included))
4306 }
4307
4308 fn view_from_row_indices(&self, row_indices: Vec<usize>) -> TidyView {
4313 let nrows_base = self.base.nrows();
4314 let mut words = vec![0u64; nwords_for(nrows_base)];
4315 for &r in &row_indices {
4316 words[r / 64] |= 1u64 << (r % 64);
4317 }
4318 TidyView {
4319 base: Rc::clone(&self.base),
4320 mask: BitMask { words, nrows: nrows_base },
4321 proj: self.proj.clone(),
4322 }
4323 }
4324}
4325
/// Tag passed to the join row-matching helpers.
///
/// NOTE(review): `join_match_rows` and `join_match_rows_optional` take this as
/// `_kind` and ignore it. The join flavour is actually decided by which helper
/// the caller invokes.
#[derive(Clone, Copy)]
enum JoinKind { Inner, Left }
4330
4331fn resolve_join_keys(
4333 left: &TidyView,
4334 right: &TidyView,
4335 on: &[(&str, &str)],
4336) -> Result<(Vec<usize>, Vec<usize>), TidyError> {
4337 let mut li = Vec::new();
4338 let mut ri = Vec::new();
4339 for &(lk, rk) in on {
4340 let l = left.base.columns.iter().position(|(n, _)| n == lk)
4341 .ok_or_else(|| TidyError::ColumnNotFound(lk.to_string()))?;
4342 let r = right.base.columns.iter().position(|(n, _)| n == rk)
4343 .ok_or_else(|| TidyError::ColumnNotFound(rk.to_string()))?;
4344 li.push(l);
4345 ri.push(r);
4346 }
4347 Ok((li, ri))
4348}
4349
4350fn row_key(base: &DataFrame, col_indices: &[usize], row: usize) -> Vec<String> {
4352 col_indices
4353 .iter()
4354 .map(|&ci| base.columns[ci].1.get_display(row))
4355 .collect()
4356}
4357
4358fn build_right_lookup(
4361 right: &TidyView,
4362 right_key_cols: &[usize],
4363) -> Vec<(Vec<String>, usize)> {
4364 let mut lookup: Vec<(Vec<String>, usize)> = right
4365 .mask
4366 .iter_set()
4367 .map(|r| (row_key(&right.base, right_key_cols, r), r))
4368 .collect();
4369 lookup.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
4371 lookup
4372}
4373
/// Returns the row indices stored under `key` in a key-sorted lookup table
/// (as built by `build_right_lookup`), in table order.
///
/// A binary search finds the first candidate, then a forward scan collects
/// the run of equal keys, for O(log n + matches) total. The comparison is
/// done directly on slices, with no per-call copy of `key`.
fn find_matches(lookup: &[(Vec<String>, usize)], key: &[String]) -> Vec<usize> {
    let start = lookup.partition_point(|(k, _)| k.as_slice() < key);
    lookup[start..]
        .iter()
        .take_while(|(k, _)| k.as_slice() == key)
        .map(|&(_, r)| r)
        .collect()
}
4389
4390fn build_right_lookup_btree(
4395 right: &TidyView,
4396 right_key_cols: &[usize],
4397) -> BTreeMap<Vec<String>, Vec<usize>> {
4398 let mut lookup: BTreeMap<Vec<String>, Vec<usize>> = BTreeMap::new();
4399 for r in right.mask.iter_set() {
4400 let key = row_key(&right.base, right_key_cols, r);
4401 lookup.entry(key).or_default().push(r);
4402 }
4403 lookup
4404}
4405
4406fn join_match_rows(
4408 left: &TidyView,
4409 right: &TidyView,
4410 on: &[(&str, &str)],
4411 _kind: JoinKind,
4412) -> Result<(Vec<usize>, Vec<usize>), TidyError> {
4413 let (left_key_cols, right_key_cols) = resolve_join_keys(left, right, on)?;
4414 let lookup = build_right_lookup_btree(right, &right_key_cols);
4416
4417 let mut out_left = Vec::new();
4418 let mut out_right = Vec::new();
4419
4420 for l_row in left.mask.iter_set() {
4421 let key = row_key(&left.base, &left_key_cols, l_row);
4422 if let Some(matches) = lookup.get(&key) {
4423 for &r_row in matches {
4424 out_left.push(l_row);
4425 out_right.push(r_row);
4426 }
4427 }
4428 }
4429 Ok((out_left, out_right))
4430}
4431
4432fn join_match_rows_optional(
4434 left: &TidyView,
4435 right: &TidyView,
4436 on: &[(&str, &str)],
4437 _kind: JoinKind,
4438) -> Result<(Vec<usize>, Vec<Option<usize>>), TidyError> {
4439 let (left_key_cols, right_key_cols) = resolve_join_keys(left, right, on)?;
4440 let lookup = build_right_lookup_btree(right, &right_key_cols);
4442
4443 let mut out_left = Vec::new();
4444 let mut out_right: Vec<Option<usize>> = Vec::new();
4445
4446 for l_row in left.mask.iter_set() {
4447 let key = row_key(&left.base, &left_key_cols, l_row);
4448 match lookup.get(&key) {
4449 Some(matches) if !matches.is_empty() => {
4450 for &r_row in matches {
4451 out_left.push(l_row);
4452 out_right.push(Some(r_row));
4453 }
4454 }
4455 _ => {
4456 out_left.push(l_row);
4457 out_right.push(None);
4458 }
4459 }
4460 }
4461 Ok((out_left, out_right))
4462}
4463
4464fn semi_anti_match_rows(
4466 left: &TidyView,
4467 right: &TidyView,
4468 on: &[(&str, &str)],
4469 semi: bool,
4470) -> Result<Vec<usize>, TidyError> {
4471 let (left_key_cols, right_key_cols) = resolve_join_keys(left, right, on)?;
4472 let lookup = build_right_lookup_btree(right, &right_key_cols);
4474
4475 let mut out = Vec::new();
4476 for l_row in left.mask.iter_set() {
4477 let key = row_key(&left.base, &left_key_cols, l_row);
4478 let has_match = lookup.contains_key(&key);
4479 if has_match == semi {
4480 out.push(l_row);
4481 }
4482 }
4483 Ok(out)
4484}
4485
4486fn build_join_frame(
4489 left: &TidyView,
4490 right: &TidyView,
4491 left_rows: &[usize],
4492 right_rows: &[usize],
4493 on: &[(&str, &str)],
4494 _include_unmatched: bool,
4495) -> Result<TidyFrame, TidyError> {
4496 let right_key_names: std::collections::BTreeSet<&str> =
4497 on.iter().map(|(_, rk)| *rk).collect();
4498
4499 let n = left_rows.len();
4500 let mut columns: Vec<(String, Column)> = Vec::new();
4501
4502 for &ci in left.proj.indices() {
4504 let (name, col) = &left.base.columns[ci];
4505 columns.push((name.clone(), gather_column(col, left_rows)));
4506 }
4507
4508 for &ci in right.proj.indices() {
4510 let (name, col) = &right.base.columns[ci];
4511 if right_key_names.contains(name.as_str()) {
4512 continue;
4513 }
4514 columns.push((name.clone(), gather_column(col, right_rows)));
4515 }
4516
4517 assert_eq!(n, left_rows.len());
4518 let df = DataFrame::from_columns(columns)
4519 .map_err(|e| TidyError::Internal(e.to_string()))?;
4520 Ok(TidyFrame::from_df(df))
4521}
4522
4523fn build_left_join_frame(
4525 left: &TidyView,
4526 right: &TidyView,
4527 left_rows: &[usize],
4528 right_rows_opt: &[Option<usize>],
4529 on: &[(&str, &str)],
4530) -> Result<TidyFrame, TidyError> {
4531 let right_key_names: std::collections::BTreeSet<&str> =
4532 on.iter().map(|(_, rk)| *rk).collect();
4533
4534 let mut columns: Vec<(String, Column)> = Vec::new();
4535
4536 for &ci in left.proj.indices() {
4538 let (name, col) = &left.base.columns[ci];
4539 columns.push((name.clone(), gather_column(col, left_rows)));
4540 }
4541
4542 for &ci in right.proj.indices() {
4544 let (name, col) = &right.base.columns[ci];
4545 if right_key_names.contains(name.as_str()) {
4546 continue;
4547 }
4548 let new_col = gather_column_nullable(col, right_rows_opt);
4549 columns.push((name.clone(), new_col));
4550 }
4551
4552 let df = DataFrame::from_columns(columns)
4553 .map_err(|e| TidyError::Internal(e.to_string()))?;
4554 Ok(TidyFrame::from_df(df))
4555}
4556
4557fn compare_column_rows(col: &Column, a: usize, b: usize) -> std::cmp::Ordering {
4564 match col {
4565 Column::Int(v) => v[a].cmp(&v[b]),
4566 Column::Float(v) => {
4567 let va = v[a];
4568 let vb = v[b];
4569 match (va.is_nan(), vb.is_nan()) {
4570 (true, true) => std::cmp::Ordering::Equal,
4571 (true, false) => std::cmp::Ordering::Greater, (false, true) => std::cmp::Ordering::Less,
4573 (false, false) => va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal),
4574 }
4575 }
4576 Column::Bool(v) => v[a].cmp(&v[b]),
4577 Column::Str(v) => v[a].cmp(&v[b]),
4578 Column::Categorical { levels, codes } => {
4579 levels[codes[a] as usize].cmp(&levels[codes[b] as usize])
4581 }
4582 Column::DateTime(v) => v[a].cmp(&v[b]),
4583 }
4584}
4585
#[cfg(test)]
mod phase10_unit_tests {
    use super::*;

    /// Five-row fixture: Int `x`, Float `y` and Bool `flag`.
    fn make_df() -> DataFrame {
        let columns: Vec<(String, Column)> = vec![
            (String::from("x"), Column::Int(vec![1, 2, 3, 4, 5])),
            (String::from("y"), Column::Float(vec![1.0, 2.0, 3.0, 4.0, 5.0])),
            (String::from("flag"), Column::Bool(vec![true, false, true, false, true])),
        ];
        DataFrame::from_columns(columns).unwrap()
    }

    #[test]
    fn bitmask_all_true_count() {
        let mask = BitMask::all_true(5);
        assert_eq!(mask.count_ones(), 5);
    }

    #[test]
    fn bitmask_all_false_count() {
        let mask = BitMask::all_false(5);
        assert_eq!(mask.count_ones(), 0);
    }

    #[test]
    fn bitmask_tail_bits_clean() {
        // 65 rows span two 64-bit words; only bit 0 of the second may be set.
        let mask = BitMask::all_true(65);
        assert_eq!(mask.count_ones(), 65);
        assert_eq!(mask.words.len(), 2);
        assert_eq!(mask.words[1], 1u64);
    }

    #[test]
    fn bitmask_and_semantics() {
        let lhs = BitMask::from_bools(&[true, true, false, true]);
        let rhs = BitMask::from_bools(&[true, false, true, true]);
        let both = lhs.and(&rhs);
        let expected = [true, false, false, true];
        for (bit, &want) in expected.iter().enumerate() {
            assert_eq!(both.get(bit), want);
        }
    }

    #[test]
    fn tidy_view_nrows_ncols() {
        let df = make_df();
        let view = df.tidy();
        assert_eq!(view.nrows(), 5);
        assert_eq!(view.ncols(), 3);
    }

    #[test]
    fn filter_basic() {
        let df = make_df();
        let view = df.tidy();
        let x_gt_two = DExpr::BinOp {
            op: DBinOp::Gt,
            left: Box::new(DExpr::Col("x".into())),
            right: Box::new(DExpr::LitInt(2)),
        };
        let kept = view.filter(&x_gt_two).unwrap();
        assert_eq!(kept.nrows(), 3);
    }

    #[test]
    fn filter_empty_df() {
        let df = DataFrame::from_columns(vec![("x".into(), Column::Int(vec![]))]).unwrap();
        let view = df.tidy();
        let x_gt_zero = DExpr::BinOp {
            op: DBinOp::Gt,
            left: Box::new(DExpr::Col("x".into())),
            right: Box::new(DExpr::LitInt(0)),
        };
        let kept = view.filter(&x_gt_zero).unwrap();
        assert_eq!(kept.nrows(), 0);
    }

    #[test]
    fn select_reorder() {
        let df = make_df();
        let view = df.tidy();
        let picked = view.select(&["y", "x"]).unwrap();
        assert_eq!(picked.column_names(), vec!["y", "x"]);
    }

    #[test]
    fn select_zero_cols() {
        let df = make_df();
        let view = df.tidy();
        let picked = view.select(&[]).unwrap();
        assert_eq!(picked.ncols(), 0);
        assert_eq!(picked.nrows(), 5);
    }

    #[test]
    fn select_unknown_col() {
        let df = make_df();
        let view = df.tidy();
        let outcome = view.select(&["nonexistent"]).unwrap_err();
        assert!(matches!(outcome, TidyError::ColumnNotFound(_)));
    }

    #[test]
    fn select_duplicate_col() {
        let df = make_df();
        let view = df.tidy();
        let outcome = view.select(&["x", "x"]).unwrap_err();
        assert!(matches!(outcome, TidyError::DuplicateColumn(_)));
    }

    #[test]
    fn mutate_new_col() {
        let df = make_df();
        let view = df.tidy();
        let x_plus_ten = DExpr::BinOp {
            op: DBinOp::Add,
            left: Box::new(DExpr::Col("x".into())),
            right: Box::new(DExpr::LitInt(10)),
        };
        let frame = view.mutate(&[("z", x_plus_ten)]).unwrap();
        let borrowed = frame.borrow();
        let z = borrowed.get_column("z").unwrap();
        assert_eq!(z.len(), 5);
        match z {
            Column::Int(vals) => {
                assert_eq!(vals[0], 11);
                assert_eq!(vals[4], 15);
            }
            _ => panic!("expected Int column"),
        }
    }

    #[test]
    fn mutate_type_promotion() {
        // Int + Float is expected to promote to a Float column.
        let df = make_df();
        let view = df.tidy();
        let x_plus_y = DExpr::BinOp {
            op: DBinOp::Add,
            left: Box::new(DExpr::Col("x".into())),
            right: Box::new(DExpr::Col("y".into())),
        };
        let frame = view.mutate(&[("promoted", x_plus_y)]).unwrap();
        let borrowed = frame.borrow();
        let promoted = borrowed.get_column("promoted").unwrap();
        assert!(matches!(promoted, Column::Float(_)));
    }
}
4762
4763impl TidyError {
4815 pub fn schema_mismatch(msg: impl Into<String>) -> Self {
4817 TidyError::Internal(format!("schema mismatch: {}", msg.into()))
4818 }
4819 pub fn join_type_mismatch(col: &str, lt: &str, rt: &str) -> Self {
4821 TidyError::TypeMismatch {
4822 expected: format!("{} (from left key `{}`)", lt, col),
4823 got: rt.to_string(),
4824 }
4825 }
4826 pub fn duplicate_key(key: impl Into<String>) -> Self {
4828 TidyError::DuplicateColumn(format!("duplicate key: {}", key.into()))
4829 }
4830 pub fn empty_selection(msg: impl Into<String>) -> Self {
4832 TidyError::Internal(format!("empty selection: {}", msg.into()))
4833 }
4834}
4835
/// A column of `T` values paired with a validity bitmap.
///
/// A set bit `i` in `validity` means `values[i]` is present; a clear bit means
/// the slot is null and `values[i]` holds only a placeholder.
#[derive(Debug, Clone)]
pub struct NullableColumn<T: Clone> {
    /// Raw values, including placeholders in null slots.
    pub values: Vec<T>,
    /// One validity bit per entry of `values`.
    pub validity: BitMask,
}
4849
4850impl<T: Clone + Default> NullableColumn<T> {
4851 pub fn from_values(values: Vec<T>) -> Self {
4853 let n = values.len();
4854 Self {
4855 values,
4856 validity: BitMask::all_true(n),
4857 }
4858 }
4859
4860 pub fn new(values: Vec<T>, validity: BitMask) -> Self {
4863 assert_eq!(values.len(), validity.nrows(), "NullableColumn: length mismatch");
4864 Self { values, validity }
4865 }
4866
4867 pub fn len(&self) -> usize {
4869 self.values.len()
4870 }
4871
4872 pub fn is_null(&self, i: usize) -> bool {
4874 !self.validity.get(i)
4875 }
4876
4877 pub fn get(&self, i: usize) -> Option<&T> {
4879 if self.validity.get(i) { Some(&self.values[i]) } else { None }
4880 }
4881
4882 pub fn count_valid(&self) -> usize {
4884 self.validity.count_ones()
4885 }
4886
4887 pub fn gather(&self, indices: &[usize]) -> Self {
4889 let mut vals = Vec::with_capacity(indices.len());
4890 let mut bools = Vec::with_capacity(indices.len());
4891 for &i in indices {
4892 vals.push(self.values[i].clone());
4893 bools.push(self.validity.get(i));
4894 }
4895 let validity = BitMask::from_bools(&bools);
4896 Self { values: vals, validity }
4897 }
4898}
4899
/// A nullable column of one of four physical types.
///
/// `NullCol::from_column` converts `Column::Categorical` to `Str` (decoded
/// level strings) and `Column::DateTime` to `Int`, so those logical types have
/// no variant of their own here.
#[derive(Debug, Clone)]
pub enum NullCol {
    /// Nullable 64-bit integers.
    Int(NullableColumn<i64>),
    /// Nullable 64-bit floats.
    Float(NullableColumn<f64>),
    /// Nullable strings.
    Str(NullableColumn<String>),
    /// Nullable booleans.
    Bool(NullableColumn<bool>),
}
4923
4924impl NullCol {
4925 pub fn len(&self) -> usize {
4927 match self {
4928 NullCol::Int(c) => c.len(),
4929 NullCol::Float(c) => c.len(),
4930 NullCol::Str(c) => c.len(),
4931 NullCol::Bool(c) => c.len(),
4932 }
4933 }
4934
4935 pub fn is_null(&self, i: usize) -> bool {
4937 match self {
4938 NullCol::Int(c) => c.is_null(i),
4939 NullCol::Float(c) => c.is_null(i),
4940 NullCol::Str(c) => c.is_null(i),
4941 NullCol::Bool(c) => c.is_null(i),
4942 }
4943 }
4944
4945 pub fn type_name(&self) -> &'static str {
4947 match self {
4948 NullCol::Int(_) => "Int",
4949 NullCol::Float(_) => "Float",
4950 NullCol::Str(_) => "Str",
4951 NullCol::Bool(_) => "Bool",
4952 }
4953 }
4954
4955 pub fn from_column(col: &Column) -> Self {
4957 match col {
4958 Column::Int(v) => NullCol::Int(NullableColumn::from_values(v.clone())),
4959 Column::Float(v) => NullCol::Float(NullableColumn::from_values(v.clone())),
4960 Column::Str(v) => NullCol::Str(NullableColumn::from_values(v.clone())),
4961 Column::Bool(v) => NullCol::Bool(NullableColumn::from_values(v.clone())),
4962 Column::Categorical { levels, codes } => {
4964 let strings: Vec<String> = codes.iter().map(|&c| levels[c as usize].clone()).collect();
4965 NullCol::Str(NullableColumn::from_values(strings))
4966 }
4967 Column::DateTime(v) => NullCol::Int(NullableColumn::from_values(v.clone())),
4968 }
4969 }
4970
4971 pub fn to_column_strict(&self) -> Result<Column, TidyError> {
4974 match self {
4975 NullCol::Int(nc) => {
4976 if nc.count_valid() == nc.len() {
4977 Ok(Column::Int(nc.values.clone()))
4978 } else {
4979 Err(TidyError::Internal("null values in non-nullable context".into()))
4980 }
4981 }
4982 NullCol::Float(nc) => {
4983 if nc.count_valid() == nc.len() {
4984 Ok(Column::Float(nc.values.clone()))
4985 } else {
4986 Err(TidyError::Internal("null values in non-nullable context".into()))
4987 }
4988 }
4989 NullCol::Str(nc) => {
4990 if nc.count_valid() == nc.len() {
4991 Ok(Column::Str(nc.values.clone()))
4992 } else {
4993 Err(TidyError::Internal("null values in non-nullable context".into()))
4994 }
4995 }
4996 NullCol::Bool(nc) => {
4997 if nc.count_valid() == nc.len() {
4998 Ok(Column::Bool(nc.values.clone()))
4999 } else {
5000 Err(TidyError::Internal("null values in non-nullable context".into()))
5001 }
5002 }
5003 }
5004 }
5005
5006 pub fn to_column_filled(&self) -> Column {
5009 match self {
5010 NullCol::Int(nc) => Column::Int(nc.values.clone()),
5011 NullCol::Float(nc) => {
5012 let v: Vec<f64> = (0..nc.len())
5013 .map(|i| if nc.is_null(i) { f64::NAN } else { nc.values[i] })
5014 .collect();
5015 Column::Float(v)
5016 }
5017 NullCol::Str(nc) => Column::Str(nc.values.clone()),
5018 NullCol::Bool(nc) => Column::Bool(nc.values.clone()),
5019 }
5020 }
5021
5022 pub fn get_display(&self, i: usize) -> String {
5024 if self.is_null(i) {
5025 return "null".to_string();
5026 }
5027 match self {
5028 NullCol::Int(nc) => format!("{}", nc.values[i]),
5029 NullCol::Float(nc) => format!("{}", nc.values[i]),
5030 NullCol::Str(nc) => nc.values[i].clone(),
5031 NullCol::Bool(nc) => format!("{}", nc.values[i]),
5032 }
5033 }
5034
5035 pub fn null_of_type(type_name: &str, len: usize) -> Self {
5037 match type_name {
5038 "Int" => NullCol::Int(NullableColumn {
5039 values: vec![0i64; len],
5040 validity: BitMask::all_false(len),
5041 }),
5042 "Float" => NullCol::Float(NullableColumn {
5043 values: vec![0.0f64; len],
5044 validity: BitMask::all_false(len),
5045 }),
5046 "Bool" => NullCol::Bool(NullableColumn {
5047 values: vec![false; len],
5048 validity: BitMask::all_false(len),
5049 }),
5050 _ => NullCol::Str(NullableColumn {
5051 values: vec![String::new(); len],
5052 validity: BitMask::all_false(len),
5053 }),
5054 }
5055 }
5056
5057 pub fn gather(&self, indices: &[usize]) -> Self {
5059 match self {
5060 NullCol::Int(nc) => NullCol::Int(nc.gather(indices)),
5061 NullCol::Float(nc) => NullCol::Float(nc.gather(indices)),
5062 NullCol::Str(nc) => NullCol::Str(nc.gather(indices)),
5063 NullCol::Bool(nc) => NullCol::Bool(nc.gather(indices)),
5064 }
5065 }
5066}
5067
/// An ordered list of named nullable columns.
///
/// The row count is taken from the first column; column lengths are not
/// checked against each other here.
#[derive(Debug, Clone)]
pub struct NullableFrame {
    /// `(name, column)` pairs in output order.
    pub columns: Vec<(String, NullCol)>,
}
5074
5075impl NullableFrame {
5076 pub fn new() -> Self {
5078 Self { columns: Vec::new() }
5079 }
5080
5081 pub fn nrows(&self) -> usize {
5083 self.columns.first().map(|(_, c)| c.len()).unwrap_or(0)
5084 }
5085
5086 pub fn ncols(&self) -> usize {
5088 self.columns.len()
5089 }
5090
5091 pub fn column_names(&self) -> Vec<&str> {
5093 self.columns.iter().map(|(n, _)| n.as_str()).collect()
5094 }
5095
5096 pub fn get_column(&self, name: &str) -> Option<&NullCol> {
5098 self.columns.iter().find(|(n, _)| n == name).map(|(_, c)| c)
5099 }
5100
5101 pub fn to_dataframe_filled(&self) -> DataFrame {
5103 let cols: Vec<(String, Column)> = self.columns.iter()
5104 .map(|(n, c)| (n.clone(), c.to_column_filled()))
5105 .collect();
5106 DataFrame { columns: cols }
5108 }
5109
5110 pub fn to_tidy_frame_filled(&self) -> TidyFrame {
5112 TidyFrame::from_df(self.to_dataframe_filled())
5113 }
5114
5115 pub fn to_tidy_view_filled(&self) -> TidyView {
5117 TidyView::from_df(self.to_dataframe_filled())
5118 }
5119}
5120
5121impl Default for NullableFrame {
5122 fn default() -> Self { Self::new() }
5123}
5124
5125fn gather_column_nullable_null(col: &Column, indices: &[Option<usize>]) -> NullCol {
5130 match col {
5131 Column::Int(v) => {
5132 let mut vals = Vec::with_capacity(indices.len());
5133 let mut valid = Vec::with_capacity(indices.len());
5134 for &idx in indices {
5135 match idx {
5136 Some(i) => { vals.push(v[i]); valid.push(true); }
5137 None => { vals.push(0); valid.push(false); }
5138 }
5139 }
5140 NullCol::Int(NullableColumn::new(vals, BitMask::from_bools(&valid)))
5141 }
5142 Column::Float(v) => {
5143 let mut vals = Vec::with_capacity(indices.len());
5144 let mut valid = Vec::with_capacity(indices.len());
5145 for &idx in indices {
5146 match idx {
5147 Some(i) => { vals.push(v[i]); valid.push(true); }
5148 None => { vals.push(0.0); valid.push(false); }
5149 }
5150 }
5151 NullCol::Float(NullableColumn::new(vals, BitMask::from_bools(&valid)))
5152 }
5153 Column::Str(v) => {
5154 let mut vals = Vec::with_capacity(indices.len());
5155 let mut valid = Vec::with_capacity(indices.len());
5156 for &idx in indices {
5157 match idx {
5158 Some(i) => { vals.push(v[i].clone()); valid.push(true); }
5159 None => { vals.push(String::new()); valid.push(false); }
5160 }
5161 }
5162 NullCol::Str(NullableColumn::new(vals, BitMask::from_bools(&valid)))
5163 }
5164 Column::Bool(v) => {
5165 let mut vals = Vec::with_capacity(indices.len());
5166 let mut valid = Vec::with_capacity(indices.len());
5167 for &idx in indices {
5168 match idx {
5169 Some(i) => { vals.push(v[i]); valid.push(true); }
5170 None => { vals.push(false); valid.push(false); }
5171 }
5172 }
5173 NullCol::Bool(NullableColumn::new(vals, BitMask::from_bools(&valid)))
5174 }
5175 Column::Categorical { levels, codes } => {
5176 let mut vals = Vec::with_capacity(indices.len());
5177 let mut valid = Vec::with_capacity(indices.len());
5178 for &idx in indices {
5179 match idx {
5180 Some(i) => { vals.push(levels[codes[i] as usize].clone()); valid.push(true); }
5181 None => { vals.push(String::new()); valid.push(false); }
5182 }
5183 }
5184 NullCol::Str(NullableColumn::new(vals, BitMask::from_bools(&valid)))
5185 }
5186 Column::DateTime(v) => {
5187 let mut vals = Vec::with_capacity(indices.len());
5188 let mut valid = Vec::with_capacity(indices.len());
5189 for &idx in indices {
5190 match idx {
5191 Some(i) => { vals.push(v[i]); valid.push(true); }
5192 None => { vals.push(0); valid.push(false); }
5193 }
5194 }
5195 NullCol::Int(NullableColumn::new(vals, BitMask::from_bools(&valid)))
5196 }
5197 }
5198}
5199
/// Boxed column transform stored in [`AcrossTransform`].
///
/// Called with a `&str` (presumably the column's name — confirm at call sites)
/// and the column. Returns the transformed column or an error.
pub type AcrossFn = Box<dyn Fn(&str, &Column) -> Result<Column, TidyError>>;
5207
/// A named column transform for use in an [`AcrossSpec`].
pub struct AcrossTransform {
    /// Short name substituted for `{fn}` in output-column templates.
    pub fn_name: String,
    /// The transform itself.
    pub func: AcrossFn,
}
5215
5216impl AcrossTransform {
5217 pub fn new(fn_name: impl Into<String>, func: impl Fn(&str, &Column) -> Result<Column, TidyError> + 'static) -> Self {
5219 Self {
5220 fn_name: fn_name.into(),
5221 func: Box::new(func),
5222 }
5223 }
5224}
5225
/// Applies one [`AcrossTransform`] to several columns.
pub struct AcrossSpec {
    /// Names of the input columns.
    pub cols: Vec<String>,
    /// Transform applied to each input column.
    pub transform: AcrossTransform,
    /// Optional output-name template using `{col}` and `{fn}` placeholders;
    /// when `None`, outputs are named `<col>_<fn>`.
    pub name_template: Option<String>,
}
5236
5237impl AcrossSpec {
5238 pub fn new(cols: impl IntoIterator<Item = impl Into<String>>, transform: AcrossTransform) -> Self {
5240 Self {
5241 cols: cols.into_iter().map(|c| c.into()).collect(),
5242 transform,
5243 name_template: None,
5244 }
5245 }
5246
5247 pub fn with_template(mut self, tmpl: impl Into<String>) -> Self {
5249 self.name_template = Some(tmpl.into());
5250 self
5251 }
5252
5253 pub fn output_name(&self, col_name: &str) -> String {
5255 match &self.name_template {
5256 Some(tmpl) => tmpl
5257 .replace("{col}", col_name)
5258 .replace("{fn}", &self.transform.fn_name),
5259 None => format!("{}_{}", col_name, self.transform.fn_name),
5260 }
5261 }
5262}
5263
/// Left/right suffix pair for a join, presumably appended to column names that
/// appear on both sides — confirm at the call sites. Defaults to `.x` / `.y`.
#[derive(Debug, Clone)]
pub struct JoinSuffix {
    /// Suffix for the left-hand side.
    pub left: String,
    /// Suffix for the right-hand side.
    pub right: String,
}
5272
5273impl Default for JoinSuffix {
5274 fn default() -> Self {
5275 Self { left: ".x".into(), right: ".y".into() }
5276 }
5277}
5278
5279impl JoinSuffix {
5280 pub fn new(left: impl Into<String>, right: impl Into<String>) -> Self {
5282 Self { left: left.into(), right: right.into() }
5283 }
5284}
5285
5286fn join_types_compatible(left: &Column, right: &Column) -> bool {
5291 match (left, right) {
5292 (Column::Int(_), Column::Int(_)) => true,
5293 (Column::Float(_), Column::Float(_)) => true,
5294 (Column::Int(_), Column::Float(_)) => true,
5295 (Column::Float(_), Column::Int(_)) => true,
5296 (Column::Str(_), Column::Str(_)) => true,
5297 (Column::Bool(_), Column::Bool(_)) => true,
5298 _ => false,
5299 }
5300}
5301
5302impl TidyView {
5305
5306 pub fn pivot_longer(
5324 &self,
5325 value_cols: &[&str],
5326 names_to: &str,
5327 values_to: &str,
5328 ) -> Result<TidyFrame, TidyError> {
5329 if value_cols.is_empty() {
5330 return Err(TidyError::empty_selection("pivot_longer requires at least one value_col"));
5331 }
5332
5333 let mut seen_vc: Vec<&str> = Vec::new();
5335 let mut vc_indices: Vec<usize> = Vec::new();
5336 for &name in value_cols {
5337 if seen_vc.contains(&name) {
5338 return Err(TidyError::DuplicateColumn(name.to_string()));
5339 }
5340 seen_vc.push(name);
5341 let idx = self.base.columns.iter().position(|(n, _)| n == name)
5342 .ok_or_else(|| TidyError::ColumnNotFound(name.to_string()))?;
5343 vc_indices.push(idx);
5344 }
5345
5346 let first_type = self.base.columns[vc_indices[0]].1.type_name();
5348 for &idx in &vc_indices[1..] {
5349 let t = self.base.columns[idx].1.type_name();
5350 if t != first_type {
5351 return Err(TidyError::TypeMismatch {
5352 expected: first_type.to_string(),
5353 got: t.to_string(),
5354 });
5355 }
5356 }
5357
5358 let vc_set: std::collections::BTreeSet<usize> = vc_indices.iter().copied().collect();
5360 let id_col_indices: Vec<usize> = self.proj.indices().iter()
5361 .copied()
5362 .filter(|i| !vc_set.contains(i))
5363 .collect();
5364
5365 let visible_rows: Vec<usize> = self.mask.iter_set().collect();
5366 let n_out = visible_rows.len() * value_cols.len();
5367
5368 let mut out_cols: Vec<(String, Column)> = Vec::new();
5370 for &id_idx in &id_col_indices {
5371 let (name, col) = &self.base.columns[id_idx];
5372 let new_col = match col {
5373 Column::Int(v) => {
5374 let mut out = Vec::with_capacity(n_out);
5375 for &r in &visible_rows {
5376 for _ in 0..value_cols.len() { out.push(v[r]); }
5377 }
5378 Column::Int(out)
5379 }
5380 Column::Float(v) => {
5381 let mut out = Vec::with_capacity(n_out);
5382 for &r in &visible_rows {
5383 for _ in 0..value_cols.len() { out.push(v[r]); }
5384 }
5385 Column::Float(out)
5386 }
5387 Column::Str(v) => {
5388 let mut out = Vec::with_capacity(n_out);
5389 for &r in &visible_rows {
5390 for _ in 0..value_cols.len() { out.push(v[r].clone()); }
5391 }
5392 Column::Str(out)
5393 }
5394 Column::Bool(v) => {
5395 let mut out = Vec::with_capacity(n_out);
5396 for &r in &visible_rows {
5397 for _ in 0..value_cols.len() { out.push(v[r]); }
5398 }
5399 Column::Bool(out)
5400 }
5401 Column::Categorical { levels, codes } => {
5402 let mut out = Vec::with_capacity(n_out);
5403 for &r in &visible_rows {
5404 for _ in 0..value_cols.len() { out.push(codes[r]); }
5405 }
5406 Column::Categorical { levels: levels.clone(), codes: out }
5407 }
5408 Column::DateTime(v) => {
5409 let mut out = Vec::with_capacity(n_out);
5410 for &r in &visible_rows {
5411 for _ in 0..value_cols.len() { out.push(v[r]); }
5412 }
5413 Column::DateTime(out)
5414 }
5415 };
5416 out_cols.push((name.clone(), new_col));
5417 }
5418
5419 let names_col: Vec<String> = visible_rows.iter()
5421 .flat_map(|_| value_cols.iter().map(|s| s.to_string()))
5422 .collect();
5423 out_cols.push((names_to.to_string(), Column::Str(names_col)));
5424
5425 match &self.base.columns[vc_indices[0]].1 {
5427 Column::Int(_) => {
5428 let mut vals: Vec<i64> = Vec::with_capacity(n_out);
5429 for &r in &visible_rows {
5430 for &vci in &vc_indices {
5431 if let Column::Int(v) = &self.base.columns[vci].1 {
5432 vals.push(v[r]);
5433 }
5434 }
5435 }
5436 out_cols.push((values_to.to_string(), Column::Int(vals)));
5437 }
5438 Column::Float(_) => {
5439 let mut vals: Vec<f64> = Vec::with_capacity(n_out);
5440 for &r in &visible_rows {
5441 for &vci in &vc_indices {
5442 if let Column::Float(v) = &self.base.columns[vci].1 {
5443 vals.push(v[r]);
5444 }
5445 }
5446 }
5447 out_cols.push((values_to.to_string(), Column::Float(vals)));
5448 }
5449 Column::Str(_) => {
5450 let mut vals: Vec<String> = Vec::with_capacity(n_out);
5451 for &r in &visible_rows {
5452 for &vci in &vc_indices {
5453 if let Column::Str(v) = &self.base.columns[vci].1 {
5454 vals.push(v[r].clone());
5455 }
5456 }
5457 }
5458 out_cols.push((values_to.to_string(), Column::Str(vals)));
5459 }
5460 Column::Bool(_) => {
5461 let mut vals: Vec<bool> = Vec::with_capacity(n_out);
5462 for &r in &visible_rows {
5463 for &vci in &vc_indices {
5464 if let Column::Bool(v) = &self.base.columns[vci].1 {
5465 vals.push(v[r]);
5466 }
5467 }
5468 }
5469 out_cols.push((values_to.to_string(), Column::Bool(vals)));
5470 }
5471 Column::Categorical { .. } | Column::DateTime(_) => {
5472 let mut vals: Vec<String> = Vec::with_capacity(n_out);
5474 for &r in &visible_rows {
5475 for &vci in &vc_indices {
5476 vals.push(self.base.columns[vci].1.get_display(r));
5477 }
5478 }
5479 out_cols.push((values_to.to_string(), Column::Str(vals)));
5480 }
5481 }
5482
5483 let df = DataFrame::from_columns(out_cols)
5484 .map_err(|e| TidyError::Internal(e.to_string()))?;
5485 Ok(TidyFrame::from_df(df))
5486 }
5487
5488 pub fn pivot_wider(
5505 &self,
5506 id_cols: &[&str],
5507 names_from: &str,
5508 values_from: &str,
5509 ) -> Result<NullableFrame, TidyError> {
5510 let _names_col_idx = self.base.columns.iter().position(|(n, _)| n == names_from)
5512 .ok_or_else(|| TidyError::ColumnNotFound(names_from.to_string()))?;
5513 let _values_col_idx = self.base.columns.iter().position(|(n, _)| n == values_from)
5514 .ok_or_else(|| TidyError::ColumnNotFound(values_from.to_string()))?;
5515 for &id in id_cols {
5516 let _ = self.base.columns.iter().position(|(n, _)| n == id)
5517 .ok_or_else(|| TidyError::ColumnNotFound(id.to_string()))?;
5518 }
5519
5520 let visible_rows: Vec<usize> = self.mask.iter_set().collect();
5521
5522 let mut key_values: Vec<String> = Vec::new();
5524 for &r in &visible_rows {
5525 let kv = self.base.get_column(names_from).unwrap().get_display(r);
5526 if !key_values.contains(&kv) {
5527 key_values.push(kv);
5528 }
5529 }
5530
5531 let id_col_refs: Vec<&Column> = id_cols.iter()
5534 .map(|&name| self.base.get_column(name).unwrap())
5535 .collect();
5536
5537 let mut id_order: Vec<Vec<String>> = Vec::new(); let mut id_to_slot: Vec<(Vec<String>, usize)> = Vec::new(); for &r in &visible_rows {
5541 let id_key: Vec<String> = id_col_refs.iter()
5542 .map(|col| col.get_display(r))
5543 .collect();
5544 if !id_to_slot.iter().any(|(k, _)| k == &id_key) {
5545 let slot = id_order.len();
5546 id_order.push(id_key.clone());
5547 id_to_slot.push((id_key, slot));
5548 }
5549 }
5550
5551 let n_rows = id_order.len();
5552 let n_keys = key_values.len();
5553
5554 let mut cell_map: Vec<Vec<Option<usize>>> = vec![vec![None; n_keys]; n_rows];
5557
5558 for &r in &visible_rows {
5559 let id_key: Vec<String> = id_col_refs.iter()
5560 .map(|col| col.get_display(r))
5561 .collect();
5562 let id_slot = id_to_slot.iter().find(|(k, _)| k == &id_key).unwrap().1;
5563
5564 let kv = self.base.get_column(names_from).unwrap().get_display(r);
5565 let key_slot = key_values.iter().position(|v| v == &kv).unwrap();
5566
5567 if cell_map[id_slot][key_slot].is_some() {
5568 return Err(TidyError::duplicate_key(
5569 format!("({}, {})", id_key.join(", "), kv)
5570 ));
5571 }
5572 cell_map[id_slot][key_slot] = Some(r);
5573 }
5574
5575 let mut out_cols: Vec<(String, NullCol)> = Vec::new();
5577
5578 for (id_idx, &id_name) in id_cols.iter().enumerate() {
5580 let id_col = self.base.get_column(id_name).unwrap();
5581 let id_row_indices: Vec<usize> = id_order.iter()
5582 .map(|id_tup| {
5583 *visible_rows.iter().find(|&&r| {
5585 id_col_refs.iter().enumerate().all(|(i, col)| {
5586 col.get_display(r) == id_tup[i]
5587 })
5588 }).unwrap()
5589 })
5590 .collect();
5591 let gathered = gather_column(id_col, &id_row_indices);
5592 out_cols.push((id_name.to_string(), NullCol::from_column(&gathered)));
5593 let _ = id_idx;
5594 }
5595
5596 let values_col = self.base.get_column(values_from).unwrap();
5598 let val_type = values_col.type_name();
5599 for (key_slot, key_val) in key_values.iter().enumerate() {
5600 let row_opts: Vec<Option<usize>> = (0..n_rows)
5601 .map(|id_slot| cell_map[id_slot][key_slot])
5602 .collect();
5603 let null_col = gather_column_nullable_null(values_col, &row_opts);
5604 out_cols.push((key_val.clone(), null_col));
5605 let _ = val_type;
5606 }
5607
5608 Ok(NullableFrame { columns: out_cols })
5609 }
5610
5611 pub fn rename(&self, renames: &[(&str, &str)]) -> Result<TidyView, TidyError> {
5622 let mut rename_map: Vec<(usize, String)> = Vec::new();
5624 let col_names: Vec<&str> = self.base.columns.iter().map(|(n, _)| n.as_str()).collect();
5625
5626 for &(old, new) in renames {
5627 let idx = col_names.iter().position(|&n| n == old)
5628 .ok_or_else(|| TidyError::ColumnNotFound(old.to_string()))?;
5629 if old != new {
5631 let new_name_exists = col_names.iter().any(|&n| n == new)
5632 || rename_map.iter().any(|(_, n)| n == new);
5633 if new_name_exists {
5634 return Err(TidyError::DuplicateColumn(new.to_string()));
5635 }
5636 }
5637 rename_map.push((idx, new.to_string()));
5638 }
5639
5640 let mut new_cols: Vec<(String, Column)> = Vec::new();
5642 for (i, (name, col)) in self.base.columns.iter().enumerate() {
5643 let new_name = rename_map.iter()
5644 .find(|(idx, _)| *idx == i)
5645 .map(|(_, n)| n.clone())
5646 .unwrap_or_else(|| name.clone());
5647 new_cols.push((new_name, col.clone()));
5648 }
5649
5650 let new_base = DataFrame { columns: new_cols };
5651 Ok(TidyView {
5652 base: Rc::new(new_base),
5653 mask: self.mask.clone(),
5654 proj: self.proj.clone(),
5655 })
5656 }
5657
5658 pub fn relocate(&self, cols: &[&str], position: RelocatePos<'_>) -> Result<TidyView, TidyError> {
5673 let proj_names: Vec<&str> = self.column_names();
5675 for &name in cols {
5676 if !proj_names.contains(&name) {
5677 return Err(TidyError::ColumnNotFound(name.to_string()));
5678 }
5679 }
5680
5681 let moved_set: std::collections::BTreeSet<&str> = cols.iter().copied().collect();
5683 let remaining: Vec<&str> = proj_names.iter()
5684 .copied()
5685 .filter(|n| !moved_set.contains(n))
5686 .collect();
5687
5688 let new_order: Vec<&str> = match &position {
5689 RelocatePos::Front => {
5690 let mut v: Vec<&str> = cols.to_vec();
5691 v.extend_from_slice(&remaining);
5692 v
5693 }
5694 RelocatePos::Back => {
5695 let mut v = remaining.clone();
5696 v.extend_from_slice(cols);
5697 v
5698 }
5699 RelocatePos::Before(anchor) => {
5700 if !proj_names.contains(anchor) {
5701 return Err(TidyError::ColumnNotFound(anchor.to_string()));
5702 }
5703 let mut v = Vec::new();
5704 for &n in &remaining {
5705 if n == *anchor {
5706 v.extend_from_slice(cols);
5707 }
5708 v.push(n);
5709 }
5710 v
5711 }
5712 RelocatePos::After(anchor) => {
5713 if !proj_names.contains(anchor) {
5714 return Err(TidyError::ColumnNotFound(anchor.to_string()));
5715 }
5716 let mut v = Vec::new();
5717 for &n in &remaining {
5718 v.push(n);
5719 if n == *anchor {
5720 v.extend_from_slice(cols);
5721 }
5722 }
5723 v
5724 }
5725 };
5726
5727 let new_indices: Vec<usize> = new_order.iter()
5729 .map(|&name| {
5730 self.base.columns.iter().position(|(n, _)| n == name).unwrap()
5731 })
5732 .collect();
5733
5734 Ok(TidyView {
5735 base: Rc::clone(&self.base),
5736 mask: self.mask.clone(),
5737 proj: ProjectionMap::from_indices(new_indices),
5738 })
5739 }
5740
5741 pub fn drop_cols(&self, cols: &[&str]) -> Result<TidyView, TidyError> {
5751 let proj_names = self.column_names();
5752 for &name in cols {
5753 if !proj_names.contains(&name) {
5754 return Err(TidyError::ColumnNotFound(name.to_string()));
5755 }
5756 }
5757 let drop_set: std::collections::BTreeSet<&str> = cols.iter().copied().collect();
5758 let keep: Vec<&str> = proj_names.iter()
5759 .copied()
5760 .filter(|n| !drop_set.contains(n))
5761 .collect();
5762 self.select(&keep)
5763 }
5764
5765 pub fn bind_rows(&self, other: &TidyView) -> Result<TidyFrame, TidyError> {
5776 let self_names = self.column_names();
5777 let other_names = other.column_names();
5778
5779 if self_names != other_names {
5780 return Err(TidyError::schema_mismatch(format!(
5781 "left has {:?}, right has {:?}",
5782 self_names, other_names
5783 )));
5784 }
5785
5786 let self_rows: Vec<usize> = self.mask.iter_set().collect();
5787 let other_rows: Vec<usize> = other.mask.iter_set().collect();
5788
5789 let mut out_cols: Vec<(String, Column)> = Vec::new();
5790 for &ci in self.proj.indices() {
5791 let (name, self_col) = &self.base.columns[ci];
5792 let other_ci = other.proj.indices().iter().copied()
5794 .find(|&i| other.base.columns[i].0 == *name)
5795 .ok_or_else(|| TidyError::ColumnNotFound(name.clone()))?;
5796 let other_col = &other.base.columns[other_ci].1;
5797
5798 let col = concat_columns(self_col, &self_rows, other_col, &other_rows)?;
5799 out_cols.push((name.clone(), col));
5800 }
5801
5802 let df = DataFrame::from_columns(out_cols)
5803 .map_err(|e| TidyError::Internal(e.to_string()))?;
5804 Ok(TidyFrame::from_df(df))
5805 }
5806
5807 pub fn bind_cols(&self, other: &TidyView) -> Result<TidyFrame, TidyError> {
5818 let self_nrows = self.nrows();
5819 let other_nrows = other.nrows();
5820
5821 if self_nrows != other_nrows {
5822 return Err(TidyError::LengthMismatch {
5823 expected: self_nrows,
5824 got: other_nrows,
5825 });
5826 }
5827
5828 let self_names = self.column_names();
5829 let other_names = other.column_names();
5830 for name in &other_names {
5831 if self_names.contains(name) {
5832 return Err(TidyError::DuplicateColumn(name.to_string()));
5833 }
5834 }
5835
5836 let self_rows: Vec<usize> = self.mask.iter_set().collect();
5837 let other_rows: Vec<usize> = other.mask.iter_set().collect();
5838
5839 let mut out_cols: Vec<(String, Column)> = Vec::new();
5840
5841 for &ci in self.proj.indices() {
5842 let (name, col) = &self.base.columns[ci];
5843 out_cols.push((name.clone(), gather_column(col, &self_rows)));
5844 }
5845 for &ci in other.proj.indices() {
5846 let (name, col) = &other.base.columns[ci];
5847 out_cols.push((name.clone(), gather_column(col, &other_rows)));
5848 }
5849
5850 let df = DataFrame::from_columns(out_cols)
5851 .map_err(|e| TidyError::Internal(e.to_string()))?;
5852 Ok(TidyFrame::from_df(df))
5853 }
5854
5855 pub fn mutate_across(&self, specs: &[AcrossSpec]) -> Result<TidyFrame, TidyError> {
5865 let base_df = self.materialize()?;
5867
5868 let mut output_names: Vec<String> = base_df.column_names()
5870 .into_iter().map(|s| s.to_string()).collect();
5871 let mut extra_cols: Vec<(String, Column)> = Vec::new();
5872
5873 for spec in specs {
5874 for col_name in &spec.cols {
5875 let out_name = spec.output_name(col_name);
5876 if output_names.contains(&out_name) && !base_df.column_names().contains(&out_name.as_str()) {
5878 return Err(TidyError::DuplicateColumn(out_name));
5879 }
5880 let col = base_df.get_column(col_name)
5881 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
5882 let new_col = (spec.transform.func)(col_name, col)?;
5883 if !base_df.column_names().contains(&out_name.as_str()) {
5885 output_names.push(out_name.clone());
5886 }
5887 extra_cols.push((out_name, new_col));
5888 }
5889 }
5890
5891 let mut col_map: indexmap_simple::IndexMap = indexmap_simple::IndexMap::from_df(&base_df);
5893 for (name, col) in extra_cols {
5894 col_map.insert(name, col);
5895 }
5896 let df = col_map.into_df()
5897 .map_err(|e| TidyError::Internal(e.to_string()))?;
5898 Ok(TidyFrame::from_df(df))
5899 }
5900
5901 pub fn right_join(
5909 &self,
5910 right: &TidyView,
5911 on: &[(&str, &str)],
5912 suffix: &JoinSuffix,
5913 ) -> Result<NullableFrame, TidyError> {
5914 validate_join_key_types(self, right, on)?;
5916 let swapped_on: Vec<(&str, &str)> = on.iter().map(|&(l, r)| (r, l)).collect();
5918 let (right_rows, left_rows_opt) =
5919 join_match_rows_optional(right, self, &swapped_on, JoinKind::Left)?;
5920 build_right_join_frame(self, right, &left_rows_opt, &right_rows, on, suffix)
5921 }
5922
5923 pub fn full_join(
5929 &self,
5930 right: &TidyView,
5931 on: &[(&str, &str)],
5932 suffix: &JoinSuffix,
5933 ) -> Result<NullableFrame, TidyError> {
5934 validate_join_key_types(self, right, on)?;
5935 build_full_join_frame(self, right, on, suffix)
5936 }
5937
5938 pub fn inner_join_typed(
5946 &self,
5947 right: &TidyView,
5948 on: &[(&str, &str)],
5949 suffix: &JoinSuffix,
5950 ) -> Result<TidyFrame, TidyError> {
5951 validate_join_key_types(self, right, on)?;
5952 let (left_rows, right_rows) = join_match_rows(self, right, on, JoinKind::Inner)?;
5953 build_join_frame_with_suffix(self, right, &left_rows, &right_rows, on, suffix, false)
5954 }
5955
5956 pub fn left_join_typed(
5960 &self,
5961 right: &TidyView,
5962 on: &[(&str, &str)],
5963 suffix: &JoinSuffix,
5964 ) -> Result<TidyFrame, TidyError> {
5965 validate_join_key_types(self, right, on)?;
5966 let (left_rows, right_rows_opt) =
5967 join_match_rows_optional(self, right, on, JoinKind::Left)?;
5968 build_left_join_frame_with_suffix(self, right, &left_rows, &right_rows_opt, on, suffix)
5969 }
5970}
5971
/// Target position for `TidyView::relocate`.
///
/// Borrows the anchor column name, so the enum is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RelocatePos<'a> {
    /// Move the columns to the start of the projection.
    Front,
    /// Move the columns to the end of the projection.
    Back,
    /// Move the columns immediately before the named anchor column.
    Before(&'a str),
    /// Move the columns immediately after the named anchor column.
    After(&'a str),
}
5985
5986fn concat_columns(
5989 left: &Column,
5990 left_rows: &[usize],
5991 right: &Column,
5992 right_rows: &[usize],
5993) -> Result<Column, TidyError> {
5994 match (left, right) {
5995 (Column::Int(lv), Column::Int(rv)) => {
5996 let mut out: Vec<i64> = left_rows.iter().map(|&i| lv[i]).collect();
5997 out.extend(right_rows.iter().map(|&i| rv[i]));
5998 Ok(Column::Int(out))
5999 }
6000 (Column::Float(lv), Column::Float(rv)) => {
6001 let mut out: Vec<f64> = left_rows.iter().map(|&i| lv[i]).collect();
6002 out.extend(right_rows.iter().map(|&i| rv[i]));
6003 Ok(Column::Float(out))
6004 }
6005 (Column::Int(lv), Column::Float(rv)) => {
6006 let mut out: Vec<f64> = left_rows.iter().map(|&i| lv[i] as f64).collect();
6007 out.extend(right_rows.iter().map(|&i| rv[i]));
6008 Ok(Column::Float(out))
6009 }
6010 (Column::Float(lv), Column::Int(rv)) => {
6011 let mut out: Vec<f64> = left_rows.iter().map(|&i| lv[i]).collect();
6012 out.extend(right_rows.iter().map(|&i| rv[i] as f64));
6013 Ok(Column::Float(out))
6014 }
6015 (Column::Str(lv), Column::Str(rv)) => {
6016 let mut out: Vec<String> = left_rows.iter().map(|&i| lv[i].clone()).collect();
6017 out.extend(right_rows.iter().map(|&i| rv[i].clone()));
6018 Ok(Column::Str(out))
6019 }
6020 (Column::Bool(lv), Column::Bool(rv)) => {
6021 let mut out: Vec<bool> = left_rows.iter().map(|&i| lv[i]).collect();
6022 out.extend(right_rows.iter().map(|&i| rv[i]));
6023 Ok(Column::Bool(out))
6024 }
6025 _ => Err(TidyError::schema_mismatch(format!(
6026 "type mismatch in bind_rows: {} vs {}",
6027 left.type_name(), right.type_name()
6028 ))),
6029 }
6030}
6031
6032fn validate_join_key_types(
6035 left: &TidyView,
6036 right: &TidyView,
6037 on: &[(&str, &str)],
6038) -> Result<(), TidyError> {
6039 for &(lk, rk) in on {
6040 let l_col = left.base.get_column(lk)
6041 .ok_or_else(|| TidyError::ColumnNotFound(lk.to_string()))?;
6042 let r_col = right.base.get_column(rk)
6043 .ok_or_else(|| TidyError::ColumnNotFound(rk.to_string()))?;
6044 if !join_types_compatible(l_col, r_col) {
6045 return Err(TidyError::join_type_mismatch(lk, l_col.type_name(), r_col.type_name()));
6046 }
6047 }
6048 Ok(())
6049}
6050
6051fn build_join_frame_with_suffix(
6054 left: &TidyView,
6055 right: &TidyView,
6056 left_rows: &[usize],
6057 right_rows: &[usize],
6058 on: &[(&str, &str)],
6059 suffix: &JoinSuffix,
6060 _include_unmatched: bool,
6061) -> Result<TidyFrame, TidyError> {
6062 let right_key_names: std::collections::BTreeSet<&str> =
6063 on.iter().map(|(_, rk)| *rk).collect();
6064
6065 let left_col_names: Vec<String> = left.proj.indices().iter()
6067 .map(|&ci| left.base.columns[ci].0.clone())
6068 .collect();
6069
6070 let mut columns: Vec<(String, Column)> = Vec::new();
6071
6072 for &ci in left.proj.indices() {
6074 let (name, col) = &left.base.columns[ci];
6075 columns.push((name.clone(), gather_column(col, left_rows)));
6076 }
6077
6078 for &ci in right.proj.indices() {
6080 let (name, col) = &right.base.columns[ci];
6081 if right_key_names.contains(name.as_str()) {
6082 continue; }
6084 let out_name = if left_col_names.contains(name) {
6085 format!("{}{}", name, suffix.right)
6086 } else {
6087 name.clone()
6088 };
6089 if left_col_names.contains(name) {
6091 let left_pos = columns.iter().position(|(n, _)| n == name);
6093 if let Some(pos) = left_pos {
6094 let entry = &mut columns[pos];
6095 entry.0 = format!("{}{}", entry.0, suffix.left);
6096 }
6097 }
6098 columns.push((out_name, gather_column(col, right_rows)));
6099 }
6100
6101 let df = DataFrame::from_columns(columns)
6102 .map_err(|e| TidyError::Internal(e.to_string()))?;
6103 Ok(TidyFrame::from_df(df))
6104}
6105
6106fn build_left_join_frame_with_suffix(
6107 left: &TidyView,
6108 right: &TidyView,
6109 left_rows: &[usize],
6110 right_rows_opt: &[Option<usize>],
6111 on: &[(&str, &str)],
6112 suffix: &JoinSuffix,
6113) -> Result<TidyFrame, TidyError> {
6114 let right_key_names: std::collections::BTreeSet<&str> =
6115 on.iter().map(|(_, rk)| *rk).collect();
6116
6117 let left_col_names: Vec<String> = left.proj.indices().iter()
6118 .map(|&ci| left.base.columns[ci].0.clone())
6119 .collect();
6120
6121 let mut columns: Vec<(String, Column)> = Vec::new();
6122
6123 for &ci in left.proj.indices() {
6125 let (name, col) = &left.base.columns[ci];
6126 columns.push((name.clone(), gather_column(col, left_rows)));
6127 }
6128
6129 for &ci in right.proj.indices() {
6131 let (name, col) = &right.base.columns[ci];
6132 if right_key_names.contains(name.as_str()) { continue; }
6133 let out_name = if left_col_names.contains(name) {
6134 let left_pos = columns.iter().position(|(n, _)| n == name);
6136 if let Some(pos) = left_pos {
6137 columns[pos].0 = format!("{}{}", name, suffix.left);
6138 }
6139 format!("{}{}", name, suffix.right)
6140 } else {
6141 name.clone()
6142 };
6143 let new_col = gather_column_nullable(col, right_rows_opt);
6144 columns.push((out_name, new_col));
6145 }
6146
6147 let df = DataFrame::from_columns(columns)
6148 .map_err(|e| TidyError::Internal(e.to_string()))?;
6149 Ok(TidyFrame::from_df(df))
6150}
6151
6152fn build_right_join_frame(
6153 left: &TidyView,
6154 right: &TidyView,
6155 left_rows_opt: &[Option<usize>],
6156 right_rows: &[usize],
6157 on: &[(&str, &str)],
6158 suffix: &JoinSuffix,
6159) -> Result<NullableFrame, TidyError> {
6160 let right_key_names: std::collections::BTreeSet<&str> =
6161 on.iter().map(|(_, rk)| *rk).collect();
6162 let left_key_names: std::collections::BTreeSet<&str> =
6163 on.iter().map(|(lk, _)| *lk).collect();
6164
6165 let right_col_names: Vec<String> = right.proj.indices().iter()
6166 .map(|&ci| right.base.columns[ci].0.clone())
6167 .collect();
6168
6169 let mut columns: Vec<(String, NullCol)> = Vec::new();
6170
6171 for &ci in left.proj.indices() {
6173 let (name, col) = &left.base.columns[ci];
6174 if left_key_names.contains(name.as_str()) { continue; }
6175 let out_name = if right_col_names.contains(name) {
6176 format!("{}{}", name, suffix.left)
6177 } else {
6178 name.clone()
6179 };
6180 let null_col = gather_column_nullable_null(col, left_rows_opt);
6181 columns.push((out_name, null_col));
6182 }
6183
6184 for &ci in right.proj.indices() {
6186 let (name, col) = &right.base.columns[ci];
6187 let out_name = if !right_key_names.contains(name.as_str())
6188 && left.proj.indices().iter().any(|&lci| left.base.columns[lci].0 == *name)
6189 && !left_key_names.contains(name.as_str())
6190 {
6191 format!("{}{}", name, suffix.right)
6192 } else {
6193 name.clone()
6194 };
6195 columns.push((out_name, NullCol::from_column(&gather_column(col, right_rows))));
6196 }
6197
6198 Ok(NullableFrame { columns })
6199}
6200
6201fn build_full_join_frame(
6202 left: &TidyView,
6203 right: &TidyView,
6204 on: &[(&str, &str)],
6205 suffix: &JoinSuffix,
6206) -> Result<NullableFrame, TidyError> {
6207 let (left_key_cols, right_key_cols) = resolve_join_keys(left, right, on)?;
6208 let lookup = build_right_lookup(right, &right_key_cols);
6209
6210 let mut out_left_rows: Vec<usize> = Vec::new();
6212 let mut out_right_rows: Vec<Option<usize>> = Vec::new();
6213 let mut right_matched: Vec<bool> = vec![false; right.base.nrows()];
6214
6215 for l_row in left.mask.iter_set() {
6216 let key = row_key(&left.base, &left_key_cols, l_row);
6217 let matches = find_matches(&lookup, &key);
6218 if matches.is_empty() {
6219 out_left_rows.push(l_row);
6220 out_right_rows.push(None);
6221 } else {
6222 for r_row in &matches {
6223 out_left_rows.push(l_row);
6224 out_right_rows.push(Some(*r_row));
6225 if *r_row < right_matched.len() {
6226 right_matched[*r_row] = true;
6227 }
6228 }
6229 }
6230 }
6231
6232 let mut unmatched_right: Vec<usize> = Vec::new();
6234 for r_row in right.mask.iter_set() {
6235 if r_row < right_matched.len() && !right_matched[r_row] {
6236 unmatched_right.push(r_row);
6237 }
6238 }
6239
6240 let right_key_names: std::collections::BTreeSet<&str> =
6241 on.iter().map(|(_, rk)| *rk).collect();
6242 let left_key_names: std::collections::BTreeSet<&str> =
6243 on.iter().map(|(lk, _)| *lk).collect();
6244 let right_col_names: Vec<String> = right.proj.indices().iter()
6245 .map(|&ci| right.base.columns[ci].0.clone())
6246 .collect();
6247
6248 let n_matched = out_left_rows.len();
6249 let n_unmatched_r = unmatched_right.len();
6250 let total = n_matched + n_unmatched_r;
6251
6252 let mut columns: Vec<(String, NullCol)> = Vec::new();
6253
6254 for &ci in left.proj.indices() {
6256 let (name, col) = &left.base.columns[ci];
6257 let out_name = if right_col_names.contains(name) && !left_key_names.contains(name.as_str()) {
6258 format!("{}{}", name, suffix.left)
6259 } else {
6260 name.clone()
6261 };
6262 let mut matched_vals: Vec<Option<usize>> = out_left_rows.iter()
6263 .map(|&r| Some(r))
6264 .collect();
6265 matched_vals.extend(std::iter::repeat(None).take(n_unmatched_r));
6267 assert_eq!(matched_vals.len(), total);
6268 columns.push((out_name, gather_column_nullable_null(col, &matched_vals)));
6269 }
6270
6271 for &ci in right.proj.indices() {
6273 let (name, col) = &right.base.columns[ci];
6274 if right_key_names.contains(name.as_str()) { continue; }
6275 let out_name = if left.proj.indices().iter().any(|&lci| left.base.columns[lci].0 == *name)
6276 && !left_key_names.contains(name.as_str())
6277 {
6278 format!("{}{}", name, suffix.right)
6279 } else {
6280 name.clone()
6281 };
6282
6283 let mut row_opts: Vec<Option<usize>> = out_right_rows.clone();
6284 row_opts.extend(unmatched_right.iter().map(|&r| Some(r)));
6286 assert_eq!(row_opts.len(), total);
6287 columns.push((out_name, gather_column_nullable_null(col, &row_opts)));
6288 }
6289
6290 Ok(NullableFrame { columns })
6296}
6297
6298impl GroupedTidyView {
6301
6302 pub fn mutate_across(&self, specs: &[AcrossSpec]) -> Result<TidyFrame, TidyError> {
6308 self.view.mutate_across(specs)
6311 }
6312
6313 pub fn summarise_across(&self, specs: &[AcrossSpec]) -> Result<TidyFrame, TidyError> {
6320 let n_groups = self.ngroups();
6321
6322 let key_names = &self.index.key_names;
6324 let mut out_cols: Vec<(String, Column)> = Vec::new();
6325
6326 for ki in 0..key_names.len() {
6328 let col_vals: Vec<String> = self.index.groups.iter()
6329 .map(|g| g.key_values[ki].clone())
6330 .collect();
6331 out_cols.push((key_names[ki].clone(), Column::Str(col_vals)));
6332 }
6333
6334 for spec in specs {
6336 for col_name in &spec.cols {
6337 let out_name = spec.output_name(col_name);
6338 if out_cols.iter().any(|(n, _)| n == &out_name) {
6340 return Err(TidyError::DuplicateColumn(out_name));
6341 }
6342
6343 let base_col = self.view.base.get_column(col_name)
6344 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
6345
6346 let mut agg_floats: Vec<f64> = Vec::with_capacity(n_groups);
6348 for group in &self.index.groups {
6349 let group_col = gather_column(base_col, &group.row_indices);
6350 let result_col = (spec.transform.func)(col_name, &group_col)?;
6351 if result_col.len() != 1 {
6352 return Err(TidyError::LengthMismatch {
6353 expected: 1,
6354 got: result_col.len(),
6355 });
6356 }
6357 let v = match &result_col {
6358 Column::Float(v) => v[0],
6359 Column::Int(v) => v[0] as f64,
6360 _ => return Err(TidyError::TypeMismatch {
6361 expected: "Float or Int".into(),
6362 got: result_col.type_name().into(),
6363 }),
6364 };
6365 agg_floats.push(v);
6366 }
6367 out_cols.push((out_name, Column::Float(agg_floats)));
6368 }
6369 }
6370
6371 let df = DataFrame::from_columns(out_cols)
6372 .map_err(|e| TidyError::Internal(e.to_string()))?;
6373 Ok(TidyFrame::from_df(df))
6374 }
6375}
6376
6377mod indexmap_simple {
6382 use super::{Column, DataFrame, DataError};
6383
6384 pub struct IndexMap {
6385 entries: Vec<(String, Column)>,
6386 }
6387
6388 impl IndexMap {
6389 pub fn from_df(df: &DataFrame) -> Self {
6390 Self {
6391 entries: df.columns.iter()
6392 .map(|(n, c)| (n.clone(), c.clone()))
6393 .collect(),
6394 }
6395 }
6396
6397 pub fn insert(&mut self, name: String, col: Column) {
6399 if let Some(pos) = self.entries.iter().position(|(n, _)| n == &name) {
6400 self.entries[pos] = (name, col);
6401 } else {
6402 self.entries.push((name, col));
6403 }
6404 }
6405
6406 pub fn into_df(self) -> Result<DataFrame, DataError> {
6407 DataFrame::from_columns(self.entries)
6408 }
6409 }
6410}
6411
6412impl GroupIndex {
6429 pub fn build_fast(
6434 base: &DataFrame,
6435 key_col_indices: &[usize],
6436 visible_rows: &[usize],
6437 key_names: Vec<String>,
6438 ) -> Self {
6439 use std::collections::BTreeMap;
6440
6441 let mut groups: Vec<GroupMeta> = Vec::new();
6442 let mut key_to_slot: BTreeMap<Vec<String>, usize> = BTreeMap::new();
6443
6444 for &row in visible_rows {
6445 let key: Vec<String> = key_col_indices.iter()
6446 .map(|&ci| base.columns[ci].1.get_display(row))
6447 .collect();
6448
6449 if let Some(&slot) = key_to_slot.get(&key) {
6450 groups[slot].row_indices.push(row);
6451 } else {
6452 let slot = groups.len();
6453 let key_values = key.clone();
6454 key_to_slot.insert(key, slot);
6455 groups.push(GroupMeta { key_values, row_indices: vec![row] });
6456 }
6457 }
6458
6459 GroupIndex { groups, key_names }
6460 }
6461}
6462
6463impl TidyView {
6466 pub fn group_by_fast(&self, keys: &[&str]) -> Result<GroupedTidyView, TidyError> {
6471 let mut key_col_indices = Vec::with_capacity(keys.len());
6472 for &key in keys {
6473 let idx = self.base.columns.iter().position(|(n, _)| n == key)
6474 .ok_or_else(|| TidyError::ColumnNotFound(key.to_string()))?;
6475 key_col_indices.push(idx);
6476 }
6477 let visible_rows: Vec<usize> = self.mask.iter_set().collect();
6478 let key_names: Vec<String> = keys.iter().map(|s| s.to_string()).collect();
6479 let index = GroupIndex::build_fast(&self.base, &key_col_indices, &visible_rows, key_names);
6480 Ok(GroupedTidyView { view: self.clone(), index })
6481 }
6482}
6483
/// A compact factor (categorical) column.
///
/// `levels[k]` is the label for code `k`, and `data[i]` is the code of row `i`.
/// `FctColumn::encode` refuses to create more than 65 535 levels.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FctColumn {
    /// Distinct labels; a label's position is its code.
    pub levels: Vec<String>,
    /// One code per row, indexing into `levels`.
    pub data: Vec<u16>,
}
6557
6558impl FctColumn {
6559 pub fn encode(strings: &[String]) -> Result<Self, TidyError> {
6566 use std::collections::BTreeMap;
6567 let mut levels: Vec<String> = Vec::new();
6568 let mut level_map: BTreeMap<String, u16> = BTreeMap::new();
6572 let mut data: Vec<u16> = Vec::with_capacity(strings.len());
6573
6574 for s in strings {
6575 let idx = if let Some(&existing) = level_map.get(s.as_str()) {
6576 existing
6577 } else {
6578 let next = levels.len();
6579 if next >= 65_535 {
6580 return Err(TidyError::CapacityExceeded {
6581 limit: 65_535,
6582 got: next + 1,
6583 });
6584 }
6585 let idx = next as u16;
6586 levels.push(s.clone());
6587 level_map.insert(s.clone(), idx);
6588 idx
6589 };
6590 data.push(idx);
6591 }
6592 Ok(FctColumn { levels, data })
6593 }
6594
6595 pub fn encode_from_view(view: &TidyView, col: &str) -> Result<Self, TidyError> {
6597 let base_idx = view.base.columns.iter()
6598 .position(|(n, _)| n == col)
6599 .ok_or_else(|| TidyError::ColumnNotFound(col.to_string()))?;
6600 if !view.proj.indices().contains(&base_idx) {
6602 return Err(TidyError::ColumnNotFound(col.to_string()));
6603 }
6604 let col_data = &view.base.columns[base_idx].1;
6605 let visible: Vec<usize> = view.mask.iter_set().collect();
6606 let strings: Vec<String> = visible.iter()
6607 .map(|&r| col_data.get_display(r))
6608 .collect();
6609 Self::encode(&strings)
6610 }
6611
6612 pub fn nrows(&self) -> usize { self.data.len() }
6616 pub fn nlevels(&self) -> usize { self.levels.len() }
6618
6619 pub fn decode(&self, i: usize) -> &str {
6621 &self.levels[self.data[i] as usize]
6622 }
6623
6624 pub fn fct_lump(&self, n: usize) -> Result<Self, TidyError> {
6636 if n >= self.levels.len() {
6637 return Ok(self.clone()); }
6639
6640 let mut freq = vec![0usize; self.levels.len()];
6642 for &idx in &self.data {
6643 freq[idx as usize] += 1;
6644 }
6645
6646 let mut ranked: Vec<(usize, usize)> = freq.iter().copied().enumerate().collect();
6649 ranked.sort_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
6650
6651 let mut keep_set: Vec<usize> = ranked[..n].iter().map(|(i, _)| *i).collect();
6653 keep_set.sort_unstable(); let mut other_name = "Other".to_string();
6657 while keep_set.iter().any(|&ki| self.levels[ki] == other_name) {
6658 other_name.push('_');
6659 }
6660
6661 let mut new_levels: Vec<String> = keep_set.iter().map(|&ki| self.levels[ki].clone()).collect();
6663 let other_idx = new_levels.len() as u16;
6664 new_levels.push(other_name);
6665
6666 let mut remap = vec![other_idx; self.levels.len()];
6668 for (new_i, &old_i) in keep_set.iter().enumerate() {
6669 remap[old_i] = new_i as u16;
6670 }
6671
6672 let new_data: Vec<u16> = self.data.iter().map(|&d| remap[d as usize]).collect();
6673 Ok(FctColumn { levels: new_levels, data: new_data })
6674 }
6675
6676 pub fn fct_reorder(&self, summary_vals: &[f64], descending: bool) -> Result<Self, TidyError> {
6685 if summary_vals.len() != self.levels.len() {
6686 return Err(TidyError::LengthMismatch {
6687 expected: self.levels.len(),
6688 got: summary_vals.len(),
6689 });
6690 }
6691 let mut order: Vec<usize> = (0..self.levels.len()).collect();
6695 order.sort_by(|&a, &b| {
6696 let va = summary_vals[a];
6697 let vb = summary_vals[b];
6698 match (va.is_nan(), vb.is_nan()) {
6699 (true, true) => std::cmp::Ordering::Equal,
6700 (true, false) => std::cmp::Ordering::Greater, (false, true) => std::cmp::Ordering::Less, (false, false) => {
6703 let cmp = va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal);
6704 if descending { cmp.reverse() } else { cmp }
6705 }
6706 }
6707 });
6708
6709 let new_levels: Vec<String> = order.iter().map(|&i| self.levels[i].clone()).collect();
6711
6712 let mut remap = vec![0u16; self.levels.len()];
6714 for (new_i, &old_i) in order.iter().enumerate() {
6715 remap[old_i] = new_i as u16;
6716 }
6717
6718 let new_data: Vec<u16> = self.data.iter().map(|&d| remap[d as usize]).collect();
6719 Ok(FctColumn { levels: new_levels, data: new_data })
6720 }
6721
6722 pub fn fct_reorder_by_col(&self, numeric_col: &Column, descending: bool) -> Result<Self, TidyError> {
6728 if numeric_col.len() != self.data.len() {
6729 return Err(TidyError::LengthMismatch {
6730 expected: self.data.len(),
6731 got: numeric_col.len(),
6732 });
6733 }
6734 let mut sums = vec![0.0f64; self.levels.len()];
6735 let mut counts = vec![0usize; self.levels.len()];
6736 match numeric_col {
6737 Column::Float(v) => {
6738 for (i, &d) in self.data.iter().enumerate() {
6739 let val = v[i];
6740 if !val.is_nan() {
6741 sums[d as usize] += val;
6742 counts[d as usize] += 1;
6743 }
6744 }
6745 }
6746 Column::Int(v) => {
6747 for (i, &d) in self.data.iter().enumerate() {
6748 sums[d as usize] += v[i] as f64;
6749 counts[d as usize] += 1;
6750 }
6751 }
6752 _ => return Err(TidyError::TypeMismatch {
6753 expected: "Float or Int".to_string(),
6754 got: numeric_col.type_name().to_string(),
6755 }),
6756 }
6757 let means: Vec<f64> = sums.iter().zip(counts.iter())
6758 .map(|(&s, &c)| if c == 0 { f64::NAN } else { s / c as f64 })
6759 .collect();
6760 self.fct_reorder(&means, descending)
6761 }
6762
6763 pub fn fct_collapse(&self, mapping: &[(&str, &str)]) -> Result<Self, TidyError> {
6780 if mapping.is_empty() {
6781 return Ok(self.clone());
6782 }
6783 let new_name_for: Vec<String> = self.levels.iter().map(|old| {
6785 if let Some((_, new)) = mapping.iter().find(|(o, _)| *o == old.as_str()) {
6786 new.to_string()
6787 } else {
6788 old.clone()
6789 }
6790 }).collect();
6791
6792 use std::collections::BTreeMap;
6795 let mut new_levels: Vec<String> = Vec::new();
6796 let mut new_name_to_idx: BTreeMap<String, u16> = BTreeMap::new();
6797
6798 let mut old_to_new: Vec<u16> = Vec::with_capacity(self.levels.len());
6799 for name in &new_name_for {
6800 let idx = if let Some(&existing) = new_name_to_idx.get(name.as_str()) {
6801 existing
6802 } else {
6803 let idx = new_levels.len() as u16;
6804 new_levels.push(name.clone());
6805 new_name_to_idx.insert(name.clone(), idx);
6806 idx
6807 };
6808 old_to_new.push(idx);
6809 }
6810
6811 let changed = old_to_new.iter().enumerate().any(|(i, &new)| new != i as u16);
6813 let new_data = if changed {
6814 self.data.iter().map(|&d| old_to_new[d as usize]).collect()
6815 } else {
6816 self.data.clone()
6817 };
6818 Ok(FctColumn { levels: new_levels, data: new_data })
6819 }
6820
6821 pub fn to_str_column(&self) -> Column {
6825 Column::Str(self.data.iter().map(|&d| self.levels[d as usize].clone()).collect())
6826 }
6827
6828 pub fn gather(&self, indices: &[usize]) -> FctColumn {
6830 FctColumn {
6831 levels: self.levels.clone(),
6832 data: indices.iter().map(|&i| self.data[i]).collect(),
6833 }
6834 }
6835}
6836
6837impl TidyError {
6840 pub fn capacity_exceeded(limit: usize, got: usize) -> Self {
6842 TidyError::CapacityExceeded { limit, got }
6843 }
6844}
6845
6846#[derive(Clone, Debug)]
6851pub struct NullableFactor {
6852 pub fct: FctColumn,
6853 pub validity: BitMask,
6854}
6855
6856impl NullableFactor {
6857 pub fn from_fct(fct: FctColumn) -> Self {
6859 let n = fct.nrows();
6860 NullableFactor { fct, validity: BitMask::all_true(n) }
6861 }
6862
6863 pub fn new(fct: FctColumn, validity: BitMask) -> Self {
6865 NullableFactor { fct, validity }
6866 }
6867
6868 pub fn encode_nullable(strings: &[Option<String>]) -> Result<Self, TidyError> {
6872 use std::collections::BTreeMap;
6873 let mut levels: Vec<String> = Vec::new();
6874 let mut level_map: BTreeMap<String, u16> = BTreeMap::new();
6875 let mut data: Vec<u16> = Vec::with_capacity(strings.len());
6876 let mut valid_flags: Vec<bool> = Vec::with_capacity(strings.len());
6877
6878 for opt in strings {
6879 match opt {
6880 None => {
6881 data.push(0); valid_flags.push(false);
6883 }
6884 Some(s) => {
6885 let idx = if let Some(&existing) = level_map.get(s.as_str()) {
6886 existing
6887 } else {
6888 let next = levels.len();
6889 if next >= 65_535 {
6890 return Err(TidyError::CapacityExceeded { limit: 65_535, got: next + 1 });
6891 }
6892 let idx = next as u16;
6893 levels.push(s.clone());
6894 level_map.insert(s.clone(), idx);
6895 idx
6896 };
6897 data.push(idx);
6898 valid_flags.push(true);
6899 }
6900 }
6901 }
6902 let fct = FctColumn { levels, data };
6903 let validity = BitMask::from_bools(&valid_flags);
6904 Ok(NullableFactor { fct, validity })
6905 }
6906
6907 pub fn nrows(&self) -> usize { self.fct.nrows() }
6909 pub fn nlevels(&self) -> usize { self.fct.nlevels() }
6911 pub fn is_null(&self, i: usize) -> bool { !self.validity.get(i) }
6913 pub fn count_valid(&self) -> usize { self.validity.count_ones() }
6915
6916 pub fn decode(&self, i: usize) -> Option<&str> {
6918 if self.is_null(i) { None } else { Some(self.fct.decode(i)) }
6919 }
6920
6921 pub fn fct_lump(&self, n: usize) -> Result<Self, TidyError> {
6923 let lumped = self.fct.fct_lump(n)?;
6924 Ok(NullableFactor { fct: lumped, validity: self.validity.clone() })
6925 }
6926
6927 pub fn fct_reorder(&self, summary_vals: &[f64], descending: bool) -> Result<Self, TidyError> {
6929 let reordered = self.fct.fct_reorder(summary_vals, descending)?;
6930 Ok(NullableFactor { fct: reordered, validity: self.validity.clone() })
6931 }
6932
6933 pub fn fct_collapse(&self, mapping: &[(&str, &str)]) -> Result<Self, TidyError> {
6935 let collapsed = self.fct.fct_collapse(mapping)?;
6936 Ok(NullableFactor { fct: collapsed, validity: self.validity.clone() })
6937 }
6938}
6939
6940impl TidyView {
6943 pub fn fct_encode(&self, col: &str) -> Result<FctColumn, TidyError> {
6948 FctColumn::encode_from_view(self, col)
6949 }
6950
6951 pub fn fct_summary_means(
6956 &self,
6957 fct: &FctColumn,
6958 numeric_col: &str,
6959 ) -> Result<Vec<f64>, TidyError> {
6960 let base_idx = self.base.columns.iter()
6961 .position(|(n, _)| n == numeric_col)
6962 .ok_or_else(|| TidyError::ColumnNotFound(numeric_col.to_string()))?;
6963 let nc = &self.base.columns[base_idx].1;
6964 if nc.len() != fct.nrows() {
6965 return Err(TidyError::LengthMismatch { expected: fct.nrows(), got: nc.len() });
6966 }
6967 match nc {
6969 Column::Float(_) | Column::Int(_) => {}
6970 _ => return Err(TidyError::TypeMismatch {
6971 expected: "Float or Int".to_string(),
6972 got: nc.type_name().to_string(),
6973 }),
6974 }
6975 let mut sums = vec![0.0f64; fct.levels.len()];
6976 let mut counts = vec![0usize; fct.levels.len()];
6977 match nc {
6978 Column::Float(v) => {
6979 for (i, &d) in fct.data.iter().enumerate() {
6980 if !v[i].is_nan() {
6981 sums[d as usize] += v[i];
6982 counts[d as usize] += 1;
6983 }
6984 }
6985 }
6986 Column::Int(v) => {
6987 for (i, &d) in fct.data.iter().enumerate() {
6988 sums[d as usize] += v[i] as f64;
6989 counts[d as usize] += 1;
6990 }
6991 }
6992 _ => unreachable!(),
6993 }
6994 Ok(sums.iter().zip(counts.iter())
6995 .map(|(&s, &c)| if c == 0 { f64::NAN } else { s / c as f64 })
6996 .collect())
6997 }
6998}
6999
/// Encodes strings as integer codes over their sorted distinct values.
///
/// Returns `(levels, codes)`:
/// * `levels` is the lexicographically sorted set of distinct strings in `col`;
/// * `codes[i]` is the position of `col[i]` within `levels`.
pub fn label_encode(col: &[String]) -> (Vec<String>, Vec<u32>) {
    let mut distinct: Vec<&str> = col.iter().map(String::as_str).collect();
    distinct.sort_unstable();
    distinct.dedup();

    let codes: Vec<u32> = col
        .iter()
        .map(|s| {
            let pos = distinct
                .binary_search(&s.as_str())
                .expect("every value is present in its own distinct set");
            pos as u32
        })
        .collect();
    let levels: Vec<String> = distinct.into_iter().map(str::to_string).collect();
    (levels, codes)
}
7019
/// Encodes strings by their position in a caller-supplied ordering.
///
/// `order` lists the levels from lowest to highest. If a level appears more than
/// once, its last position is used. Returns `(order.to_vec(), codes)` where
/// `codes[i]` is the position of `col[i]` in `order`.
///
/// # Errors
/// A message naming the first value of `col` that does not occur in `order`.
pub fn ordinal_encode(col: &[String], order: &[String]) -> Result<(Vec<String>, Vec<u32>), String> {
    let rank: BTreeMap<&str, u32> = order
        .iter()
        .enumerate()
        .map(|(pos, level)| (level.as_str(), pos as u32))
        .collect();

    let codes = col
        .iter()
        .map(|s| {
            rank.get(s.as_str())
                .copied()
                .ok_or_else(|| format!("value {:?} not found in specified order", s))
        })
        .collect::<Result<Vec<u32>, String>>()?;
    Ok((order.to_vec(), codes))
}
7040
/// Expands label codes into one boolean indicator column per level.
///
/// Returns `(names, columns)` where `names` is a copy of `levels` and
/// `columns[k][row]` is `true` exactly when `codes[row] == k`.
///
/// # Panics
/// If any code is `>= levels.len()`.
pub fn one_hot_encode(levels: &[String], codes: &[u32]) -> (Vec<String>, Vec<Vec<bool>>) {
    let mut indicators: Vec<Vec<bool>> = vec![vec![false; codes.len()]; levels.len()];
    codes
        .iter()
        .enumerate()
        .for_each(|(row, &code)| indicators[code as usize][row] = true);
    (levels.to_vec(), indicators)
}
7057
#[cfg(test)]
mod rolling_window_tests {
    use super::*;

    /// Builds a one-column frame storing `vals` as a Float column named `col_name`.
    fn make_df(col_name: &str, vals: Vec<f64>) -> DataFrame {
        let column = Column::Float(vals);
        DataFrame {
            columns: vec![(col_name.to_string(), column)],
        }
    }

    /// Unwraps a Float column, failing the test for any other variant.
    fn floats(col: Column) -> Vec<f64> {
        if let Column::Float(v) = col {
            v
        } else {
            panic!("expected Float column")
        }
    }

    /// Asserts `got` has the same length as `want` and agrees element-wise within `tol`.
    fn assert_close(got: &[f64], want: &[f64], tol: f64) {
        assert_eq!(got.len(), want.len());
        for (i, (&g, &w)) in got.iter().zip(want.iter()).enumerate() {
            assert!((g - w).abs() < tol, "index {}: got {}, want {}", i, g, w);
        }
    }

    #[test]
    fn rolling_sum_basic() {
        // The first two outputs come from the partial leading window.
        let df = make_df("x", vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let v = floats(eval_expr_column(&df, &DExpr::RollingSum("x".into(), 3), 5).unwrap());
        assert_close(&v, &[1.0, 3.0, 6.0, 9.0, 12.0], 1e-12);
    }

    #[test]
    fn rolling_mean_basic() {
        let df = make_df("x", vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let v = floats(eval_expr_column(&df, &DExpr::RollingMean("x".into(), 3), 5).unwrap());
        assert_close(&v, &[1.0, 1.5, 2.0, 3.0, 4.0], 1e-12);
    }

    #[test]
    fn rolling_min_basic() {
        let df = make_df("x", vec![5.0, 3.0, 4.0, 1.0, 2.0]);
        let v = floats(eval_expr_column(&df, &DExpr::RollingMin("x".into(), 3), 5).unwrap());
        assert_close(&v, &[5.0, 3.0, 3.0, 1.0, 1.0], 1e-12);
    }

    #[test]
    fn rolling_max_basic() {
        let df = make_df("x", vec![1.0, 5.0, 3.0, 2.0, 4.0]);
        let v = floats(eval_expr_column(&df, &DExpr::RollingMax("x".into(), 3), 5).unwrap());
        assert_close(&v, &[1.0, 5.0, 5.0, 5.0, 4.0], 1e-12);
    }

    #[test]
    fn rolling_var_basic() {
        // Expected values match the sample variance (n - 1 denominator):
        // [2,4] -> 2, [2,4,6] -> 4, [4,6,8] -> 4; a single value gives 0.
        let df = make_df("x", vec![2.0, 4.0, 6.0, 8.0]);
        let v = floats(eval_expr_column(&df, &DExpr::RollingVar("x".into(), 3), 4).unwrap());
        assert_eq!(v.len(), 4);
        assert!((v[0] - 0.0).abs() < 1e-12);
        assert_close(&v[1..], &[2.0, 4.0, 4.0], 1e-10);
    }

    #[test]
    fn rolling_sd_basic() {
        let df = make_df("x", vec![2.0, 4.0, 6.0, 8.0]);
        let v = floats(eval_expr_column(&df, &DExpr::RollingSd("x".into(), 3), 4).unwrap());
        assert_eq!(v.len(), 4);
        assert!((v[0] - 0.0).abs() < 1e-12);
        assert_close(&v[1..], &[2.0_f64.sqrt(), 2.0, 2.0], 1e-10);
    }

    #[test]
    fn rolling_window_larger_than_data() {
        let df = make_df("x", vec![1.0, 2.0, 3.0]);
        let v = floats(eval_expr_column(&df, &DExpr::RollingSum("x".into(), 10), 3).unwrap());
        assert_close(&v, &[1.0, 3.0, 6.0], 1e-12);
    }

    #[test]
    fn rolling_window_of_one() {
        let data = [3.0, 1.0, 4.0, 1.0, 5.0];
        let df = make_df("x", data.to_vec());
        let lo = eval_expr_column(&df, &DExpr::RollingMin("x".into(), 1), 5).unwrap();
        let hi = eval_expr_column(&df, &DExpr::RollingMax("x".into(), 1), 5).unwrap();
        match (lo, hi) {
            (Column::Float(mins), Column::Float(maxs)) => {
                for (i, &want) in data.iter().enumerate() {
                    assert!((mins[i] - want).abs() < 1e-12, "min[{}]", i);
                    assert!((maxs[i] - want).abs() < 1e-12, "max[{}]", i);
                }
            }
            _ => panic!("expected Float columns"),
        }
    }

    #[test]
    fn rolling_sum_with_nan() {
        let df = make_df("x", vec![1.0, f64::NAN, 3.0, 4.0]);
        let v = floats(eval_expr_column(&df, &DExpr::RollingSum("x".into(), 2), 4).unwrap());
        assert_eq!(v.len(), 4);
        assert!((v[0] - 1.0).abs() < 1e-12);
        // Index 3's window is [3.0, 4.0], yet NaN is expected there too: this pins
        // the current behavior where a NaN keeps affecting later outputs.
        // NOTE(review): presumably an incremental running sum; confirm it is intended.
        assert!(v[1..].iter().all(|x| x.is_nan()));
    }

    #[test]
    fn rolling_determinism() {
        let df = make_df("x", vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let expr = DExpr::RollingSum("x".into(), 4);
        let runs: Vec<Vec<f64>> = (0..3)
            .map(|_| floats(eval_expr_column(&df, &expr, 10).unwrap()))
            .collect();
        assert_eq!(runs[0], runs[1]);
        assert_eq!(runs[1], runs[2]);
    }

    #[test]
    fn rolling_display() {
        let cases = [
            (DExpr::RollingSum("val".into(), 5), "rolling_sum(\"val\", 5)"),
            (DExpr::RollingMean("col".into(), 3), "rolling_mean(\"col\", 3)"),
        ];
        for (expr, want) in &cases {
            assert_eq!(format!("{}", expr), *want);
        }
    }

    #[test]
    fn rolling_collect_columns() {
        let expr = DExpr::RollingSum("revenue".into(), 7);
        let mut cols = Vec::new();
        collect_expr_columns(&expr, &mut cols);
        assert_eq!(cols, vec!["revenue".to_string()]);
    }

    #[test]
    fn rolling_not_allowed_in_row_context() {
        let df = make_df("x", vec![1.0, 2.0, 3.0]);
        let outcome = eval_expr_row(&df, &DExpr::RollingSum("x".into(), 2), 0);
        assert!(outcome.is_err());
    }
}
7280
7281