use std::collections::{BTreeMap, BTreeSet};
use std::fmt;
use std::rc::Rc;
use crate::bitmask::BitMask;
use crate::column::{Column, ColumnKeyRef, GroupKey};
use crate::dataframe::DataFrame;
use crate::expr::{self, DExpr, ExprValue};
use crate::kahan::KahanAccumulator;
/// Errors reported by the view, grouping and join operations in this module.
#[derive(Debug, Clone)]
pub enum TidyError {
    /// A referenced column name could not be found.
    ColumnNotFound(String),
    /// The same column name was supplied more than once.
    DuplicateColumn(String),
    /// A filter predicate produced a non-Bool value; `got` names the type seen.
    PredicateNotBool { got: String },
    /// A value had a different type than the one required.
    TypeMismatch { expected: String, got: String },
    /// A row count differed from the one required.
    LengthMismatch { expected: usize, got: usize },
    /// A failure from a lower layer, carried as its message text.
    Internal(String),
    /// An aggregation was applied to a group with no rows.
    EmptyGroup,
}

impl fmt::Display for TidyError {
    /// Writes the human-readable message for each error variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EmptyGroup => f.write_str("aggregation on empty group"),
            Self::Internal(detail) => write!(f, "internal error: {}", detail),
            Self::ColumnNotFound(column) => write!(f, "column `{}` not found", column),
            Self::DuplicateColumn(column) => write!(f, "duplicate column `{}`", column),
            Self::PredicateNotBool { got: actual } => {
                write!(f, "filter predicate must be Bool, got {}", actual)
            }
            Self::TypeMismatch {
                expected: wanted,
                got: actual,
            } => write!(f, "type mismatch: expected {}, got {}", wanted, actual),
            Self::LengthMismatch {
                expected: wanted,
                got: actual,
            } => write!(f, "length mismatch: expected {} rows, got {}", wanted, actual),
        }
    }
}

impl std::error::Error for TidyError {}
/// Column projection of a view: either every base column or an explicit index list.
#[derive(Debug, Clone)]
pub struct ProjectionMap {
    /// `None` means "all columns"; `Some` holds the selected base-column indices in order.
    indices: Option<Vec<usize>>,
}

impl ProjectionMap {
    /// Creates a projection that keeps every column.
    pub fn all() -> Self {
        ProjectionMap { indices: None }
    }

    /// Creates a projection limited to `indices`, kept in the order given.
    pub fn from_indices(indices: Vec<usize>) -> Self {
        ProjectionMap {
            indices: Some(indices),
        }
    }

    /// Returns the concrete column indices for a base frame with `ncols` columns.
    ///
    /// An explicit index list is returned as a copy; the "all" projection expands to `0..ncols`.
    pub fn resolve(&self, ncols: usize) -> Vec<usize> {
        self.indices
            .clone()
            .unwrap_or_else(|| (0..ncols).collect())
    }

    /// Returns how many columns are projected for a base frame with `ncols` columns.
    pub fn len(&self, ncols: usize) -> usize {
        self.indices.as_ref().map_or(ncols, Vec::len)
    }
}
/// One sort key for `arrange`: a column name plus a direction flag.
#[derive(Debug, Clone)]
pub struct ArrangeKey {
    /// Name of the column to sort by.
    pub col_name: String,
    /// `true` to sort this key in descending order.
    pub descending: bool,
}

impl ArrangeKey {
    /// Shared constructor behind `asc` and `desc`.
    fn with_direction(col_name: &str, descending: bool) -> Self {
        ArrangeKey {
            col_name: col_name.to_owned(),
            descending,
        }
    }

    /// Builds an ascending sort key on `col_name`.
    pub fn asc(col_name: &str) -> Self {
        Self::with_direction(col_name, false)
    }

