1use cjc_repro::kahan_sum_f64;
10use std::cell::RefCell;
11use std::collections::{BTreeMap, BTreeSet, VecDeque};
12use std::fmt;
13use std::rc::Rc;
14
15mod csv;
16pub use csv::{CsvConfig, CsvReader, StreamingCsvProcessor};
17
18pub mod agg_kernels;
19pub mod column_meta;
20pub mod dict_encoding;
21pub mod lazy;
22pub mod tidy_dispatch;
23
/// A single typed column of a [`DataFrame`].
///
/// Every variant stores one entry per row.
#[derive(Debug, Clone)]
pub enum Column {
    /// 64-bit signed integers.
    Int(Vec<i64>),
    /// 64-bit floating point values.
    Float(Vec<f64>),
    /// Owned strings.
    Str(Vec<String>),
    /// Booleans.
    Bool(Vec<bool>),
    /// Dictionary-encoded strings: row `i` holds `levels[codes[i]]`.
    Categorical {
        /// String values that `codes` index into.
        levels: Vec<String>,
        /// One index into `levels` per row.
        codes: Vec<u32>,
    },
    /// Timestamps stored as `i64`.
    // NOTE(review): `get_display` appends an `ms` suffix, so these are presumably
    // milliseconds — confirm the epoch/unit against the producers.
    DateTime(Vec<i64>),
}
41
42impl Column {
43 pub fn len(&self) -> usize {
44 match self {
45 Column::Int(v) => v.len(),
46 Column::Float(v) => v.len(),
47 Column::Str(v) => v.len(),
48 Column::Bool(v) => v.len(),
49 Column::Categorical { codes, .. } => codes.len(),
50 Column::DateTime(v) => v.len(),
51 }
52 }
53
54 pub fn is_empty(&self) -> bool {
55 self.len() == 0
56 }
57
58 pub fn type_name(&self) -> &'static str {
59 match self {
60 Column::Int(_) => "Int",
61 Column::Float(_) => "Float",
62 Column::Str(_) => "Str",
63 Column::Bool(_) => "Bool",
64 Column::Categorical { .. } => "Categorical",
65 Column::DateTime(_) => "DateTime",
66 }
67 }
68
69 pub fn get_display(&self, idx: usize) -> String {
71 match self {
72 Column::Int(v) => format!("{}", v[idx]),
73 Column::Float(v) => format!("{}", v[idx]),
74 Column::Str(v) => v[idx].clone(),
75 Column::Bool(v) => format!("{}", v[idx]),
76 Column::Categorical { levels, codes } => levels[codes[idx] as usize].clone(),
77 Column::DateTime(v) => format!("{}ms", v[idx]),
78 }
79 }
80}
81
/// An ordered collection of named columns.
///
/// `from_columns` checks that all columns have the same length; because the
/// field is public, code that builds a `DataFrame` directly bypasses that check.
#[derive(Debug, Clone)]
pub struct DataFrame {
    /// `(name, column)` pairs in display order.
    pub columns: Vec<(String, Column)>,
}
89
90impl DataFrame {
91 pub fn new() -> Self {
92 Self {
93 columns: Vec::new(),
94 }
95 }
96
97 pub fn from_columns(columns: Vec<(String, Column)>) -> Result<Self, DataError> {
98 if columns.is_empty() {
99 return Ok(Self { columns });
100 }
101 let len = columns[0].1.len();
102 for (name, col) in &columns {
103 if col.len() != len {
104 return Err(DataError::ColumnLengthMismatch {
105 expected: len,
106 got: col.len(),
107 column: name.clone(),
108 });
109 }
110 }
111 Ok(Self { columns })
112 }
113
114 pub fn nrows(&self) -> usize {
115 self.columns.first().map(|(_, c)| c.len()).unwrap_or(0)
116 }
117
118 pub fn ncols(&self) -> usize {
119 self.columns.len()
120 }
121
122 pub fn column_names(&self) -> Vec<&str> {
123 self.columns.iter().map(|(n, _)| n.as_str()).collect()
124 }
125
126 pub fn get_column(&self, name: &str) -> Option<&Column> {
127 self.columns
128 .iter()
129 .find(|(n, _)| n == name)
130 .map(|(_, c)| c)
131 }
132
133 pub fn to_tensor_data(&self, col_names: &[&str]) -> Result<(Vec<f64>, Vec<usize>), DataError> {
135 let nrows = self.nrows();
136 let ncols = col_names.len();
137 let mut data = Vec::with_capacity(nrows * ncols);
138
139 for row in 0..nrows {
140 for &col_name in col_names {
141 let col = self
142 .get_column(col_name)
143 .ok_or_else(|| DataError::ColumnNotFound(col_name.to_string()))?;
144 let val = match col {
145 Column::Float(v) => v[row],
146 Column::Int(v) => v[row] as f64,
147 _ => {
148 return Err(DataError::InvalidOperation(format!(
149 "column `{}` is not numeric",
150 col_name
151 )))
152 }
153 };
154 data.push(val);
155 }
156 }
157
158 Ok((data, vec![nrows, ncols]))
159 }
160}
161
impl Default for DataFrame {
    /// Same as [`DataFrame::new`]: a frame with no columns.
    fn default() -> Self {
        Self::new()
    }
}
167
168impl fmt::Display for DataFrame {
169 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170 if self.columns.is_empty() {
171 return write!(f, "(empty DataFrame)");
172 }
173
174 let names: Vec<&str> = self.columns.iter().map(|(n, _)| n.as_str()).collect();
176 let mut col_widths: Vec<usize> = names.iter().map(|n| n.len()).collect();
177
178 let nrows = self.nrows();
180 for (col_idx, (_, col)) in self.columns.iter().enumerate() {
181 for row in 0..nrows {
182 let s = col.get_display(row);
183 col_widths[col_idx] = col_widths[col_idx].max(s.len());
184 }
185 }
186
187 for (i, name) in names.iter().enumerate() {
189 if i > 0 {
190 write!(f, " | ")?;
191 }
192 write!(f, "{:>width$}", name, width = col_widths[i])?;
193 }
194 writeln!(f)?;
195
196 for (i, &w) in col_widths.iter().enumerate() {
198 if i > 0 {
199 write!(f, "-+-")?;
200 }
201 write!(f, "{}", "-".repeat(w))?;
202 }
203 writeln!(f)?;
204
205 for row in 0..nrows {
207 for (col_idx, (_, col)) in self.columns.iter().enumerate() {
208 if col_idx > 0 {
209 write!(f, " | ")?;
210 }
211 let s = col.get_display(row);
212 write!(f, "{:>width$}", s, width = col_widths[col_idx])?;
213 }
214 writeln!(f)?;
215 }
216
217 Ok(())
218 }
219}
220
/// Expression tree used in filters and aggregations.
///
/// `eval_expr_row` evaluates the row-level variants (`Col`, literals, `BinOp`,
/// `FnCall`) and rejects aggregations and window functions; `eval_agg_expr`
/// evaluates `Agg` and `Count` over a group of rows.
#[derive(Debug, Clone)]
pub enum DExpr {
    /// Reference to a column by name.
    Col(String),
    /// Integer literal.
    LitInt(i64),
    /// Float literal.
    LitFloat(f64),
    /// Boolean literal.
    LitBool(bool),
    /// String literal.
    LitStr(String),
    /// Binary operation applied to two sub-expressions.
    BinOp {
        /// Operator to apply.
        op: DBinOp,
        /// Left operand.
        left: Box<DExpr>,
        /// Right operand.
        right: Box<DExpr>,
    },
    /// Aggregation over a group; `eval_agg_expr` requires the inner expression
    /// to be a `Col`.
    Agg(AggFunc, Box<DExpr>),
    /// Number of rows in the group.
    Count,
    /// Named function call. `eval_expr_row` accepts exactly one numeric argument
    /// and the names `log`, `exp`, `sqrt`, `abs`, `ceil`, `floor`, `round`,
    /// `sin`, `cos`, `tan`.
    FnCall(String, Vec<DExpr>),
    /// Window function, displayed as `cumsum(..)`; not evaluable per row.
    CumSum(Box<DExpr>),
    /// Window function, displayed as `cumprod(..)`; not evaluable per row.
    CumProd(Box<DExpr>),
    /// Window function, displayed as `cummax(..)`; not evaluable per row.
    CumMax(Box<DExpr>),
    /// Window function, displayed as `cummin(..)`; not evaluable per row.
    CumMin(Box<DExpr>),
    /// Window function, displayed as `lag(expr, k)`; not evaluable per row.
    Lag(Box<DExpr>, usize),
    /// Window function, displayed as `lead(expr, k)`; not evaluable per row.
    Lead(Box<DExpr>, usize),
    /// Window function, displayed as `rank(..)`; not evaluable per row.
    Rank(Box<DExpr>),
    /// Window function, displayed as `dense_rank(..)`; not evaluable per row.
    DenseRank(Box<DExpr>),
    /// Window function, displayed as `row_number()`; not evaluable per row.
    RowNumber,
    // The rolling variants carry a column name and a `usize` — presumably the
    // window length; confirm against the column-level evaluator.
    /// Displayed as `rolling_sum("col", w)`.
    RollingSum(String, usize),
    /// Displayed as `rolling_mean("col", w)`.
    RollingMean(String, usize),
    /// Displayed as `rolling_min("col", w)`.
    RollingMin(String, usize),
    /// Displayed as `rolling_max("col", w)`.
    RollingMax(String, usize),
    /// Displayed as `rolling_var("col", w)`.
    RollingVar(String, usize),
    /// Displayed as `rolling_sd("col", w)`.
    RollingSd(String, usize),
}
279
/// Binary operators usable in [`DExpr::BinOp`].
///
/// `eval_binop` supports the arithmetic and comparison operators on numeric
/// operands, `And`/`Or`/`Eq`/`Ne` on booleans, and `Eq`/`Ne` on strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DBinOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `>`
    Gt,
    /// `<`
    Lt,
    /// `>=`
    Ge,
    /// `<=`
    Le,
    /// `==`
    Eq,
    /// `!=`
    Ne,
    /// `&&`
    And,
    /// `||`
    Or,
}
295
/// Aggregation functions usable in [`DExpr::Agg`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggFunc {
    /// Sum of the values.
    Sum,
    /// Arithmetic mean of the values.
    Mean,
    /// Smallest value.
    Min,
    /// Largest value.
    Max,
    /// Number of values.
    Count,
}
304
305impl fmt::Display for DExpr {
306 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
307 match self {
308 DExpr::Col(name) => write!(f, "col(\"{}\")", name),
309 DExpr::LitInt(v) => write!(f, "{}", v),
310 DExpr::LitFloat(v) => write!(f, "{}", v),
311 DExpr::LitBool(b) => write!(f, "{}", b),
312 DExpr::LitStr(s) => write!(f, "\"{}\"", s),
313 DExpr::BinOp { op, left, right } => {
314 let op_str = match op {
315 DBinOp::Add => "+",
316 DBinOp::Sub => "-",
317 DBinOp::Mul => "*",
318 DBinOp::Div => "/",
319 DBinOp::Gt => ">",
320 DBinOp::Lt => "<",
321 DBinOp::Ge => ">=",
322 DBinOp::Le => "<=",
323 DBinOp::Eq => "==",
324 DBinOp::Ne => "!=",
325 DBinOp::And => "&&",
326 DBinOp::Or => "||",
327 };
328 write!(f, "({} {} {})", left, op_str, right)
329 }
330 DExpr::Agg(func, expr) => {
331 let name = match func {
332 AggFunc::Sum => "sum",
333 AggFunc::Mean => "mean",
334 AggFunc::Min => "min",
335 AggFunc::Max => "max",
336 AggFunc::Count => "count",
337 };
338 write!(f, "{}({})", name, expr)
339 }
340 DExpr::Count => write!(f, "count()"),
341 DExpr::FnCall(name, args) => {
342 let args_str: Vec<String> = args.iter().map(|a| format!("{}", a)).collect();
343 write!(f, "{}({})", name, args_str.join(", "))
344 }
345 DExpr::CumSum(e) => write!(f, "cumsum({})", e),
346 DExpr::CumProd(e) => write!(f, "cumprod({})", e),
347 DExpr::CumMax(e) => write!(f, "cummax({})", e),
348 DExpr::CumMin(e) => write!(f, "cummin({})", e),
349 DExpr::Lag(e, k) => write!(f, "lag({}, {})", e, k),
350 DExpr::Lead(e, k) => write!(f, "lead({}, {})", e, k),
351 DExpr::Rank(e) => write!(f, "rank({})", e),
352 DExpr::DenseRank(e) => write!(f, "dense_rank({})", e),
353 DExpr::RowNumber => write!(f, "row_number()"),
354 DExpr::RollingSum(col, w) => write!(f, "rolling_sum(\"{}\", {})", col, w),
355 DExpr::RollingMean(col, w) => write!(f, "rolling_mean(\"{}\", {})", col, w),
356 DExpr::RollingMin(col, w) => write!(f, "rolling_min(\"{}\", {})", col, w),
357 DExpr::RollingMax(col, w) => write!(f, "rolling_max(\"{}\", {})", col, w),
358 DExpr::RollingVar(col, w) => write!(f, "rolling_var(\"{}\", {})", col, w),
359 DExpr::RollingSd(col, w) => write!(f, "rolling_sd(\"{}\", {})", col, w),
360 }
361 }
362}
363
/// Tree of relational operations, evaluated bottom-up by `execute`.
#[derive(Debug, Clone)]
pub enum LogicalPlan {
    /// Leaf node: yields a clone of an in-memory frame.
    Scan {
        /// Frame to read.
        source: DataFrame,
    },
    /// Keeps the rows for which `predicate` evaluates to `true`.
    Filter {
        /// Plan producing the rows to filter.
        input: Box<LogicalPlan>,
        /// Row-level boolean expression.
        predicate: DExpr,
    },
    /// Grouping marker. `execute` passes the input through unchanged; the keys
    /// are only consulted by the optimizer and `referenced_columns`.
    GroupBy {
        /// Plan producing the rows to group.
        input: Box<LogicalPlan>,
        /// Names of the grouping columns.
        keys: Vec<String>,
    },
    /// Groups by `keys` and computes one output column per entry in `aggs`.
    Aggregate {
        /// Plan producing the rows to aggregate.
        input: Box<LogicalPlan>,
        /// Names of the grouping columns.
        keys: Vec<String>,
        /// `(output column name, aggregation expression)` pairs.
        aggs: Vec<(String, DExpr)>,
    },
    /// Keeps only the listed columns.
    Project {
        /// Plan producing the frame to project.
        input: Box<LogicalPlan>,
        /// Names of the columns to keep.
        columns: Vec<String>,
    },
    /// Inner join on `left_on == right_on`.
    InnerJoin {
        /// Left input.
        left: Box<LogicalPlan>,
        /// Right input.
        right: Box<LogicalPlan>,
        /// Join key column in the left input.
        left_on: String,
        /// Join key column in the right input.
        right_on: String,
    },
    /// Left join on `left_on == right_on`; unmatched left rows are kept.
    LeftJoin {
        /// Left input.
        left: Box<LogicalPlan>,
        /// Right input.
        right: Box<LogicalPlan>,
        /// Join key column in the left input.
        left_on: String,
        /// Join key column in the right input.
        right_on: String,
    },
    /// Cartesian product of both inputs.
    CrossJoin {
        /// Left input.
        left: Box<LogicalPlan>,
        /// Right input.
        right: Box<LogicalPlan>,
    },
}
414
415impl LogicalPlan {
416 pub fn referenced_columns(&self) -> Vec<String> {
418 let mut cols = Vec::new();
419 self.collect_columns(&mut cols);
420 cols.sort();
421 cols.dedup();
422 cols
423 }
424
425 fn collect_columns(&self, cols: &mut Vec<String>) {
426 match self {
427 LogicalPlan::Scan { .. } => {}
428 LogicalPlan::Filter { input, predicate } => {
429 input.collect_columns(cols);
430 collect_expr_columns(predicate, cols);
431 }
432 LogicalPlan::GroupBy { input, keys } => {
433 input.collect_columns(cols);
434 cols.extend(keys.clone());
435 }
436 LogicalPlan::Aggregate {
437 input, keys, aggs, ..
438 } => {
439 input.collect_columns(cols);
440 cols.extend(keys.clone());
441 for (_, expr) in aggs {
442 collect_expr_columns(expr, cols);
443 }
444 }
445 LogicalPlan::Project { input, columns } => {
446 input.collect_columns(cols);
447 cols.extend(columns.clone());
448 }
449 LogicalPlan::InnerJoin {
450 left,
451 right,
452 left_on,
453 right_on,
454 }
455 | LogicalPlan::LeftJoin {
456 left,
457 right,
458 left_on,
459 right_on,
460 } => {
461 left.collect_columns(cols);
462 right.collect_columns(cols);
463 cols.push(left_on.clone());
464 cols.push(right_on.clone());
465 }
466 LogicalPlan::CrossJoin { left, right } => {
467 left.collect_columns(cols);
468 right.collect_columns(cols);
469 }
470 }
471 }
472}
473
474fn collect_expr_columns(expr: &DExpr, cols: &mut Vec<String>) {
475 match expr {
476 DExpr::Col(name) => cols.push(name.clone()),
477 DExpr::BinOp { left, right, .. } => {
478 collect_expr_columns(left, cols);
479 collect_expr_columns(right, cols);
480 }
481 DExpr::Agg(_, inner) => collect_expr_columns(inner, cols),
482 DExpr::FnCall(_, args) => {
483 for arg in args {
484 collect_expr_columns(arg, cols);
485 }
486 }
487 DExpr::CumSum(e) | DExpr::CumProd(e) | DExpr::CumMax(e) | DExpr::CumMin(e)
488 | DExpr::Lag(e, _) | DExpr::Lead(e, _) | DExpr::Rank(e) | DExpr::DenseRank(e) => {
489 collect_expr_columns(e, cols);
490 }
491 DExpr::RollingSum(col, _) | DExpr::RollingMean(col, _)
492 | DExpr::RollingMin(col, _) | DExpr::RollingMax(col, _)
493 | DExpr::RollingVar(col, _) | DExpr::RollingSd(col, _) => {
494 cols.push(col.clone());
495 }
496 _ => {}
497 }
498}
499
500pub fn optimize(plan: LogicalPlan) -> LogicalPlan {
504 let plan = push_down_predicates(plan);
505 let plan = prune_columns(plan);
506 plan
507}
508
509fn push_down_predicates(plan: LogicalPlan) -> LogicalPlan {
511 match plan {
512 LogicalPlan::Filter {
513 input,
514 predicate,
515 } => {
516 let optimized_input = push_down_predicates(*input);
517 match optimized_input {
518 LogicalPlan::GroupBy {
520 input: inner,
521 keys,
522 } => {
523 let pred_cols = {
524 let mut c = Vec::new();
525 collect_expr_columns(&predicate, &mut c);
526 c
527 };
528 let can_push = pred_cols.iter().all(|c| !keys.contains(c))
529 || pred_cols.iter().all(|c| {
530 !keys.contains(c) || keys.contains(c)
532 });
533 if can_push && pred_cols.iter().all(|c| !keys.contains(c)) {
535 LogicalPlan::GroupBy {
536 input: Box::new(LogicalPlan::Filter {
537 input: inner,
538 predicate,
539 }),
540 keys,
541 }
542 } else {
543 LogicalPlan::Filter {
544 input: Box::new(LogicalPlan::GroupBy {
545 input: inner,
546 keys,
547 }),
548 predicate,
549 }
550 }
551 }
552 other => LogicalPlan::Filter {
553 input: Box::new(other),
554 predicate,
555 },
556 }
557 }
558 LogicalPlan::GroupBy { input, keys } => LogicalPlan::GroupBy {
559 input: Box::new(push_down_predicates(*input)),
560 keys,
561 },
562 LogicalPlan::Aggregate {
563 input,
564 keys,
565 aggs,
566 } => LogicalPlan::Aggregate {
567 input: Box::new(push_down_predicates(*input)),
568 keys,
569 aggs,
570 },
571 LogicalPlan::Project { input, columns } => LogicalPlan::Project {
572 input: Box::new(push_down_predicates(*input)),
573 columns,
574 },
575 LogicalPlan::InnerJoin {
576 left,
577 right,
578 left_on,
579 right_on,
580 } => LogicalPlan::InnerJoin {
581 left: Box::new(push_down_predicates(*left)),
582 right: Box::new(push_down_predicates(*right)),
583 left_on,
584 right_on,
585 },
586 LogicalPlan::LeftJoin {
587 left,
588 right,
589 left_on,
590 right_on,
591 } => LogicalPlan::LeftJoin {
592 left: Box::new(push_down_predicates(*left)),
593 right: Box::new(push_down_predicates(*right)),
594 left_on,
595 right_on,
596 },
597 LogicalPlan::CrossJoin { left, right } => LogicalPlan::CrossJoin {
598 left: Box::new(push_down_predicates(*left)),
599 right: Box::new(push_down_predicates(*right)),
600 },
601 other => other,
602 }
603}
604
/// Column-pruning pass of the optimizer.
///
/// Currently an identity function: the plan is returned unchanged.
fn prune_columns(plan: LogicalPlan) -> LogicalPlan {
    plan
}
611
612pub fn execute(plan: &LogicalPlan) -> Result<DataFrame, DataError> {
616 match plan {
617 LogicalPlan::Scan { source } => Ok(source.clone()),
618
619 LogicalPlan::Filter { input, predicate } => {
620 let df = execute(input)?;
621 execute_filter(&df, predicate)
622 }
623
624 LogicalPlan::GroupBy { input, keys: _ } => {
625 let df = execute(input)?;
627 Ok(df)
629 }
630
631 LogicalPlan::Aggregate { input, keys, aggs } => {
632 let df = execute(input)?;
633 execute_aggregate(&df, keys, aggs)
634 }
635
636 LogicalPlan::Project { input, columns } => {
637 let df = execute(input)?;
638 let projected = df
639 .columns
640 .into_iter()
641 .filter(|(name, _)| columns.contains(name))
642 .collect();
643 Ok(DataFrame { columns: projected })
644 }
645
646 LogicalPlan::InnerJoin {
647 left,
648 right,
649 left_on,
650 right_on,
651 } => {
652 let left_df = execute(left)?;
653 let right_df = execute(right)?;
654 execute_inner_join(&left_df, &right_df, left_on, right_on)
655 }
656
657 LogicalPlan::LeftJoin {
658 left,
659 right,
660 left_on,
661 right_on,
662 } => {
663 let left_df = execute(left)?;
664 let right_df = execute(right)?;
665 execute_left_join(&left_df, &right_df, left_on, right_on)
666 }
667
668 LogicalPlan::CrossJoin { left, right } => {
669 let left_df = execute(left)?;
670 let right_df = execute(right)?;
671 execute_cross_join(&left_df, &right_df)
672 }
673 }
674}
675
676fn execute_filter(df: &DataFrame, predicate: &DExpr) -> Result<DataFrame, DataError> {
677 let nrows = df.nrows();
678 let mut mask = vec![false; nrows];
679
680 for row in 0..nrows {
681 let val = eval_expr_row(df, predicate, row)?;
682 mask[row] = match val {
683 ExprValue::Bool(b) => b,
684 _ => return Err(DataError::InvalidOperation("filter predicate must be boolean".into())),
685 };
686 }
687
688 let mut new_columns = Vec::new();
689 for (name, col) in &df.columns {
690 let filtered = filter_column(col, &mask);
691 new_columns.push((name.clone(), filtered));
692 }
693
694 Ok(DataFrame {
695 columns: new_columns,
696 })
697}
698
699fn filter_column(col: &Column, mask: &[bool]) -> Column {
700 match col {
701 Column::Int(v) => Column::Int(
702 v.iter()
703 .zip(mask)
704 .filter(|(_, &m)| m)
705 .map(|(v, _)| *v)
706 .collect(),
707 ),
708 Column::Float(v) => Column::Float(
709 v.iter()
710 .zip(mask)
711 .filter(|(_, &m)| m)
712 .map(|(v, _)| *v)
713 .collect(),
714 ),
715 Column::Str(v) => Column::Str(
716 v.iter()
717 .zip(mask)
718 .filter(|(_, &m)| m)
719 .map(|(v, _)| v.clone())
720 .collect(),
721 ),
722 Column::Bool(v) => Column::Bool(
723 v.iter()
724 .zip(mask)
725 .filter(|(_, &m)| m)
726 .map(|(v, _)| *v)
727 .collect(),
728 ),
729 Column::Categorical { levels, codes } => Column::Categorical {
730 levels: levels.clone(),
731 codes: codes
732 .iter()
733 .zip(mask)
734 .filter(|(_, &m)| m)
735 .map(|(v, _)| *v)
736 .collect(),
737 },
738 Column::DateTime(v) => Column::DateTime(
739 v.iter()
740 .zip(mask)
741 .filter(|(_, &m)| m)
742 .map(|(v, _)| *v)
743 .collect(),
744 ),
745 }
746}
747
748fn execute_aggregate(
749 df: &DataFrame,
750 keys: &[String],
751 aggs: &[(String, DExpr)],
752) -> Result<DataFrame, DataError> {
753 let nrows = df.nrows();
755 let mut groups: BTreeMap<Vec<String>, Vec<usize>> = BTreeMap::new();
756
757 for row in 0..nrows {
758 let key: Vec<String> = keys
759 .iter()
760 .map(|k| {
761 df.get_column(k)
762 .map(|col| col.get_display(row))
763 .ok_or_else(|| DataError::ColumnNotFound(k.to_string()))
764 })
765 .collect::<Result<Vec<String>, DataError>>()?;
766 groups.entry(key).or_default().push(row);
767 }
768
769 let mut sorted_groups: Vec<(Vec<String>, Vec<usize>)> = groups.into_iter().collect();
771 sorted_groups.sort_by(|a, b| a.0.cmp(&b.0));
772
773 let mut result_columns: Vec<(String, Column)> = Vec::new();
775
776 for (key_idx, key_name) in keys.iter().enumerate() {
778 let values: Vec<String> = sorted_groups
779 .iter()
780 .map(|(key, _)| key[key_idx].clone())
781 .collect();
782 let source_col = df.get_column(key_name).ok_or_else(|| {
784 DataError::ColumnNotFound(key_name.clone())
785 })?;
786 match source_col {
787 Column::Int(_) => {
788 let int_vals: Vec<i64> = values.iter().map(|s| s.parse().unwrap_or(0)).collect();
789 result_columns.push((key_name.clone(), Column::Int(int_vals)));
790 }
791 Column::Str(_) => {
792 result_columns.push((key_name.clone(), Column::Str(values)));
793 }
794 _ => {
795 result_columns.push((key_name.clone(), Column::Str(values)));
796 }
797 }
798 }
799
800 for (agg_name, agg_expr) in aggs {
802 let mut values = Vec::new();
803 for (_, row_indices) in &sorted_groups {
804 let val = eval_agg_expr(df, agg_expr, row_indices)?;
805 values.push(val);
806 }
807 result_columns.push((agg_name.clone(), Column::Float(values)));
808 }
809
810 Ok(DataFrame {
811 columns: result_columns,
812 })
813}
814
/// Scalar result of evaluating a [`DExpr`] on a single row.
#[derive(Debug, Clone)]
enum ExprValue {
    /// Integer value; also used for `DateTime` column entries.
    Int(i64),
    /// Floating point value.
    Float(f64),
    /// String value; also used for categorical column entries.
    Str(String),
    /// Boolean value.
    Bool(bool),
}
824
825fn eval_expr_row(df: &DataFrame, expr: &DExpr, row: usize) -> Result<ExprValue, DataError> {
826 match expr {
827 DExpr::Col(name) => {
828 let col = df
829 .get_column(name)
830 .ok_or_else(|| DataError::ColumnNotFound(name.clone()))?;
831 match col {
832 Column::Int(v) => Ok(ExprValue::Int(v[row])),
833 Column::Float(v) => Ok(ExprValue::Float(v[row])),
834 Column::Str(v) => Ok(ExprValue::Str(v[row].clone())),
835 Column::Bool(v) => Ok(ExprValue::Bool(v[row])),
836 Column::Categorical { levels, codes } => {
837 Ok(ExprValue::Str(levels[codes[row] as usize].clone()))
838 }
839 Column::DateTime(v) => Ok(ExprValue::Int(v[row])),
840 }
841 }
842 DExpr::LitInt(v) => Ok(ExprValue::Int(*v)),
843 DExpr::LitFloat(v) => Ok(ExprValue::Float(*v)),
844 DExpr::LitBool(b) => Ok(ExprValue::Bool(*b)),
845 DExpr::LitStr(s) => Ok(ExprValue::Str(s.clone())),
846 DExpr::BinOp { op, left, right } => {
847 let lv = eval_expr_row(df, left, row)?;
848 let rv = eval_expr_row(df, right, row)?;
849 eval_binop(*op, lv, rv)
850 }
851 DExpr::Agg(_, _) | DExpr::Count => Err(DataError::InvalidOperation(
852 "aggregation not allowed in row context".into(),
853 )),
854 DExpr::FnCall(name, args) => {
855 if args.len() != 1 {
856 return Err(DataError::InvalidOperation(
857 format!("FnCall '{}' requires exactly 1 argument, got {}", name, args.len()),
858 ));
859 }
860 let val = eval_expr_row(df, &args[0], row)?;
861 let x = match val {
862 ExprValue::Float(f) => f,
863 ExprValue::Int(i) => i as f64,
864 _ => return Err(DataError::InvalidOperation(
865 format!("FnCall '{}' requires numeric argument", name),
866 )),
867 };
868 let result = match name.as_str() {
869 "log" => x.ln(),
870 "exp" => x.exp(),
871 "sqrt" => x.sqrt(),
872 "abs" => x.abs(),
873 "ceil" => x.ceil(),
874 "floor" => x.floor(),
875 "round" => x.round(),
876 "sin" => x.sin(),
877 "cos" => x.cos(),
878 "tan" => x.tan(),
879 other => return Err(DataError::InvalidOperation(
880 format!("unknown DExpr function: {}", other),
881 )),
882 };
883 Ok(ExprValue::Float(result))
884 }
885 DExpr::CumSum(_) | DExpr::CumProd(_) | DExpr::CumMax(_) | DExpr::CumMin(_)
886 | DExpr::Lag(_, _) | DExpr::Lead(_, _) | DExpr::Rank(_) | DExpr::DenseRank(_)
887 | DExpr::RowNumber
888 | DExpr::RollingSum(..) | DExpr::RollingMean(..) | DExpr::RollingMin(..)
889 | DExpr::RollingMax(..) | DExpr::RollingVar(..) | DExpr::RollingSd(..) => {
890 Err(DataError::InvalidOperation(
891 "window function not allowed in row context; use eval_expr_column".into(),
892 ))
893 }
894 }
895}
896
897fn eval_binop(op: DBinOp, left: ExprValue, right: ExprValue) -> Result<ExprValue, DataError> {
898 match (left, right) {
899 (ExprValue::Int(a), ExprValue::Int(b)) => match op {
900 DBinOp::Add => Ok(ExprValue::Int(a + b)),
901 DBinOp::Sub => Ok(ExprValue::Int(a - b)),
902 DBinOp::Mul => Ok(ExprValue::Int(a * b)),
903 DBinOp::Div => Ok(ExprValue::Int(a / b)),
904 DBinOp::Gt => Ok(ExprValue::Bool(a > b)),
905 DBinOp::Lt => Ok(ExprValue::Bool(a < b)),
906 DBinOp::Ge => Ok(ExprValue::Bool(a >= b)),
907 DBinOp::Le => Ok(ExprValue::Bool(a <= b)),
908 DBinOp::Eq => Ok(ExprValue::Bool(a == b)),
909 DBinOp::Ne => Ok(ExprValue::Bool(a != b)),
910 _ => Err(DataError::InvalidOperation(format!(
911 "unsupported op {:?} on Int",
912 op
913 ))),
914 },
915 (ExprValue::Float(a), ExprValue::Float(b)) => match op {
916 DBinOp::Add => Ok(ExprValue::Float(a + b)),
917 DBinOp::Sub => Ok(ExprValue::Float(a - b)),
918 DBinOp::Mul => Ok(ExprValue::Float(a * b)),
919 DBinOp::Div => Ok(ExprValue::Float(a / b)),
920 DBinOp::Gt => Ok(ExprValue::Bool(a > b)),
921 DBinOp::Lt => Ok(ExprValue::Bool(a < b)),
922 DBinOp::Ge => Ok(ExprValue::Bool(a >= b)),
923 DBinOp::Le => Ok(ExprValue::Bool(a <= b)),
924 DBinOp::Eq => Ok(ExprValue::Bool(a == b)),
925 DBinOp::Ne => Ok(ExprValue::Bool(a != b)),
926 _ => Err(DataError::InvalidOperation(format!(
927 "unsupported op {:?} on Float",
928 op
929 ))),
930 },
931 (ExprValue::Int(a), ExprValue::Float(b)) => {
933 eval_binop(op, ExprValue::Float(a as f64), ExprValue::Float(b))
934 }
935 (ExprValue::Float(a), ExprValue::Int(b)) => {
936 eval_binop(op, ExprValue::Float(a), ExprValue::Float(b as f64))
937 }
938 (ExprValue::Bool(a), ExprValue::Bool(b)) => match op {
939 DBinOp::And => Ok(ExprValue::Bool(a && b)),
940 DBinOp::Or => Ok(ExprValue::Bool(a || b)),
941 DBinOp::Eq => Ok(ExprValue::Bool(a == b)),
942 DBinOp::Ne => Ok(ExprValue::Bool(a != b)),
943 _ => Err(DataError::InvalidOperation(format!(
944 "unsupported op {:?} on Bool",
945 op
946 ))),
947 },
948 (ExprValue::Str(a), ExprValue::Str(b)) => match op {
949 DBinOp::Eq => Ok(ExprValue::Bool(a == b)),
950 DBinOp::Ne => Ok(ExprValue::Bool(a != b)),
951 _ => Err(DataError::InvalidOperation(format!(
952 "unsupported op {:?} on String",
953 op
954 ))),
955 },
956 _ => Err(DataError::InvalidOperation(
957 "type mismatch in binary operation".into(),
958 )),
959 }
960}
961
962fn eval_agg_expr(
963 df: &DataFrame,
964 expr: &DExpr,
965 rows: &[usize],
966) -> Result<f64, DataError> {
967 match expr {
968 DExpr::Agg(func, inner) => {
969 let values = extract_float_values(df, inner, rows)?;
970 match func {
971 AggFunc::Sum => Ok(kahan_sum_f64(&values)),
972 AggFunc::Mean => {
973 if values.is_empty() {
974 Ok(0.0)
975 } else {
976 Ok(kahan_sum_f64(&values) / values.len() as f64)
977 }
978 }
979 AggFunc::Min => Ok(values
980 .iter()
981 .cloned()
982 .fold(f64::INFINITY, f64::min)),
983 AggFunc::Max => Ok(values
984 .iter()
985 .cloned()
986 .fold(f64::NEG_INFINITY, f64::max)),
987 AggFunc::Count => Ok(values.len() as f64),
988 }
989 }
990 DExpr::Count => Ok(rows.len() as f64),
991 _ => Err(DataError::InvalidOperation(
992 "expected aggregation expression".into(),
993 )),
994 }
995}
996
997fn extract_float_values(
998 df: &DataFrame,
999 expr: &DExpr,
1000 rows: &[usize],
1001) -> Result<Vec<f64>, DataError> {
1002 match expr {
1003 DExpr::Col(name) => {
1004 let col = df
1005 .get_column(name)
1006 .ok_or_else(|| DataError::ColumnNotFound(name.clone()))?;
1007 let vals: Vec<f64> = match col {
1008 Column::Float(v) => rows.iter().map(|&r| v[r]).collect(),
1009 Column::Int(v) => rows.iter().map(|&r| v[r] as f64).collect(),
1010 _ => {
1011 return Err(DataError::InvalidOperation(format!(
1012 "cannot aggregate non-numeric column `{}`",
1013 name
1014 )))
1015 }
1016 };
1017 Ok(vals)
1018 }
1019 _ => Err(DataError::InvalidOperation(
1020 "expected column reference in aggregation".into(),
1021 )),
1022 }
1023}
1024
/// Builder that assembles a [`LogicalPlan`] step by step.
///
/// Each builder method wraps the current plan in a new node; `collect`
/// optimizes and executes the result.
pub struct Pipeline {
    /// Plan built so far.
    plan: LogicalPlan,
}
1031
1032impl Pipeline {
1033 pub fn scan(df: DataFrame) -> Self {
1034 Self {
1035 plan: LogicalPlan::Scan { source: df },
1036 }
1037 }
1038
1039 pub fn filter(self, predicate: DExpr) -> Self {
1040 Self {
1041 plan: LogicalPlan::Filter {
1042 input: Box::new(self.plan),
1043 predicate,
1044 },
1045 }
1046 }
1047
1048 pub fn group_by(self, keys: Vec<String>) -> Self {
1049 Self {
1050 plan: LogicalPlan::GroupBy {
1051 input: Box::new(self.plan),
1052 keys,
1053 },
1054 }
1055 }
1056
1057 pub fn summarize(self, keys: Vec<String>, aggs: Vec<(String, DExpr)>) -> Self {
1058 Self {
1059 plan: LogicalPlan::Aggregate {
1060 input: Box::new(self.plan),
1061 keys,
1062 aggs,
1063 },
1064 }
1065 }
1066
1067 pub fn select(self, columns: Vec<String>) -> Self {
1068 Self {
1069 plan: LogicalPlan::Project {
1070 input: Box::new(self.plan),
1071 columns,
1072 },
1073 }
1074 }
1075
1076 pub fn inner_join(self, right: DataFrame, left_on: &str, right_on: &str) -> Self {
1077 Self {
1078 plan: LogicalPlan::InnerJoin {
1079 left: Box::new(self.plan),
1080 right: Box::new(LogicalPlan::Scan { source: right }),
1081 left_on: left_on.to_string(),
1082 right_on: right_on.to_string(),
1083 },
1084 }
1085 }
1086
1087 pub fn left_join(self, right: DataFrame, left_on: &str, right_on: &str) -> Self {
1088 Self {
1089 plan: LogicalPlan::LeftJoin {
1090 left: Box::new(self.plan),
1091 right: Box::new(LogicalPlan::Scan { source: right }),
1092 left_on: left_on.to_string(),
1093 right_on: right_on.to_string(),
1094 },
1095 }
1096 }
1097
1098 pub fn cross_join(self, right: DataFrame) -> Self {
1099 Self {
1100 plan: LogicalPlan::CrossJoin {
1101 left: Box::new(self.plan),
1102 right: Box::new(LogicalPlan::Scan { source: right }),
1103 },
1104 }
1105 }
1106
1107 pub fn collect(self) -> Result<DataFrame, DataError> {
1109 let optimized = optimize(self.plan);
1110 execute(&optimized)
1111 }
1112
1113 pub fn plan(&self) -> &LogicalPlan {
1115 &self.plan
1116 }
1117}
1118
/// Errors produced by frame construction and plan execution.
#[derive(Debug, Clone)]
pub enum DataError {
    /// A column with the given name does not exist.
    ColumnNotFound(String),
    /// A column's length differs from the expected row count.
    ColumnLengthMismatch {
        /// Row count the column was expected to have.
        expected: usize,
        /// Row count the column actually has.
        got: usize,
        /// Name of the offending column.
        column: String,
    },
    /// An operation was not valid for its inputs; the payload describes why.
    InvalidOperation(String),
}
1131
1132impl fmt::Display for DataError {
1133 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1134 match self {
1135 DataError::ColumnNotFound(name) => write!(f, "column `{}` not found", name),
1136 DataError::ColumnLengthMismatch {
1137 expected,
1138 got,
1139 column,
1140 } => write!(
1141 f,
1142 "column `{}` has {} rows, expected {}",
1143 column, got, expected
1144 ),
1145 DataError::InvalidOperation(msg) => write!(f, "invalid operation: {}", msg),
1146 }
1147 }
1148}
1149
// Marker impl: `DataError` relies on the trait's default methods, so it
// reports no underlying source error.
impl std::error::Error for DataError {}
1151
1152fn column_value_str(col: &Column, row: usize) -> String {
1156 match col {
1157 Column::Int(v) => v[row].to_string(),
1158 Column::Float(v) => v[row].to_string(),
1159 Column::Str(v) => v[row].clone(),
1160 Column::Bool(v) => v[row].to_string(),
1161 Column::Categorical { levels, codes } => levels[codes[row] as usize].clone(),
1162 Column::DateTime(v) => v[row].to_string(),
1163 }
1164}
1165
1166fn execute_inner_join(
1167 left: &DataFrame,
1168 right: &DataFrame,
1169 left_on: &str,
1170 right_on: &str,
1171) -> Result<DataFrame, DataError> {
1172 let left_col = left.get_column(left_on)
1173 .ok_or_else(|| DataError::InvalidOperation(format!("join key `{}` not found in left", left_on)))?;
1174 let right_col = right.get_column(right_on)
1175 .ok_or_else(|| DataError::InvalidOperation(format!("join key `{}` not found in right", right_on)))?;
1176
1177 let right_nrows = right.nrows();
1179 let mut index: std::collections::BTreeMap<String, Vec<usize>> = std::collections::BTreeMap::new();
1180 for i in 0..right_nrows {
1181 let key = column_value_str(right_col, i);
1182 index.entry(key).or_default().push(i);
1183 }
1184
1185 let left_nrows = left.nrows();
1186 let mut left_indices = Vec::new();
1187 let mut right_indices = Vec::new();
1188
1189 for i in 0..left_nrows {
1190 let key = column_value_str(left_col, i);
1191 if let Some(matches) = index.get(&key) {
1192 for &j in matches {
1193 left_indices.push(i);
1194 right_indices.push(j);
1195 }
1196 }
1197 }
1198
1199 build_join_result(left, right, &left_indices, &right_indices, right_on)
1200}
1201
1202fn execute_left_join(
1203 left: &DataFrame,
1204 right: &DataFrame,
1205 left_on: &str,
1206 right_on: &str,
1207) -> Result<DataFrame, DataError> {
1208 let left_col = left.get_column(left_on)
1209 .ok_or_else(|| DataError::InvalidOperation(format!("join key `{}` not found in left", left_on)))?;
1210 let right_col = right.get_column(right_on)
1211 .ok_or_else(|| DataError::InvalidOperation(format!("join key `{}` not found in right", right_on)))?;
1212
1213 let right_nrows = right.nrows();
1214 let mut index: std::collections::BTreeMap<String, Vec<usize>> = std::collections::BTreeMap::new();
1215 for i in 0..right_nrows {
1216 let key = column_value_str(right_col, i);
1217 index.entry(key).or_default().push(i);
1218 }
1219
1220 let left_nrows = left.nrows();
1221 let mut left_indices = Vec::new();
1222 let mut right_indices: Vec<Option<usize>> = Vec::new();
1223
1224 for i in 0..left_nrows {
1225 let key = column_value_str(left_col, i);
1226 if let Some(matches) = index.get(&key) {
1227 for &j in matches {
1228 left_indices.push(i);
1229 right_indices.push(Some(j));
1230 }
1231 } else {
1232 left_indices.push(i);
1233 right_indices.push(None);
1234 }
1235 }
1236
1237 build_left_join_result(left, right, &left_indices, &right_indices, right_on)
1238}
1239
1240fn execute_cross_join(left: &DataFrame, right: &DataFrame) -> Result<DataFrame, DataError> {
1241 let left_nrows = left.nrows();
1242 let right_nrows = right.nrows();
1243 let mut left_indices = Vec::with_capacity(left_nrows * right_nrows);
1244 let mut right_indices = Vec::with_capacity(left_nrows * right_nrows);
1245
1246 for i in 0..left_nrows {
1247 for j in 0..right_nrows {
1248 left_indices.push(i);
1249 right_indices.push(j);
1250 }
1251 }
1252
1253 build_join_result(left, right, &left_indices, &right_indices, "")
1254}
1255
1256fn build_join_result(
1257 left: &DataFrame,
1258 right: &DataFrame,
1259 left_indices: &[usize],
1260 right_indices: &[usize],
1261 right_on: &str,
1262) -> Result<DataFrame, DataError> {
1263 let mut columns = Vec::new();
1264
1265 for (name, col) in &left.columns {
1267 columns.push((name.clone(), gather_column(col, left_indices)));
1268 }
1269
1270 for (name, col) in &right.columns {
1272 if name == right_on {
1273 continue;
1274 }
1275 let out_name = if left.get_column(name).is_some() {
1276 format!("{}_right", name)
1277 } else {
1278 name.clone()
1279 };
1280 columns.push((out_name, gather_column(col, right_indices)));
1281 }
1282
1283 Ok(DataFrame { columns })
1284}
1285
1286fn build_left_join_result(
1287 left: &DataFrame,
1288 right: &DataFrame,
1289 left_indices: &[usize],
1290 right_indices: &[Option<usize>],
1291 right_on: &str,
1292) -> Result<DataFrame, DataError> {
1293 let mut columns = Vec::new();
1294
1295 for (name, col) in &left.columns {
1296 columns.push((name.clone(), gather_column(col, left_indices)));
1297 }
1298
1299 for (name, col) in &right.columns {
1300 if name == right_on {
1301 continue;
1302 }
1303 let out_name = if left.get_column(name).is_some() {
1304 format!("{}_right", name)
1305 } else {
1306 name.clone()
1307 };
1308 columns.push((out_name, gather_column_nullable(col, right_indices)));
1309 }
1310
1311 Ok(DataFrame { columns })
1312}
1313
1314fn gather_column(col: &Column, indices: &[usize]) -> Column {
1315 match col {
1316 Column::Int(v) => Column::Int(indices.iter().map(|&i| v[i]).collect()),
1317 Column::Float(v) => Column::Float(indices.iter().map(|&i| v[i]).collect()),
1318 Column::Str(v) => Column::Str(indices.iter().map(|&i| v[i].clone()).collect()),
1319 Column::Bool(v) => Column::Bool(indices.iter().map(|&i| v[i]).collect()),
1320 Column::Categorical { levels, codes } => Column::Categorical {
1321 levels: levels.clone(),
1322 codes: indices.iter().map(|&i| codes[i]).collect(),
1323 },
1324 Column::DateTime(v) => Column::DateTime(indices.iter().map(|&i| v[i]).collect()),
1325 }
1326}
1327
1328fn gather_column_nullable(col: &Column, indices: &[Option<usize>]) -> Column {
1329 match col {
1330 Column::Int(v) => Column::Int(indices.iter().map(|opt| opt.map_or(0, |i| v[i])).collect()),
1331 Column::Float(v) => Column::Float(indices.iter().map(|opt| opt.map_or(f64::NAN, |i| v[i])).collect()),
1332 Column::Str(v) => Column::Str(indices.iter().map(|opt| opt.map_or_else(String::new, |i| v[i].clone())).collect()),
1333 Column::Bool(v) => Column::Bool(indices.iter().map(|opt| opt.map_or(false, |i| v[i])).collect()),
1334 Column::Categorical { levels, codes } => Column::Categorical {
1335 levels: levels.clone(),
1336 codes: indices.iter().map(|opt| opt.map_or(0, |i| codes[i])).collect(),
1337 },
1338 Column::DateTime(v) => Column::DateTime(indices.iter().map(|opt| opt.map_or(0, |i| v[i])).collect()),
1339 }
1340}
1341
#[cfg(test)]
mod tests {
    use super::*;

    /// Six-row employee fixture: four `eng` rows and two `sales` rows, with a
    /// `Float` salary column and an `Int` tenure column.
    fn sample_df() -> DataFrame {
        DataFrame::from_columns(vec![
            (
                "name".into(),
                Column::Str(vec![
                    "Alice".into(),
                    "Bob".into(),
                    "Carol".into(),
                    "Dave".into(),
                    "Eve".into(),
                    "Frank".into(),
                ]),
            ),
            (
                "dept".into(),
                Column::Str(vec![
                    "eng".into(),
                    "eng".into(),
                    "sales".into(),
                    "eng".into(),
                    "sales".into(),
                    "eng".into(),
                ]),
            ),
            (
                "salary".into(),
                Column::Float(vec![95000.0, 102000.0, 78000.0, 110000.0, 82000.0, 98000.0]),
            ),
            (
                "tenure".into(),
                Column::Int(vec![3, 7, 2, 10, 1, 5]),
            ),
        ])
        .unwrap()
    }

    /// Returns the values of the `Float` column `name`.
    ///
    /// Panics with a descriptive message when the column is missing or has a
    /// different type, so a type change can never make an assertion vanish.
    fn float_values<'a>(df: &'a DataFrame, name: &str) -> &'a [f64] {
        match df.get_column(name) {
            Some(Column::Float(values)) => values.as_slice(),
            Some(other) => panic!(
                "column `{}` should be Float, got {}",
                name,
                other.type_name()
            ),
            None => panic!("column `{}` missing from result", name),
        }
    }

    #[test]
    fn test_dataframe_creation() {
        let df = sample_df();
        assert_eq!(df.nrows(), 6);
        assert_eq!(df.ncols(), 4);
        assert_eq!(
            df.column_names(),
            vec!["name", "dept", "salary", "tenure"]
        );
    }

    #[test]
    fn test_filter() {
        let df = sample_df();

        let result = Pipeline::scan(df)
            .filter(DExpr::BinOp {
                op: DBinOp::Gt,
                left: Box::new(DExpr::Col("tenure".into())),
                right: Box::new(DExpr::LitInt(2)),
            })
            .collect()
            .unwrap();

        // tenure > 2 keeps Alice(3), Bob(7), Dave(10), Frank(5).
        assert_eq!(result.nrows(), 4);
    }

    #[test]
    fn test_group_by_summarize() {
        let df = sample_df();

        let result = Pipeline::scan(df)
            .summarize(
                vec!["dept".into()],
                vec![
                    (
                        "avg_salary".into(),
                        DExpr::Agg(AggFunc::Mean, Box::new(DExpr::Col("salary".into()))),
                    ),
                    ("headcount".into(), DExpr::Count),
                ],
            )
            .collect()
            .unwrap();

        // One output row per department.
        assert_eq!(result.nrows(), 2);

        let dept_col = result.get_column("dept").unwrap();
        let avg_col = result.get_column("avg_salary").unwrap();
        let count_col = result.get_column("headcount").unwrap();

        if let (Column::Str(depts), Column::Float(avgs), Column::Float(counts)) =
            (dept_col, avg_col, count_col)
        {
            let eng_idx = depts.iter().position(|d| d == "eng").unwrap();
            let sales_idx = depts.iter().position(|d| d == "sales").unwrap();

            // eng: (95000 + 102000 + 110000 + 98000) / 4
            assert!((avgs[eng_idx] - 101250.0).abs() < 0.01);
            assert!((counts[eng_idx] - 4.0).abs() < 0.01);

            // sales: (78000 + 82000) / 2
            assert!((avgs[sales_idx] - 80000.0).abs() < 0.01);
            assert!((counts[sales_idx] - 2.0).abs() < 0.01);
        } else {
            panic!("unexpected column types");
        }
    }

    #[test]
    fn test_filter_then_aggregate() {
        let df = sample_df();

        let result = Pipeline::scan(df)
            .filter(DExpr::BinOp {
                op: DBinOp::Gt,
                left: Box::new(DExpr::Col("tenure".into())),
                right: Box::new(DExpr::LitInt(2)),
            })
            .summarize(
                vec!["dept".into()],
                vec![
                    (
                        "avg_salary".into(),
                        DExpr::Agg(AggFunc::Mean, Box::new(DExpr::Col("salary".into()))),
                    ),
                    (
                        "max_tenure".into(),
                        DExpr::Agg(AggFunc::Max, Box::new(DExpr::Col("tenure".into()))),
                    ),
                    ("headcount".into(), DExpr::Count),
                ],
            )
            .collect()
            .unwrap();

        // Only `eng` rows have tenure > 2, so a single group survives.
        assert_eq!(result.nrows(), 1);

        let avgs = float_values(&result, "avg_salary");
        assert!((avgs[0] - 101250.0).abs() < 0.01);

        // The source column is Int; accept either numeric representation for
        // the aggregate, but never let a non-numeric column pass unnoticed.
        match result.get_column("max_tenure").unwrap() {
            Column::Float(maxes) => assert!((maxes[0] - 10.0).abs() < 0.01),
            Column::Int(maxes) => assert_eq!(maxes[0], 10),
            other => panic!("max_tenure should be numeric, got {}", other.type_name()),
        }
    }

    #[test]
    fn test_to_tensor_data() {
        let df = DataFrame::from_columns(vec![
            ("x".into(), Column::Float(vec![1.0, 2.0, 3.0])),
            ("y".into(), Column::Float(vec![4.0, 5.0, 6.0])),
        ])
        .unwrap();

        // Row-major layout: [x0, y0, x1, y1, x2, y2].
        let (data, shape) = df.to_tensor_data(&["x", "y"]).unwrap();
        assert_eq!(shape, vec![3, 2]);
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_display() {
        let df = DataFrame::from_columns(vec![
            ("x".into(), Column::Int(vec![1, 2, 3])),
            ("y".into(), Column::Float(vec![4.5, 5.5, 6.5])),
        ])
        .unwrap();

        let output = format!("{}", df);
        assert!(output.contains("x"));
        assert!(output.contains("y"));
        assert!(output.contains("4.5"));
    }

    #[test]
    fn test_column_not_found() {
        let df = sample_df();
        let result = Pipeline::scan(df)
            .filter(DExpr::BinOp {
                op: DBinOp::Gt,
                left: Box::new(DExpr::Col("nonexistent".into())),
                right: Box::new(DExpr::LitInt(0)),
            })
            .collect();

        assert!(result.is_err());
    }

    #[test]
    fn test_aggregation_functions() {
        let df = DataFrame::from_columns(vec![
            ("group".into(), Column::Str(vec!["a".into(), "a".into(), "a".into()])),
            ("val".into(), Column::Float(vec![10.0, 20.0, 30.0])),
        ])
        .unwrap();

        let result = Pipeline::scan(df)
            .summarize(
                vec!["group".into()],
                vec![
                    ("total".into(), DExpr::Agg(AggFunc::Sum, Box::new(DExpr::Col("val".into())))),
                    ("avg".into(), DExpr::Agg(AggFunc::Mean, Box::new(DExpr::Col("val".into())))),
                    ("lo".into(), DExpr::Agg(AggFunc::Min, Box::new(DExpr::Col("val".into())))),
                    ("hi".into(), DExpr::Agg(AggFunc::Max, Box::new(DExpr::Col("val".into())))),
                    ("n".into(), DExpr::Count),
                ],
            )
            .collect()
            .unwrap();

        // (output column, expected value) for the single group "a".
        let expectations = [
            ("total", 60.0),
            ("avg", 20.0),
            ("lo", 10.0),
            ("hi", 30.0),
            ("n", 3.0),
        ];
        for &(name, expected) in expectations.iter() {
            let values = float_values(&result, name);
            assert!(
                (values[0] - expected).abs() < 0.01,
                "column `{}`: expected {}, got {}",
                name,
                expected,
                values[0]
            );
        }
    }

    #[test]
    fn test_empty_dataframe() {
        let df = DataFrame::new();
        assert_eq!(df.nrows(), 0);
        assert_eq!(df.ncols(), 0);
    }

    #[test]
    fn test_expr_display() {
        let expr = DExpr::BinOp {
            op: DBinOp::Gt,
            left: Box::new(DExpr::Col("age".into())),
            right: Box::new(DExpr::LitInt(18)),
        };
        assert_eq!(format!("{}", expr), "(col(\"age\") > 18)");
    }

    #[test]
    fn test_categorical_column_basics() {
        let col = Column::Categorical {
            levels: vec!["bird".into(), "cat".into(), "dog".into()],
            codes: vec![1, 2, 1, 0],
        };
        assert_eq!(col.len(), 4);
        assert_eq!(col.type_name(), "Categorical");
        assert_eq!(col.get_display(0), "cat");
        assert_eq!(col.get_display(1), "dog");
        assert_eq!(col.get_display(2), "cat");
        assert_eq!(col.get_display(3), "bird");
    }

    #[test]
    fn test_datetime_column_basics() {
        let col = Column::DateTime(vec![1000, 2000, 3000]);
        assert_eq!(col.len(), 3);
        assert_eq!(col.type_name(), "DateTime");
        assert_eq!(col.get_display(0), "1000ms");
        assert_eq!(col.get_display(1), "2000ms");
    }

    #[test]
    fn test_label_encode() {
        let data: Vec<String> = vec!["cat".into(), "dog".into(), "cat".into(), "bird".into()];
        let (levels, codes) = label_encode(&data);
        assert_eq!(levels, vec!["bird", "cat", "dog"]);
        assert_eq!(codes, vec![1, 2, 1, 0]);
    }

    #[test]
    fn test_label_encode_empty() {
        let data: Vec<String> = vec![];
        let (levels, codes) = label_encode(&data);
        assert!(levels.is_empty());
        assert!(codes.is_empty());
    }

    #[test]
    fn test_label_encode_single_level() {
        let data: Vec<String> = vec!["x".into(), "x".into(), "x".into()];
        let (levels, codes) = label_encode(&data);
        assert_eq!(levels, vec!["x"]);
        assert_eq!(codes, vec![0, 0, 0]);
    }

    #[test]
    fn test_label_encode_deterministic() {
        let data: Vec<String> = vec!["z".into(), "a".into(), "m".into(), "a".into(), "z".into()];
        // Two runs over the same input must agree exactly.
        let (levels1, codes1) = label_encode(&data);
        let (levels2, codes2) = label_encode(&data);
        assert_eq!(levels1, levels2);
        assert_eq!(codes1, codes2);
        assert_eq!(levels1, vec!["a", "m", "z"]);
    }

    #[test]
    fn test_ordinal_encode() {
        let data: Vec<String> = vec!["low".into(), "high".into(), "medium".into(), "low".into()];
        let order: Vec<String> = vec!["low".into(), "medium".into(), "high".into()];
        let (levels, codes) = ordinal_encode(&data, &order).unwrap();
        assert_eq!(levels, vec!["low", "medium", "high"]);
        assert_eq!(codes, vec![0, 2, 1, 0]);
    }

    #[test]
    fn test_ordinal_encode_missing_value() {
        let data: Vec<String> = vec!["low".into(), "unknown".into()];
        let order: Vec<String> = vec!["low".into(), "medium".into(), "high".into()];
        let result = ordinal_encode(&data, &order);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("unknown"));
    }

    #[test]
    fn test_one_hot_encode() {
        let levels = vec!["bird".to_string(), "cat".to_string(), "dog".to_string()];
        let codes = vec![1u32, 2, 1, 0];
        let (names, cols) = one_hot_encode(&levels, &codes);
        assert_eq!(names, vec!["bird", "cat", "dog"]);
        assert_eq!(cols.len(), 3);
        // bird
        assert_eq!(cols[0], vec![false, false, false, true]);
        // cat
        assert_eq!(cols[1], vec![true, false, true, false]);
        // dog
        assert_eq!(cols[2], vec![false, true, false, false]);

        // Each row must be hot in exactly one indicator column.
        for row in 0..4 {
            let count: usize = cols.iter().map(|c| if c[row] { 1 } else { 0 }).sum();
            assert_eq!(count, 1, "row {} should have exactly one true", row);
        }
    }

    #[test]
    fn test_one_hot_encode_empty() {
        let levels = vec!["a".to_string(), "b".to_string()];
        let codes: Vec<u32> = vec![];
        let (names, cols) = one_hot_encode(&levels, &codes);
        assert_eq!(names.len(), 2);
        assert!(cols[0].is_empty());
        assert!(cols[1].is_empty());
    }

    #[test]
    fn test_categorical_column_in_dataframe() {
        let data: Vec<String> = vec!["cat".into(), "dog".into(), "cat".into()];
        let (levels, codes) = label_encode(&data);
        let df = DataFrame::from_columns(vec![
            ("animal".into(), Column::Categorical { levels, codes }),
            ("score".into(), Column::Float(vec![1.0, 2.0, 3.0])),
        ])
        .unwrap();
        assert_eq!(df.nrows(), 3);
        assert_eq!(df.ncols(), 2);
        assert_eq!(df.get_column("animal").unwrap().type_name(), "Categorical");
    }

    #[test]
    fn test_datetime_column_in_dataframe() {
        let df = DataFrame::from_columns(vec![
            ("ts".into(), Column::DateTime(vec![1000, 2000, 3000])),
            ("val".into(), Column::Float(vec![1.0, 2.0, 3.0])),
        ])
        .unwrap();
        assert_eq!(df.nrows(), 3);
        assert_eq!(df.get_column("ts").unwrap().type_name(), "DateTime");
    }

    #[test]
    fn test_label_encode_to_column_roundtrip() {
        let data: Vec<String> = vec!["cat".into(), "dog".into(), "cat".into(), "bird".into()];
        let (levels, codes) = label_encode(&data);
        let col = Column::Categorical { levels: levels.clone(), codes: codes.clone() };
        // Decoding every row must give back the original string.
        for (i, original) in data.iter().enumerate() {
            assert_eq!(col.get_display(i), *original);
        }
    }
}
1739
1740impl DataFrame {
1745 pub fn to_tensor(
1751 &self,
1752 col_names: &[&str],
1753 ) -> Result<cjc_runtime::Tensor, DataError> {
1754 let (data, shape) = self.to_tensor_data(col_names)?;
1755 cjc_runtime::Tensor::from_vec(data, &shape)
1756 .map_err(|e| DataError::InvalidOperation(format!("tensor conversion: {}", e)))
1757 }
1758
1759 pub fn push_row(&mut self, values: &[&str]) -> Result<(), DataError> {
1768 if values.len() != self.ncols() {
1769 return Err(DataError::ColumnLengthMismatch {
1770 expected: self.ncols(),
1771 got: values.len(),
1772 column: "row".to_string(),
1773 });
1774 }
1775 for (i, (_, col)) in self.columns.iter_mut().enumerate() {
1776 let s = values[i];
1777 match col {
1778 Column::Float(v) => v.push(s.trim().parse::<f64>().unwrap_or(0.0)),
1779 Column::Int(v) => v.push(s.trim().parse::<i64>().unwrap_or(0)),
1780 Column::Str(v) => v.push(s.to_string()),
1781 Column::Bool(v) => v.push(matches!(s.trim(), "true" | "1")),
1782 Column::Categorical { .. } => {
1783 }
1785 Column::DateTime(v) => v.push(s.trim().parse::<i64>().unwrap_or(0)),
1786 }
1787 }
1788 Ok(())
1789 }
1790}
1791
1792#[derive(Debug, Clone, PartialEq, Eq)]
1808pub struct BitMask {
1809 words: Vec<u64>,
1810 nrows: usize,
1811}
1812
1813impl BitMask {
1814 pub fn all_true(nrows: usize) -> Self {
1816 let nwords = nwords_for(nrows);
1817 let mut words = vec![u64::MAX; nwords];
1818 if nrows % 64 != 0 && nwords > 0 {
1820 let tail = nrows % 64;
1821 words[nwords - 1] = (1u64 << tail) - 1;
1822 }
1823 BitMask { words, nrows }
1824 }
1825
1826 pub fn all_false(nrows: usize) -> Self {
1828 let nwords = nwords_for(nrows);
1829 BitMask {
1830 words: vec![0u64; nwords],
1831 nrows,
1832 }
1833 }
1834
1835 pub fn from_bools(bools: &[bool]) -> Self {
1837 let nrows = bools.len();
1838 let nwords = nwords_for(nrows);
1839 let mut words = vec![0u64; nwords];
1840 for (i, &b) in bools.iter().enumerate() {
1841 if b {
1842 words[i / 64] |= 1u64 << (i % 64);
1843 }
1844 }
1845 BitMask { words, nrows }
1846 }
1847
1848 #[inline]
1850 pub fn get(&self, i: usize) -> bool {
1851 debug_assert!(i < self.nrows);
1852 (self.words[i / 64] >> (i % 64)) & 1 == 1
1853 }
1854
1855 pub fn count_ones(&self) -> usize {
1857 self.words.iter().map(|w| w.count_ones() as usize).sum()
1858 }
1859
1860 pub fn and(&self, other: &BitMask) -> BitMask {
1864 assert_eq!(
1865 self.nrows, other.nrows,
1866 "BitMask::and: nrows mismatch ({} vs {})",
1867 self.nrows, other.nrows
1868 );
1869 let words = self
1870 .words
1871 .iter()
1872 .zip(other.words.iter())
1873 .map(|(a, b)| a & b)
1874 .collect();
1875 BitMask {
1876 words,
1877 nrows: self.nrows,
1878 }
1879 }
1880
1881 pub fn iter_set(&self) -> impl Iterator<Item = usize> + '_ {
1883 (0..self.nrows).filter(move |&i| self.get(i))
1884 }
1885
1886 pub fn nrows(&self) -> usize {
1887 self.nrows
1888 }
1889
1890 pub fn nwords(&self) -> usize {
1892 self.words.len()
1893 }
1894}
1895
/// Number of 64-bit words needed to hold `nrows` bits, i.e. `ceil(nrows / 64)`.
///
/// Computed from quotient and remainder rather than `(nrows + 63) / 64`, so
/// row counts within 63 of `usize::MAX` cannot overflow.
#[inline]
fn nwords_for(nrows: usize) -> usize {
    nrows / 64 + usize::from(nrows % 64 != 0)
}
1900
/// Ordered list of column positions that are visible through a view.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProjectionMap {
    /// Column positions, in output order.
    indices: Vec<usize>,
}