    /// Builds a descending sort key on `col_name`.
    pub fn desc(col_name: &str) -> Self {
        Self::with_direction(col_name, true)
    }
}
/// Aggregations accepted by `GroupedTidyView::summarise`.
///
/// Every variant except `Count` carries the name of the column it aggregates.
#[derive(Debug, Clone)]
pub enum TidyAgg {
/// Number of rows in each group (Int output).
Count,
/// Kahan-compensated sum of the values `get_f64` yields (Float output).
Sum(String),
/// Sum divided by the group's row count; NaN for an empty group (Float output).
Mean(String),
/// Smallest value in the group; starts from `f64::INFINITY` (Float output).
Min(String),
/// Largest value in the group; starts from `f64::NEG_INFINITY` (Float output).
Max(String),
/// Square root of the `n - 1` variance; NaN when the group has fewer than two rows.
Sd(String),
/// Variance with an `n - 1` denominator; NaN when the group has fewer than two rows.
Var(String),
/// Display string of the group's first row (Str output).
First(String),
/// Display string of the group's last row (Str output).
Last(String),
/// Count of distinct display strings in the group (Int output).
NDistinct(String),
}
/// One group produced by `GroupIndex`: its key and the rows that belong to it.
#[derive(Debug, Clone)]
pub struct GroupMeta {
// Owned key values, one per grouping column, in key-column order.
pub key_values: Vec<GroupKey>,
// Base-frame row indices of the group's members, in the order they were visited.
pub row_indices: Vec<usize>,
}
/// Result of grouping: the groups in first-occurrence order plus the key column names.
#[derive(Debug, Clone)]
pub struct GroupIndex {
// Groups in the order their key was first encountered while scanning rows.
pub groups: Vec<GroupMeta>,
// Names of the grouping columns, as passed in by the caller.
pub key_names: Vec<String>,
}
impl GroupIndex {
/// Builds a group index over `visible_rows` of `base`, keyed by the given columns.
///
/// Exactly one key column takes the single-key path; any other count
/// (including zero) goes through the composite-key path.
pub fn build_fast_typed(
    base: &DataFrame,
    key_col_indices: &[usize],
    visible_rows: &[usize],
    key_names: Vec<String>,
) -> Self {
    match key_col_indices {
        [only] => Self::build_single(base, *only, visible_rows, key_names),
        _ => Self::build_multi(base, key_col_indices, visible_rows, key_names),
    }
}
/// Groups `visible_rows` by a single key column.
///
/// Groups appear in the order their key is first seen; rows inside a group
/// keep the order of `visible_rows`.
fn build_single(
    base: &DataFrame,
    key_col_idx: usize,
    visible_rows: &[usize],
    key_names: Vec<String>,
) -> Self {
    let key_column = &base.columns[key_col_idx].1;
    // Maps a borrowed key to the position of its group inside `groups`.
    let mut slot_of: BTreeMap<ColumnKeyRef<'_>, usize> = BTreeMap::new();
    let mut groups: Vec<GroupMeta> = Vec::new();

    for &row in visible_rows {
        let key = ColumnKeyRef::from_column(key_column, row);
        match slot_of.get(&key).copied() {
            Some(slot) => groups[slot].row_indices.push(row),
            None => {
                // Take the owned copy before the borrowed key moves into the map.
                let owned = key.to_owned_key();
                slot_of.insert(key, groups.len());
                groups.push(GroupMeta {
                    key_values: vec![owned],
                    row_indices: vec![row],
                });
            }
        }
    }

    GroupIndex { groups, key_names }
}
/// Groups `visible_rows` by a composite key made of several columns.
///
/// Groups appear in the order their key tuple is first seen; rows inside a
/// group keep the order of `visible_rows`.
fn build_multi(
    base: &DataFrame,
    key_col_indices: &[usize],
    visible_rows: &[usize],
    key_names: Vec<String>,
) -> Self {
    let key_columns: Vec<&Column> = key_col_indices
        .iter()
        .map(|&ci| &base.columns[ci].1)
        .collect();
    // Maps a borrowed composite key to the position of its group inside `groups`.
    let mut slot_of: BTreeMap<Vec<ColumnKeyRef<'_>>, usize> = BTreeMap::new();
    let mut groups: Vec<GroupMeta> = Vec::new();

    for &row in visible_rows {
        let key: Vec<ColumnKeyRef<'_>> = key_columns
            .iter()
            .map(|col| ColumnKeyRef::from_column(col, row))
            .collect();
        match slot_of.get(&key).copied() {
            Some(slot) => groups[slot].row_indices.push(row),
            None => {
                // Take the owned copies before the borrowed key moves into the map.
                let owned: Vec<GroupKey> = key.iter().map(|part| part.to_owned_key()).collect();
                slot_of.insert(key, groups.len());
                groups.push(GroupMeta {
                    key_values: owned,
                    row_indices: vec![row],
                });
            }
        }
    }

    GroupIndex { groups, key_names }
}
}
/// A lazy view over a shared `DataFrame`: row mask, column projection and optional row order.
#[derive(Debug, Clone)]
pub struct TidyView {
// Shared base frame; view operations clone the `Rc` rather than the data.
pub(crate) base: Rc<DataFrame>,
// One bit per base row; set bits are the rows visible through this view.
pub(crate) mask: BitMask,
// Which base columns are visible, and in what order.
pub(crate) proj: ProjectionMap,
// Base row indices in arranged order (set by `arrange`); `None` means mask order.
pub(crate) ordering: Option<Rc<Vec<usize>>>,
}
impl TidyView {
pub fn new(df: DataFrame) -> Self {
let nrows = df.nrows();
Self {
base: Rc::new(df),
mask: BitMask::all_true(nrows),
proj: ProjectionMap::all(),
ordering: None,
}
}
pub fn from_rc(base: Rc<DataFrame>) -> Self {
let nrows = base.nrows();
Self {
base,
mask: BitMask::all_true(nrows),
proj: ProjectionMap::all(),
ordering: None,
}
}
/// Number of visible rows, i.e. the count of set bits in the mask.
pub fn nrows(&self) -> usize {
self.mask.count_ones()
}
/// Number of projected columns.
pub fn ncols(&self) -> usize {
self.proj.len(self.base.ncols())
}
/// Borrows the row-visibility mask.
pub fn mask(&self) -> &BitMask {
&self.mask
}
/// Borrows the underlying base frame (unfiltered, unprojected).
pub fn base(&self) -> &DataFrame {
&self.base
}
fn visible_rows_ordered(&self) -> Vec<usize> {
if let Some(ref ord) = self.ordering {
ord.as_ref().clone()
} else {
self.mask.iter_set().collect()
}
}
/// Bakes a pending ordering into a fresh base frame.
///
/// Returns `None` when the view has no ordering. Otherwise every base column
/// is gathered in ordering sequence, and the new view shows all rows of that
/// frame with the same projection and no ordering.
///
/// # Panics
/// Panics if `DataFrame::from_columns` rejects the gathered columns.
fn resolve_ordering(&self) -> Option<TidyView> {
    let order: &[usize] = match &self.ordering {
        Some(rows) => rows.as_slice(),
        None => return None,
    };
    let reordered: Vec<_> = self
        .base
        .columns
        .iter()
        .map(|(name, col)| (name.clone(), col.gather(order)))
        .collect();
    let new_base =
        DataFrame::from_columns(reordered).expect("resolve_ordering: column length mismatch");
    let mask = BitMask::all_true(new_base.nrows());
    Some(TidyView {
        base: Rc::new(new_base),
        mask,
        proj: self.proj.clone(),
        ordering: None,
    })
}
/// Builds a view over the same base whose mask has exactly `row_indices` set.
///
/// The projection is carried over; the result has no ordering.
/// Indexing panics if a row index falls outside the mask's word buffer.
fn view_from_row_indices(&self, row_indices: Vec<usize>) -> TidyView {
    let base_rows = self.base.nrows();
    let mut bits = vec![0u64; crate::bitmask::nwords_for(base_rows)];
    for row in row_indices.iter().copied() {
        let (word, bit) = (row / 64, row % 64);
        bits[word] |= 1u64 << bit;
    }
    let mask = BitMask {
        words: bits,
        nrows: base_rows,
    };
    TidyView {
        base: Rc::clone(&self.base),
        mask,
        proj: self.proj.clone(),
        ordering: None,
    }
}
pub fn filter(&self, predicate: &DExpr) -> Result<TidyView, TidyError> {
if let Some(ref ord) = self.ordering {
let pred_mask =
if let Some(m) = expr::try_eval_predicate_columnar(&self.base, predicate, &self.mask)
{
m
} else {
let nrows_base = self.base.nrows();
let mut new_words = self.mask.words.clone();
for &row in ord.iter() {
let b = expr::eval_expr_row(&self.base, predicate, row)
.map_err(|e| TidyError::Internal(e))?;
let pass = match b {
ExprValue::Bool(v) => v,
_ => {
return Err(TidyError::PredicateNotBool {
got: b.type_name().to_string(),
})
}
};
if !pass {
new_words[row / 64] &= !(1u64 << (row % 64));
}
}
BitMask {
words: new_words,
nrows: nrows_base,
}
};
let new_ord: Vec<usize> = ord.iter().filter(|&&row| pred_mask.get(row)).copied().collect();
return Ok(TidyView {
base: Rc::clone(&self.base),
mask: pred_mask,
proj: self.proj.clone(),
ordering: Some(Rc::new(new_ord)),
});
}
if let Some(new_mask) =
expr::try_eval_predicate_columnar(&self.base, predicate, &self.mask)
{
return Ok(TidyView {
base: Rc::clone(&self.base),
mask: new_mask,
proj: self.proj.clone(),
ordering: None,
});
}
let nrows_base = self.base.nrows();
let mut new_words = self.mask.words.clone();
for row in self.mask.iter_set() {
let b = expr::eval_expr_row(&self.base, predicate, row)
.map_err(|e| TidyError::Internal(e))?;
let pass = match b {
ExprValue::Bool(v) => v,
_ => {
return Err(TidyError::PredicateNotBool {
got: b.type_name().to_string(),
})
}
};
if !pass {
new_words[row / 64] &= !(1u64 << (row % 64));
}
}
Ok(TidyView {
base: Rc::clone(&self.base),
mask: BitMask {
words: new_words,
nrows: nrows_base,
},
proj: self.proj.clone(),
ordering: None,
})
}
/// Returns a view projected onto `cols`, in the order given.
///
/// Names are looked up in the base frame's columns. Mask and ordering carry over.
///
/// # Errors
/// * `TidyError::DuplicateColumn` if a name appears twice (checked before any lookup).
/// * `TidyError::ColumnNotFound` for the first name missing from the base frame.
pub fn select(&self, cols: &[&str]) -> Result<TidyView, TidyError> {
    let mut unique: BTreeSet<&str> = BTreeSet::new();
    if let Some(repeated) = cols.iter().copied().find(|name| !unique.insert(*name)) {
        return Err(TidyError::DuplicateColumn(repeated.to_string()));
    }

    let picked = cols
        .iter()
        .map(|&name| {
            self.base
                .columns
                .iter()
                .position(|(n, _)| n == name)
                .ok_or_else(|| TidyError::ColumnNotFound(name.to_string()))
        })
        .collect::<Result<Vec<usize>, TidyError>>()?;

    Ok(TidyView {
        base: Rc::clone(&self.base),
        mask: self.mask.clone(),
        proj: ProjectionMap::from_indices(picked),
        ordering: self.ordering.clone(),
    })
}
pub fn mutate(&self, assignments: &[(&str, DExpr)]) -> Result<DataFrame, TidyError> {
let mut seen = BTreeSet::new();
for &(name, _) in assignments {
if !seen.insert(name) {
return Err(TidyError::DuplicateColumn(name.to_string()));
}
}
let mut df = self.materialize()?;
let snapshot_names: Vec<String> = df.columns.iter().map(|(n, _)| n.clone()).collect();
for &(col_name, ref dexpr) in assignments {
validate_expr_columns(dexpr, &snapshot_names)?;
let nrows = df.nrows();
let new_col = eval_expr_column(&df, dexpr, nrows)?;
if let Some(pos) = df.columns.iter().position(|(n, _)| n == col_name) {
df.columns[pos].1 = new_col;
} else {
df.columns.push((col_name.to_string(), new_col));
}
}
Ok(df)
}
/// Groups the visible rows by the columns named in `keys`.
///
/// Key names are looked up in the base frame. Rows are scanned in the view's
/// ordering when one is set, otherwise in mask order.
///
/// # Errors
/// `TidyError::ColumnNotFound` for the first key missing from the base frame.
pub fn group_by(&self, keys: &[&str]) -> Result<GroupedTidyView, TidyError> {
    let key_col_indices = keys
        .iter()
        .map(|&key| {
            self.base
                .columns
                .iter()
                .position(|(n, _)| n == key)
                .ok_or_else(|| TidyError::ColumnNotFound(key.to_string()))
        })
        .collect::<Result<Vec<usize>, TidyError>>()?;
    let key_names: Vec<String> = keys.iter().map(|&k| k.to_owned()).collect();

    let base_rows = self.base.nrows();
    // With no ordering and a full mask, the visible rows are simply 0..nrows.
    let needs_explicit_rows = self.ordering.is_some() || self.mask.count_ones() != base_rows;
    let visible_rows: Vec<usize> = if needs_explicit_rows {
        self.visible_rows_ordered()
    } else {
        (0..base_rows).collect()
    };

    Ok(GroupedTidyView {
        view: self.clone(),
        index: GroupIndex::build_fast_typed(
            &self.base,
            &key_col_indices,
            &visible_rows,
            key_names,
        ),
    })
}
/// Returns a view whose rows are ordered by `keys`, with earlier keys taking precedence.
///
/// The sort is stable and starts from the mask's row order. The mask and
/// projection are unchanged; only `ordering` is set.
///
/// # Errors
/// `TidyError::ColumnNotFound` for the first key whose column is missing.
pub fn arrange(&self, keys: &[ArrangeKey]) -> Result<TidyView, TidyError> {
    let sort_cols = keys
        .iter()
        .map(|key| {
            self.base
                .get_column(&key.col_name)
                .map(|col| (col, key.descending))
                .ok_or_else(|| TidyError::ColumnNotFound(key.col_name.clone()))
        })
        .collect::<Result<Vec<(&Column, bool)>, TidyError>>()?;

    let mut sorted_rows: Vec<usize> = self.mask.iter_set().collect();
    sorted_rows.sort_by(|&a, &b| {
        // The first key that distinguishes the two rows decides.
        sort_cols
            .iter()
            .map(|&(col, desc)| {
                let ord = col.compare_rows(a, b);
                if desc {
                    ord.reverse()
                } else {
                    ord
                }
            })
            .find(|ord| *ord != std::cmp::Ordering::Equal)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    Ok(TidyView {
        base: Rc::clone(&self.base),
        mask: self.mask.clone(),
        proj: self.proj.clone(),
        ordering: Some(Rc::new(sorted_rows)),
    })
}
pub fn slice_head(&self, n: usize) -> TidyView {
let rows: Vec<usize> = self.visible_rows_ordered().into_iter().take(n).collect();
self.view_from_row_indices(rows)
}
pub fn slice_tail(&self, n: usize) -> TidyView {
let all = self.visible_rows_ordered();
let start = all.len().saturating_sub(n);
let rows = all[start..].to_vec();
self.view_from_row_indices(rows)
}
/// Returns a view of `n` visible rows chosen by a seeded, deterministic shuffle.
///
/// A pending ordering is resolved into a fresh base first. If `n` is at
/// least the number of visible rows, all of them are returned. Otherwise a
/// partial Fisher-Yates shuffle, driven by a 64-bit linear congruential
/// generator seeded with `seed`, picks the rows. The same `seed` always
/// selects the same rows.
pub fn slice_sample(&self, n: usize, seed: u64) -> TidyView {
    let materialized = self.resolve_ordering();
    let source = match &materialized {
        Some(view) => view,
        None => self,
    };

    let mut pool: Vec<usize> = source.mask.iter_set().collect();
    let total = pool.len();
    if total <= n {
        return source.view_from_row_indices(pool);
    }

    let mut state = seed;
    for picked in 0..n {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Choose among the not-yet-picked tail of the pool.
        let remaining = total - picked;
        let offset = state as usize % remaining;
        pool.swap(picked, picked + offset);
    }
    pool.truncate(n);
    pool.sort_unstable();
    source.view_from_row_indices(pool)
}
/// Returns a view keeping the first visible row for each distinct combination of `cols`.
///
/// A pending ordering is resolved into a fresh base first. Rows are compared
/// by the display strings of the named columns.
///
/// # Errors
/// `TidyError::ColumnNotFound` for the first name missing from the base frame.
pub fn distinct(&self, cols: &[&str]) -> Result<TidyView, TidyError> {
    let materialized = self.resolve_ordering();
    let source = match &materialized {
        Some(view) => view,
        None => self,
    };

    let col_indices = cols
        .iter()
        .map(|&name| {
            source
                .base
                .columns
                .iter()
                .position(|(n, _)| n == name)
                .ok_or_else(|| TidyError::ColumnNotFound(name.to_string()))
        })
        .collect::<Result<Vec<usize>, TidyError>>()?;

    let mut seen_keys: BTreeSet<Vec<String>> = BTreeSet::new();
    let first_occurrences: Vec<usize> = source
        .mask
        .iter_set()
        .filter(|&row| {
            let key: Vec<String> = col_indices
                .iter()
                .map(|&ci| source.base.columns[ci].1.get_display(row))
                .collect();
            // `insert` returns true only the first time a key is seen.
            seen_keys.insert(key)
        })
        .collect();

    Ok(source.view_from_row_indices(first_occurrences))
}
/// Inner-joins this view with `right` on the `(left_name, right_name)` pairs in `on`.
///
/// Pending orderings on either side are resolved into fresh bases first.
/// The result is a materialized frame containing only matching row pairs.
///
/// # Errors
/// Propagates key-lookup and frame-construction errors.
pub fn inner_join(
    &self,
    right: &TidyView,
    on: &[(&str, &str)],
) -> Result<DataFrame, TidyError> {
    let left_resolved = self.resolve_ordering();
    let lhs = match &left_resolved {
        Some(view) => view,
        None => self,
    };
    let right_resolved = right.resolve_ordering();
    let rhs = match &right_resolved {
        Some(view) => view,
        None => right,
    };
    join_match_rows(lhs, rhs, on).and_then(|(left_rows, right_rows)| {
        build_join_frame(lhs, rhs, &left_rows, &right_rows, on)
    })
}
/// Left-joins this view with `right` on the `(left_name, right_name)` pairs in `on`.
///
/// Pending orderings on either side are resolved into fresh bases first.
/// Every visible left row appears in the result; unmatched rows are paired
/// with no right row.
///
/// # Errors
/// Propagates key-lookup and frame-construction errors.
pub fn left_join(
    &self,
    right: &TidyView,
    on: &[(&str, &str)],
) -> Result<DataFrame, TidyError> {
    let left_resolved = self.resolve_ordering();
    let lhs = match &left_resolved {
        Some(view) => view,
        None => self,
    };
    let right_resolved = right.resolve_ordering();
    let rhs = match &right_resolved {
        Some(view) => view,
        None => right,
    };
    join_match_rows_optional(lhs, rhs, on).and_then(|(left_rows, right_rows_opt)| {
        build_left_join_frame(lhs, rhs, &left_rows, &right_rows_opt, on)
    })
}
/// Copies the visible rows and projected columns into a standalone `DataFrame`.
///
/// A pending ordering is resolved first, so rows come out in arranged order.
///
/// # Errors
/// `TidyError::Internal` if `DataFrame::from_columns` rejects the gathered columns.
pub fn materialize(&self) -> Result<DataFrame, TidyError> {
    let materialized = self.resolve_ordering();
    let source = match &materialized {
        Some(view) => view,
        None => self,
    };

    let visible: Vec<usize> = source.mask.iter_set().collect();
    let gathered: Vec<_> = source
        .proj
        .resolve(source.base.ncols())
        .into_iter()
        .map(|ci| {
            let (name, col) = &source.base.columns[ci];
            (name.clone(), col.gather(&visible))
        })
        .collect();

    match DataFrame::from_columns(gathered) {
        Ok(frame) => Ok(frame),
        Err(e) => Err(TidyError::Internal(e.to_string())),
    }
}
/// Returns the names of the projected columns, in projection order.
pub fn column_names(&self) -> Vec<String> {
    let mut names = Vec::new();
    for ci in self.proj.resolve(self.base.ncols()) {
        names.push(self.base.columns[ci].0.clone());
    }
    names
}
}
/// A view together with the group index built over its visible rows.
#[derive(Debug, Clone)]
pub struct GroupedTidyView {
// The view that was grouped (a clone taken at `group_by` time).
pub view: TidyView,
// Groups and key names produced for that view.
pub index: GroupIndex,
}
impl GroupedTidyView {
/// Collapses each group to one row: key columns first, then one column per aggregation.
///
/// Key columns are rebuilt with the type of the corresponding base column.
/// A stored key of an unexpected kind falls back to `0`, `0.0` or `false`;
/// base columns that are not Int, Float or Bool become Str columns of display strings.
///
/// # Errors
/// * `TidyError::DuplicateColumn` if an output name appears twice.
/// * `TidyError::ColumnNotFound` if a key or aggregated column is missing.
/// * `TidyError::Internal` if the result frame cannot be constructed.
pub fn summarise(
    &self,
    assignments: &[(&str, TidyAgg)],
) -> Result<DataFrame, TidyError> {
    let mut names_seen: BTreeSet<&str> = BTreeSet::new();
    for (name, _) in assignments {
        if !names_seen.insert(*name) {
            return Err(TidyError::DuplicateColumn((*name).to_string()));
        }
    }

    let base = &self.view.base;
    let groups = &self.index.groups;
    let mut out: Vec<(String, Column)> = Vec::new();

    for (ki, key_name) in self.index.key_names.iter().enumerate() {
        let source_col = base
            .get_column(key_name)
            .ok_or_else(|| TidyError::ColumnNotFound(key_name.clone()))?;
        // The `ki`-th key value of every group, in group order.
        let keys = groups.iter().map(|g| &g.key_values[ki]);
        let rebuilt = match source_col {
            Column::Int(_) => Column::Int(
                keys.map(|k| match k {
                    GroupKey::Int(v) => *v,
                    _ => 0,
                })
                .collect(),
            ),
            Column::Float(_) => Column::Float(
                keys.map(|k| match k {
                    GroupKey::Float(fk) => fk.0,
                    GroupKey::Int(v) => *v as f64,
                    _ => 0.0,
                })
                .collect(),
            ),
            Column::Bool(_) => Column::Bool(
                keys.map(|k| match k {
                    GroupKey::Bool(v) => *v,
                    _ => false,
                })
                .collect(),
            ),
            _ => Column::Str(keys.map(|k| k.to_display()).collect()),
        };
        out.push((key_name.clone(), rebuilt));
    }

    for (out_name, agg) in assignments {
        let aggregated = self.eval_agg(agg, groups.len(), base)?;
        out.push(((*out_name).to_string(), aggregated));
    }

    DataFrame::from_columns(out).map_err(|e| TidyError::Internal(e.to_string()))
}
fn eval_agg(
&self,
agg: &TidyAgg,
_n_groups: usize,
base: &DataFrame,
) -> Result<Column, TidyError> {
match agg {
TidyAgg::Count => {
let counts: Vec<i64> = self
.index
.groups
.iter()
.map(|g| g.row_indices.len() as i64)
.collect();
Ok(Column::Int(counts))
}
TidyAgg::Sum(col_name) => {
let col = base
.get_column(col_name)
.ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
let sums: Vec<f64> = self.index.groups.iter().map(|g| {
let mut acc = KahanAccumulator::new();
for &i in &g.row_indices {
if let Some(v) = col.get_f64(i) {
acc.add(v);
}
}
acc.finalize()
}).collect();
Ok(Column::Float(sums))
}
TidyAgg::Mean(col_name) => {
let col = base
.get_column(col_name)
.ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
let means: Vec<f64> = self.index.groups.iter().map(|g| {
if g.row_indices.is_empty() {
return f64::NAN;
}
let mut acc = KahanAccumulator::new();
for &i in &g.row_indices {
if let Some(v) = col.get_f64(i) {
acc.add(v);
}
}
acc.finalize() / g.row_indices.len() as f64
}).collect();
Ok(Column::Float(means))
}
TidyAgg::Min(col_name) => {
let col = base
.get_column(col_name)
.ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
let mins: Vec<f64> = self.index.groups.iter().map(|g| {
let mut min = f64::INFINITY;
for &i in &g.row_indices {
if let Some(v) = col.get_f64(i) {
if v < min { min = v; }
}
}
min
}).collect();
Ok(Column::Float(mins))
}
TidyAgg::Max(col_name) => {
let col = base
.get_column(col_name)
.ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
let maxs: Vec<f64> = self.index.groups.iter().map(|g| {
let mut max = f64::NEG_INFINITY;
for &i in &g.row_indices {
if let Some(v) = col.get_f64(i) {
if v > max { max = v; }
}
}
max
}).collect();
Ok(Column::Float(maxs))
}
TidyAgg::Sd(col_name) => {
let col = base
.get_column(col_name)
.ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
let sds: Vec<f64> = self.index.groups.iter().map(|g| {
let n = g.row_indices.len();
if n < 2 { return f64::NAN; }
let mut acc = KahanAccumulator::new();
for &i in &g.row_indices {
if let Some(v) = col.get_f64(i) { acc.add(v); }
}
let mean = acc.finalize() / n as f64;
let mut var_acc = KahanAccumulator::new();
for &i in &g.row_indices {
if let Some(v) = col.get_f64(i) {
let diff = v - mean;
var_acc.add(diff * diff);
}
}
(var_acc.finalize() / (n - 1) as f64).sqrt()
}).collect();
Ok(Column::Float(sds))
}
TidyAgg::Var(col_name) => {
let col = base
.get_column(col_name)
.ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
let vars: Vec<f64> = self.index.groups.iter().map(|g| {
let n = g.row_indices.len();
if n < 2 { return f64::NAN; }
let mut acc = KahanAccumulator::new();
for &i in &g.row_indices {
if let Some(v) = col.get_f64(i) { acc.add(v); }
}
let mean = acc.finalize() / n as f64;
let mut var_acc = KahanAccumulator::new();
for &i in &g.row_indices {
if let Some(v) = col.get_f64(i) {
let diff = v - mean;
var_acc.add(diff * diff);
}
}
var_acc.finalize() / (n - 1) as f64
}).collect();
Ok(Column::Float(vars))
}
TidyAgg::First(col_name) => {
let col = base
.get_column(col_name)
.ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
let vals: Result<Vec<String>, _> = self.index.groups.iter().map(|g| {
g.row_indices.first()
.map(|&i| col.get_display(i))
.ok_or(TidyError::EmptyGroup)
}).collect();
Ok(Column::Str(vals?))
}
TidyAgg::Last(col_name) => {
let col = base
.get_column(col_name)
.ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
let vals: Result<Vec<String>, _> = self.index.groups.iter().map(|g| {
g.row_indices.last()
.map(|&i| col.get_display(i))
.ok_or(TidyError::EmptyGroup)
}).collect();
Ok(Column::Str(vals?))
}
TidyAgg::NDistinct(col_name) => {
let col = base
.get_column(col_name)
.ok_or_else(|| TidyError::ColumnNotFound(col_name.clone()))?;
let counts: Vec<i64> = self.index.groups.iter().map(|g| {
let mut uniq = BTreeSet::new();
for &i in &g.row_indices {
uniq.insert(col.get_display(i));
}
uniq.len() as i64
}).collect();
Ok(Column::Int(counts))
}
}
}
/// Borrows the group index built by `group_by`.
pub fn group_index(&self) -> &GroupIndex {
&self.index
}
}
/// Resolves each `(left_name, right_name)` pair in `on` to column indices
/// in the respective base frames.
///
/// Returns `(left_indices, right_indices)`, both in the order of `on`.
///
/// # Errors
/// `TidyError::ColumnNotFound` for the first missing name; within a pair the
/// left name is checked before the right one.
fn resolve_join_keys(
    left: &TidyView,
    right: &TidyView,
    on: &[(&str, &str)],
) -> Result<(Vec<usize>, Vec<usize>), TidyError> {
    let pairs = on
        .iter()
        .map(|&(lk, rk)| -> Result<(usize, usize), TidyError> {
            let li = left
                .base
                .column_index(lk)
                .ok_or_else(|| TidyError::ColumnNotFound(lk.to_string()))?;
            let ri = right
                .base
                .column_index(rk)
                .ok_or_else(|| TidyError::ColumnNotFound(rk.to_string()))?;
            Ok((li, ri))
        })
        .collect::<Result<Vec<(usize, usize)>, TidyError>>()?;
    Ok(pairs.into_iter().unzip())
}
/// Computes the matching row pairs for an inner join.
///
/// The visible right rows are indexed by key; the visible left rows are then
/// probed in mask order. Returns parallel vectors `(left_rows, right_rows)`
/// with one entry per matching pair.
///
/// # Errors
/// `TidyError::ColumnNotFound` if a join key is missing on either side.
fn join_match_rows(
    left: &TidyView,
    right: &TidyView,
    on: &[(&str, &str)],
) -> Result<(Vec<usize>, Vec<usize>), TidyError> {
    let (l_key_idx, r_key_idx) = resolve_join_keys(left, right, on)?;
    let mut matched_left: Vec<usize> = Vec::new();
    let mut matched_right: Vec<usize> = Vec::new();

    if l_key_idx.len() != 1 {
        // Composite key: one `ColumnKeyRef` per key column.
        let r_cols: Vec<&Column> = r_key_idx
            .iter()
            .map(|&ci| &right.base.columns[ci].1)
            .collect();
        let l_cols: Vec<&Column> = l_key_idx
            .iter()
            .map(|&ci| &left.base.columns[ci].1)
            .collect();
        let mut index: BTreeMap<Vec<ColumnKeyRef<'_>>, Vec<usize>> = BTreeMap::new();
        for r_row in right.mask.iter_set() {
            let key: Vec<ColumnKeyRef<'_>> = r_cols
                .iter()
                .map(|col| ColumnKeyRef::from_column(col, r_row))
                .collect();
            index.entry(key).or_default().push(r_row);
        }
        for l_row in left.mask.iter_set() {
            let key: Vec<ColumnKeyRef<'_>> = l_cols
                .iter()
                .map(|col| ColumnKeyRef::from_column(col, l_row))
                .collect();
            if let Some(hits) = index.get(&key) {
                matched_left.extend(std::iter::repeat(l_row).take(hits.len()));
                matched_right.extend(hits.iter().copied());
            }
        }
        return Ok((matched_left, matched_right));
    }

    // Single key column: key directly on one `ColumnKeyRef`.
    let r_col = &right.base.columns[r_key_idx[0]].1;
    let l_col = &left.base.columns[l_key_idx[0]].1;
    let mut index: BTreeMap<ColumnKeyRef<'_>, Vec<usize>> = BTreeMap::new();
    for r_row in right.mask.iter_set() {
        index
            .entry(ColumnKeyRef::from_column(r_col, r_row))
            .or_default()
            .push(r_row);
    }
    for l_row in left.mask.iter_set() {
        let key = ColumnKeyRef::from_column(l_col, l_row);
        if let Some(hits) = index.get(&key) {
            matched_left.extend(std::iter::repeat(l_row).take(hits.len()));
            matched_right.extend(hits.iter().copied());
        }
    }
    Ok((matched_left, matched_right))
}
/// Computes the row pairs for a left join.
///
/// Works like `join_match_rows`, except that a visible left row with no
/// match still produces one output entry, paired with `None`.
///
/// # Errors
/// `TidyError::ColumnNotFound` if a join key is missing on either side.
fn join_match_rows_optional(
    left: &TidyView,
    right: &TidyView,
    on: &[(&str, &str)],
) -> Result<(Vec<usize>, Vec<Option<usize>>), TidyError> {
    let (l_key_idx, r_key_idx) = resolve_join_keys(left, right, on)?;
    let mut paired_left: Vec<usize> = Vec::new();
    let mut paired_right: Vec<Option<usize>> = Vec::new();

    if l_key_idx.len() != 1 {
        // Composite key: one `ColumnKeyRef` per key column.
        let r_cols: Vec<&Column> = r_key_idx
            .iter()
            .map(|&ci| &right.base.columns[ci].1)
            .collect();
        let l_cols: Vec<&Column> = l_key_idx
            .iter()
            .map(|&ci| &left.base.columns[ci].1)
            .collect();
        let mut index: BTreeMap<Vec<ColumnKeyRef<'_>>, Vec<usize>> = BTreeMap::new();
        for r_row in right.mask.iter_set() {
            let key: Vec<ColumnKeyRef<'_>> = r_cols
                .iter()
                .map(|col| ColumnKeyRef::from_column(col, r_row))
                .collect();
            index.entry(key).or_default().push(r_row);
        }
        for l_row in left.mask.iter_set() {
            let key: Vec<ColumnKeyRef<'_>> = l_cols
                .iter()
                .map(|col| ColumnKeyRef::from_column(col, l_row))
                .collect();
            match index.get(&key).filter(|hits| !hits.is_empty()) {
                Some(hits) => {
                    paired_left.extend(std::iter::repeat(l_row).take(hits.len()));
                    paired_right.extend(hits.iter().copied().map(Some));
                }
                None => {
                    paired_left.push(l_row);
                    paired_right.push(None);
                }
            }
        }
        return Ok((paired_left, paired_right));
    }

    // Single key column: key directly on one `ColumnKeyRef`.
    let r_col = &right.base.columns[r_key_idx[0]].1;
    let l_col = &left.base.columns[l_key_idx[0]].1;
    let mut index: BTreeMap<ColumnKeyRef<'_>, Vec<usize>> = BTreeMap::new();
    for r_row in right.mask.iter_set() {
        index
            .entry(ColumnKeyRef::from_column(r_col, r_row))
            .or_default()
            .push(r_row);
    }
    for l_row in left.mask.iter_set() {
        let key = ColumnKeyRef::from_column(l_col, l_row);
        match index.get(&key).filter(|hits| !hits.is_empty()) {
            Some(hits) => {
                paired_left.extend(std::iter::repeat(l_row).take(hits.len()));
                paired_right.extend(hits.iter().copied().map(Some));
            }
            None => {
                paired_left.push(l_row);
                paired_right.push(None);
            }
        }
    }
    Ok((paired_left, paired_right))
}
/// Assembles the inner-join result from matched row pairs.
///
/// The output holds every column of the left base frame gathered at
/// `left_rows`, followed by every right base column gathered at `right_rows`,
/// except those whose name is a right-hand join key.
///
/// # Errors
/// `TidyError::Internal` if `DataFrame::from_columns` rejects the columns.
fn build_join_frame(
    left: &TidyView,
    right: &TidyView,
    left_rows: &[usize],
    right_rows: &[usize],
    on: &[(&str, &str)],
) -> Result<DataFrame, TidyError> {
    let skipped: BTreeSet<&str> = on.iter().map(|&(_, rk)| rk).collect();

    let left_part = left
        .base
        .columns
        .iter()
        .map(|(name, col)| (name.clone(), col.gather(left_rows)));
    let right_part = right
        .base
        .columns
        .iter()
        .filter(|(name, _)| !skipped.contains(name.as_str()))
        .map(|(name, col)| (name.clone(), col.gather(right_rows)));
    let result_cols: Vec<(String, Column)> = left_part.chain(right_part).collect();

    DataFrame::from_columns(result_cols).map_err(|e| TidyError::Internal(e.to_string()))
}
/// Assembles the left-join result from row pairs where the right side may be absent.
///
/// Left columns are gathered at `left_rows`. Right columns (except
/// right-hand join keys) take the matched value, or a type default when the
/// pairing is `None`: `0`, `0.0`, the empty string, or `false`.
///
/// # Errors
/// `TidyError::Internal` if `DataFrame::from_columns` rejects the columns.
fn build_left_join_frame(
    left: &TidyView,
    right: &TidyView,
    left_rows: &[usize],
    right_rows_opt: &[Option<usize>],
    on: &[(&str, &str)],
) -> Result<DataFrame, TidyError> {
    // Picks `values[i]` for `Some(i)` and `fill` for `None`.
    fn pick_or<T: Clone>(values: &[T], rows: &[Option<usize>], fill: T) -> Vec<T> {
        rows.iter()
            .map(|slot| match slot {
                Some(i) => values[*i].clone(),
                None => fill.clone(),
            })
            .collect()
    }

    let skipped: BTreeSet<&str> = on.iter().map(|&(_, rk)| rk).collect();
    let mut result_cols: Vec<(String, Column)> = left
        .base
        .columns
        .iter()
        .map(|(name, col)| (name.clone(), col.gather(left_rows)))
        .collect();

    for (name, col) in right
        .base
        .columns
        .iter()
        .filter(|(name, _)| !skipped.contains(name.as_str()))
    {
        let filled = match col {
            Column::Int(v) => Column::Int(pick_or(v, right_rows_opt, 0)),
            Column::Float(v) => Column::Float(pick_or(v, right_rows_opt, 0.0)),
            Column::Str(v) => Column::Str(pick_or(v, right_rows_opt, String::new())),
            Column::Bool(v) => Column::Bool(pick_or(v, right_rows_opt, false)),
        };
        result_cols.push((name.clone(), filled));
    }

    DataFrame::from_columns(result_cols).map_err(|e| TidyError::Internal(e.to_string()))
}
/// Checks that every column reference reachable through `Col`, `BinOp`,
/// `Not`, `And` and `Or` nodes names one of `valid_names`.
///
/// Other expression kinds are accepted without inspection.
///
/// # Errors
/// `TidyError::ColumnNotFound` for the first unknown column, visiting
/// operands left to right.
fn validate_expr_columns(expr: &DExpr, valid_names: &[String]) -> Result<(), TidyError> {
    match expr {
        DExpr::Col(name) if valid_names.contains(name) => Ok(()),
        DExpr::Col(name) => Err(TidyError::ColumnNotFound(name.clone())),
        DExpr::Not(inner) => validate_expr_columns(inner, valid_names),
        DExpr::BinOp { left, right, .. } => validate_expr_columns(left, valid_names)
            .and_then(|()| validate_expr_columns(right, valid_names)),
        DExpr::And(a, b) | DExpr::Or(a, b) => validate_expr_columns(a, valid_names)
            .and_then(|()| validate_expr_columns(b, valid_names)),
        _ => Ok(()),
    }
}
fn eval_expr_column(
df: &DataFrame,
dexpr: &DExpr,
nrows: usize,
) -> Result<Column, TidyError> {
let mut floats = Vec::with_capacity(nrows);
let mut ints = Vec::with_capacity(nrows);
let mut strings = Vec::with_capacity(nrows);
let mut bools = Vec::with_capacity(nrows);
let mut first_type: Option<&str> = None;
for row in 0..nrows {
let val = expr::eval_expr_row(df, dexpr, row)
.map_err(|e| TidyError::Internal(e))?;
match &val {
ExprValue::Float(v) => {
if first_type.is_none() { first_type = Some("Float"); }
floats.push(*v);
}
ExprValue::Int(v) => {
if first_type.is_none() { first_type = Some("Int"); }
ints.push(*v);
}
ExprValue::Str(v) => {
if first_type.is_none() { first_type = Some("Str"); }
strings.push(v.clone());
}
ExprValue::Bool(v) => {
if first_type.is_none() { first_type = Some("Bool"); }
bools.push(*v);
}
}
}
match first_type {
Some("Float") => Ok(Column::Float(floats)),
Some("Int") => Ok(Column::Int(ints)),
Some("Str") => Ok(Column::Str(strings)),
Some("Bool") => Ok(Column::Bool(bools)),
_ => Ok(Column::Float(Vec::new())),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::{binop, col, BinOp};

    /// Five-row fixture: `id` (Int 1..=5), `region` (Str, West/East alternating),
    /// `value` (Float 10.0..=50.0 in steps of 10).
    fn make_test_df() -> DataFrame {
        DataFrame::from_columns(vec![
            ("id".into(), Column::Int(vec![1, 2, 3, 4, 5])),
            (
                "region".into(),
                Column::Str(vec![
                    "West".into(),
                    "East".into(),
                    "West".into(),
                    "East".into(),
                    "West".into(),
                ]),
            ),
            (
                "value".into(),
                Column::Float(vec![10.0, 20.0, 30.0, 40.0, 50.0]),
            ),
        ])
        .unwrap()
    }

    /// Returns the Float payload of column `idx`, failing the test if the
    /// column has another type (so assertions can never be skipped silently).
    fn float_column(df: &DataFrame, idx: usize) -> &[f64] {
        match &df.columns[idx].1 {
            Column::Float(vals) => vals.as_slice(),
            _ => panic!("column {} is not a Float column", idx),
        }
    }

    /// Filtering produces a narrower view and leaves the source view untouched.
    #[test]
    fn test_filter_no_copy() {
        let view = TidyView::new(make_test_df());
        let filtered = view
            .filter(&binop(BinOp::Gt, col("value"), DExpr::LitFloat(25.0)))
            .unwrap();
        assert_eq!(filtered.nrows(), 3);
        assert_eq!(view.nrows(), 5);
    }

    /// Two chained filters intersect their predicates.
    #[test]
    fn test_chained_filter() {
        let view = TidyView::new(make_test_df());
        let filtered = view
            .filter(&binop(BinOp::Gt, col("value"), DExpr::LitFloat(15.0)))
            .unwrap()
            .filter(&binop(BinOp::Lt, col("value"), DExpr::LitFloat(45.0)))
            .unwrap();
        assert_eq!(filtered.nrows(), 3);
    }

    /// Selecting columns narrows the projection of the new view only.
    #[test]
    fn test_select() {
        let view = TidyView::new(make_test_df());
        let selected = view.select(&["id", "value"]).unwrap();
        assert_eq!(selected.ncols(), 2);
        assert_eq!(view.ncols(), 3);
    }

    /// Summarising yields one row per group: the key column plus one column per aggregation.
    #[test]
    fn test_group_by_summarise() {
        let view = TidyView::new(make_test_df());
        let grouped = view.group_by(&["region"]).unwrap();
        let summary = grouped
            .summarise(&[
                ("n", TidyAgg::Count),
                ("total", TidyAgg::Sum("value".into())),
                ("avg", TidyAgg::Mean("value".into())),
            ])
            .unwrap();
        assert_eq!(summary.nrows(), 2);
        assert_eq!(summary.ncols(), 4);
    }

    /// Groups are listed in the order their key first appears.
    #[test]
    fn test_group_order_first_occurrence() {
        let view = TidyView::new(make_test_df());
        let grouped = view.group_by(&["region"]).unwrap();
        assert_eq!(grouped.index.groups[0].key_values[0].to_display(), "West");
        assert_eq!(grouped.index.groups[1].key_values[0].to_display(), "East");
    }

    /// A descending arrange is reflected in the materialized row order.
    #[test]
    fn test_arrange() {
        let view = TidyView::new(make_test_df());
        let sorted = view.arrange(&[ArrangeKey::desc("value")]).unwrap();
        let mat = sorted.materialize().unwrap();
        assert_eq!(
            float_column(&mat, 2),
            &[50.0, 40.0, 30.0, 20.0, 10.0][..]
        );
    }

    /// An inner join keeps only the ids present on both sides.
    #[test]
    fn test_inner_join() {
        let left = DataFrame::from_columns(vec![
            ("id".into(), Column::Int(vec![1, 2, 3])),
            (
                "name".into(),
                Column::Str(vec!["a".into(), "b".into(), "c".into()]),
            ),
        ])
        .unwrap();
        let right = DataFrame::from_columns(vec![
            ("id".into(), Column::Int(vec![1, 3, 4])),
            (
                "dept".into(),
                Column::Str(vec!["eng".into(), "sales".into(), "hr".into()]),
            ),
        ])
        .unwrap();
        let lv = TidyView::new(left);
        let rv = TidyView::new(right);
        let joined = lv.inner_join(&rv, &[("id", "id")]).unwrap();
        assert_eq!(joined.nrows(), 2);
    }

    /// Sampling with the same seed selects the same rows.
    #[test]
    fn test_deterministic_sample() {
        let view = TidyView::new(make_test_df());
        let s1 = view.slice_sample(3, 42);
        let s2 = view.slice_sample(3, 42);
        let r1: Vec<usize> = s1.mask.iter_set().collect();
        let r2: Vec<usize> = s2.mask.iter_set().collect();
        assert_eq!(r1, r2);
    }

    /// `distinct` keeps one row per distinct value.
    #[test]
    fn test_distinct() {
        let df = DataFrame::from_columns(vec![(
            "x".into(),
            Column::Str(vec!["a".into(), "b".into(), "a".into(), "c".into()]),
        )])
        .unwrap();
        let view = TidyView::new(df);
        let unique = view.distinct(&["x"]).unwrap();
        assert_eq!(unique.nrows(), 3);
    }

    /// Summing 10,000 copies of 0.1 stays close to 1000.0.
    #[test]
    fn test_kahan_summation_in_summarise() {
        let n = 10_000;
        let values: Vec<f64> = (0..n).map(|_| 0.1).collect();
        let grps: Vec<String> = (0..n).map(|_| "a".to_string()).collect();
        let df = DataFrame::from_columns(vec![
            ("grp".into(), Column::Str(grps)),
            ("val".into(), Column::Float(values)),
        ])
        .unwrap();
        let view = TidyView::new(df);
        let grouped = view.group_by(&["grp"]).unwrap();
        let summary = grouped
            .summarise(&[("total", TidyAgg::Sum("val".into()))])
            .unwrap();
        let totals = float_column(&summary, 1);
        assert_eq!(totals.len(), 1);
        assert!(
            (totals[0] - 1000.0).abs() < 1e-6,
            "Kahan sum {} should be close to 1000.0",
            totals[0]
        );
    }

    /// An expression may not reference a column created earlier in the same `mutate`.
    #[test]
    fn test_snapshot_semantics() {
        let df = DataFrame::from_columns(vec![("x".into(), Column::Int(vec![1, 2, 3]))]).unwrap();
        let view = TidyView::new(df);
        let result = view.mutate(&[
            ("a", binop(BinOp::Add, col("x"), DExpr::LitInt(1))),
            ("b", binop(BinOp::Mul, col("a"), DExpr::LitInt(2))),
        ]);
        assert!(result.is_err());
    }
}