impl ProjectionMap {
    /// Projection that exposes columns `0..ncols` in their original order.
    pub fn identity(ncols: usize) -> Self {
        let mut indices = Vec::with_capacity(ncols);
        indices.extend(0..ncols);
        Self { indices }
    }

    /// Projection over an explicit list of column positions.
    pub fn from_indices(indices: Vec<usize>) -> Self {
        Self { indices }
    }

    /// Number of projected columns.
    pub fn len(&self) -> usize {
        self.indices.len()
    }

    /// Whether the projection exposes no columns.
    pub fn is_empty(&self) -> bool {
        self.indices.len() == 0
    }

    /// Projected column positions, in output order.
    pub fn indices(&self) -> &[usize] {
        self.indices.as_slice()
    }
}
1939
/// View over a shared `DataFrame`: a row mask plus a column projection.
///
/// `filter` and `select` return new views that share `base` and differ only
/// in `mask` / `proj`; row data is copied only by `materialize`.
#[derive(Debug, Clone)]
pub struct TidyView {
    /// The underlying frame, shared between views derived from each other.
    base: Rc<DataFrame>,
    /// One bit per row of `base`; set bits are the rows visible in this view.
    mask: BitMask,
    /// Positions into `base.columns` of the visible columns, in output order.
    proj: ProjectionMap,
}
1956
1957fn try_eval_predicate_columnar(
1971 base: &DataFrame,
1972 predicate: &DExpr,
1973 existing_mask: &BitMask,
1974) -> Option<BitMask> {
1975 match predicate {
1976 DExpr::BinOp {
1978 op: DBinOp::And,
1979 left,
1980 right,
1981 } => {
1982 let lmask = try_eval_predicate_columnar(base, left, existing_mask)?;
1983 let rmask = try_eval_predicate_columnar(base, right, &lmask)?;
1984 Some(rmask)
1985 }
1986 DExpr::BinOp {
1989 op: DBinOp::Or,
1990 left,
1991 right,
1992 } => {
1993 let all_mask = BitMask::all_true(existing_mask.nrows);
1996 let lmask = try_eval_predicate_columnar(base, left, &all_mask)?;
1997 let rmask = try_eval_predicate_columnar(base, right, &all_mask)?;
1998 let nrows = existing_mask.nrows;
2000 let or_words: Vec<u64> = lmask
2001 .words
2002 .iter()
2003 .zip(rmask.words.iter())
2004 .map(|(a, b)| a | b)
2005 .collect();
2006 let final_words: Vec<u64> = or_words
2008 .iter()
2009 .zip(existing_mask.words.iter())
2010 .map(|(a, b)| a & b)
2011 .collect();
2012 Some(BitMask {
2013 words: final_words,
2014 nrows,
2015 })
2016 }
2017 DExpr::BinOp { op, left, right } => {
2019 if !matches!(
2021 op,
2022 DBinOp::Gt | DBinOp::Lt | DBinOp::Ge | DBinOp::Le | DBinOp::Eq | DBinOp::Ne
2023 ) {
2024 return None;
2025 }
2026
2027 enum LitVal {
2030 F(f64),
2031 I(i64),
2032 }
2033
2034 let (col_name, lit, reversed) = match (left.as_ref(), right.as_ref()) {
2035 (DExpr::Col(name), DExpr::LitFloat(v)) => (name.as_str(), LitVal::F(*v), false),
2036 (DExpr::LitFloat(v), DExpr::Col(name)) => (name.as_str(), LitVal::F(*v), true),
2037 (DExpr::Col(name), DExpr::LitInt(v)) => (name.as_str(), LitVal::I(*v), false),
2038 (DExpr::LitInt(v), DExpr::Col(name)) => (name.as_str(), LitVal::I(*v), true),
2039 _ => return None,
2040 };
2041
2042 let column = base.get_column(col_name)?;
2043
2044 let effective_op = if reversed {
2046 match op {
2047 DBinOp::Gt => DBinOp::Lt,
2048 DBinOp::Lt => DBinOp::Gt,
2049 DBinOp::Ge => DBinOp::Le,
2050 DBinOp::Le => DBinOp::Ge,
2051 other => *other, }
2053 } else {
2054 *op
2055 };
2056
2057 let nrows = existing_mask.nrows;
2058 let nwords = nwords_for(nrows);
2059 let mut words = vec![0u64; nwords];
2060
2061 match (column, &lit) {
2062 (Column::Float(data), LitVal::F(v)) => {
2064 columnar_cmp_f64(data, *v, effective_op, &mut words);
2065 }
2066 (Column::Float(data), LitVal::I(v)) => {
2068 columnar_cmp_f64(data, *v as f64, effective_op, &mut words);
2069 }
2070 (Column::Int(data), LitVal::I(v)) => {
2072 columnar_cmp_i64(data, *v, effective_op, &mut words);
2073 }
2074 (Column::Int(data), LitVal::F(v)) => {
2076 let floats: Vec<f64> = data.iter().map(|&x| x as f64).collect();
2078 columnar_cmp_f64(&floats, *v, effective_op, &mut words);
2079 }
2080 _ => return None,
2081 }
2082
2083 for (w, ew) in words.iter_mut().zip(existing_mask.words.iter()) {
2085 *w &= *ew;
2086 }
2087
2088 Some(BitMask { words, nrows })
2089 }
2090 _ => None,
2091 }
2092}
2093
2094#[inline]
2098fn columnar_cmp_f64(data: &[f64], lit: f64, op: DBinOp, out_words: &mut [u64]) {
2099 for (i, &val) in data.iter().enumerate() {
2100 let pass = match op {
2101 DBinOp::Gt => val > lit,
2102 DBinOp::Lt => val < lit,
2103 DBinOp::Ge => val >= lit,
2104 DBinOp::Le => val <= lit,
2105 DBinOp::Eq => val == lit,
2106 DBinOp::Ne => val != lit,
2107 _ => false,
2108 };
2109 if pass {
2110 out_words[i / 64] |= 1u64 << (i % 64);
2111 }
2112 }
2113}
2114
2115#[inline]
2118fn columnar_cmp_i64(data: &[i64], lit: i64, op: DBinOp, out_words: &mut [u64]) {
2119 for (i, &val) in data.iter().enumerate() {
2120 let pass = match op {
2121 DBinOp::Gt => val > lit,
2122 DBinOp::Lt => val < lit,
2123 DBinOp::Ge => val >= lit,
2124 DBinOp::Le => val <= lit,
2125 DBinOp::Eq => val == lit,
2126 DBinOp::Ne => val != lit,
2127 _ => false,
2128 };
2129 if pass {
2130 out_words[i / 64] |= 1u64 << (i % 64);
2131 }
2132 }
2133}
2134
2135impl TidyView {
2136 pub fn from_df(df: DataFrame) -> Self {
2140 let nrows = df.nrows();
2141 let ncols = df.ncols();
2142 TidyView {
2143 base: Rc::new(df),
2144 mask: BitMask::all_true(nrows),
2145 proj: ProjectionMap::identity(ncols),
2146 }
2147 }
2148
2149 pub fn from_rc(df: Rc<DataFrame>) -> Self {
2151 let nrows = df.nrows();
2152 let ncols = df.ncols();
2153 TidyView {
2154 base: df,
2155 mask: BitMask::all_true(nrows),
2156 proj: ProjectionMap::identity(ncols),
2157 }
2158 }
2159
2160 pub fn nrows(&self) -> usize {
2164 self.mask.count_ones()
2165 }
2166
2167 pub fn ncols(&self) -> usize {
2169 self.proj.len()
2170 }
2171
2172 pub fn column_names(&self) -> Vec<&str> {
2174 self.proj
2175 .indices()
2176 .iter()
2177 .map(|&ci| self.base.columns[ci].0.as_str())
2178 .collect()
2179 }
2180
2181 pub fn filter(&self, predicate: &DExpr) -> Result<TidyView, TidyError> {
2194 validate_expr_columns_proj(predicate, &self.base, &self.proj)?;
2196
2197 if let Some(new_mask) = try_eval_predicate_columnar(&self.base, predicate, &self.mask) {
2199 return Ok(TidyView {
2200 base: Rc::clone(&self.base),
2201 mask: new_mask,
2202 proj: self.proj.clone(),
2203 });
2204 }
2205
2206 let nrows_base = self.base.nrows();
2208 let mut new_words = self.mask.words.clone();
2209
2210 for row in self.mask.iter_set() {
2213 let b = eval_expr_row_proj(&self.base, predicate, row, &self.proj)?;
2214 let pass = match b {
2215 ExprValue::Bool(v) => v,
2216 _ => {
2217 return Err(TidyError::PredicateNotBool {
2218 got: b.type_name().to_string(),
2219 })
2220 }
2221 };
2222 if !pass {
2223 new_words[row / 64] &= !(1u64 << (row % 64));
2224 }
2225 }
2226
2227 Ok(TidyView {
2228 base: Rc::clone(&self.base),
2229 mask: BitMask {
2230 words: new_words,
2231 nrows: nrows_base,
2232 },
2233 proj: self.proj.clone(),
2234 })
2235 }
2236
2237 pub fn select(&self, cols: &[&str]) -> Result<TidyView, TidyError> {
2250 {
2252 let mut seen = std::collections::BTreeSet::new();
2253 for &name in cols {
2254 if !seen.insert(name) {
2255 return Err(TidyError::DuplicateColumn(name.to_string()));
2256 }
2257 }
2258 }
2259
2260 let mut new_indices = Vec::with_capacity(cols.len());
2262 for &name in cols {
2263 let idx = self
2264 .base
2265 .columns
2266 .iter()
2267 .position(|(n, _)| n == name)
2268 .ok_or_else(|| TidyError::ColumnNotFound(name.to_string()))?;
2269 new_indices.push(idx);
2270 }
2271
2272 Ok(TidyView {
2273 base: Rc::clone(&self.base),
2274 mask: self.mask.clone(),
2275 proj: ProjectionMap::from_indices(new_indices),
2276 })
2277 }
2278
2279 pub fn mutate(&self, assignments: &[(&str, DExpr)]) -> Result<TidyFrame, TidyError> {
2299 {
2301 let mut seen = std::collections::BTreeSet::new();
2302 for &(name, _) in assignments {
2303 if !seen.insert(name) {
2304 return Err(TidyError::DuplicateColumn(name.to_string()));
2305 }
2306 }
2307 }
2308
2309 let mut df = self.materialize()?;
2311
2312 let snapshot_names: Vec<String> = df.columns.iter().map(|(n, _)| n.clone()).collect();
2314
2315 for &(col_name, ref expr) in assignments {
2316 validate_expr_columns_snapshot(expr, &snapshot_names)?;
2318
2319 let nrows = df.nrows();
2320 let new_col = eval_expr_column(&df, expr, nrows)?;
2322
2323 if let Some(pos) = df.columns.iter().position(|(n, _)| n == col_name) {
2325 df.columns[pos].1 = new_col;
2326 } else {
2327 df.columns.push((col_name.to_string(), new_col));
2328 }
2329 }
2330
2331 Ok(TidyFrame {
2332 inner: Rc::new(RefCell::new(df)),
2333 })
2334 }
2335
2336 pub fn materialize(&self) -> Result<DataFrame, TidyError> {
2348 let row_indices: Vec<usize> = self.mask.iter_set().collect();
2349
2350 let mut columns = Vec::with_capacity(self.proj.len());
2351 for &ci in self.proj.indices() {
2352 let (name, col) = &self.base.columns[ci];
2353 let new_col = gather_column(col, &row_indices);
2354 columns.push((name.clone(), new_col));
2355 }
2356
2357 DataFrame::from_columns(columns)
2358 .map_err(|e| TidyError::Internal(e.to_string()))
2359 }
2360
2361 pub fn to_tensor(&self, col_names: &[&str]) -> Result<cjc_runtime::Tensor, TidyError> {
2365 let df = self.materialize()?;
2366 df.to_tensor(col_names)
2367 .map_err(|e| TidyError::Internal(e.to_string()))
2368 }
2369
2370 pub fn mask(&self) -> &BitMask {
2372 &self.mask
2373 }
2374
2375 pub fn proj(&self) -> &ProjectionMap {
2377 &self.proj
2378 }
2379
2380 pub fn base_column(&self, name: &str) -> Option<&Column> {
2385 self.base.columns.iter()
2386 .find(|(n, _)| n == name)
2387 .map(|(_, c)| c)
2388 }
2389}
2390
2391#[derive(Debug, Clone)]
2399pub struct TidyFrame {
2400 inner: Rc<RefCell<DataFrame>>,
2401}
2402
2403impl TidyFrame {
2404 pub fn from_df(df: DataFrame) -> Self {
2406 TidyFrame {
2407 inner: Rc::new(RefCell::new(df)),
2408 }
2409 }
2410
2411 pub fn borrow(&self) -> std::cell::Ref<'_, DataFrame> {
2413 self.inner.borrow()
2414 }
2415
2416 pub fn view(&self) -> TidyView {
2418 let df = self.inner.borrow().clone();
2419 TidyView::from_df(df)
2420 }
2421
2422 pub fn mutate(&mut self, assignments: &[(&str, DExpr)]) -> Result<(), TidyError> {
2424 if Rc::strong_count(&self.inner) > 1 {
2426 let cloned = self.inner.borrow().clone();
2427 self.inner = Rc::new(RefCell::new(cloned));
2428 }
2429
2430 {
2432 let mut seen = std::collections::BTreeSet::new();
2433 for &(name, _) in assignments {
2434 if !seen.insert(name) {
2435 return Err(TidyError::DuplicateColumn(name.to_string()));
2436 }
2437 }
2438 }
2439
2440 let mut df = self.inner.borrow_mut();
2441
2442 let snapshot_names: Vec<String> = df.columns.iter().map(|(n, _)| n.clone()).collect();
2444
2445 for &(col_name, ref expr) in assignments {
2446 validate_expr_columns_snapshot(expr, &snapshot_names)?;
2447
2448 let nrows = df.nrows();
2449 let new_col = eval_expr_column(&df, expr, nrows)?;
2450
2451 if let Some(pos) = df.columns.iter().position(|(n, _)| n == col_name) {
2452 df.columns[pos].1 = new_col;
2453 } else {
2454 df.columns.push((col_name.to_string(), new_col));
2455 }
2456 }
2457
2458 Ok(())
2459 }
2460}
2461
/// Errors reported by the tidy view / frame operations.
#[derive(Debug, Clone, PartialEq)]
pub enum TidyError {
    /// A referenced column does not exist.
    ColumnNotFound(String),
    /// The same column name was supplied more than once.
    DuplicateColumn(String),
    /// A filter predicate evaluated to a value of type `got`, not a boolean.
    PredicateNotBool { got: String },
    /// A value of type `got` was found where `expected` was required.
    TypeMismatch { expected: String, got: String },
    /// `got` rows were supplied where `expected` were required.
    LengthMismatch { expected: usize, got: usize },
    /// An unexpected failure, carried as a message.
    Internal(String),
    /// An aggregation was applied to a group with no rows.
    EmptyGroup,
    /// `got` distinct factor levels were seen, above the `limit`.
    CapacityExceeded { limit: usize, got: usize },
}

impl fmt::Display for TidyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ColumnNotFound(name) => write!(f, "column `{}` not found", name),
            Self::DuplicateColumn(name) => write!(f, "duplicate column `{}`", name),
            Self::PredicateNotBool { got } => {
                write!(f, "filter predicate must be Bool, got {}", got)
            }
            Self::TypeMismatch { expected, got } => {
                write!(f, "type mismatch: expected {}, got {}", expected, got)
            }
            Self::LengthMismatch { expected, got } => {
                write!(f, "length mismatch: expected {} rows, got {}", expected, got)
            }
            Self::Internal(message) => write!(f, "internal error: {}", message),
            Self::EmptyGroup => f.write_str("aggregation on empty group"),
            Self::CapacityExceeded { limit, got } => write!(
                f,
                "factor capacity exceeded: limit {} distinct levels, got {}",
                limit, got
            ),
        }
    }
}

impl std::error::Error for TidyError {}
2513
2514fn extract_f64_column(df: &DataFrame, expr: &DExpr, nrows: usize) -> Result<Vec<f64>, TidyError> {
2526 let col = eval_expr_column(df, expr, nrows)?;
2527 match col {
2528 Column::Float(v) => Ok(v),
2529 Column::Int(v) => Ok(v.into_iter().map(|i| i as f64).collect()),
2530 _ => Err(TidyError::TypeMismatch {
2531 expected: "numeric".into(),
2532 got: "non-numeric".into(),
2533 }),
2534 }
2535}
2536
2537fn eval_window_column(
2540 df: &DataFrame,
2541 expr: &DExpr,
2542 nrows: usize,
2543) -> Result<Option<Column>, TidyError> {
2544 match expr {
2545 DExpr::RowNumber => {
2546 let vals: Vec<i64> = (1..=nrows as i64).collect();
2547 Ok(Some(Column::Int(vals)))
2548 }
2549 DExpr::CumSum(inner) => {
2550 let src = extract_f64_column(df, inner, nrows)?;
2551 let mut result = Vec::with_capacity(nrows);
2552 let mut sum = 0.0_f64;
2553 let mut comp = 0.0_f64; for &v in &src {
2555 let y = v - comp;
2556 let t = sum + y;
2557 comp = (t - sum) - y;
2558 sum = t;
2559 result.push(sum);
2560 }
2561 Ok(Some(Column::Float(result)))
2562 }
2563 DExpr::CumProd(inner) => {
2564 let src = extract_f64_column(df, inner, nrows)?;
2565 let mut result = Vec::with_capacity(nrows);
2566 let mut prod = 1.0_f64;
2567 for &v in &src {
2568 prod *= v;
2569 result.push(prod);
2570 }
2571 Ok(Some(Column::Float(result)))
2572 }
2573 DExpr::CumMax(inner) => {
2574 let src = extract_f64_column(df, inner, nrows)?;
2575 let mut result = Vec::with_capacity(nrows);
2576 let mut max = f64::NEG_INFINITY;
2577 for &v in &src {
2578 if v > max { max = v; }
2579 result.push(max);
2580 }
2581 Ok(Some(Column::Float(result)))
2582 }
2583 DExpr::CumMin(inner) => {
2584 let src = extract_f64_column(df, inner, nrows)?;
2585 let mut result = Vec::with_capacity(nrows);
2586 let mut min = f64::INFINITY;
2587 for &v in &src {
2588 if v < min { min = v; }
2589 result.push(min);
2590 }
2591 Ok(Some(Column::Float(result)))
2592 }
2593 DExpr::Lag(inner, k) => {
2594 let src = extract_f64_column(df, inner, nrows)?;
2595 let mut result = Vec::with_capacity(nrows);
2596 for i in 0..nrows {
2597 if i < *k {
2598 result.push(f64::NAN);
2599 } else {
2600 result.push(src[i - k]);
2601 }
2602 }
2603 Ok(Some(Column::Float(result)))
2604 }
2605 DExpr::Lead(inner, k) => {
2606 let src = extract_f64_column(df, inner, nrows)?;
2607 let mut result = Vec::with_capacity(nrows);
2608 for i in 0..nrows {
2609 if i + k >= nrows {
2610 result.push(f64::NAN);
2611 } else {
2612 result.push(src[i + k]);
2613 }
2614 }
2615 Ok(Some(Column::Float(result)))
2616 }
2617 DExpr::Rank(inner) => {
2618 let src = extract_f64_column(df, inner, nrows)?;
2619 let mut indexed: Vec<(usize, f64)> = src.iter().cloned().enumerate().collect();
2621 indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
2622 let mut ranks = vec![0.0_f64; nrows];
2623 let mut i = 0;
2624 while i < nrows {
2625 let mut j = i;
2626 while j < nrows && indexed[j].1 == indexed[i].1 {
2627 j += 1;
2628 }
2629 let avg_rank = (i + 1 + j) as f64 / 2.0; for idx in i..j {
2631 ranks[indexed[idx].0] = avg_rank;
2632 }
2633 i = j;
2634 }
2635 Ok(Some(Column::Float(ranks)))
2636 }
2637 DExpr::DenseRank(inner) => {
2638 let src = extract_f64_column(df, inner, nrows)?;
2639 let mut indexed: Vec<(usize, f64)> = src.iter().cloned().enumerate().collect();
2640 indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
2641 let mut ranks = vec![0_i64; nrows];
2642 let mut rank = 0_i64;
2643 let mut prev: Option<f64> = None;
2644 for &(orig_idx, val) in &indexed {
2645 if prev.is_none() || prev.unwrap() != val {
2646 rank += 1;
2647 }
2648 ranks[orig_idx] = rank;
2649 prev = Some(val);
2650 }
2651 Ok(Some(Column::Int(ranks)))
2652 }
2653 DExpr::RollingSum(col_name, window) => {
2654 let vals = rolling_get_floats(df, col_name)?;
2655 let n = vals.len();
2656 let w = *window;
2657 let mut result = Vec::with_capacity(n);
2658 let mut sum = 0.0_f64;
2659 let mut comp = 0.0_f64;
2660 for i in 0..n {
2661 let y = vals[i] - comp;
2663 let t = sum + y;
2664 comp = (t - sum) - y;
2665 sum = t;
2666 if i >= w {
2668 let y2 = -vals[i - w] - comp;
2669 let t2 = sum + y2;
2670 comp = (t2 - sum) - y2;
2671 sum = t2;
2672 }
2673 result.push(sum);
2674 }
2675 Ok(Some(Column::Float(result)))
2676 }
2677 DExpr::RollingMean(col_name, window) => {
2678 let vals = rolling_get_floats(df, col_name)?;
2679 let n = vals.len();
2680 let w = *window;
2681 let mut result = Vec::with_capacity(n);
2682 let mut sum = 0.0_f64;
2683 let mut comp = 0.0_f64;
2684 for i in 0..n {
2685 let y = vals[i] - comp;
2686 let t = sum + y;
2687 comp = (t - sum) - y;
2688 sum = t;
2689 if i >= w {
2690 let y2 = -vals[i - w] - comp;
2691 let t2 = sum + y2;
2692 comp = (t2 - sum) - y2;
2693 sum = t2;
2694 }
2695 let count = if i < w { i + 1 } else { w };
2696 result.push(sum / count as f64);
2697 }
2698 Ok(Some(Column::Float(result)))
2699 }
2700 DExpr::RollingMin(col_name, window) => {
2701 let vals = rolling_get_floats(df, col_name)?;
2702 let n = vals.len();
2703 let w = *window;
2704 let mut result = Vec::with_capacity(n);
2705 let mut deque: VecDeque<usize> = VecDeque::new();
2706 for i in 0..n {
2707 while !deque.is_empty() && *deque.front().unwrap() + w <= i {
2709 deque.pop_front();
2710 }
2711 while !deque.is_empty() && vals[*deque.back().unwrap()] >= vals[i] {
2713 deque.pop_back();
2714 }
2715 deque.push_back(i);
2716 result.push(vals[*deque.front().unwrap()]);
2717 }
2718 Ok(Some(Column::Float(result)))
2719 }
2720 DExpr::RollingMax(col_name, window) => {
2721 let vals = rolling_get_floats(df, col_name)?;
2722 let n = vals.len();
2723 let w = *window;
2724 let mut result = Vec::with_capacity(n);
2725 let mut deque: VecDeque<usize> = VecDeque::new();
2726 for i in 0..n {
2727 while !deque.is_empty() && *deque.front().unwrap() + w <= i {
2728 deque.pop_front();
2729 }
2730 while !deque.is_empty() && vals[*deque.back().unwrap()] <= vals[i] {
2732 deque.pop_back();
2733 }
2734 deque.push_back(i);
2735 result.push(vals[*deque.front().unwrap()]);
2736 }
2737 Ok(Some(Column::Float(result)))
2738 }
2739 DExpr::RollingVar(col_name, window) => {
2740 let vals = rolling_get_floats(df, col_name)?;
2741 let n = vals.len();
2742 let w = *window;
2743 let mut result = Vec::with_capacity(n);
2744 let mut count = 0_usize;
2746 let mut mean = 0.0_f64;
2747 let mut m2 = 0.0_f64;
2748 for i in 0..n {
2749 count += 1;
2751 let delta = vals[i] - mean;
2752 mean += delta / count as f64;
2753 let delta2 = vals[i] - mean;
2754 m2 += delta * delta2;
2755 if i >= w {
2757 let old = vals[i - w];
2758 count -= 1;
2759 if count == 0 {
2760 mean = 0.0;
2761 m2 = 0.0;
2762 } else {
2763 let delta_old = old - mean;
2764 mean -= delta_old / count as f64;
2765 let delta_old2 = old - mean;
2766 m2 -= delta_old * delta_old2;
2767 }
2768 }
2769 if count < 2 {
2770 result.push(0.0);
2771 } else {
2772 result.push(m2 / (count - 1) as f64);
2775 }
2776 }
2777 Ok(Some(Column::Float(result)))
2778 }
2779 DExpr::RollingSd(col_name, window) => {
2780 let vals = rolling_get_floats(df, col_name)?;
2781 let n = vals.len();
2782 let w = *window;
2783 let mut result = Vec::with_capacity(n);
2784 let mut count = 0_usize;
2785 let mut mean = 0.0_f64;
2786 let mut m2 = 0.0_f64;
2787 for i in 0..n {
2788 count += 1;
2789 let delta = vals[i] - mean;
2790 mean += delta / count as f64;
2791 let delta2 = vals[i] - mean;
2792 m2 += delta * delta2;
2793 if i >= w {
2794 let old = vals[i - w];
2795 count -= 1;
2796 if count == 0 {
2797 mean = 0.0;
2798 m2 = 0.0;
2799 } else {
2800 let delta_old = old - mean;
2801 mean -= delta_old / count as f64;
2802 let delta_old2 = old - mean;
2803 m2 -= delta_old * delta_old2;
2804 }
2805 }
2806 if count < 2 {
2807 result.push(0.0);
2808 } else {
2809 result.push((m2 / (count - 1) as f64).sqrt());
2810 }
2811 }
2812 Ok(Some(Column::Float(result)))
2813 }
2814 _ => Ok(None),
2815 }
2816}
2817
2818fn rolling_get_floats(df: &DataFrame, col_name: &str) -> Result<Vec<f64>, TidyError> {
2820 let col = df
2821 .get_column(col_name)
2822 .ok_or_else(|| TidyError::ColumnNotFound(col_name.to_string()))?;
2823 match col {
2824 Column::Float(v) => Ok(v.clone()),
2825 Column::Int(v) => Ok(v.iter().map(|&i| i as f64).collect()),
2826 _ => Err(TidyError::TypeMismatch {
2827 expected: "numeric".into(),
2828 got: "non-numeric".into(),
2829 }),
2830 }
2831}
2832
2833fn vectorized_binop(op: DBinOp, left: &Column, right: &Column) -> Result<Column, TidyError> {
2838 match (left, right) {
2839 (Column::Int(a), Column::Int(b)) => {
2840 let n = a.len();
2841 match op {
2842 DBinOp::Add => { let mut r = vec![0i64; n]; for i in 0..n { r[i] = a[i] + b[i]; } Ok(Column::Int(r)) }
2843 DBinOp::Sub => { let mut r = vec![0i64; n]; for i in 0..n { r[i] = a[i] - b[i]; } Ok(Column::Int(r)) }
2844 DBinOp::Mul => { let mut r = vec![0i64; n]; for i in 0..n { r[i] = a[i] * b[i]; } Ok(Column::Int(r)) }
2845 DBinOp::Div => { let mut r = vec![0i64; n]; for i in 0..n { r[i] = a[i] / b[i]; } Ok(Column::Int(r)) }
2846 DBinOp::Gt => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] > b[i]; } Ok(Column::Bool(r)) }
2847 DBinOp::Lt => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] < b[i]; } Ok(Column::Bool(r)) }
2848 DBinOp::Ge => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] >= b[i]; } Ok(Column::Bool(r)) }
2849 DBinOp::Le => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] <= b[i]; } Ok(Column::Bool(r)) }
2850 DBinOp::Eq => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] == b[i]; } Ok(Column::Bool(r)) }
2851 DBinOp::Ne => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] != b[i]; } Ok(Column::Bool(r)) }
2852 _ => Err(TidyError::Internal(format!("unsupported op {:?} on Int", op))),
2853 }
2854 }
2855 (Column::Float(a), Column::Float(b)) => {
2856 let n = a.len();
2857 match op {
2858 DBinOp::Add => { let mut r = vec![0.0f64; n]; for i in 0..n { r[i] = a[i] + b[i]; } Ok(Column::Float(r)) }
2859 DBinOp::Sub => { let mut r = vec![0.0f64; n]; for i in 0..n { r[i] = a[i] - b[i]; } Ok(Column::Float(r)) }
2860 DBinOp::Mul => { let mut r = vec![0.0f64; n]; for i in 0..n { r[i] = a[i] * b[i]; } Ok(Column::Float(r)) }
2861 DBinOp::Div => { let mut r = vec![0.0f64; n]; for i in 0..n { r[i] = a[i] / b[i]; } Ok(Column::Float(r)) }
2862 DBinOp::Gt => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] > b[i]; } Ok(Column::Bool(r)) }
2863 DBinOp::Lt => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] < b[i]; } Ok(Column::Bool(r)) }
2864 DBinOp::Ge => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] >= b[i]; } Ok(Column::Bool(r)) }
2865 DBinOp::Le => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] <= b[i]; } Ok(Column::Bool(r)) }
2866 DBinOp::Eq => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] == b[i]; } Ok(Column::Bool(r)) }
2867 DBinOp::Ne => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] != b[i]; } Ok(Column::Bool(r)) }
2868 _ => Err(TidyError::Internal(format!("unsupported op {:?} on Float", op))),
2869 }
2870 }
2871 (Column::Int(a), Column::Float(_b)) => {
2872 let promoted: Vec<f64> = a.iter().map(|&v| v as f64).collect();
2873 vectorized_binop(op, &Column::Float(promoted), right)
2874 }
2875 (Column::Float(_a), Column::Int(b)) => {
2876 let promoted: Vec<f64> = b.iter().map(|&v| v as f64).collect();
2877 vectorized_binop(op, left, &Column::Float(promoted))
2878 }
2879 (Column::Bool(a), Column::Bool(b)) => {
2880 let n = a.len();
2881 match op {
2882 DBinOp::And => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] && b[i]; } Ok(Column::Bool(r)) }
2883 DBinOp::Or => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] || b[i]; } Ok(Column::Bool(r)) }
2884 DBinOp::Eq => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] == b[i]; } Ok(Column::Bool(r)) }
2885 DBinOp::Ne => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] != b[i]; } Ok(Column::Bool(r)) }
2886 _ => Err(TidyError::Internal(format!("unsupported op {:?} on Bool", op))),
2887 }
2888 }
2889 (Column::Str(a), Column::Str(b)) => {
2890 let n = a.len();
2891 match op {
2892 DBinOp::Eq => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] == b[i]; } Ok(Column::Bool(r)) }
2893 DBinOp::Ne => { let mut r = vec![false; n]; for i in 0..n { r[i] = a[i] != b[i]; } Ok(Column::Bool(r)) }
2894 _ => Err(TidyError::Internal(format!("unsupported op {:?} on String", op))),
2895 }
2896 }
2897 _ => Err(TidyError::Internal("type mismatch in binary operation".into())),
2898 }
2899}
2900
2901fn vectorized_fn_call(name: &str, arg: &Column) -> Result<Column, TidyError> {
2904 let floats: Vec<f64> = match arg {
2905 Column::Float(v) => v.clone(),
2906 Column::Int(v) => v.iter().map(|&i| i as f64).collect(),
2907 _ => return Err(TidyError::Internal(format!(
2908 "FnCall '{}' requires numeric argument", name
2909 ))),
2910 };
2911 let f: fn(f64) -> f64 = match name {
2912 "log" => f64::ln,
2913 "exp" => f64::exp,
2914 "sqrt" => f64::sqrt,
2915 "abs" => f64::abs,
2916 "ceil" => f64::ceil,
2917 "floor" => f64::floor,
2918 "round" => f64::round,
2919 "sin" => f64::sin,
2920 "cos" => f64::cos,
2921 "tan" => f64::tan,
2922 _ => return Err(TidyError::Internal(format!(
2923 "unknown DExpr function: {}", name
2924 ))),
2925 };
2926 let mut result = vec![0.0f64; floats.len()];
2927 for i in 0..floats.len() {
2928 result[i] = f(floats[i]);
2929 }
2930 Ok(Column::Float(result))
2931}
2932
2933fn try_eval_expr_column_vectorized(
2937 df: &DataFrame,
2938 expr: &DExpr,
2939 nrows: usize,
2940) -> Option<Result<Column, TidyError>> {
2941 match expr {
2942 DExpr::Col(name) => {
2943 let col = df.get_column(name)?;
2944 let result = match col {
2945 Column::Int(v) => Column::Int(v[..nrows].to_vec()),
2946 Column::Float(v) => Column::Float(v[..nrows].to_vec()),
2947 Column::Str(v) => Column::Str(v[..nrows].to_vec()),
2948 Column::Bool(v) => Column::Bool(v[..nrows].to_vec()),
2949 Column::Categorical { levels, codes } => {
2950 let strs: Vec<String> = codes[..nrows]
2951 .iter()
2952 .map(|&c| levels[c as usize].clone())
2953 .collect();
2954 Column::Str(strs)
2955 }
2956 Column::DateTime(v) => Column::Int(v[..nrows].to_vec()),
2957 };
2958 Some(Ok(result))
2959 }
2960 DExpr::LitFloat(v) => Some(Ok(Column::Float(vec![*v; nrows]))),
2961 DExpr::LitInt(v) => Some(Ok(Column::Int(vec![*v; nrows]))),
2962 DExpr::LitBool(b) => Some(Ok(Column::Bool(vec![*b; nrows]))),
2963 DExpr::LitStr(s) => Some(Ok(Column::Str(vec![s.clone(); nrows]))),
2964 DExpr::BinOp { op, left, right } => {
2965 let left_col = try_eval_expr_column_vectorized(df, left, nrows)?.ok()?;
2966 let right_col = try_eval_expr_column_vectorized(df, right, nrows)?.ok()?;
2967 Some(vectorized_binop(*op, &left_col, &right_col))
2968 }
2969 DExpr::FnCall(name, args) if args.len() == 1 => {
2970 let arg_col = try_eval_expr_column_vectorized(df, &args[0], nrows)?.ok()?;
2971 Some(vectorized_fn_call(name, &arg_col))
2972 }
2973 _ => None,
2974 }
2975}
2976
2977fn eval_expr_column(df: &DataFrame, expr: &DExpr, nrows: usize) -> Result<Column, TidyError> {
2978 if nrows == 0 {
2979 return Ok(Column::Float(vec![]));
2981 }
2982
2983 if let Some(col) = eval_window_column(df, expr, nrows)? {
2985 return Ok(col);
2986 }
2987
2988 if let Some(result) = try_eval_expr_column_vectorized(df, expr, nrows) {
2990 return result;
2991 }
2992
2993 let sample = eval_dexpr_row(df, expr, 0)?;
2995 match sample {
2996 ExprValue::Int(_) => {
2997 let vals: Result<Vec<i64>, TidyError> = (0..nrows)
2998 .map(|r| {
2999 eval_dexpr_row(df, expr, r).and_then(|v| match v {
3000 ExprValue::Int(i) => Ok(i),
3001 ExprValue::Float(f) => Ok(f as i64),
3002 other => Err(TidyError::TypeMismatch {
3003 expected: "Int".into(),
3004 got: other.type_name().into(),
3005 }),
3006 })
3007 })
3008 .collect();
3009 Ok(Column::Int(vals?))
3010 }
3011 ExprValue::Float(_) => {
3012 let vals: Result<Vec<f64>, TidyError> = (0..nrows)
3013 .map(|r| {
3014 eval_dexpr_row(df, expr, r).and_then(|v| match v {
3015 ExprValue::Float(f) => Ok(f),
3016 ExprValue::Int(i) => Ok(i as f64),
3017 other => Err(TidyError::TypeMismatch {
3018 expected: "Float".into(),
3019 got: other.type_name().into(),
3020 }),
3021 })
3022 })
3023 .collect();
3024 Ok(Column::Float(vals?))
3025 }
3026 ExprValue::Bool(_) => {
3027 let vals: Result<Vec<bool>, TidyError> = (0..nrows)
3028 .map(|r| {
3029 eval_dexpr_row(df, expr, r).and_then(|v| match v {
3030 ExprValue::Bool(b) => Ok(b),
3031 other => Err(TidyError::TypeMismatch {
3032 expected: "Bool".into(),
3033 got: other.type_name().into(),
3034 }),
3035 })
3036 })
3037 .collect();
3038 Ok(Column::Bool(vals?))
3039 }
3040 ExprValue::Str(_) => {
3041 let vals: Result<Vec<String>, TidyError> = (0..nrows)
3042 .map(|r| {
3043 eval_dexpr_row(df, expr, r).and_then(|v| match v {
3044 ExprValue::Str(s) => Ok(s),
3045 other => Err(TidyError::TypeMismatch {
3046 expected: "Str".into(),
3047 got: other.type_name().into(),
3048 }),
3049 })
3050 })
3051 .collect();
3052 Ok(Column::Str(vals?))
3053 }
3054 }
3055}
3056
3057fn eval_dexpr_row(df: &DataFrame, expr: &DExpr, row: usize) -> Result<ExprValue, TidyError> {
3059 eval_expr_row(df, expr, row).map_err(|e| TidyError::Internal(e.to_string()))
3060}
3061
3062fn eval_expr_row_proj(
3064 base: &DataFrame,
3065 expr: &DExpr,
3066 row: usize,
3067 _proj: &ProjectionMap,
3068) -> Result<ExprValue, TidyError> {
3069 eval_expr_row(base, expr, row).map_err(|e| TidyError::Internal(e.to_string()))
3073}
3074
3075fn validate_expr_columns_proj(
3082 expr: &DExpr,
3083 base: &DataFrame,
3084 _proj: &ProjectionMap,
3085) -> Result<(), TidyError> {
3086 let mut refs = Vec::new();
3087 collect_expr_columns(expr, &mut refs);
3088 for col_name in refs {
3089 if base.get_column(&col_name).is_none() {
3090 return Err(TidyError::ColumnNotFound(col_name));
3091 }
3092 }
3093 Ok(())
3094}
3095
3096fn validate_expr_columns_snapshot(
3098 expr: &DExpr,
3099 snapshot_names: &[String],
3100) -> Result<(), TidyError> {
3101 let mut refs = Vec::new();
3102 collect_expr_columns(expr, &mut refs);
3103 for col_name in refs {
3104 if !snapshot_names.iter().any(|n| n == &col_name) {
3105 return Err(TidyError::ColumnNotFound(col_name));
3106 }
3107 }
3108 Ok(())
3109}
3110
3111impl ExprValue {
3112 fn type_name(&self) -> &'static str {
3113 match self {
3114 ExprValue::Int(_) => "Int",
3115 ExprValue::Float(_) => "Float",
3116 ExprValue::Str(_) => "Str",
3117 ExprValue::Bool(_) => "Bool",
3118 }
3119 }
3120}
3121
impl DataFrame {
    /// Consumes the frame and wraps it in a [`TidyView`] via
    /// `TidyView::from_df`.
    pub fn tidy(self) -> TidyView {
        TidyView::from_df(self)
    }

    /// Consumes the frame and wraps it in a [`TidyFrame`] via
    /// `TidyFrame::from_df`.
    pub fn tidy_mut(self) -> TidyFrame {
        TidyFrame::from_df(self)
    }
}
3137
/// An ordered list of row indices.
///
/// The wrapper stores the indices exactly as given: no sorting, deduplication
/// or bounds checking is performed here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RowIndexMap {
    pub(crate) indices: Vec<usize>,
}

impl RowIndexMap {
    /// Wraps `indices` without modifying them.
    pub fn new(indices: Vec<usize>) -> Self {
        Self { indices }
    }

    /// Number of stored row indices.
    pub fn len(&self) -> usize {
        self.as_slice().len()
    }

    /// `true` when no row indices are stored.
    pub fn is_empty(&self) -> bool {
        self.as_slice().is_empty()
    }

    /// Borrows the stored indices in their original order.
    pub fn as_slice(&self) -> &[usize] {
        self.indices.as_slice()
    }
}
3219
/// One group produced by grouping: its key and the rows that belong to it.
#[derive(Debug, Clone)]
pub struct GroupMeta {
    /// The group's key, one string per key column. `GroupIndex::build` fills
    /// this with each key column's `get_display` text.
    pub key_values: Vec<String>,
    /// Row indices that belong to this group. `GroupIndex::build` appends them
    /// in the order the rows are visited.
    pub row_indices: Vec<usize>,
}
3230
3231#[derive(Debug, Clone)]
3242pub struct GroupIndex {
3243 pub groups: Vec<GroupMeta>,
3245 pub key_names: Vec<String>,
3247}
3248
3249impl GroupIndex {
3250 pub fn build(
3255 base: &DataFrame,
3256 key_col_indices: &[usize],
3257 visible_rows: &[usize],
3258 key_names: Vec<String>,
3259 ) -> Self {
3260 let mut group_order: Vec<Vec<String>> = Vec::new(); let mut group_map: Vec<(Vec<String>, usize)> = Vec::new(); for &row in visible_rows {
3266 let key: Vec<String> = key_col_indices
3267 .iter()
3268 .map(|&ci| base.columns[ci].1.get_display(row))
3269 .collect();
3270
3271 let slot = group_map
3273 .iter()
3274 .position(|(k, _)| k == &key)
3275 .unwrap_or_else(|| {
3276 let s = group_map.len();
3277 group_map.push((key.clone(), s));
3278 group_order.push(key);
3279 s
3280 });
3281
3282 let _ = slot; }
3284
3285 let mut groups: Vec<GroupMeta> = group_order
3287 .iter()
3288 .map(|k| GroupMeta {
3289 key_values: k.clone(),
3290 row_indices: Vec::new(),
3291 })
3292 .collect();
3293
3294 let key_to_slot: Vec<(Vec<String>, usize)> = group_order
3296 .iter()
3297 .enumerate()
3298 .map(|(i, k)| (k.clone(), i))
3299 .collect();
3300
3301 for &row in visible_rows {
3302 let key: Vec<String> = key_col_indices
3303 .iter()
3304 .map(|&ci| base.columns[ci].1.get_display(row))
3305 .collect();
3306 if let Some((_, slot)) = key_to_slot.iter().find(|(k, _)| k == &key) {
3307 groups[*slot].row_indices.push(row);
3308 }
3309 }
3310
3311 GroupIndex { groups, key_names }
3312 }
3313}
3314
3315#[derive(Debug, Clone)]
3325pub struct GroupedTidyView {
3326 view: TidyView,
3327 index: GroupIndex,
3328}
3329
3330impl GroupedTidyView {
3331 pub fn ngroups(&self) -> usize {
3333 self.index.groups.len()
3334 }
3335
3336 pub fn ungroup(self) -> TidyView {
3338 self.view
3339 }
3340
3341 pub fn group_index(&self) -> &GroupIndex {
3343 &self.index
3344 }
3345
3346 pub fn summarise(
3365 &self,
3366 assignments: &[(&str, TidyAgg)],
3367 ) -> Result<TidyFrame, TidyError> {
3368 {
3370 let mut seen = std::collections::BTreeSet::new();
3371 for &(name, _) in assignments {
3372 if !seen.insert(name) {
3373 return Err(TidyError::DuplicateColumn(name.to_string()));
3374 }
3375 }
3376 }
3377
3378 let base = &self.view.base;
3379 let n_groups = self.index.groups.len();
3380
3381 let mut result_columns: Vec<(String, Column)> = Vec::new();
3383
3384 for key_name in &self.index.key_names {
3385 let base_col = base
3386 .get_column(key_name)
3387 .ok_or_else(|| TidyError::ColumnNotFound(key_name.clone()))?;
3388
3389 let col = match base_col {
3390 Column::Int(_) => {
3391 let vals: Vec<i64> = self
3392 .index
3393 .groups
3394 .iter()
3395 .map(|g| {
3396 g.key_values[self
3397 .index
3398 .key_names
3399 .iter()
3400 .position(|k| k == key_name)
3401 .unwrap()]
3402 .parse::<i64>()
3403 .unwrap_or(0)
3404 })
3405 .collect();
3406 Column::Int(vals)
3407 }
3408 Column::Bool(_) => {
3409 let vals: Vec<bool> = self
3410 .index
3411 .groups
3412 .iter()
3413 .map(|g| {
3414 let s = &g.key_values[self
3415 .index
3416 .key_names
3417 .iter()
3418 .position(|k| k == key_name)
3419 .unwrap()];
3420 matches!(s.as_str(), "true" | "1")
3421 })
3422 .collect();
3423 Column::Bool(vals)
3424 }
3425 _ => {
3426 let vals: Vec<String> = self
3428 .index
3429 .groups
3430 .iter()
3431 .map(|g| {
3432 g.key_values[self
3433 .index
3434 .key_names
3435 .iter()
3436 .position(|k| k == key_name)
3437 .unwrap()]
3438 .clone()
3439 })
3440 .collect();
3441 Column::Str(vals)
3442 }
3443 };
3444 result_columns.push((key_name.clone(), col));
3445 }
3446
3447 for &(out_name, ref agg) in assignments {
3449 let col_vals = self.eval_agg_over_groups_fast(agg, n_groups, base)?;
3450 result_columns.push((out_name.to_string(), col_vals));
3451 }
3452
3453 let df = DataFrame::from_columns(result_columns)
3454 .map_err(|e| TidyError::Internal(e.to_string()))?;
3455 Ok(TidyFrame::from_df(df))
3456 }
3457
3458 #[allow(dead_code)]
3460 fn eval_agg_over_groups(
3461 &self,
3462 agg: &TidyAgg,
3463 n_groups: usize,
3464 base: &DataFrame,
3465 ) -> Result<Column, TidyError> {
3466 match agg {
3467 TidyAgg::Count => {
3468 let counts: Vec<i64> = self
3469 .index
3470 .groups
3471 .iter()
3472 .map(|g| g.row_indices.len() as i64)
3473 .collect();
3474 Ok(Column::Int(counts))
3475 }
3476
3477 TidyAgg::Sum(col_name) | TidyAgg::Mean(col_name)
3478 | TidyAgg::Min(col_name) | TidyAgg::Max(col_name)
3479 | TidyAgg::First(col_name) | TidyAgg::Last(col_name)
3480 | TidyAgg::Median(col_name) | TidyAgg::Sd(col_name)
3481 | TidyAgg::Var(col_name) | TidyAgg::Quantile(col_name, _)
3482 | TidyAgg::NDistinct(col_name) | TidyAgg::Iqr(col_name) => {
3483 let src = base
3484 .get_column(col_name)
3485 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
3486
3487 let mut vals = Vec::with_capacity(n_groups);
3488 for group in &self.index.groups {
3489 let v = agg_reduce(agg, src, &group.row_indices)?;
3490 vals.push(v);
3491 }
3492 Ok(Column::Float(vals))
3493 }
3494 }
3495 }
3496
3497 fn eval_agg_over_groups_fast(
3500 &self,
3501 agg: &TidyAgg,
3502 n_groups: usize,
3503 base: &DataFrame,
3504 ) -> Result<Column, TidyError> {
3505 match agg {
3506 TidyAgg::Count => {
3507 let counts: Vec<i64> = self
3508 .index
3509 .groups
3510 .iter()
3511 .map(|g| g.row_indices.len() as i64)
3512 .collect();
3513 Ok(Column::Int(counts))
3514 }
3515 TidyAgg::Sum(col_name) => {
3516 let src = base.get_column(col_name)
3517 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
3518 Ok(Column::Float(fast_agg_sum(&self.index.groups, src)?))
3519 }
3520 TidyAgg::Mean(col_name) => {
3521 let src = base.get_column(col_name)
3522 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
3523 Ok(Column::Float(fast_agg_mean(&self.index.groups, src)?))
3524 }
3525 TidyAgg::Min(col_name) => {
3526 let src = base.get_column(col_name)
3527 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
3528 Ok(Column::Float(fast_agg_min(&self.index.groups, src)?))
3529 }
3530 TidyAgg::Max(col_name) => {
3531 let src = base.get_column(col_name)
3532 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
3533 Ok(Column::Float(fast_agg_max(&self.index.groups, src)?))
3534 }
3535 TidyAgg::First(col_name) => {
3536 let src = base.get_column(col_name)
3537 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
3538 Ok(Column::Float(fast_agg_first(&self.index.groups, src)?))
3539 }
3540 TidyAgg::Last(col_name) => {
3541 let src = base.get_column(col_name)
3542 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
3543 Ok(Column::Float(fast_agg_last(&self.index.groups, src)?))
3544 }
3545 TidyAgg::Var(col_name)
3546 | TidyAgg::Sd(col_name)
3547 | TidyAgg::Median(col_name)
3548 | TidyAgg::Quantile(col_name, _)
3549 | TidyAgg::NDistinct(col_name)
3550 | TidyAgg::Iqr(col_name) => {
3551 let src = base.get_column(col_name)
3552 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
3553 Ok(Column::Float(fast_agg_arena(
3554 agg, &self.index.groups, src, n_groups,
3555 )?))
3556 }
3557 }
3558 }
3559}
3560
/// Borrowed view of a numeric column's data, so the aggregation kernels can
/// read values without copying the column.
enum ColRef<'a> {
    /// Data borrowed from a `Column::Float`.
    Float(&'a [f64]),
    /// Data borrowed from a `Column::Int`; kernels cast elements to `f64` on read.
    Int(&'a [i64]),
}
3567
3568fn col_to_ref(col: &Column) -> Result<ColRef<'_>, TidyError> {
3569 match col {
3570 Column::Float(v) => Ok(ColRef::Float(v)),
3571 Column::Int(v) => Ok(ColRef::Int(v)),
3572 _ => Err(TidyError::TypeMismatch {
3573 expected: "numeric (Int or Float)".into(),
3574 got: col.type_name().into(),
3575 }),
3576 }
3577}
3578
3579fn fast_agg_sum(groups: &[GroupMeta], col: &Column) -> Result<Vec<f64>, TidyError> {
3580 use cjc_repro::kahan::KahanAccumulatorF64;
3581 let cr = col_to_ref(col)?;
3582 Ok(groups.iter().map(|g| {
3583 let mut acc = KahanAccumulatorF64::new();
3584 match cr {
3585 ColRef::Float(d) => { for &i in &g.row_indices { acc.add(d[i]); } }
3586 ColRef::Int(d) => { for &i in &g.row_indices { acc.add(d[i] as f64); } }
3587 }
3588 acc.finalize()
3589 }).collect())
3590}
3591
3592fn fast_agg_mean(groups: &[GroupMeta], col: &Column) -> Result<Vec<f64>, TidyError> {
3593 use cjc_repro::kahan::KahanAccumulatorF64;
3594 let cr = col_to_ref(col)?;
3595 Ok(groups.iter().map(|g| {
3596 if g.row_indices.is_empty() { return f64::NAN; }
3597 let mut acc = KahanAccumulatorF64::new();
3598 match cr {
3599 ColRef::Float(d) => { for &i in &g.row_indices { acc.add(d[i]); } }
3600 ColRef::Int(d) => { for &i in &g.row_indices { acc.add(d[i] as f64); } }
3601 }
3602 acc.finalize() / g.row_indices.len() as f64
3603 }).collect())
3604}
3605
3606fn fast_agg_min(groups: &[GroupMeta], col: &Column) -> Result<Vec<f64>, TidyError> {
3607 let cr = col_to_ref(col)?;
3608 Ok(groups.iter().map(|g| {
3609 if g.row_indices.is_empty() { return f64::NAN; }
3610 match cr {
3611 ColRef::Float(d) => g.row_indices.iter().fold(f64::INFINITY, |a, &i| {
3612 let b = d[i]; if b.is_nan() || b < a { b } else { a }
3613 }),
3614 ColRef::Int(d) => g.row_indices.iter().fold(f64::INFINITY, |a, &i| {
3615 let b = d[i] as f64; if b.is_nan() || b < a { b } else { a }
3616 }),
3617 }
3618 }).collect())
3619}
3620
3621fn fast_agg_max(groups: &[GroupMeta], col: &Column) -> Result<Vec<f64>, TidyError> {
3622 let cr = col_to_ref(col)?;
3623 Ok(groups.iter().map(|g| {
3624 if g.row_indices.is_empty() { return f64::NAN; }
3625 match cr {
3626 ColRef::Float(d) => g.row_indices.iter().fold(f64::NEG_INFINITY, |a, &i| {
3627 let b = d[i]; if b.is_nan() || b > a { b } else { a }
3628 }),
3629 ColRef::Int(d) => g.row_indices.iter().fold(f64::NEG_INFINITY, |a, &i| {
3630 let b = d[i] as f64; if b.is_nan() || b > a { b } else { a }
3631 }),
3632 }
3633 }).collect())
3634}
3635
3636fn fast_agg_first(groups: &[GroupMeta], col: &Column) -> Result<Vec<f64>, TidyError> {
3637 let cr = col_to_ref(col)?;
3638 let mut vals = Vec::with_capacity(groups.len());
3639 for g in groups {
3640 if g.row_indices.is_empty() { return Err(TidyError::EmptyGroup); }
3641 match cr {
3642 ColRef::Float(d) => vals.push(d[g.row_indices[0]]),
3643 ColRef::Int(d) => vals.push(d[g.row_indices[0]] as f64),
3644 }
3645 }
3646 Ok(vals)
3647}
3648
3649fn fast_agg_last(groups: &[GroupMeta], col: &Column) -> Result<Vec<f64>, TidyError> {
3650 let cr = col_to_ref(col)?;
3651 let mut vals = Vec::with_capacity(groups.len());
3652 for g in groups {
3653 if g.row_indices.is_empty() { return Err(TidyError::EmptyGroup); }
3654 let last = *g.row_indices.last().unwrap();
3655 match cr {
3656 ColRef::Float(d) => vals.push(d[last]),
3657 ColRef::Int(d) => vals.push(d[last] as f64),
3658 }
3659 }
3660 Ok(vals)
3661}
3662
3663fn fast_agg_arena(
3666 agg: &TidyAgg,
3667 groups: &[GroupMeta],
3668 col: &Column,
3669 n_groups: usize,
3670) -> Result<Vec<f64>, TidyError> {
3671 let cr = col_to_ref(col)?;
3672 let max_size = groups.iter().map(|g| g.row_indices.len()).max().unwrap_or(0);
3673 let mut arena: Vec<f64> = Vec::with_capacity(max_size);
3674 let mut results = Vec::with_capacity(n_groups);
3675 for group in groups {
3676 arena.clear();
3677 match cr {
3678 ColRef::Float(d) => { for &i in &group.row_indices { arena.push(d[i]); } }
3679 ColRef::Int(d) => { for &i in &group.row_indices { arena.push(d[i] as f64); } }
3680 }
3681 let val = agg_reduce_slice(agg, &mut arena)?;
3682 results.push(val);
3683 }
3684 Ok(results)
3685}
3686
3687fn agg_reduce_slice(agg: &TidyAgg, values: &mut [f64]) -> Result<f64, TidyError> {
3690 match agg {
3691 TidyAgg::Var(_) => {
3692 if values.len() < 2 {
3693 Ok(f64::NAN)
3694 } else {
3695 let n = values.len() as f64;
3696 let mean = kahan_sum_f64(values) / n;
3697 let sq_diffs: Vec<f64> = values.iter().map(|v| (v - mean) * (v - mean)).collect();
3698 Ok(kahan_sum_f64(&sq_diffs) / (n - 1.0))
3699 }
3700 }
3701 TidyAgg::Sd(_) => {
3702 if values.len() < 2 {
3703 Ok(f64::NAN)
3704 } else {
3705 let n = values.len() as f64;
3706 let mean = kahan_sum_f64(values) / n;
3707 let sq_diffs: Vec<f64> = values.iter().map(|v| (v - mean) * (v - mean)).collect();
3708 Ok((kahan_sum_f64(&sq_diffs) / (n - 1.0)).sqrt())
3709 }
3710 }
3711 TidyAgg::Median(_) => {
3712 if values.is_empty() {
3713 Ok(f64::NAN)
3714 } else {
3715 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3716 let n = values.len();
3717 if n % 2 == 1 { Ok(values[n / 2]) }
3718 else { Ok((values[n / 2 - 1] + values[n / 2]) / 2.0) }
3719 }
3720 }
3721 TidyAgg::Quantile(_, p) => {
3722 if values.is_empty() {
3723 Ok(f64::NAN)
3724 } else {
3725 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3726 let n = values.len();
3727 if n == 1 { return Ok(values[0]); }
3728 let pos = p * (n as f64 - 1.0);
3729 let lo = pos.floor() as usize;
3730 let hi = pos.ceil() as usize;
3731 let frac = pos - lo as f64;
3732 if lo == hi || hi >= n { Ok(values[lo.min(n - 1)]) }
3733 else { Ok(values[lo] + frac * (values[hi] - values[lo])) }
3734 }
3735 }
3736 TidyAgg::NDistinct(_) => {
3737 let distinct: std::collections::BTreeSet<u64> = values.iter().map(|v| v.to_bits()).collect();
3738 Ok(distinct.len() as f64)
3739 }
3740 TidyAgg::Iqr(_) => {
3741 if values.is_empty() {
3742 Ok(f64::NAN)
3743 } else {
3744 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3745 let n = values.len();
3746 if n == 1 { return Ok(0.0); }
3747 let q1 = {
3748 let pos = 0.25 * (n as f64 - 1.0);
3749 let lo = pos.floor() as usize;
3750 let hi = pos.ceil() as usize;
3751 let frac = pos - lo as f64;
3752 if lo == hi || hi >= n { values[lo.min(n - 1)] }
3753 else { values[lo] + frac * (values[hi] - values[lo]) }
3754 };
3755 let q3 = {
3756 let pos = 0.75 * (n as f64 - 1.0);
3757 let lo = pos.floor() as usize;
3758 let hi = pos.ceil() as usize;
3759 let frac = pos - lo as f64;
3760 if lo == hi || hi >= n { values[lo.min(n - 1)] }
3761 else { values[lo] + frac * (values[hi] - values[lo]) }
3762 };
3763 Ok(q3 - q1)
3764 }
3765 }
3766 _ => unreachable!("agg_reduce_slice called for non-arena aggregator"),
3767 }
3768}
3769
3770#[allow(dead_code)]
3772fn agg_reduce(
3773 agg: &TidyAgg,
3774 col: &Column,
3775 rows: &[usize],
3776) -> Result<f64, TidyError> {
3777 let values: Vec<f64> = match col {
3779 Column::Int(v) => rows.iter().map(|&r| v[r] as f64).collect(),
3780 Column::Float(v) => rows.iter().map(|&r| v[r]).collect(),
3781 _ => {
3782 return Err(TidyError::TypeMismatch {
3783 expected: "numeric (Int or Float)".into(),
3784 got: col.type_name().into(),
3785 })
3786 }
3787 };
3788
3789 match agg {
3790 TidyAgg::Sum(_) => Ok(kahan_sum_f64(&values)),
3791 TidyAgg::Mean(_) => {
3792 if values.is_empty() {
3793 Ok(f64::NAN)
3794 } else {
3795 Ok(kahan_sum_f64(&values) / values.len() as f64)
3796 }
3797 }
3798 TidyAgg::Min(_) => {
3799 if values.is_empty() {
3800 Ok(f64::NAN)
3801 } else {
3802 Ok(values.iter().cloned().fold(f64::INFINITY, |a, b| {
3803 if b.is_nan() || b < a { b } else { a }
3804 }))
3805 }
3806 }
3807 TidyAgg::Max(_) => {
3808 if values.is_empty() {
3809 Ok(f64::NAN)
3810 } else {
3811 Ok(values.iter().cloned().fold(f64::NEG_INFINITY, |a, b| {
3812 if b.is_nan() || b > a { b } else { a }
3813 }))
3814 }
3815 }
3816 TidyAgg::First(_) => {
3817 if values.is_empty() {
3818 Err(TidyError::EmptyGroup)
3819 } else {
3820 Ok(values[0])
3821 }
3822 }
3823 TidyAgg::Last(_) => {
3824 if values.is_empty() {
3825 Err(TidyError::EmptyGroup)
3826 } else {
3827 Ok(*values.last().unwrap())
3828 }
3829 }
3830 TidyAgg::Count => Ok(values.len() as f64),
3831 TidyAgg::Median(_) => {
3832 if values.is_empty() {
3833 Ok(f64::NAN)
3834 } else {
3835 let mut sorted = values.clone();
3836 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3837 let n = sorted.len();
3838 if n % 2 == 1 {
3839 Ok(sorted[n / 2])
3840 } else {
3841 Ok((sorted[n / 2 - 1] + sorted[n / 2]) / 2.0)
3842 }
3843 }
3844 }
3845 TidyAgg::Var(_) => {
3846 if values.len() < 2 {
3847 Ok(f64::NAN)
3848 } else {
3849 let n = values.len() as f64;
3850 let mean = kahan_sum_f64(&values) / n;
3851 let sq_diffs: Vec<f64> = values.iter().map(|v| (v - mean) * (v - mean)).collect();
3852 Ok(kahan_sum_f64(&sq_diffs) / (n - 1.0))
3853 }
3854 }
3855 TidyAgg::Sd(_) => {
3856 if values.len() < 2 {
3857 Ok(f64::NAN)
3858 } else {
3859 let n = values.len() as f64;
3860 let mean = kahan_sum_f64(&values) / n;
3861 let sq_diffs: Vec<f64> = values.iter().map(|v| (v - mean) * (v - mean)).collect();
3862 Ok((kahan_sum_f64(&sq_diffs) / (n - 1.0)).sqrt())
3863 }
3864 }
3865 TidyAgg::Quantile(_, p) => {
3866 if values.is_empty() {
3867 Ok(f64::NAN)
3868 } else {
3869 let mut sorted = values.clone();
3870 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3871 let n = sorted.len();
3872 if n == 1 {
3873 return Ok(sorted[0]);
3874 }
3875 let pos = p * (n as f64 - 1.0);
3876 let lo = pos.floor() as usize;
3877 let hi = pos.ceil() as usize;
3878 let frac = pos - lo as f64;
3879 if lo == hi || hi >= n {
3880 Ok(sorted[lo.min(n - 1)])
3881 } else {
3882 Ok(sorted[lo] + frac * (sorted[hi] - sorted[lo]))
3883 }
3884 }
3885 }
3886 TidyAgg::NDistinct(_) => {
3887 use std::collections::BTreeSet;
3888 let distinct: BTreeSet<u64> = values.iter().map(|v| v.to_bits()).collect();
3889 Ok(distinct.len() as f64)
3890 }
3891 TidyAgg::Iqr(_) => {
3892 if values.is_empty() {
3893 Ok(f64::NAN)
3894 } else {
3895 let mut sorted = values.clone();
3896 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
3897 let n = sorted.len();
3898 if n == 1 {
3899 return Ok(0.0);
3900 }
3901 let q1 = {
3902 let pos = 0.25 * (n as f64 - 1.0);
3903 let lo = pos.floor() as usize;
3904 let hi = pos.ceil() as usize;
3905 let frac = pos - lo as f64;
3906 if lo == hi || hi >= n { sorted[lo.min(n - 1)] }
3907 else { sorted[lo] + frac * (sorted[hi] - sorted[lo]) }
3908 };
3909 let q3 = {
3910 let pos = 0.75 * (n as f64 - 1.0);
3911 let lo = pos.floor() as usize;
3912 let hi = pos.ceil() as usize;
3913 let frac = pos - lo as f64;
3914 if lo == hi || hi >= n { sorted[lo.min(n - 1)] }
3915 else { sorted[lo] + frac * (sorted[hi] - sorted[lo]) }
3916 };
3917 Ok(q3 - q1)
3918 }
3919 }
3920 }
3921}
3922
/// Aggregation to compute over the `f64` values of one group.
///
/// NOTE(review): the `String` payload of each variant is presumably the name
/// of the source column; confirm against the summarise entry point.
#[derive(Debug, Clone)]
pub enum TidyAgg {
    /// Number of values in the group, returned as `f64`.
    Count,
    /// Sum aggregation.
    Sum(String),
    /// Mean aggregation.
    Mean(String),
    /// Minimum aggregation.
    Min(String),
    /// Maximum; `NaN` for an empty group, and a `NaN` input wins over numbers.
    Max(String),
    /// First value of the group; an empty group is `TidyError::EmptyGroup`.
    First(String),
    /// Last value of the group; an empty group is `TidyError::EmptyGroup`.
    Last(String),
    /// Median (mean of the two middle values for even counts); `NaN` when empty.
    Median(String),
    /// Sample standard deviation (square root of the `n - 1` variance);
    /// `NaN` for fewer than two values.
    Sd(String),
    /// Sample variance (divides by `n - 1`); `NaN` for fewer than two values.
    Var(String),
    /// Quantile at probability `p`, linearly interpolated at position
    /// `p * (n - 1)` of the sorted values; `NaN` when empty.
    Quantile(String, f64),
    /// Number of distinct values, compared by `f64` bit pattern.
    NDistinct(String),
    /// Interquartile range (0.75 quantile minus 0.25 quantile); `NaN` when
    /// empty, `0.0` for a single value.
    Iqr(String),
}
3955
/// One sort key for `TidyView::arrange`: a column name plus a direction.
#[derive(Debug, Clone)]
pub struct ArrangeKey {
    /// Name of the column to order by.
    pub col_name: String,
    /// `true` reverses the column's natural ordering for this key.
    pub descending: bool,
}

impl ArrangeKey {
    /// Builds an ascending key on `col_name`.
    pub fn asc(col_name: &str) -> Self {
        Self {
            col_name: String::from(col_name),
            descending: false,
        }
    }

    /// Builds a descending key on `col_name`.
    pub fn desc(col_name: &str) -> Self {
        Self {
            col_name: String::from(col_name),
            descending: true,
        }
    }
}
3975
3976impl TidyView {
3979
3980 pub fn group_by(&self, keys: &[&str]) -> Result<GroupedTidyView, TidyError> {
3994 let mut key_col_indices = Vec::with_capacity(keys.len());
3996 for &key in keys {
3997 let idx = self
3998 .base
3999 .columns
4000 .iter()
4001 .position(|(n, _)| n == key)
4002 .ok_or_else(|| TidyError::ColumnNotFound(key.to_string()))?;
4003 key_col_indices.push(idx);
4004 }
4005
4006 let visible_rows: Vec<usize> = self.mask.iter_set().collect();
4007 let key_names: Vec<String> = keys.iter().map(|s| s.to_string()).collect();
4008
4009 let index = GroupIndex::build_fast(&self.base, &key_col_indices, &visible_rows, key_names);
4011
4012 Ok(GroupedTidyView {
4013 view: self.clone(),
4014 index,
4015 })
4016 }
4017
4018 pub fn arrange(&self, keys: &[ArrangeKey]) -> Result<TidyView, TidyError> {
4038 for key in keys {
4040 if self.base.get_column(&key.col_name).is_none() {
4041 return Err(TidyError::ColumnNotFound(key.col_name.clone()));
4042 }
4043 }
4044
4045 let mut row_indices: Vec<usize> = self.mask.iter_set().collect();
4047
4048 row_indices.sort_by(|&a, &b| {
4050 for key in keys {
4051 let col = self.base.get_column(&key.col_name).unwrap();
4052 let ord = compare_column_rows(col, a, b);
4053 let ord = if key.descending { ord.reverse() } else { ord };
4054 if ord != std::cmp::Ordering::Equal {
4055 return ord;
4056 }
4057 }
4058 std::cmp::Ordering::Equal
4059 });
4060
4061 let mut new_columns = Vec::with_capacity(self.proj.len());
4063 for &ci in self.proj.indices() {
4064 let (name, col) = &self.base.columns[ci];
4065 let new_col = gather_column(col, &row_indices);
4066 new_columns.push((name.clone(), new_col));
4067 }
4068 let mut sorted_all_cols = Vec::with_capacity(self.base.ncols());
4071 for (name, col) in &self.base.columns {
4072 sorted_all_cols.push((name.clone(), gather_column(col, &row_indices)));
4073 }
4074
4075 let new_base = DataFrame::from_columns(sorted_all_cols)
4076 .map_err(|e| TidyError::Internal(e.to_string()))?;
4077 let nrows = new_base.nrows();
4078 let new_proj = self.proj.clone();
4079
4080 Ok(TidyView {
4081 base: Rc::new(new_base),
4082 mask: BitMask::all_true(nrows),
4083 proj: new_proj,
4084 })
4085 }
4086
4087 pub fn slice(&self, start: usize, end: usize) -> TidyView {
4094 let visible: Vec<usize> = self.mask.iter_set().collect();
4095 let n = visible.len();
4096 let s = start.min(n);
4097 let e = end.min(n);
4098 let selected = if s >= e { vec![] } else { visible[s..e].to_vec() };
4099 self.view_from_row_indices(selected)
4100 }
4101
4102 pub fn slice_head(&self, n: usize) -> TidyView {
4104 self.slice(0, n)
4105 }
4106
4107 pub fn slice_tail(&self, n: usize) -> TidyView {
4109 let total = self.mask.count_ones();
4110 let start = total.saturating_sub(n);
4111 self.slice(start, total)
4112 }
4113
4114 pub fn slice_sample(&self, n: usize, seed: u64) -> TidyView {
4119 let mut visible: Vec<usize> = self.mask.iter_set().collect();
4120 let total = visible.len();
4121 if n >= total {
4122 return self.view_from_row_indices(visible);
4123 }
4124 let mut rng = seed;
4126 let selected_count = n;
4127 for i in 0..selected_count {
4128 rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
4130 let j = i + (rng as usize % (total - i));
4131 visible.swap(i, j);
4132 }
4133 visible.truncate(selected_count);
4134 visible.sort_unstable();
4136 self.view_from_row_indices(visible)
4137 }
4138
4139 pub fn distinct(&self, cols: &[&str]) -> Result<TidyView, TidyError> {
4151 let mut col_indices = Vec::with_capacity(cols.len());
4153 for &name in cols {
4154 let idx = self
4155 .base
4156 .columns
4157 .iter()
4158 .position(|(n, _)| n == name)
4159 .ok_or_else(|| TidyError::ColumnNotFound(name.to_string()))?;
4160 col_indices.push(idx);
4161 }
4162
4163 let mut seen_keys: BTreeSet<Vec<String>> = BTreeSet::new();
4165 let mut selected_rows: Vec<usize> = Vec::new();
4166
4167 for row in self.mask.iter_set() {
4168 let key: Vec<String> = if col_indices.is_empty() {
4169 vec!["__all__".into()]
4170 } else {
4171 col_indices
4172 .iter()
4173 .map(|&ci| self.base.columns[ci].1.get_display(row))
4174 .collect()
4175 };
4176
4177 if seen_keys.insert(key) {
4178 selected_rows.push(row);
4179 }
4180 }
4181
4182 Ok(self.view_from_row_indices(selected_rows))
4183 }
4184
4185 pub fn inner_join(
4198 &self,
4199 right: &TidyView,
4200 on: &[(&str, &str)],
4201 ) -> Result<TidyFrame, TidyError> {
4202 let (left_rows, right_rows) = join_match_rows(self, right, on, JoinKind::Inner)?;
4203 build_join_frame(self, right, &left_rows, &right_rows, on, false)
4204 }
4205
4206 pub fn left_join(
4210 &self,
4211 right: &TidyView,
4212 on: &[(&str, &str)],
4213 ) -> Result<TidyFrame, TidyError> {
4214 let (left_rows, right_rows_opt) =
4215 join_match_rows_optional(self, right, on, JoinKind::Left)?;
4216 build_left_join_frame(self, right, &left_rows, &right_rows_opt, on)
4217 }
4218
4219 pub fn semi_join(
4223 &self,
4224 right: &TidyView,
4225 on: &[(&str, &str)],
4226 ) -> Result<TidyView, TidyError> {
4227 let included = semi_anti_match_rows(self, right, on, true)?;
4228 Ok(self.view_from_row_indices(included))
4229 }
4230
4231 pub fn anti_join(
4235 &self,
4236 right: &TidyView,
4237 on: &[(&str, &str)],
4238 ) -> Result<TidyView, TidyError> {
4239 let included = semi_anti_match_rows(self, right, on, false)?;
4240 Ok(self.view_from_row_indices(included))
4241 }
4242
4243 fn view_from_row_indices(&self, row_indices: Vec<usize>) -> TidyView {
4248 let nrows_base = self.base.nrows();
4249 let mut words = vec![0u64; nwords_for(nrows_base)];
4250 for &r in &row_indices {
4251 words[r / 64] |= 1u64 << (r % 64);
4252 }
4253 TidyView {
4254 base: Rc::clone(&self.base),
4255 mask: BitMask { words, nrows: nrows_base },
4256 proj: self.proj.clone(),
4257 }
4258 }
4259}
4260
/// Join flavour tag handed to the row-matching helpers.
///
/// `join_match_rows` and `join_match_rows_optional` accept it as `_kind` and
/// do not read it.
#[derive(Clone, Copy)]
enum JoinKind { Inner, Left }
4265
4266fn resolve_join_keys(
4268 left: &TidyView,
4269 right: &TidyView,
4270 on: &[(&str, &str)],
4271) -> Result<(Vec<usize>, Vec<usize>), TidyError> {
4272 let mut li = Vec::new();
4273 let mut ri = Vec::new();
4274 for &(lk, rk) in on {
4275 let l = left.base.columns.iter().position(|(n, _)| n == lk)
4276 .ok_or_else(|| TidyError::ColumnNotFound(lk.to_string()))?;
4277 let r = right.base.columns.iter().position(|(n, _)| n == rk)
4278 .ok_or_else(|| TidyError::ColumnNotFound(rk.to_string()))?;
4279 li.push(l);
4280 ri.push(r);
4281 }
4282 Ok((li, ri))
4283}
4284
4285fn row_key(base: &DataFrame, col_indices: &[usize], row: usize) -> Vec<String> {
4287 col_indices
4288 .iter()
4289 .map(|&ci| base.columns[ci].1.get_display(row))
4290 .collect()
4291}
4292
4293fn build_right_lookup(
4296 right: &TidyView,
4297 right_key_cols: &[usize],
4298) -> Vec<(Vec<String>, usize)> {
4299 let mut lookup: Vec<(Vec<String>, usize)> = right
4300 .mask
4301 .iter_set()
4302 .map(|r| (row_key(&right.base, right_key_cols, r), r))
4303 .collect();
4304 lookup.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
4306 lookup
4307}
4308
/// Returns the row indices stored under `key` in a key-sorted `lookup`.
///
/// `lookup` must be sorted by key (as produced by `build_right_lookup`).
/// The start of the run is located by binary search, then rows are taken
/// while the key still matches. The result is empty when the key is absent.
fn find_matches(lookup: &[(Vec<String>, usize)], key: &[String]) -> Vec<usize> {
    // Compare against the borrowed slice directly; no need to clone the key.
    let start = lookup.partition_point(|(k, _)| k.as_slice() < key);
    lookup[start..]
        .iter()
        .take_while(|(k, _)| k.as_slice() == key)
        .map(|(_, r)| *r)
        .collect()
}
4324
4325fn build_right_lookup_btree(
4330 right: &TidyView,
4331 right_key_cols: &[usize],
4332) -> BTreeMap<Vec<String>, Vec<usize>> {
4333 let mut lookup: BTreeMap<Vec<String>, Vec<usize>> = BTreeMap::new();
4334 for r in right.mask.iter_set() {
4335 let key = row_key(&right.base, right_key_cols, r);
4336 lookup.entry(key).or_default().push(r);
4337 }
4338 lookup
4339}
4340
4341fn join_match_rows(
4343 left: &TidyView,
4344 right: &TidyView,
4345 on: &[(&str, &str)],
4346 _kind: JoinKind,
4347) -> Result<(Vec<usize>, Vec<usize>), TidyError> {
4348 let (left_key_cols, right_key_cols) = resolve_join_keys(left, right, on)?;
4349 let lookup = build_right_lookup_btree(right, &right_key_cols);
4351
4352 let mut out_left = Vec::new();
4353 let mut out_right = Vec::new();
4354
4355 for l_row in left.mask.iter_set() {
4356 let key = row_key(&left.base, &left_key_cols, l_row);
4357 if let Some(matches) = lookup.get(&key) {
4358 for &r_row in matches {
4359 out_left.push(l_row);
4360 out_right.push(r_row);
4361 }
4362 }
4363 }
4364 Ok((out_left, out_right))
4365}
4366
4367fn join_match_rows_optional(
4369 left: &TidyView,
4370 right: &TidyView,
4371 on: &[(&str, &str)],
4372 _kind: JoinKind,
4373) -> Result<(Vec<usize>, Vec<Option<usize>>), TidyError> {
4374 let (left_key_cols, right_key_cols) = resolve_join_keys(left, right, on)?;
4375 let lookup = build_right_lookup_btree(right, &right_key_cols);
4377
4378 let mut out_left = Vec::new();
4379 let mut out_right: Vec<Option<usize>> = Vec::new();
4380
4381 for l_row in left.mask.iter_set() {
4382 let key = row_key(&left.base, &left_key_cols, l_row);
4383 match lookup.get(&key) {
4384 Some(matches) if !matches.is_empty() => {
4385 for &r_row in matches {
4386 out_left.push(l_row);
4387 out_right.push(Some(r_row));
4388 }
4389 }
4390 _ => {
4391 out_left.push(l_row);
4392 out_right.push(None);
4393 }
4394 }
4395 }
4396 Ok((out_left, out_right))
4397}
4398
4399fn semi_anti_match_rows(
4401 left: &TidyView,
4402 right: &TidyView,
4403 on: &[(&str, &str)],
4404 semi: bool,
4405) -> Result<Vec<usize>, TidyError> {
4406 let (left_key_cols, right_key_cols) = resolve_join_keys(left, right, on)?;
4407 let lookup = build_right_lookup_btree(right, &right_key_cols);
4409
4410 let mut out = Vec::new();
4411 for l_row in left.mask.iter_set() {
4412 let key = row_key(&left.base, &left_key_cols, l_row);
4413 let has_match = lookup.contains_key(&key);
4414 if has_match == semi {
4415 out.push(l_row);
4416 }
4417 }
4418 Ok(out)
4419}
4420
4421fn build_join_frame(
4424 left: &TidyView,
4425 right: &TidyView,
4426 left_rows: &[usize],
4427 right_rows: &[usize],
4428 on: &[(&str, &str)],
4429 _include_unmatched: bool,
4430) -> Result<TidyFrame, TidyError> {
4431 let right_key_names: std::collections::BTreeSet<&str> =
4432 on.iter().map(|(_, rk)| *rk).collect();
4433
4434 let n = left_rows.len();
4435 let mut columns: Vec<(String, Column)> = Vec::new();
4436
4437 for &ci in left.proj.indices() {
4439 let (name, col) = &left.base.columns[ci];
4440 columns.push((name.clone(), gather_column(col, left_rows)));
4441 }
4442
4443 for &ci in right.proj.indices() {
4445 let (name, col) = &right.base.columns[ci];
4446 if right_key_names.contains(name.as_str()) {
4447 continue;
4448 }
4449 columns.push((name.clone(), gather_column(col, right_rows)));
4450 }
4451
4452 assert_eq!(n, left_rows.len());
4453 let df = DataFrame::from_columns(columns)
4454 .map_err(|e| TidyError::Internal(e.to_string()))?;
4455 Ok(TidyFrame::from_df(df))
4456}
4457
4458fn build_left_join_frame(
4460 left: &TidyView,
4461 right: &TidyView,
4462 left_rows: &[usize],
4463 right_rows_opt: &[Option<usize>],
4464 on: &[(&str, &str)],
4465) -> Result<TidyFrame, TidyError> {
4466 let right_key_names: std::collections::BTreeSet<&str> =
4467 on.iter().map(|(_, rk)| *rk).collect();
4468
4469 let mut columns: Vec<(String, Column)> = Vec::new();
4470
4471 for &ci in left.proj.indices() {
4473 let (name, col) = &left.base.columns[ci];
4474 columns.push((name.clone(), gather_column(col, left_rows)));
4475 }
4476
4477 for &ci in right.proj.indices() {
4479 let (name, col) = &right.base.columns[ci];
4480 if right_key_names.contains(name.as_str()) {
4481 continue;
4482 }
4483 let new_col = gather_column_nullable(col, right_rows_opt);
4484 columns.push((name.clone(), new_col));
4485 }
4486
4487 let df = DataFrame::from_columns(columns)
4488 .map_err(|e| TidyError::Internal(e.to_string()))?;
4489 Ok(TidyFrame::from_df(df))
4490}
4491
4492fn compare_column_rows(col: &Column, a: usize, b: usize) -> std::cmp::Ordering {
4499 match col {
4500 Column::Int(v) => v[a].cmp(&v[b]),
4501 Column::Float(v) => {
4502 let va = v[a];
4503 let vb = v[b];
4504 match (va.is_nan(), vb.is_nan()) {
4505 (true, true) => std::cmp::Ordering::Equal,
4506 (true, false) => std::cmp::Ordering::Greater, (false, true) => std::cmp::Ordering::Less,
4508 (false, false) => va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal),
4509 }
4510 }
4511 Column::Bool(v) => v[a].cmp(&v[b]),
4512 Column::Str(v) => v[a].cmp(&v[b]),
4513 Column::Categorical { levels, codes } => {
4514 levels[codes[a] as usize].cmp(&levels[codes[b] as usize])
4516 }
4517 Column::DateTime(v) => v[a].cmp(&v[b]),
4518 }
4519}
4520
#[cfg(test)]
mod phase10_unit_tests {
    use super::*;

    /// Five-row fixture: `x` = 1..=5 (Int), `y` = 1.0..=5.0 (Float),
    /// `flag` alternating starting with `true` (Bool).
    fn make_df() -> DataFrame {
        let columns = vec![
            (String::from("x"), Column::Int(vec![1, 2, 3, 4, 5])),
            (String::from("y"), Column::Float(vec![1.0, 2.0, 3.0, 4.0, 5.0])),
            (String::from("flag"), Column::Bool(vec![true, false, true, false, true])),
        ];
        DataFrame::from_columns(columns).unwrap()
    }

    /// Shorthand for a binary expression node.
    fn binop(op: DBinOp, left: DExpr, right: DExpr) -> DExpr {
        DExpr::BinOp {
            op,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// Shorthand for a column reference.
    fn col(name: &str) -> DExpr {
        DExpr::Col(name.into())
    }

    // --- BitMask -----------------------------------------------------------

    #[test]
    fn bitmask_all_true_count() {
        assert_eq!(BitMask::all_true(5).count_ones(), 5);
    }

    #[test]
    fn bitmask_all_false_count() {
        assert_eq!(BitMask::all_false(5).count_ones(), 0);
    }

    #[test]
    fn bitmask_tail_bits_clean() {
        // 65 rows need two words; only bit 0 of the second word may be set.
        let mask = BitMask::all_true(65);
        assert_eq!(mask.count_ones(), 65);
        assert_eq!(mask.words.len(), 2);
        assert_eq!(mask.words[1], 1u64);
    }

    #[test]
    fn bitmask_and_semantics() {
        let lhs = BitMask::from_bools(&[true, true, false, true]);
        let rhs = BitMask::from_bools(&[true, false, true, true]);
        let both = lhs.and(&rhs);
        assert!(both.get(0));
        assert!(!both.get(1));
        assert!(!both.get(2));
        assert!(both.get(3));
    }

    // --- TidyView basics ---------------------------------------------------

    #[test]
    fn tidy_view_nrows_ncols() {
        let df = make_df();
        let view = df.tidy();
        assert_eq!(view.nrows(), 5);
        assert_eq!(view.ncols(), 3);
    }

    #[test]
    fn filter_basic() {
        let df = make_df();
        let view = df.tidy();
        let predicate = binop(DBinOp::Gt, col("x"), DExpr::LitInt(2));
        let filtered = view.filter(&predicate).unwrap();
        assert_eq!(filtered.nrows(), 3);
    }

    #[test]
    fn filter_empty_df() {
        let df = DataFrame::from_columns(vec![(String::from("x"), Column::Int(vec![]))])
            .unwrap();
        let view = df.tidy();
        let predicate = binop(DBinOp::Gt, col("x"), DExpr::LitInt(0));
        let filtered = view.filter(&predicate).unwrap();
        assert_eq!(filtered.nrows(), 0);
    }

    // --- select ------------------------------------------------------------

    #[test]
    fn select_reorder() {
        let df = make_df();
        let view = df.tidy();
        let selected = view.select(&["y", "x"]).unwrap();
        assert_eq!(selected.column_names(), vec!["y", "x"]);
    }

    #[test]
    fn select_zero_cols() {
        let df = make_df();
        let view = df.tidy();
        let selected = view.select(&[]).unwrap();
        assert_eq!(selected.ncols(), 0);
        assert_eq!(selected.nrows(), 5);
    }

    #[test]
    fn select_unknown_col() {
        let df = make_df();
        let view = df.tidy();
        let err = view.select(&["nonexistent"]).unwrap_err();
        assert!(matches!(err, TidyError::ColumnNotFound(_)));
    }

    #[test]
    fn select_duplicate_col() {
        let df = make_df();
        let view = df.tidy();
        let err = view.select(&["x", "x"]).unwrap_err();
        assert!(matches!(err, TidyError::DuplicateColumn(_)));
    }

    // --- mutate ------------------------------------------------------------

    #[test]
    fn mutate_new_col() {
        let df = make_df();
        let view = df.tidy();
        let expr = binop(DBinOp::Add, col("x"), DExpr::LitInt(10));
        let frame = view.mutate(&[("z", expr)]).unwrap();
        let b = frame.borrow();
        let z = b.get_column("z").unwrap();
        assert_eq!(z.len(), 5);
        match z {
            Column::Int(v) => {
                assert_eq!(v[0], 11);
                assert_eq!(v[4], 15);
            }
            _ => panic!("expected Int column"),
        }
    }

    #[test]
    fn mutate_type_promotion() {
        // Int + Float must come back as a Float column.
        let df = make_df();
        let view = df.tidy();
        let expr = binop(DBinOp::Add, col("x"), col("y"));
        let frame = view.mutate(&[("promoted", expr)]).unwrap();
        let b = frame.borrow();
        let promoted = b.get_column("promoted").unwrap();
        assert!(matches!(promoted, Column::Float(_)));
    }
}
4697
4698impl TidyError {
4750 pub fn schema_mismatch(msg: impl Into<String>) -> Self {
4752 TidyError::Internal(format!("schema mismatch: {}", msg.into()))
4753 }
4754 pub fn join_type_mismatch(col: &str, lt: &str, rt: &str) -> Self {
4756 TidyError::TypeMismatch {
4757 expected: format!("{} (from left key `{}`)", lt, col),
4758 got: rt.to_string(),
4759 }
4760 }
4761 pub fn duplicate_key(key: impl Into<String>) -> Self {
4763 TidyError::DuplicateColumn(format!("duplicate key: {}", key.into()))
4764 }
4765 pub fn empty_selection(msg: impl Into<String>) -> Self {
4767 TidyError::Internal(format!("empty selection: {}", msg.into()))
4768 }
4769}
4770
4771#[derive(Debug, Clone)]
4780pub struct NullableColumn<T: Clone> {
4781 pub values: Vec<T>,
4782 pub validity: BitMask,
4783}
4784
4785impl<T: Clone + Default> NullableColumn<T> {
4786 pub fn from_values(values: Vec<T>) -> Self {
4788 let n = values.len();
4789 Self {
4790 values,
4791 validity: BitMask::all_true(n),
4792 }
4793 }
4794
4795 pub fn new(values: Vec<T>, validity: BitMask) -> Self {
4798 assert_eq!(values.len(), validity.nrows(), "NullableColumn: length mismatch");
4799 Self { values, validity }
4800 }
4801
4802 pub fn len(&self) -> usize {
4804 self.values.len()
4805 }
4806
4807 pub fn is_null(&self, i: usize) -> bool {
4809 !self.validity.get(i)
4810 }
4811
4812 pub fn get(&self, i: usize) -> Option<&T> {
4814 if self.validity.get(i) { Some(&self.values[i]) } else { None }
4815 }
4816
4817 pub fn count_valid(&self) -> usize {
4819 self.validity.count_ones()
4820 }
4821
4822 pub fn gather(&self, indices: &[usize]) -> Self {
4824 let mut vals = Vec::with_capacity(indices.len());
4825 let mut bools = Vec::with_capacity(indices.len());
4826 for &i in indices {
4827 vals.push(self.values[i].clone());
4828 bools.push(self.validity.get(i));
4829 }
4830 let validity = BitMask::from_bools(&bools);
4831 Self { values: vals, validity }
4832 }
4833}
4834
4835#[derive(Debug, Clone)]
4848pub enum NullCol {
4849 Int(NullableColumn<i64>),
4850 Float(NullableColumn<f64>),
4851 Str(NullableColumn<String>),
4852 Bool(NullableColumn<bool>),
4853}
4854
4855impl NullCol {
4856 pub fn len(&self) -> usize {
4857 match self {
4858 NullCol::Int(c) => c.len(),
4859 NullCol::Float(c) => c.len(),
4860 NullCol::Str(c) => c.len(),
4861 NullCol::Bool(c) => c.len(),
4862 }
4863 }
4864
4865 pub fn is_null(&self, i: usize) -> bool {
4866 match self {
4867 NullCol::Int(c) => c.is_null(i),
4868 NullCol::Float(c) => c.is_null(i),
4869 NullCol::Str(c) => c.is_null(i),
4870 NullCol::Bool(c) => c.is_null(i),
4871 }
4872 }
4873
4874 pub fn type_name(&self) -> &'static str {
4875 match self {
4876 NullCol::Int(_) => "Int",
4877 NullCol::Float(_) => "Float",
4878 NullCol::Str(_) => "Str",
4879 NullCol::Bool(_) => "Bool",
4880 }
4881 }
4882
4883 pub fn from_column(col: &Column) -> Self {
4885 match col {
4886 Column::Int(v) => NullCol::Int(NullableColumn::from_values(v.clone())),
4887 Column::Float(v) => NullCol::Float(NullableColumn::from_values(v.clone())),
4888 Column::Str(v) => NullCol::Str(NullableColumn::from_values(v.clone())),
4889 Column::Bool(v) => NullCol::Bool(NullableColumn::from_values(v.clone())),
4890 Column::Categorical { levels, codes } => {
4892 let strings: Vec<String> = codes.iter().map(|&c| levels[c as usize].clone()).collect();
4893 NullCol::Str(NullableColumn::from_values(strings))
4894 }
4895 Column::DateTime(v) => NullCol::Int(NullableColumn::from_values(v.clone())),
4896 }
4897 }
4898
4899 pub fn to_column_strict(&self) -> Result<Column, TidyError> {
4902 match self {
4903 NullCol::Int(nc) => {
4904 if nc.count_valid() == nc.len() {
4905 Ok(Column::Int(nc.values.clone()))
4906 } else {
4907 Err(TidyError::Internal("null values in non-nullable context".into()))
4908 }
4909 }
4910 NullCol::Float(nc) => {
4911 if nc.count_valid() == nc.len() {
4912 Ok(Column::Float(nc.values.clone()))
4913 } else {
4914 Err(TidyError::Internal("null values in non-nullable context".into()))
4915 }
4916 }
4917 NullCol::Str(nc) => {
4918 if nc.count_valid() == nc.len() {
4919 Ok(Column::Str(nc.values.clone()))
4920 } else {
4921 Err(TidyError::Internal("null values in non-nullable context".into()))
4922 }
4923 }
4924 NullCol::Bool(nc) => {
4925 if nc.count_valid() == nc.len() {
4926 Ok(Column::Bool(nc.values.clone()))
4927 } else {
4928 Err(TidyError::Internal("null values in non-nullable context".into()))
4929 }
4930 }
4931 }
4932 }
4933
4934 pub fn to_column_filled(&self) -> Column {
4937 match self {
4938 NullCol::Int(nc) => Column::Int(nc.values.clone()),
4939 NullCol::Float(nc) => {
4940 let v: Vec<f64> = (0..nc.len())
4941 .map(|i| if nc.is_null(i) { f64::NAN } else { nc.values[i] })
4942 .collect();
4943 Column::Float(v)
4944 }
4945 NullCol::Str(nc) => Column::Str(nc.values.clone()),
4946 NullCol::Bool(nc) => Column::Bool(nc.values.clone()),
4947 }
4948 }
4949
4950 pub fn get_display(&self, i: usize) -> String {
4952 if self.is_null(i) {
4953 return "null".to_string();
4954 }
4955 match self {
4956 NullCol::Int(nc) => format!("{}", nc.values[i]),
4957 NullCol::Float(nc) => format!("{}", nc.values[i]),
4958 NullCol::Str(nc) => nc.values[i].clone(),
4959 NullCol::Bool(nc) => format!("{}", nc.values[i]),
4960 }
4961 }
4962
4963 pub fn null_of_type(type_name: &str, len: usize) -> Self {
4965 match type_name {
4966 "Int" => NullCol::Int(NullableColumn {
4967 values: vec![0i64; len],
4968 validity: BitMask::all_false(len),
4969 }),
4970 "Float" => NullCol::Float(NullableColumn {
4971 values: vec![0.0f64; len],
4972 validity: BitMask::all_false(len),
4973 }),
4974 "Bool" => NullCol::Bool(NullableColumn {
4975 values: vec![false; len],
4976 validity: BitMask::all_false(len),
4977 }),
4978 _ => NullCol::Str(NullableColumn {
4979 values: vec![String::new(); len],
4980 validity: BitMask::all_false(len),
4981 }),
4982 }
4983 }
4984
4985 pub fn gather(&self, indices: &[usize]) -> Self {
4987 match self {
4988 NullCol::Int(nc) => NullCol::Int(nc.gather(indices)),
4989 NullCol::Float(nc) => NullCol::Float(nc.gather(indices)),
4990 NullCol::Str(nc) => NullCol::Str(nc.gather(indices)),
4991 NullCol::Bool(nc) => NullCol::Bool(nc.gather(indices)),
4992 }
4993 }
4994}
4995
4996#[derive(Debug, Clone)]
4999pub struct NullableFrame {
5000 pub columns: Vec<(String, NullCol)>,
5001}
5002
5003impl NullableFrame {
5004 pub fn new() -> Self {
5005 Self { columns: Vec::new() }
5006 }
5007
5008 pub fn nrows(&self) -> usize {
5009 self.columns.first().map(|(_, c)| c.len()).unwrap_or(0)
5010 }
5011
5012 pub fn ncols(&self) -> usize {
5013 self.columns.len()
5014 }
5015
5016 pub fn column_names(&self) -> Vec<&str> {
5017 self.columns.iter().map(|(n, _)| n.as_str()).collect()
5018 }
5019
5020 pub fn get_column(&self, name: &str) -> Option<&NullCol> {
5021 self.columns.iter().find(|(n, _)| n == name).map(|(_, c)| c)
5022 }
5023
5024 pub fn to_dataframe_filled(&self) -> DataFrame {
5026 let cols: Vec<(String, Column)> = self.columns.iter()
5027 .map(|(n, c)| (n.clone(), c.to_column_filled()))
5028 .collect();
5029 DataFrame { columns: cols }
5031 }
5032
5033 pub fn to_tidy_frame_filled(&self) -> TidyFrame {
5035 TidyFrame::from_df(self.to_dataframe_filled())
5036 }
5037
5038 pub fn to_tidy_view_filled(&self) -> TidyView {
5040 TidyView::from_df(self.to_dataframe_filled())
5041 }
5042}
5043
5044impl Default for NullableFrame {
5045 fn default() -> Self { Self::new() }
5046}
5047
5048fn gather_column_nullable_null(col: &Column, indices: &[Option<usize>]) -> NullCol {
5053 match col {
5054 Column::Int(v) => {
5055 let mut vals = Vec::with_capacity(indices.len());
5056 let mut valid = Vec::with_capacity(indices.len());
5057 for &idx in indices {
5058 match idx {
5059 Some(i) => { vals.push(v[i]); valid.push(true); }
5060 None => { vals.push(0); valid.push(false); }
5061 }
5062 }
5063 NullCol::Int(NullableColumn::new(vals, BitMask::from_bools(&valid)))
5064 }
5065 Column::Float(v) => {
5066 let mut vals = Vec::with_capacity(indices.len());
5067 let mut valid = Vec::with_capacity(indices.len());
5068 for &idx in indices {
5069 match idx {
5070 Some(i) => { vals.push(v[i]); valid.push(true); }
5071 None => { vals.push(0.0); valid.push(false); }
5072 }
5073 }
5074 NullCol::Float(NullableColumn::new(vals, BitMask::from_bools(&valid)))
5075 }
5076 Column::Str(v) => {
5077 let mut vals = Vec::with_capacity(indices.len());
5078 let mut valid = Vec::with_capacity(indices.len());
5079 for &idx in indices {
5080 match idx {
5081 Some(i) => { vals.push(v[i].clone()); valid.push(true); }
5082 None => { vals.push(String::new()); valid.push(false); }
5083 }
5084 }
5085 NullCol::Str(NullableColumn::new(vals, BitMask::from_bools(&valid)))
5086 }
5087 Column::Bool(v) => {
5088 let mut vals = Vec::with_capacity(indices.len());
5089 let mut valid = Vec::with_capacity(indices.len());
5090 for &idx in indices {
5091 match idx {
5092 Some(i) => { vals.push(v[i]); valid.push(true); }
5093 None => { vals.push(false); valid.push(false); }
5094 }
5095 }
5096 NullCol::Bool(NullableColumn::new(vals, BitMask::from_bools(&valid)))
5097 }
5098 Column::Categorical { levels, codes } => {
5099 let mut vals = Vec::with_capacity(indices.len());
5100 let mut valid = Vec::with_capacity(indices.len());
5101 for &idx in indices {
5102 match idx {
5103 Some(i) => { vals.push(levels[codes[i] as usize].clone()); valid.push(true); }
5104 None => { vals.push(String::new()); valid.push(false); }
5105 }
5106 }
5107 NullCol::Str(NullableColumn::new(vals, BitMask::from_bools(&valid)))
5108 }
5109 Column::DateTime(v) => {
5110 let mut vals = Vec::with_capacity(indices.len());
5111 let mut valid = Vec::with_capacity(indices.len());
5112 for &idx in indices {
5113 match idx {
5114 Some(i) => { vals.push(v[i]); valid.push(true); }
5115 None => { vals.push(0); valid.push(false); }
5116 }
5117 }
5118 NullCol::Int(NullableColumn::new(vals, BitMask::from_bools(&valid)))
5119 }
5120 }
5121}
5122
5123pub type AcrossFn = Box<dyn Fn(&str, &Column) -> Result<Column, TidyError>>;
5130
5131pub struct AcrossTransform {
5133 pub fn_name: String,
5135 pub func: AcrossFn,
5137}
5138
5139impl AcrossTransform {
5140 pub fn new(fn_name: impl Into<String>, func: impl Fn(&str, &Column) -> Result<Column, TidyError> + 'static) -> Self {
5141 Self {
5142 fn_name: fn_name.into(),
5143 func: Box::new(func),
5144 }
5145 }
5146}
5147
5148pub struct AcrossSpec {
5150 pub cols: Vec<String>,
5152 pub transform: AcrossTransform,
5154 pub name_template: Option<String>,
5157}
5158
5159impl AcrossSpec {
5160 pub fn new(cols: impl IntoIterator<Item = impl Into<String>>, transform: AcrossTransform) -> Self {
5161 Self {
5162 cols: cols.into_iter().map(|c| c.into()).collect(),
5163 transform,
5164 name_template: None,
5165 }
5166 }
5167
5168 pub fn with_template(mut self, tmpl: impl Into<String>) -> Self {
5169 self.name_template = Some(tmpl.into());
5170 self
5171 }
5172
5173 pub fn output_name(&self, col_name: &str) -> String {
5175 match &self.name_template {
5176 Some(tmpl) => tmpl
5177 .replace("{col}", col_name)
5178 .replace("{fn}", &self.transform.fn_name),
5179 None => format!("{}_{}", col_name, self.transform.fn_name),
5180 }
5181 }
5182}
5183
/// Pair of suffixes, one per join side.
///
/// NOTE(review): presumably appended to clashing column names in join
/// output; the code that consumes it is not visible here.
#[derive(Debug, Clone)]
pub struct JoinSuffix {
    /// Suffix for the left side.
    pub left: String,
    /// Suffix for the right side.
    pub right: String,
}

impl Default for JoinSuffix {
    /// Defaults to `.x` for the left side and `.y` for the right side.
    fn default() -> Self {
        Self::new(".x", ".y")
    }
}

impl JoinSuffix {
    /// Creates a suffix pair from anything convertible into `String`.
    pub fn new(left: impl Into<String>, right: impl Into<String>) -> Self {
        let left: String = left.into();
        let right: String = right.into();
        Self { left, right }
    }
}
5204
5205fn join_types_compatible(left: &Column, right: &Column) -> bool {
5210 match (left, right) {
5211 (Column::Int(_), Column::Int(_)) => true,
5212 (Column::Float(_), Column::Float(_)) => true,
5213 (Column::Int(_), Column::Float(_)) => true,
5214 (Column::Float(_), Column::Int(_)) => true,
5215 (Column::Str(_), Column::Str(_)) => true,
5216 (Column::Bool(_), Column::Bool(_)) => true,
5217 _ => false,
5218 }
5219}
5220
5221impl TidyView {
5224
5225 pub fn pivot_longer(
5243 &self,
5244 value_cols: &[&str],
5245 names_to: &str,
5246 values_to: &str,
5247 ) -> Result<TidyFrame, TidyError> {
5248 if value_cols.is_empty() {
5249 return Err(TidyError::empty_selection("pivot_longer requires at least one value_col"));
5250 }
5251
5252 let mut seen_vc: Vec<&str> = Vec::new();
5254 let mut vc_indices: Vec<usize> = Vec::new();
5255 for &name in value_cols {
5256 if seen_vc.contains(&name) {
5257 return Err(TidyError::DuplicateColumn(name.to_string()));
5258 }
5259 seen_vc.push(name);
5260 let idx = self.base.columns.iter().position(|(n, _)| n == name)
5261 .ok_or_else(|| TidyError::ColumnNotFound(name.to_string()))?;
5262 vc_indices.push(idx);
5263 }
5264
5265 let first_type = self.base.columns[vc_indices[0]].1.type_name();
5267 for &idx in &vc_indices[1..] {
5268 let t = self.base.columns[idx].1.type_name();
5269 if t != first_type {
5270 return Err(TidyError::TypeMismatch {
5271 expected: first_type.to_string(),
5272 got: t.to_string(),
5273 });
5274 }
5275 }
5276
5277 let vc_set: std::collections::BTreeSet<usize> = vc_indices.iter().copied().collect();
5279 let id_col_indices: Vec<usize> = self.proj.indices().iter()
5280 .copied()
5281 .filter(|i| !vc_set.contains(i))
5282 .collect();
5283
5284 let visible_rows: Vec<usize> = self.mask.iter_set().collect();
5285 let n_out = visible_rows.len() * value_cols.len();
5286
5287 let mut out_cols: Vec<(String, Column)> = Vec::new();
5289 for &id_idx in &id_col_indices {
5290 let (name, col) = &self.base.columns[id_idx];
5291 let new_col = match col {
5292 Column::Int(v) => {
5293 let mut out = Vec::with_capacity(n_out);
5294 for &r in &visible_rows {
5295 for _ in 0..value_cols.len() { out.push(v[r]); }
5296 }
5297 Column::Int(out)
5298 }
5299 Column::Float(v) => {
5300 let mut out = Vec::with_capacity(n_out);
5301 for &r in &visible_rows {
5302 for _ in 0..value_cols.len() { out.push(v[r]); }
5303 }
5304 Column::Float(out)
5305 }
5306 Column::Str(v) => {
5307 let mut out = Vec::with_capacity(n_out);
5308 for &r in &visible_rows {
5309 for _ in 0..value_cols.len() { out.push(v[r].clone()); }
5310 }
5311 Column::Str(out)
5312 }
5313 Column::Bool(v) => {
5314 let mut out = Vec::with_capacity(n_out);
5315 for &r in &visible_rows {
5316 for _ in 0..value_cols.len() { out.push(v[r]); }
5317 }
5318 Column::Bool(out)
5319 }
5320 Column::Categorical { levels, codes } => {
5321 let mut out = Vec::with_capacity(n_out);
5322 for &r in &visible_rows {
5323 for _ in 0..value_cols.len() { out.push(codes[r]); }
5324 }
5325 Column::Categorical { levels: levels.clone(), codes: out }
5326 }
5327 Column::DateTime(v) => {
5328 let mut out = Vec::with_capacity(n_out);
5329 for &r in &visible_rows {
5330 for _ in 0..value_cols.len() { out.push(v[r]); }
5331 }
5332 Column::DateTime(out)
5333 }
5334 };
5335 out_cols.push((name.clone(), new_col));
5336 }
5337
5338 let names_col: Vec<String> = visible_rows.iter()
5340 .flat_map(|_| value_cols.iter().map(|s| s.to_string()))
5341 .collect();
5342 out_cols.push((names_to.to_string(), Column::Str(names_col)));
5343
5344 match &self.base.columns[vc_indices[0]].1 {
5346 Column::Int(_) => {
5347 let mut vals: Vec<i64> = Vec::with_capacity(n_out);
5348 for &r in &visible_rows {
5349 for &vci in &vc_indices {
5350 if let Column::Int(v) = &self.base.columns[vci].1 {
5351 vals.push(v[r]);
5352 }
5353 }
5354 }
5355 out_cols.push((values_to.to_string(), Column::Int(vals)));
5356 }
5357 Column::Float(_) => {
5358 let mut vals: Vec<f64> = Vec::with_capacity(n_out);
5359 for &r in &visible_rows {
5360 for &vci in &vc_indices {
5361 if let Column::Float(v) = &self.base.columns[vci].1 {
5362 vals.push(v[r]);
5363 }
5364 }
5365 }
5366 out_cols.push((values_to.to_string(), Column::Float(vals)));
5367 }
5368 Column::Str(_) => {
5369 let mut vals: Vec<String> = Vec::with_capacity(n_out);
5370 for &r in &visible_rows {
5371 for &vci in &vc_indices {
5372 if let Column::Str(v) = &self.base.columns[vci].1 {
5373 vals.push(v[r].clone());
5374 }
5375 }
5376 }
5377 out_cols.push((values_to.to_string(), Column::Str(vals)));
5378 }
5379 Column::Bool(_) => {
5380 let mut vals: Vec<bool> = Vec::with_capacity(n_out);
5381 for &r in &visible_rows {
5382 for &vci in &vc_indices {
5383 if let Column::Bool(v) = &self.base.columns[vci].1 {
5384 vals.push(v[r]);
5385 }
5386 }
5387 }
5388 out_cols.push((values_to.to_string(), Column::Bool(vals)));
5389 }
5390 Column::Categorical { .. } | Column::DateTime(_) => {
5391 let mut vals: Vec<String> = Vec::with_capacity(n_out);
5393 for &r in &visible_rows {
5394 for &vci in &vc_indices {
5395 vals.push(self.base.columns[vci].1.get_display(r));
5396 }
5397 }
5398 out_cols.push((values_to.to_string(), Column::Str(vals)));
5399 }
5400 }
5401
5402 let df = DataFrame::from_columns(out_cols)
5403 .map_err(|e| TidyError::Internal(e.to_string()))?;
5404 Ok(TidyFrame::from_df(df))
5405 }
5406
5407 pub fn pivot_wider(
5424 &self,
5425 id_cols: &[&str],
5426 names_from: &str,
5427 values_from: &str,
5428 ) -> Result<NullableFrame, TidyError> {
5429 let _names_col_idx = self.base.columns.iter().position(|(n, _)| n == names_from)
5431 .ok_or_else(|| TidyError::ColumnNotFound(names_from.to_string()))?;
5432 let _values_col_idx = self.base.columns.iter().position(|(n, _)| n == values_from)
5433 .ok_or_else(|| TidyError::ColumnNotFound(values_from.to_string()))?;
5434 for &id in id_cols {
5435 let _ = self.base.columns.iter().position(|(n, _)| n == id)
5436 .ok_or_else(|| TidyError::ColumnNotFound(id.to_string()))?;
5437 }
5438
5439 let visible_rows: Vec<usize> = self.mask.iter_set().collect();
5440
5441 let mut key_values: Vec<String> = Vec::new();
5443 for &r in &visible_rows {
5444 let kv = self.base.get_column(names_from).unwrap().get_display(r);
5445 if !key_values.contains(&kv) {
5446 key_values.push(kv);
5447 }
5448 }
5449
5450 let id_col_refs: Vec<&Column> = id_cols.iter()
5453 .map(|&name| self.base.get_column(name).unwrap())
5454 .collect();
5455
5456 let mut id_order: Vec<Vec<String>> = Vec::new(); let mut id_to_slot: Vec<(Vec<String>, usize)> = Vec::new(); for &r in &visible_rows {
5460 let id_key: Vec<String> = id_col_refs.iter()
5461 .map(|col| col.get_display(r))
5462 .collect();
5463 if !id_to_slot.iter().any(|(k, _)| k == &id_key) {
5464 let slot = id_order.len();
5465 id_order.push(id_key.clone());
5466 id_to_slot.push((id_key, slot));
5467 }
5468 }
5469
5470 let n_rows = id_order.len();
5471 let n_keys = key_values.len();
5472
5473 let mut cell_map: Vec<Vec<Option<usize>>> = vec![vec![None; n_keys]; n_rows];
5476
5477 for &r in &visible_rows {
5478 let id_key: Vec<String> = id_col_refs.iter()
5479 .map(|col| col.get_display(r))
5480 .collect();
5481 let id_slot = id_to_slot.iter().find(|(k, _)| k == &id_key).unwrap().1;
5482
5483 let kv = self.base.get_column(names_from).unwrap().get_display(r);
5484 let key_slot = key_values.iter().position(|v| v == &kv).unwrap();
5485
5486 if cell_map[id_slot][key_slot].is_some() {
5487 return Err(TidyError::duplicate_key(
5488 format!("({}, {})", id_key.join(", "), kv)
5489 ));
5490 }
5491 cell_map[id_slot][key_slot] = Some(r);
5492 }
5493
5494 let mut out_cols: Vec<(String, NullCol)> = Vec::new();
5496
5497 for (id_idx, &id_name) in id_cols.iter().enumerate() {
5499 let id_col = self.base.get_column(id_name).unwrap();
5500 let id_row_indices: Vec<usize> = id_order.iter()
5501 .map(|id_tup| {
5502 *visible_rows.iter().find(|&&r| {
5504 id_col_refs.iter().enumerate().all(|(i, col)| {
5505 col.get_display(r) == id_tup[i]
5506 })
5507 }).unwrap()
5508 })
5509 .collect();
5510 let gathered = gather_column(id_col, &id_row_indices);
5511 out_cols.push((id_name.to_string(), NullCol::from_column(&gathered)));
5512 let _ = id_idx;
5513 }
5514
5515 let values_col = self.base.get_column(values_from).unwrap();
5517 let val_type = values_col.type_name();
5518 for (key_slot, key_val) in key_values.iter().enumerate() {
5519 let row_opts: Vec<Option<usize>> = (0..n_rows)
5520 .map(|id_slot| cell_map[id_slot][key_slot])
5521 .collect();
5522 let null_col = gather_column_nullable_null(values_col, &row_opts);
5523 out_cols.push((key_val.clone(), null_col));
5524 let _ = val_type;
5525 }
5526
5527 Ok(NullableFrame { columns: out_cols })
5528 }
5529
5530 pub fn rename(&self, renames: &[(&str, &str)]) -> Result<TidyView, TidyError> {
5541 let mut rename_map: Vec<(usize, String)> = Vec::new();
5543 let col_names: Vec<&str> = self.base.columns.iter().map(|(n, _)| n.as_str()).collect();
5544
5545 for &(old, new) in renames {
5546 let idx = col_names.iter().position(|&n| n == old)
5547 .ok_or_else(|| TidyError::ColumnNotFound(old.to_string()))?;
5548 if old != new {
5550 let new_name_exists = col_names.iter().any(|&n| n == new)
5551 || rename_map.iter().any(|(_, n)| n == new);
5552 if new_name_exists {
5553 return Err(TidyError::DuplicateColumn(new.to_string()));
5554 }
5555 }
5556 rename_map.push((idx, new.to_string()));
5557 }
5558
5559 let mut new_cols: Vec<(String, Column)> = Vec::new();
5561 for (i, (name, col)) in self.base.columns.iter().enumerate() {
5562 let new_name = rename_map.iter()
5563 .find(|(idx, _)| *idx == i)
5564 .map(|(_, n)| n.clone())
5565 .unwrap_or_else(|| name.clone());
5566 new_cols.push((new_name, col.clone()));
5567 }
5568
5569 let new_base = DataFrame { columns: new_cols };
5570 Ok(TidyView {
5571 base: Rc::new(new_base),
5572 mask: self.mask.clone(),
5573 proj: self.proj.clone(),
5574 })
5575 }
5576
5577 pub fn relocate(&self, cols: &[&str], position: RelocatePos<'_>) -> Result<TidyView, TidyError> {
5592 let proj_names: Vec<&str> = self.column_names();
5594 for &name in cols {
5595 if !proj_names.contains(&name) {
5596 return Err(TidyError::ColumnNotFound(name.to_string()));
5597 }
5598 }
5599
5600 let moved_set: std::collections::BTreeSet<&str> = cols.iter().copied().collect();
5602 let remaining: Vec<&str> = proj_names.iter()
5603 .copied()
5604 .filter(|n| !moved_set.contains(n))
5605 .collect();
5606
5607 let new_order: Vec<&str> = match &position {
5608 RelocatePos::Front => {
5609 let mut v: Vec<&str> = cols.to_vec();
5610 v.extend_from_slice(&remaining);
5611 v
5612 }
5613 RelocatePos::Back => {
5614 let mut v = remaining.clone();
5615 v.extend_from_slice(cols);
5616 v
5617 }
5618 RelocatePos::Before(anchor) => {
5619 if !proj_names.contains(anchor) {
5620 return Err(TidyError::ColumnNotFound(anchor.to_string()));
5621 }
5622 let mut v = Vec::new();
5623 for &n in &remaining {
5624 if n == *anchor {
5625 v.extend_from_slice(cols);
5626 }
5627 v.push(n);
5628 }
5629 v
5630 }
5631 RelocatePos::After(anchor) => {
5632 if !proj_names.contains(anchor) {
5633 return Err(TidyError::ColumnNotFound(anchor.to_string()));
5634 }
5635 let mut v = Vec::new();
5636 for &n in &remaining {
5637 v.push(n);
5638 if n == *anchor {
5639 v.extend_from_slice(cols);
5640 }
5641 }
5642 v
5643 }
5644 };
5645
5646 let new_indices: Vec<usize> = new_order.iter()
5648 .map(|&name| {
5649 self.base.columns.iter().position(|(n, _)| n == name).unwrap()
5650 })
5651 .collect();
5652
5653 Ok(TidyView {
5654 base: Rc::clone(&self.base),
5655 mask: self.mask.clone(),
5656 proj: ProjectionMap::from_indices(new_indices),
5657 })
5658 }
5659
5660 pub fn drop_cols(&self, cols: &[&str]) -> Result<TidyView, TidyError> {
5670 let proj_names = self.column_names();
5671 for &name in cols {
5672 if !proj_names.contains(&name) {
5673 return Err(TidyError::ColumnNotFound(name.to_string()));
5674 }
5675 }
5676 let drop_set: std::collections::BTreeSet<&str> = cols.iter().copied().collect();
5677 let keep: Vec<&str> = proj_names.iter()
5678 .copied()
5679 .filter(|n| !drop_set.contains(n))
5680 .collect();
5681 self.select(&keep)
5682 }
5683
5684 pub fn bind_rows(&self, other: &TidyView) -> Result<TidyFrame, TidyError> {
5695 let self_names = self.column_names();
5696 let other_names = other.column_names();
5697
5698 if self_names != other_names {
5699 return Err(TidyError::schema_mismatch(format!(
5700 "left has {:?}, right has {:?}",
5701 self_names, other_names
5702 )));
5703 }
5704
5705 let self_rows: Vec<usize> = self.mask.iter_set().collect();
5706 let other_rows: Vec<usize> = other.mask.iter_set().collect();
5707
5708 let mut out_cols: Vec<(String, Column)> = Vec::new();
5709 for &ci in self.proj.indices() {
5710 let (name, self_col) = &self.base.columns[ci];
5711 let other_ci = other.proj.indices().iter().copied()
5713 .find(|&i| other.base.columns[i].0 == *name)
5714 .ok_or_else(|| TidyError::ColumnNotFound(name.clone()))?;
5715 let other_col = &other.base.columns[other_ci].1;
5716
5717 let col = concat_columns(self_col, &self_rows, other_col, &other_rows)?;
5718 out_cols.push((name.clone(), col));
5719 }
5720
5721 let df = DataFrame::from_columns(out_cols)
5722 .map_err(|e| TidyError::Internal(e.to_string()))?;
5723 Ok(TidyFrame::from_df(df))
5724 }
5725
5726 pub fn bind_cols(&self, other: &TidyView) -> Result<TidyFrame, TidyError> {
5737 let self_nrows = self.nrows();
5738 let other_nrows = other.nrows();
5739
5740 if self_nrows != other_nrows {
5741 return Err(TidyError::LengthMismatch {
5742 expected: self_nrows,
5743 got: other_nrows,
5744 });
5745 }
5746
5747 let self_names = self.column_names();
5748 let other_names = other.column_names();
5749 for name in &other_names {
5750 if self_names.contains(name) {
5751 return Err(TidyError::DuplicateColumn(name.to_string()));
5752 }
5753 }
5754
5755 let self_rows: Vec<usize> = self.mask.iter_set().collect();
5756 let other_rows: Vec<usize> = other.mask.iter_set().collect();
5757
5758 let mut out_cols: Vec<(String, Column)> = Vec::new();
5759
5760 for &ci in self.proj.indices() {
5761 let (name, col) = &self.base.columns[ci];
5762 out_cols.push((name.clone(), gather_column(col, &self_rows)));
5763 }
5764 for &ci in other.proj.indices() {
5765 let (name, col) = &other.base.columns[ci];
5766 out_cols.push((name.clone(), gather_column(col, &other_rows)));
5767 }
5768
5769 let df = DataFrame::from_columns(out_cols)
5770 .map_err(|e| TidyError::Internal(e.to_string()))?;
5771 Ok(TidyFrame::from_df(df))
5772 }
5773
5774 pub fn mutate_across(&self, specs: &[AcrossSpec]) -> Result<TidyFrame, TidyError> {
5784 let base_df = self.materialize()?;
5786
5787 let mut output_names: Vec<String> = base_df.column_names()
5789 .into_iter().map(|s| s.to_string()).collect();
5790 let mut extra_cols: Vec<(String, Column)> = Vec::new();
5791
5792 for spec in specs {
5793 for col_name in &spec.cols {
5794 let out_name = spec.output_name(col_name);
5795 if output_names.contains(&out_name) && !base_df.column_names().contains(&out_name.as_str()) {
5797 return Err(TidyError::DuplicateColumn(out_name));
5798 }
5799 let col = base_df.get_column(col_name)
5800 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
5801 let new_col = (spec.transform.func)(col_name, col)?;
5802 if !base_df.column_names().contains(&out_name.as_str()) {
5804 output_names.push(out_name.clone());
5805 }
5806 extra_cols.push((out_name, new_col));
5807 }
5808 }
5809
5810 let mut col_map: indexmap_simple::IndexMap = indexmap_simple::IndexMap::from_df(&base_df);
5812 for (name, col) in extra_cols {
5813 col_map.insert(name, col);
5814 }
5815 let df = col_map.into_df()
5816 .map_err(|e| TidyError::Internal(e.to_string()))?;
5817 Ok(TidyFrame::from_df(df))
5818 }
5819
5820 pub fn right_join(
5828 &self,
5829 right: &TidyView,
5830 on: &[(&str, &str)],
5831 suffix: &JoinSuffix,
5832 ) -> Result<NullableFrame, TidyError> {
5833 validate_join_key_types(self, right, on)?;
5835 let swapped_on: Vec<(&str, &str)> = on.iter().map(|&(l, r)| (r, l)).collect();
5837 let (right_rows, left_rows_opt) =
5838 join_match_rows_optional(right, self, &swapped_on, JoinKind::Left)?;
5839 build_right_join_frame(self, right, &left_rows_opt, &right_rows, on, suffix)
5840 }
5841
5842 pub fn full_join(
5848 &self,
5849 right: &TidyView,
5850 on: &[(&str, &str)],
5851 suffix: &JoinSuffix,
5852 ) -> Result<NullableFrame, TidyError> {
5853 validate_join_key_types(self, right, on)?;
5854 build_full_join_frame(self, right, on, suffix)
5855 }
5856
5857 pub fn inner_join_typed(
5865 &self,
5866 right: &TidyView,
5867 on: &[(&str, &str)],
5868 suffix: &JoinSuffix,
5869 ) -> Result<TidyFrame, TidyError> {
5870 validate_join_key_types(self, right, on)?;
5871 let (left_rows, right_rows) = join_match_rows(self, right, on, JoinKind::Inner)?;
5872 build_join_frame_with_suffix(self, right, &left_rows, &right_rows, on, suffix, false)
5873 }
5874
5875 pub fn left_join_typed(
5879 &self,
5880 right: &TidyView,
5881 on: &[(&str, &str)],
5882 suffix: &JoinSuffix,
5883 ) -> Result<TidyFrame, TidyError> {
5884 validate_join_key_types(self, right, on)?;
5885 let (left_rows, right_rows_opt) =
5886 join_match_rows_optional(self, right, on, JoinKind::Left)?;
5887 build_left_join_frame_with_suffix(self, right, &left_rows, &right_rows_opt, on, suffix)
5888 }
5889}
5890
/// Destination for the columns moved by `TidyView::relocate`.
///
/// The anchor variants borrow the anchor column's name for `'a`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RelocatePos<'a> {
    /// Move the columns ahead of all other columns.
    Front,
    /// Move the columns behind all other columns.
    Back,
    /// Move the columns directly in front of the named anchor column.
    Before(&'a str),
    /// Move the columns directly behind the named anchor column.
    After(&'a str),
}
5904
5905fn concat_columns(
5908 left: &Column,
5909 left_rows: &[usize],
5910 right: &Column,
5911 right_rows: &[usize],
5912) -> Result<Column, TidyError> {
5913 match (left, right) {
5914 (Column::Int(lv), Column::Int(rv)) => {
5915 let mut out: Vec<i64> = left_rows.iter().map(|&i| lv[i]).collect();
5916 out.extend(right_rows.iter().map(|&i| rv[i]));
5917 Ok(Column::Int(out))
5918 }
5919 (Column::Float(lv), Column::Float(rv)) => {
5920 let mut out: Vec<f64> = left_rows.iter().map(|&i| lv[i]).collect();
5921 out.extend(right_rows.iter().map(|&i| rv[i]));
5922 Ok(Column::Float(out))
5923 }
5924 (Column::Int(lv), Column::Float(rv)) => {
5925 let mut out: Vec<f64> = left_rows.iter().map(|&i| lv[i] as f64).collect();
5926 out.extend(right_rows.iter().map(|&i| rv[i]));
5927 Ok(Column::Float(out))
5928 }
5929 (Column::Float(lv), Column::Int(rv)) => {
5930 let mut out: Vec<f64> = left_rows.iter().map(|&i| lv[i]).collect();
5931 out.extend(right_rows.iter().map(|&i| rv[i] as f64));
5932 Ok(Column::Float(out))
5933 }
5934 (Column::Str(lv), Column::Str(rv)) => {
5935 let mut out: Vec<String> = left_rows.iter().map(|&i| lv[i].clone()).collect();
5936 out.extend(right_rows.iter().map(|&i| rv[i].clone()));
5937 Ok(Column::Str(out))
5938 }
5939 (Column::Bool(lv), Column::Bool(rv)) => {
5940 let mut out: Vec<bool> = left_rows.iter().map(|&i| lv[i]).collect();
5941 out.extend(right_rows.iter().map(|&i| rv[i]));
5942 Ok(Column::Bool(out))
5943 }
5944 _ => Err(TidyError::schema_mismatch(format!(
5945 "type mismatch in bind_rows: {} vs {}",
5946 left.type_name(), right.type_name()
5947 ))),
5948 }
5949}
5950
5951fn validate_join_key_types(
5954 left: &TidyView,
5955 right: &TidyView,
5956 on: &[(&str, &str)],
5957) -> Result<(), TidyError> {
5958 for &(lk, rk) in on {
5959 let l_col = left.base.get_column(lk)
5960 .ok_or_else(|| TidyError::ColumnNotFound(lk.to_string()))?;
5961 let r_col = right.base.get_column(rk)
5962 .ok_or_else(|| TidyError::ColumnNotFound(rk.to_string()))?;
5963 if !join_types_compatible(l_col, r_col) {
5964 return Err(TidyError::join_type_mismatch(lk, l_col.type_name(), r_col.type_name()));
5965 }
5966 }
5967 Ok(())
5968}
5969
5970fn build_join_frame_with_suffix(
5973 left: &TidyView,
5974 right: &TidyView,
5975 left_rows: &[usize],
5976 right_rows: &[usize],
5977 on: &[(&str, &str)],
5978 suffix: &JoinSuffix,
5979 _include_unmatched: bool,
5980) -> Result<TidyFrame, TidyError> {
5981 let right_key_names: std::collections::BTreeSet<&str> =
5982 on.iter().map(|(_, rk)| *rk).collect();
5983
5984 let left_col_names: Vec<String> = left.proj.indices().iter()
5986 .map(|&ci| left.base.columns[ci].0.clone())
5987 .collect();
5988
5989 let mut columns: Vec<(String, Column)> = Vec::new();
5990
5991 for &ci in left.proj.indices() {
5993 let (name, col) = &left.base.columns[ci];
5994 columns.push((name.clone(), gather_column(col, left_rows)));
5995 }
5996
5997 for &ci in right.proj.indices() {
5999 let (name, col) = &right.base.columns[ci];
6000 if right_key_names.contains(name.as_str()) {
6001 continue; }
6003 let out_name = if left_col_names.contains(name) {
6004 format!("{}{}", name, suffix.right)
6005 } else {
6006 name.clone()
6007 };
6008 if left_col_names.contains(name) {
6010 let left_pos = columns.iter().position(|(n, _)| n == name);
6012 if let Some(pos) = left_pos {
6013 let entry = &mut columns[pos];
6014 entry.0 = format!("{}{}", entry.0, suffix.left);
6015 }
6016 }
6017 columns.push((out_name, gather_column(col, right_rows)));
6018 }
6019
6020 let df = DataFrame::from_columns(columns)
6021 .map_err(|e| TidyError::Internal(e.to_string()))?;
6022 Ok(TidyFrame::from_df(df))
6023}
6024
6025fn build_left_join_frame_with_suffix(
6026 left: &TidyView,
6027 right: &TidyView,
6028 left_rows: &[usize],
6029 right_rows_opt: &[Option<usize>],
6030 on: &[(&str, &str)],
6031 suffix: &JoinSuffix,
6032) -> Result<TidyFrame, TidyError> {
6033 let right_key_names: std::collections::BTreeSet<&str> =
6034 on.iter().map(|(_, rk)| *rk).collect();
6035
6036 let left_col_names: Vec<String> = left.proj.indices().iter()
6037 .map(|&ci| left.base.columns[ci].0.clone())
6038 .collect();
6039
6040 let mut columns: Vec<(String, Column)> = Vec::new();
6041
6042 for &ci in left.proj.indices() {
6044 let (name, col) = &left.base.columns[ci];
6045 columns.push((name.clone(), gather_column(col, left_rows)));
6046 }
6047
6048 for &ci in right.proj.indices() {
6050 let (name, col) = &right.base.columns[ci];
6051 if right_key_names.contains(name.as_str()) { continue; }
6052 let out_name = if left_col_names.contains(name) {
6053 let left_pos = columns.iter().position(|(n, _)| n == name);
6055 if let Some(pos) = left_pos {
6056 columns[pos].0 = format!("{}{}", name, suffix.left);
6057 }
6058 format!("{}{}", name, suffix.right)
6059 } else {
6060 name.clone()
6061 };
6062 let new_col = gather_column_nullable(col, right_rows_opt);
6063 columns.push((out_name, new_col));
6064 }
6065
6066 let df = DataFrame::from_columns(columns)
6067 .map_err(|e| TidyError::Internal(e.to_string()))?;
6068 Ok(TidyFrame::from_df(df))
6069}
6070
6071fn build_right_join_frame(
6072 left: &TidyView,
6073 right: &TidyView,
6074 left_rows_opt: &[Option<usize>],
6075 right_rows: &[usize],
6076 on: &[(&str, &str)],
6077 suffix: &JoinSuffix,
6078) -> Result<NullableFrame, TidyError> {
6079 let right_key_names: std::collections::BTreeSet<&str> =
6080 on.iter().map(|(_, rk)| *rk).collect();
6081 let left_key_names: std::collections::BTreeSet<&str> =
6082 on.iter().map(|(lk, _)| *lk).collect();
6083
6084 let right_col_names: Vec<String> = right.proj.indices().iter()
6085 .map(|&ci| right.base.columns[ci].0.clone())
6086 .collect();
6087
6088 let mut columns: Vec<(String, NullCol)> = Vec::new();
6089
6090 for &ci in left.proj.indices() {
6092 let (name, col) = &left.base.columns[ci];
6093 if left_key_names.contains(name.as_str()) { continue; }
6094 let out_name = if right_col_names.contains(name) {
6095 format!("{}{}", name, suffix.left)
6096 } else {
6097 name.clone()
6098 };
6099 let null_col = gather_column_nullable_null(col, left_rows_opt);
6100 columns.push((out_name, null_col));
6101 }
6102
6103 for &ci in right.proj.indices() {
6105 let (name, col) = &right.base.columns[ci];
6106 let out_name = if !right_key_names.contains(name.as_str())
6107 && left.proj.indices().iter().any(|&lci| left.base.columns[lci].0 == *name)
6108 && !left_key_names.contains(name.as_str())
6109 {
6110 format!("{}{}", name, suffix.right)
6111 } else {
6112 name.clone()
6113 };
6114 columns.push((out_name, NullCol::from_column(&gather_column(col, right_rows))));
6115 }
6116
6117 Ok(NullableFrame { columns })
6118}
6119
6120fn build_full_join_frame(
6121 left: &TidyView,
6122 right: &TidyView,
6123 on: &[(&str, &str)],
6124 suffix: &JoinSuffix,
6125) -> Result<NullableFrame, TidyError> {
6126 let (left_key_cols, right_key_cols) = resolve_join_keys(left, right, on)?;
6127 let lookup = build_right_lookup(right, &right_key_cols);
6128
6129 let mut out_left_rows: Vec<usize> = Vec::new();
6131 let mut out_right_rows: Vec<Option<usize>> = Vec::new();
6132 let mut right_matched: Vec<bool> = vec![false; right.base.nrows()];
6133
6134 for l_row in left.mask.iter_set() {
6135 let key = row_key(&left.base, &left_key_cols, l_row);
6136 let matches = find_matches(&lookup, &key);
6137 if matches.is_empty() {
6138 out_left_rows.push(l_row);
6139 out_right_rows.push(None);
6140 } else {
6141 for r_row in &matches {
6142 out_left_rows.push(l_row);
6143 out_right_rows.push(Some(*r_row));
6144 if *r_row < right_matched.len() {
6145 right_matched[*r_row] = true;
6146 }
6147 }
6148 }
6149 }
6150
6151 let mut unmatched_right: Vec<usize> = Vec::new();
6153 for r_row in right.mask.iter_set() {
6154 if r_row < right_matched.len() && !right_matched[r_row] {
6155 unmatched_right.push(r_row);
6156 }
6157 }
6158
6159 let right_key_names: std::collections::BTreeSet<&str> =
6160 on.iter().map(|(_, rk)| *rk).collect();
6161 let left_key_names: std::collections::BTreeSet<&str> =
6162 on.iter().map(|(lk, _)| *lk).collect();
6163 let right_col_names: Vec<String> = right.proj.indices().iter()
6164 .map(|&ci| right.base.columns[ci].0.clone())
6165 .collect();
6166
6167 let n_matched = out_left_rows.len();
6168 let n_unmatched_r = unmatched_right.len();
6169 let total = n_matched + n_unmatched_r;
6170
6171 let mut columns: Vec<(String, NullCol)> = Vec::new();
6172
6173 for &ci in left.proj.indices() {
6175 let (name, col) = &left.base.columns[ci];
6176 let out_name = if right_col_names.contains(name) && !left_key_names.contains(name.as_str()) {
6177 format!("{}{}", name, suffix.left)
6178 } else {
6179 name.clone()
6180 };
6181 let mut matched_vals: Vec<Option<usize>> = out_left_rows.iter()
6182 .map(|&r| Some(r))
6183 .collect();
6184 matched_vals.extend(std::iter::repeat(None).take(n_unmatched_r));
6186 assert_eq!(matched_vals.len(), total);
6187 columns.push((out_name, gather_column_nullable_null(col, &matched_vals)));
6188 }
6189
6190 for &ci in right.proj.indices() {
6192 let (name, col) = &right.base.columns[ci];
6193 if right_key_names.contains(name.as_str()) { continue; }
6194 let out_name = if left.proj.indices().iter().any(|&lci| left.base.columns[lci].0 == *name)
6195 && !left_key_names.contains(name.as_str())
6196 {
6197 format!("{}{}", name, suffix.right)
6198 } else {
6199 name.clone()
6200 };
6201
6202 let mut row_opts: Vec<Option<usize>> = out_right_rows.clone();
6203 row_opts.extend(unmatched_right.iter().map(|&r| Some(r)));
6205 assert_eq!(row_opts.len(), total);
6206 columns.push((out_name, gather_column_nullable_null(col, &row_opts)));
6207 }
6208
6209 Ok(NullableFrame { columns })
6215}
6216
6217impl GroupedTidyView {
6220
6221 pub fn mutate_across(&self, specs: &[AcrossSpec]) -> Result<TidyFrame, TidyError> {
6227 self.view.mutate_across(specs)
6230 }
6231
6232 pub fn summarise_across(&self, specs: &[AcrossSpec]) -> Result<TidyFrame, TidyError> {
6239 let n_groups = self.ngroups();
6240
6241 let key_names = &self.index.key_names;
6243 let mut out_cols: Vec<(String, Column)> = Vec::new();
6244
6245 for ki in 0..key_names.len() {
6247 let col_vals: Vec<String> = self.index.groups.iter()
6248 .map(|g| g.key_values[ki].clone())
6249 .collect();
6250 out_cols.push((key_names[ki].clone(), Column::Str(col_vals)));
6251 }
6252
6253 for spec in specs {
6255 for col_name in &spec.cols {
6256 let out_name = spec.output_name(col_name);
6257 if out_cols.iter().any(|(n, _)| n == &out_name) {
6259 return Err(TidyError::DuplicateColumn(out_name));
6260 }
6261
6262 let base_col = self.view.base.get_column(col_name)
6263 .ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
6264
6265 let mut agg_floats: Vec<f64> = Vec::with_capacity(n_groups);
6267 for group in &self.index.groups {
6268 let group_col = gather_column(base_col, &group.row_indices);
6269 let result_col = (spec.transform.func)(col_name, &group_col)?;
6270 if result_col.len() != 1 {
6271 return Err(TidyError::LengthMismatch {
6272 expected: 1,
6273 got: result_col.len(),
6274 });
6275 }
6276 let v = match &result_col {
6277 Column::Float(v) => v[0],
6278 Column::Int(v) => v[0] as f64,
6279 _ => return Err(TidyError::TypeMismatch {
6280 expected: "Float or Int".into(),
6281 got: result_col.type_name().into(),
6282 }),
6283 };
6284 agg_floats.push(v);
6285 }
6286 out_cols.push((out_name, Column::Float(agg_floats)));
6287 }
6288 }
6289
6290 let df = DataFrame::from_columns(out_cols)
6291 .map_err(|e| TidyError::Internal(e.to_string()))?;
6292 Ok(TidyFrame::from_df(df))
6293 }
6294}
6295
6296mod indexmap_simple {
6301 use super::{Column, DataFrame, DataError};
6302
6303 pub struct IndexMap {
6304 entries: Vec<(String, Column)>,
6305 }
6306
6307 impl IndexMap {
6308 pub fn from_df(df: &DataFrame) -> Self {
6309 Self {
6310 entries: df.columns.iter()
6311 .map(|(n, c)| (n.clone(), c.clone()))
6312 .collect(),
6313 }
6314 }
6315
6316 pub fn insert(&mut self, name: String, col: Column) {
6318 if let Some(pos) = self.entries.iter().position(|(n, _)| n == &name) {
6319 self.entries[pos] = (name, col);
6320 } else {
6321 self.entries.push((name, col));
6322 }
6323 }
6324
6325 pub fn into_df(self) -> Result<DataFrame, DataError> {
6326 DataFrame::from_columns(self.entries)
6327 }
6328 }
6329}
6330
6331impl GroupIndex {
6348 pub fn build_fast(
6353 base: &DataFrame,
6354 key_col_indices: &[usize],
6355 visible_rows: &[usize],
6356 key_names: Vec<String>,
6357 ) -> Self {
6358 use std::collections::BTreeMap;
6359
6360 let mut groups: Vec<GroupMeta> = Vec::new();
6361 let mut key_to_slot: BTreeMap<Vec<String>, usize> = BTreeMap::new();
6362
6363 for &row in visible_rows {
6364 let key: Vec<String> = key_col_indices.iter()
6365 .map(|&ci| base.columns[ci].1.get_display(row))
6366 .collect();
6367
6368 if let Some(&slot) = key_to_slot.get(&key) {
6369 groups[slot].row_indices.push(row);
6370 } else {
6371 let slot = groups.len();
6372 let key_values = key.clone();
6373 key_to_slot.insert(key, slot);
6374 groups.push(GroupMeta { key_values, row_indices: vec![row] });
6375 }
6376 }
6377
6378 GroupIndex { groups, key_names }
6379 }
6380}
6381
6382impl TidyView {
6385 pub fn group_by_fast(&self, keys: &[&str]) -> Result<GroupedTidyView, TidyError> {
6390 let mut key_col_indices = Vec::with_capacity(keys.len());
6391 for &key in keys {
6392 let idx = self.base.columns.iter().position(|(n, _)| n == key)
6393 .ok_or_else(|| TidyError::ColumnNotFound(key.to_string()))?;
6394 key_col_indices.push(idx);
6395 }
6396 let visible_rows: Vec<usize> = self.mask.iter_set().collect();
6397 let key_names: Vec<String> = keys.iter().map(|s| s.to_string()).collect();
6398 let index = GroupIndex::build_fast(&self.base, &key_col_indices, &visible_rows, key_names);
6399 Ok(GroupedTidyView { view: self.clone(), index })
6400 }
6401}
6402
/// A factor column: a table of distinct string levels plus one level code
/// per row.
#[derive(Clone, Debug)]
pub struct FctColumn {
    /// The level labels; a row's code is an index into this vector.
    pub levels: Vec<String>,
    /// One entry per row: the index into `levels` of that row's label.
    pub data: Vec<u16>,
}
6476
6477impl FctColumn {
6478 pub fn encode(strings: &[String]) -> Result<Self, TidyError> {
6485 use std::collections::BTreeMap;
6486 let mut levels: Vec<String> = Vec::new();
6487 let mut level_map: BTreeMap<String, u16> = BTreeMap::new();
6491 let mut data: Vec<u16> = Vec::with_capacity(strings.len());
6492
6493 for s in strings {
6494 let idx = if let Some(&existing) = level_map.get(s.as_str()) {
6495 existing
6496 } else {
6497 let next = levels.len();
6498 if next >= 65_535 {
6499 return Err(TidyError::CapacityExceeded {
6500 limit: 65_535,
6501 got: next + 1,
6502 });
6503 }
6504 let idx = next as u16;
6505 levels.push(s.clone());
6506 level_map.insert(s.clone(), idx);
6507 idx
6508 };
6509 data.push(idx);
6510 }
6511 Ok(FctColumn { levels, data })
6512 }
6513
6514 pub fn encode_from_view(view: &TidyView, col: &str) -> Result<Self, TidyError> {
6516 let base_idx = view.base.columns.iter()
6517 .position(|(n, _)| n == col)
6518 .ok_or_else(|| TidyError::ColumnNotFound(col.to_string()))?;
6519 if !view.proj.indices().contains(&base_idx) {
6521 return Err(TidyError::ColumnNotFound(col.to_string()));
6522 }
6523 let col_data = &view.base.columns[base_idx].1;
6524 let visible: Vec<usize> = view.mask.iter_set().collect();
6525 let strings: Vec<String> = visible.iter()
6526 .map(|&r| col_data.get_display(r))
6527 .collect();
6528 Self::encode(&strings)
6529 }
6530
6531 pub fn nrows(&self) -> usize { self.data.len() }
6534 pub fn nlevels(&self) -> usize { self.levels.len() }
6535
6536 pub fn decode(&self, i: usize) -> &str {
6538 &self.levels[self.data[i] as usize]
6539 }
6540
6541 pub fn fct_lump(&self, n: usize) -> Result<Self, TidyError> {
6553 if n >= self.levels.len() {
6554 return Ok(self.clone()); }
6556
6557 let mut freq = vec![0usize; self.levels.len()];
6559 for &idx in &self.data {
6560 freq[idx as usize] += 1;
6561 }
6562
6563 let mut ranked: Vec<(usize, usize)> = freq.iter().copied().enumerate().collect();
6566 ranked.sort_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
6567
6568 let mut keep_set: Vec<usize> = ranked[..n].iter().map(|(i, _)| *i).collect();
6570 keep_set.sort_unstable(); let mut other_name = "Other".to_string();
6574 while keep_set.iter().any(|&ki| self.levels[ki] == other_name) {
6575 other_name.push('_');
6576 }
6577
6578 let mut new_levels: Vec<String> = keep_set.iter().map(|&ki| self.levels[ki].clone()).collect();
6580 let other_idx = new_levels.len() as u16;
6581 new_levels.push(other_name);
6582
6583 let mut remap = vec![other_idx; self.levels.len()];
6585 for (new_i, &old_i) in keep_set.iter().enumerate() {
6586 remap[old_i] = new_i as u16;
6587 }
6588
6589 let new_data: Vec<u16> = self.data.iter().map(|&d| remap[d as usize]).collect();
6590 Ok(FctColumn { levels: new_levels, data: new_data })
6591 }
6592
6593 pub fn fct_reorder(&self, summary_vals: &[f64], descending: bool) -> Result<Self, TidyError> {
6602 if summary_vals.len() != self.levels.len() {
6603 return Err(TidyError::LengthMismatch {
6604 expected: self.levels.len(),
6605 got: summary_vals.len(),
6606 });
6607 }
6608 let mut order: Vec<usize> = (0..self.levels.len()).collect();
6612 order.sort_by(|&a, &b| {
6613 let va = summary_vals[a];
6614 let vb = summary_vals[b];
6615 match (va.is_nan(), vb.is_nan()) {
6616 (true, true) => std::cmp::Ordering::Equal,
6617 (true, false) => std::cmp::Ordering::Greater, (false, true) => std::cmp::Ordering::Less, (false, false) => {
6620 let cmp = va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal);
6621 if descending { cmp.reverse() } else { cmp }
6622 }
6623 }
6624 });
6625
6626 let new_levels: Vec<String> = order.iter().map(|&i| self.levels[i].clone()).collect();
6628
6629 let mut remap = vec![0u16; self.levels.len()];
6631 for (new_i, &old_i) in order.iter().enumerate() {
6632 remap[old_i] = new_i as u16;
6633 }
6634
6635 let new_data: Vec<u16> = self.data.iter().map(|&d| remap[d as usize]).collect();
6636 Ok(FctColumn { levels: new_levels, data: new_data })
6637 }
6638
6639 pub fn fct_reorder_by_col(&self, numeric_col: &Column, descending: bool) -> Result<Self, TidyError> {
6645 if numeric_col.len() != self.data.len() {
6646 return Err(TidyError::LengthMismatch {
6647 expected: self.data.len(),
6648 got: numeric_col.len(),
6649 });
6650 }
6651 let mut sums = vec![0.0f64; self.levels.len()];
6652 let mut counts = vec![0usize; self.levels.len()];
6653 match numeric_col {
6654 Column::Float(v) => {
6655 for (i, &d) in self.data.iter().enumerate() {
6656 let val = v[i];
6657 if !val.is_nan() {
6658 sums[d as usize] += val;
6659 counts[d as usize] += 1;
6660 }
6661 }
6662 }
6663 Column::Int(v) => {
6664 for (i, &d) in self.data.iter().enumerate() {
6665 sums[d as usize] += v[i] as f64;
6666 counts[d as usize] += 1;
6667 }
6668 }
6669 _ => return Err(TidyError::TypeMismatch {
6670 expected: "Float or Int".to_string(),
6671 got: numeric_col.type_name().to_string(),
6672 }),
6673 }
6674 let means: Vec<f64> = sums.iter().zip(counts.iter())
6675 .map(|(&s, &c)| if c == 0 { f64::NAN } else { s / c as f64 })
6676 .collect();
6677 self.fct_reorder(&means, descending)
6678 }
6679
6680 pub fn fct_collapse(&self, mapping: &[(&str, &str)]) -> Result<Self, TidyError> {
6697 if mapping.is_empty() {
6698 return Ok(self.clone());
6699 }
6700 let new_name_for: Vec<String> = self.levels.iter().map(|old| {
6702 if let Some((_, new)) = mapping.iter().find(|(o, _)| *o == old.as_str()) {
6703 new.to_string()
6704 } else {
6705 old.clone()
6706 }
6707 }).collect();
6708
6709 use std::collections::BTreeMap;
6712 let mut new_levels: Vec<String> = Vec::new();
6713 let mut new_name_to_idx: BTreeMap<String, u16> = BTreeMap::new();
6714
6715 let mut old_to_new: Vec<u16> = Vec::with_capacity(self.levels.len());
6716 for name in &new_name_for {
6717 let idx = if let Some(&existing) = new_name_to_idx.get(name.as_str()) {
6718 existing
6719 } else {
6720 let idx = new_levels.len() as u16;
6721 new_levels.push(name.clone());
6722 new_name_to_idx.insert(name.clone(), idx);
6723 idx
6724 };
6725 old_to_new.push(idx);
6726 }
6727
6728 let changed = old_to_new.iter().enumerate().any(|(i, &new)| new != i as u16);
6730 let new_data = if changed {
6731 self.data.iter().map(|&d| old_to_new[d as usize]).collect()
6732 } else {
6733 self.data.clone()
6734 };
6735 Ok(FctColumn { levels: new_levels, data: new_data })
6736 }
6737
6738 pub fn to_str_column(&self) -> Column {
6742 Column::Str(self.data.iter().map(|&d| self.levels[d as usize].clone()).collect())
6743 }
6744
6745 pub fn gather(&self, indices: &[usize]) -> FctColumn {
6747 FctColumn {
6748 levels: self.levels.clone(),
6749 data: indices.iter().map(|&i| self.data[i]).collect(),
6750 }
6751 }
6752}
6753
6754impl TidyError {
6757 pub fn capacity_exceeded(limit: usize, got: usize) -> Self {
6758 TidyError::CapacityExceeded { limit, got }
6759 }
6760}
6761
6762#[derive(Clone, Debug)]
6767pub struct NullableFactor {
6768 pub fct: FctColumn,
6769 pub validity: BitMask,
6770}
6771
6772impl NullableFactor {
6773 pub fn from_fct(fct: FctColumn) -> Self {
6775 let n = fct.nrows();
6776 NullableFactor { fct, validity: BitMask::all_true(n) }
6777 }
6778
6779 pub fn new(fct: FctColumn, validity: BitMask) -> Self {
6781 NullableFactor { fct, validity }
6782 }
6783
6784 pub fn encode_nullable(strings: &[Option<String>]) -> Result<Self, TidyError> {
6788 use std::collections::BTreeMap;
6789 let mut levels: Vec<String> = Vec::new();
6790 let mut level_map: BTreeMap<String, u16> = BTreeMap::new();
6791 let mut data: Vec<u16> = Vec::with_capacity(strings.len());
6792 let mut valid_flags: Vec<bool> = Vec::with_capacity(strings.len());
6793
6794 for opt in strings {
6795 match opt {
6796 None => {
6797 data.push(0); valid_flags.push(false);
6799 }
6800 Some(s) => {
6801 let idx = if let Some(&existing) = level_map.get(s.as_str()) {
6802 existing
6803 } else {
6804 let next = levels.len();
6805 if next >= 65_535 {
6806 return Err(TidyError::CapacityExceeded { limit: 65_535, got: next + 1 });
6807 }
6808 let idx = next as u16;
6809 levels.push(s.clone());
6810 level_map.insert(s.clone(), idx);
6811 idx
6812 };
6813 data.push(idx);
6814 valid_flags.push(true);
6815 }
6816 }
6817 }
6818 let fct = FctColumn { levels, data };
6819 let validity = BitMask::from_bools(&valid_flags);
6820 Ok(NullableFactor { fct, validity })
6821 }
6822
6823 pub fn nrows(&self) -> usize { self.fct.nrows() }
6824 pub fn nlevels(&self) -> usize { self.fct.nlevels() }
6825 pub fn is_null(&self, i: usize) -> bool { !self.validity.get(i) }
6826 pub fn count_valid(&self) -> usize { self.validity.count_ones() }
6827
6828 pub fn decode(&self, i: usize) -> Option<&str> {
6830 if self.is_null(i) { None } else { Some(self.fct.decode(i)) }
6831 }
6832
6833 pub fn fct_lump(&self, n: usize) -> Result<Self, TidyError> {
6835 let lumped = self.fct.fct_lump(n)?;
6836 Ok(NullableFactor { fct: lumped, validity: self.validity.clone() })
6837 }
6838
6839 pub fn fct_reorder(&self, summary_vals: &[f64], descending: bool) -> Result<Self, TidyError> {
6841 let reordered = self.fct.fct_reorder(summary_vals, descending)?;
6842 Ok(NullableFactor { fct: reordered, validity: self.validity.clone() })
6843 }
6844
6845 pub fn fct_collapse(&self, mapping: &[(&str, &str)]) -> Result<Self, TidyError> {
6847 let collapsed = self.fct.fct_collapse(mapping)?;
6848 Ok(NullableFactor { fct: collapsed, validity: self.validity.clone() })
6849 }
6850}
6851
6852impl TidyView {
6855 pub fn fct_encode(&self, col: &str) -> Result<FctColumn, TidyError> {
6860 FctColumn::encode_from_view(self, col)
6861 }
6862
6863 pub fn fct_summary_means(
6868 &self,
6869 fct: &FctColumn,
6870 numeric_col: &str,
6871 ) -> Result<Vec<f64>, TidyError> {
6872 let base_idx = self.base.columns.iter()
6873 .position(|(n, _)| n == numeric_col)
6874 .ok_or_else(|| TidyError::ColumnNotFound(numeric_col.to_string()))?;
6875 let nc = &self.base.columns[base_idx].1;
6876 if nc.len() != fct.nrows() {
6877 return Err(TidyError::LengthMismatch { expected: fct.nrows(), got: nc.len() });
6878 }
6879 match nc {
6881 Column::Float(_) | Column::Int(_) => {}
6882 _ => return Err(TidyError::TypeMismatch {
6883 expected: "Float or Int".to_string(),
6884 got: nc.type_name().to_string(),
6885 }),
6886 }
6887 let mut sums = vec![0.0f64; fct.levels.len()];
6888 let mut counts = vec![0usize; fct.levels.len()];
6889 match nc {
6890 Column::Float(v) => {
6891 for (i, &d) in fct.data.iter().enumerate() {
6892 if !v[i].is_nan() {
6893 sums[d as usize] += v[i];
6894 counts[d as usize] += 1;
6895 }
6896 }
6897 }
6898 Column::Int(v) => {
6899 for (i, &d) in fct.data.iter().enumerate() {
6900 sums[d as usize] += v[i] as f64;
6901 counts[d as usize] += 1;
6902 }
6903 }
6904 _ => unreachable!(),
6905 }
6906 Ok(sums.iter().zip(counts.iter())
6907 .map(|(&s, &c)| if c == 0 { f64::NAN } else { s / c as f64 })
6908 .collect())
6909 }
6910}
6911
/// Label-encodes a string column.
///
/// Returns `(levels, codes)` where `levels` holds the distinct input values in
/// ascending lexicographic order and `codes[i]` is the index in `levels` of
/// `col[i]`. An empty input yields two empty vectors.
pub fn label_encode(col: &[String]) -> (Vec<String>, Vec<u32>) {
    // Collect the distinct values first; the ordered map keeps them sorted.
    let mut lookup: BTreeMap<&str, u32> = BTreeMap::new();
    for value in col {
        lookup.insert(value.as_str(), 0);
    }
    // Number the sorted keys 0, 1, 2, ...
    for (code, slot) in lookup.values_mut().enumerate() {
        *slot = code as u32;
    }

    let levels: Vec<String> = lookup.keys().map(|label| label.to_string()).collect();
    let codes: Vec<u32> = col.iter().map(|value| lookup[value.as_str()]).collect();
    (levels, codes)
}
6931
/// Encodes a string column against a caller-supplied level order.
///
/// On success returns `(levels, codes)` where `levels` is a copy of `order`
/// and `codes[i]` is the position in `order` of `col[i]`.
///
/// # Errors
///
/// Returns an error message naming the first value of `col` that does not
/// appear in `order`.
pub fn ordinal_encode(col: &[String], order: &[String]) -> Result<(Vec<String>, Vec<u32>), String> {
    let position_of: BTreeMap<&str, u32> = order
        .iter()
        .enumerate()
        .map(|(pos, level)| (level.as_str(), pos as u32))
        .collect();

    // Collecting into `Result` stops at the first unknown value.
    let codes = col
        .iter()
        .map(|value| {
            position_of
                .get(value.as_str())
                .copied()
                .ok_or_else(|| format!("value {:?} not found in specified order", value))
        })
        .collect::<Result<Vec<u32>, String>>()?;

    Ok((order.to_vec(), codes))
}
6952
/// Expands level codes into one boolean indicator column per level.
///
/// Returns `(names, columns)` where `names` is a copy of `levels` and
/// `columns[k][row]` is `true` exactly when `codes[row] == k`. Each indicator
/// column has `codes.len()` entries.
///
/// # Panics
///
/// Panics if any code is not a valid index into `levels`.
pub fn one_hot_encode(levels: &[String], codes: &[u32]) -> (Vec<String>, Vec<Vec<bool>>) {
    let all_false = vec![false; codes.len()];
    let mut indicators: Vec<Vec<bool>> = vec![all_false; levels.len()];

    // Flip exactly one cell per row: the one in that row's level column.
    codes.iter().enumerate().for_each(|(row, &code)| {
        indicators[code as usize][row] = true;
    });

    (levels.to_vec(), indicators)
}
6969
#[cfg(test)]
mod rolling_window_tests {
    use super::*;

    /// Builds a frame holding a single `Float` column named `col_name`.
    fn make_df(col_name: &str, vals: Vec<f64>) -> DataFrame {
        let column = (col_name.to_string(), Column::Float(vals));
        DataFrame { columns: vec![column] }
    }

    /// Unwraps a `Float` column, failing the test for any other column type.
    fn floats(col: Column) -> Vec<f64> {
        match col {
            Column::Float(values) => values,
            _ => panic!("expected Float column"),
        }
    }

    /// Asserts that `actual` has the same length as `expected` and that every
    /// element lies within `tol` of its expected value.
    fn assert_all_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len());
        for (got, want) in actual.iter().zip(expected.iter()) {
            assert!((got - want).abs() < tol);
        }
    }

    // The expected values below show that leading rows are summarised over
    // the partial window available so far.

    #[test]
    fn rolling_sum_basic() {
        let df = make_df("x", vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let expr = DExpr::RollingSum("x".into(), 3);
        let v = floats(eval_expr_column(&df, &expr, 5).unwrap());
        assert_all_close(&v, &[1.0, 3.0, 6.0, 9.0, 12.0], 1e-12);
    }

    #[test]
    fn rolling_mean_basic() {
        let df = make_df("x", vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let expr = DExpr::RollingMean("x".into(), 3);
        let v = floats(eval_expr_column(&df, &expr, 5).unwrap());
        assert_all_close(&v, &[1.0, 1.5, 2.0, 3.0, 4.0], 1e-12);
    }

    #[test]
    fn rolling_min_basic() {
        let df = make_df("x", vec![5.0, 3.0, 4.0, 1.0, 2.0]);
        let expr = DExpr::RollingMin("x".into(), 3);
        let v = floats(eval_expr_column(&df, &expr, 5).unwrap());
        assert_all_close(&v, &[5.0, 3.0, 3.0, 1.0, 1.0], 1e-12);
    }

    #[test]
    fn rolling_max_basic() {
        let df = make_df("x", vec![1.0, 5.0, 3.0, 2.0, 4.0]);
        let expr = DExpr::RollingMax("x".into(), 3);
        let v = floats(eval_expr_column(&df, &expr, 5).unwrap());
        assert_all_close(&v, &[1.0, 5.0, 5.0, 5.0, 4.0], 1e-12);
    }

    #[test]
    fn rolling_var_basic() {
        let df = make_df("x", vec![2.0, 4.0, 6.0, 8.0]);
        let expr = DExpr::RollingVar("x".into(), 3);
        let v = floats(eval_expr_column(&df, &expr, 4).unwrap());
        assert_eq!(v.len(), 4);
        // The single-element window is checked with the tighter tolerance.
        assert_all_close(&v[..1], &[0.0], 1e-12);
        assert_all_close(&v[1..], &[2.0, 4.0, 4.0], 1e-10);
    }

    #[test]
    fn rolling_sd_basic() {
        let df = make_df("x", vec![2.0, 4.0, 6.0, 8.0]);
        let expr = DExpr::RollingSd("x".into(), 3);
        let v = floats(eval_expr_column(&df, &expr, 4).unwrap());
        assert_eq!(v.len(), 4);
        assert_all_close(&v[..1], &[0.0], 1e-12);
        assert_all_close(&v[1..], &[2.0_f64.sqrt(), 2.0, 2.0], 1e-10);
    }

    #[test]
    fn rolling_window_larger_than_data() {
        let df = make_df("x", vec![1.0, 2.0, 3.0]);
        let expr = DExpr::RollingSum("x".into(), 10);
        let v = floats(eval_expr_column(&df, &expr, 3).unwrap());
        assert_all_close(&v, &[1.0, 3.0, 6.0], 1e-12);
    }

    #[test]
    fn rolling_window_of_one() {
        let df = make_df("x", vec![3.0, 1.0, 4.0, 1.0, 5.0]);
        let expr_min = DExpr::RollingMin("x".into(), 1);
        let expr_max = DExpr::RollingMax("x".into(), 1);
        let col_min = eval_expr_column(&df, &expr_min, 5).unwrap();
        let col_max = eval_expr_column(&df, &expr_max, 5).unwrap();

        // With a window of one, both min and max reproduce the input.
        let expected: [f64; 5] = [3.0, 1.0, 4.0, 1.0, 5.0];
        match (col_min, col_max) {
            (Column::Float(mins), Column::Float(maxs)) => {
                for (i, want) in expected.iter().enumerate() {
                    assert!((mins[i] - want).abs() < 1e-12, "min[{}]", i);
                    assert!((maxs[i] - want).abs() < 1e-12, "max[{}]", i);
                }
            }
            _ => panic!("expected Float columns"),
        }
    }

    #[test]
    fn rolling_sum_with_nan() {
        let df = make_df("x", vec![1.0, f64::NAN, 3.0, 4.0]);
        let expr = DExpr::RollingSum("x".into(), 2);
        let v = floats(eval_expr_column(&df, &expr, 4).unwrap());
        assert_eq!(v.len(), 4);
        assert!((v[0] - 1.0).abs() < 1e-12);
        // Every output from the NaN row onwards is expected to be NaN, including
        // the last one whose window no longer contains the NaN input.
        for value in &v[1..] {
            assert!(value.is_nan());
        }
    }

    #[test]
    fn rolling_determinism() {
        let df = make_df("x", vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let expr = DExpr::RollingSum("x".into(), 4);
        let runs: Vec<Vec<f64>> = (0..3)
            .map(|_| floats(eval_expr_column(&df, &expr, 10).unwrap()))
            .collect();
        // Repeated evaluation must give exactly equal results.
        assert_eq!(runs[0], runs[1]);
        assert_eq!(runs[1], runs[2]);
    }

    #[test]
    fn rolling_display() {
        let sum_expr = DExpr::RollingSum("val".into(), 5);
        let mean_expr = DExpr::RollingMean("col".into(), 3);
        assert_eq!(format!("{}", sum_expr), "rolling_sum(\"val\", 5)");
        assert_eq!(format!("{}", mean_expr), "rolling_mean(\"col\", 3)");
    }

    #[test]
    fn rolling_collect_columns() {
        let expr = DExpr::RollingSum("revenue".into(), 7);
        let mut referenced = Vec::new();
        collect_expr_columns(&expr, &mut referenced);
        assert_eq!(referenced, vec!["revenue".to_string()]);
    }

    #[test]
    fn rolling_not_allowed_in_row_context() {
        let df = make_df("x", vec![1.0, 2.0, 3.0]);
        let expr = DExpr::RollingSum("x".into(), 2);
        // Row-wise evaluation is expected to reject rolling expressions.
        assert!(eval_expr_row(&df, &expr, 0).is_err());
    }
}
7192
7193