mod bitmap;
mod builder;
mod column;
mod dtype;
mod field;
mod schema;
mod selection;
pub use self::bitmap::Bitmap;
pub use self::builder::TableBuilder;
pub use self::column::{BooleanCol, Column, DictionaryCol, PrimitiveCol, Utf8Col};
pub use self::dtype::DataType;
pub use self::field::Field;
pub use self::schema::Schema;
pub use self::selection::RowSelection;
use std::sync::Arc;
use std::{
cmp::Ordering,
collections::{HashMap, HashSet},
};
use crate::error::{Error, Result};
use crate::ops::{
self, AggregateExpr, ColumnSelector, CompareOp, GroupKey, JoinOptions, JoinType, Literal,
Predicate, SortKey,
};
/// Shared backing data for a [`Table`]: the schema, the column data, and the
/// number of stored rows (independent of any row selection).
#[derive(Debug)]
struct TableStorage {
/// Field descriptors for the stored columns.
schema: Arc<Schema>,
/// Column data; `Table::from_columns` requires one column per schema field.
columns: Vec<Column>,
/// Row count of the stored columns, before any selection is applied.
nrows: u32,
}
/// A view over reference-counted column storage plus a row selection.
///
/// Cloning clones the `Arc` handle and the selection; operations such as
/// `head`, `tail`, `filter` and `sort_by` reuse the same storage and only
/// replace the selection.
#[derive(Clone, Debug)]
pub struct Table {
/// Shared column data and schema.
storage: Arc<TableStorage>,
/// Which storage rows (and in which order) are visible through this table.
selection: RowSelection,
}
impl Table {
pub fn from_columns(schema: Schema, columns: Vec<Column>) -> Result<Self> {
if schema.len() != columns.len() {
return Err(Error::Schema(format!(
"schema has {} fields but {} columns were provided",
schema.len(),
columns.len()
)));
}
let nrows = columns.first().map_or(0_u32, |column| column.len() as u32);
for column in &columns {
if column.len() as u32 != nrows {
return Err(Error::Schema(
"all columns must have the same row count".to_string(),
));
}
}
Ok(Self {
storage: Arc::new(TableStorage {
schema: Arc::new(schema),
columns,
nrows,
}),
selection: RowSelection::All,
})
}
/// Returns a table with no fields, no columns and zero rows.
///
/// # Panics
///
/// Panics if `Schema::new` rejects an empty field list.
pub fn empty() -> Self {
    let schema = Schema::new(Vec::new()).expect("empty schema is valid");
    let storage = TableStorage {
        schema: Arc::new(schema),
        columns: Vec::new(),
        nrows: 0,
    };
    Self {
        storage: Arc::new(storage),
        selection: RowSelection::All,
    }
}
/// Borrows the schema describing the stored columns.
pub fn schema(&self) -> &Schema {
    &self.storage.schema
}
/// Borrows all stored columns, in storage order.
pub fn columns(&self) -> &[Column] {
    self.storage.columns.as_slice()
}
/// Borrows the column at `index`, or `None` when the index is out of range.
pub fn column(&self, index: usize) -> Option<&Column> {
    self.columns().get(index)
}
/// Borrows the column whose schema field is called `name`, or `None` when the
/// schema has no such field.
pub fn column_by_name(&self, name: &str) -> Option<&Column> {
    self.schema()
        .index_of(name)
        .and_then(|position| self.column(position))
}
pub fn nrows(&self) -> usize {
self.selected_row_indices().len()
}
/// Returns the number of stored columns.
pub fn ncols(&self) -> usize {
    self.columns().len()
}
pub fn selection(&self) -> &RowSelection {
&self.selection
}
pub fn select(&self, selectors: &[ColumnSelector]) -> Result<Self> {
let indices = ops::resolve_selectors(self.schema(), selectors)?;
let fields = indices
.iter()
.map(|&index| self.schema().field(index).cloned())
.collect::<Option<Vec<_>>>()
.ok_or_else(|| Error::InvalidSelection("column index out of bounds".to_string()))?;
let columns = indices
.iter()
.map(|&index| self.storage.columns.get(index).cloned())
.collect::<Option<Vec<_>>>()
.ok_or_else(|| Error::InvalidSelection("column index out of bounds".to_string()))?;
Ok(Self {
storage: Arc::new(TableStorage {
schema: Arc::new(Schema::new(fields)?),
columns,
nrows: self.storage.nrows,
}),
selection: self.selection.clone(),
})
}
/// Returns a view of the first `len` selected rows, sharing the same storage.
///
/// When `len` covers every selected row the existing selection is reused.
pub fn head(&self, len: usize) -> Self {
    let mut rows = self.selected_row_indices();
    let selection = if len >= rows.len() {
        self.selection.clone()
    } else {
        rows.truncate(len);
        RowSelection::Indices(rows)
    };
    Self {
        storage: Arc::clone(&self.storage),
        selection,
    }
}
/// Returns a view of the last `len` selected rows, in their existing order,
/// sharing the same storage.
///
/// When `len` covers every selected row the existing selection is reused.
pub fn tail(&self, len: usize) -> Self {
    let rows = self.selected_row_indices();
    let selection = if len >= rows.len() {
        self.selection.clone()
    } else {
        let start = rows.len() - len;
        RowSelection::Indices(rows[start..].to_vec())
    };
    Self {
        storage: Arc::clone(&self.storage),
        selection,
    }
}
/// Returns a view holding only the selected rows for which `predicate` is true.
///
/// Rows keep their relative order. If every row matches, the existing selection
/// is reused instead of an explicit index list.
///
/// # Errors
///
/// Stops at, and returns, the first predicate evaluation error.
pub fn filter(&self, predicate: &Predicate) -> Result<Self> {
    let candidates = self.selected_row_indices();
    let kept = candidates
        .iter()
        .copied()
        .filter_map(|row| match self.evaluate_predicate(row, predicate) {
            Ok(true) => Some(Ok(row)),
            Ok(false) => None,
            Err(error) => Some(Err(error)),
        })
        .collect::<Result<Vec<_>>>()?;
    let selection = if kept.len() == candidates.len() {
        self.selection.clone()
    } else {
        RowSelection::Indices(kept)
    };
    Ok(Self {
        storage: Arc::clone(&self.storage),
        selection,
    })
}
/// Returns a view whose selected rows are stably ordered by `keys`.
///
/// With no keys, or when the rows are already in order, the table is returned
/// as a plain clone.
///
/// # Errors
///
/// Fails when a sort key refers to a column that cannot be resolved.
pub fn sort_by(&self, keys: &[SortKey]) -> Result<Self> {
    if keys.is_empty() {
        return Ok(self.clone());
    }
    let mut columns = Vec::with_capacity(keys.len());
    for key in keys {
        columns.push(self.resolve_column(&key.column)?);
    }
    let original = self.selected_row_indices();
    let mut sorted = original.clone();
    sorted.sort_by(|&left, &right| compare_rows_for_sort(&columns, keys, left, right));
    if sorted == original {
        return Ok(self.clone());
    }
    Ok(Self {
        storage: Arc::clone(&self.storage),
        selection: RowSelection::Indices(sorted),
    })
}
pub fn join(&self, right: &Table, options: &JoinOptions) -> Result<Self> {
if options.keys.is_empty() {
return Err(Error::InvalidArgument(
"join requires at least one key".to_string(),
));
}
if options.keys.len() == 1 {
let key = options.keys.first().ok_or_else(|| {
Error::InvalidArgument("join requires at least one key".to_string())
})?;
let left_index = resolve_selector_index(self.schema(), &key.left)?;
let right_index = resolve_selector_index(right.schema(), &key.right)?;
let left_column = self.column(left_index).ok_or_else(|| {
Error::InvalidSelection(format!("missing column at index {left_index}"))
})?;
let right_column = right.column(right_index).ok_or_else(|| {
Error::InvalidSelection(format!("missing column at index {right_index}"))
})?;
if !join_key_types_compatible(left_column.dtype(), right_column.dtype()) {
return Err(Error::TypeMismatch {
expected: left_column.dtype().to_string(),
actual: right_column.dtype().to_string(),
});
}
if can_use_single_key_join_fast_path(left_column, right_column) {
return self.join_single_key(right, options, left_index, right_index);
}
}
self.join_multi_key(right, options)
}
/// General hash join over one or more key pairs.
///
/// Builds a lookup from composite key to positions in the right table's
/// selected rows, then probes it with each selected left row. Rows with a null
/// in any key column never match. The result is a freshly materialized table
/// holding all left columns followed by all right columns.
///
/// # Errors
///
/// Fails when a key selector cannot be resolved, when a key pair has
/// incompatible types, or when building the output columns or schema fails.
fn join_multi_key(&self, right: &Table, options: &JoinOptions) -> Result<Self> {
// Resolve every key pair to column indices and validate their types up front.
let key_pairs = options
.keys
.iter()
.map(|key| {
let left_index = resolve_selector_index(self.schema(), &key.left)?;
let right_index = resolve_selector_index(right.schema(), &key.right)?;
let left_column = self.column(left_index).ok_or_else(|| {
Error::InvalidSelection(format!("missing column at index {left_index}"))
})?;
let right_column = right.column(right_index).ok_or_else(|| {
Error::InvalidSelection(format!("missing column at index {right_index}"))
})?;
if !join_key_types_compatible(left_column.dtype(), right_column.dtype()) {
return Err(Error::TypeMismatch {
expected: left_column.dtype().to_string(),
actual: right_column.dtype().to_string(),
});
}
Ok((left_index, right_index))
})
.collect::<Result<Vec<_>>>()?;
let left_key_columns = key_pairs
.iter()
.map(|(left_index, _)| {
self.column(*left_index).ok_or_else(|| {
Error::InvalidSelection(format!("missing column at index {left_index}"))
})
})
.collect::<Result<Vec<_>>>()?;
let right_key_columns = key_pairs
.iter()
.map(|(_, right_index)| {
right.column(*right_index).ok_or_else(|| {
Error::InvalidSelection(format!("missing column at index {right_index}"))
})
})
.collect::<Result<Vec<_>>>()?;
let left_indices = self.selected_row_indices();
let right_indices = right.selected_row_indices();
// Build side: composite key -> positions within `right_indices`.
// `join_key_from_columns` yields `None` for keys containing a null, so those
// rows are never inserted.
let mut right_lookup = HashMap::<Vec<GroupValue>, Vec<usize>>::new();
for (position, row_index) in right_indices.iter().copied().enumerate() {
if let Some(key) = join_key_from_columns(&right_key_columns, row_index)? {
right_lookup.entry(key).or_default().push(position);
}
}
// Tracks which right positions matched, for the Right/Full pass below.
let mut matched_right = vec![false; right_indices.len()];
let mut row_pairs = Vec::<(Option<u32>, Option<u32>)>::new();
// Probe side: emit one pair per match, in left-row order.
for left_row in left_indices.iter().copied() {
let mut matched_any = false;
if let Some(key) = join_key_from_columns(&left_key_columns, left_row)? {
if let Some(right_positions) = right_lookup.get(&key) {
matched_any = true;
for &position in right_positions {
matched_right[position] = true;
row_pairs.push((Some(left_row), right_indices.get(position).copied()));
}
}
}
// Left/Full joins keep unmatched left rows with no right partner.
if !matched_any && matches!(options.join_type, JoinType::Left | JoinType::Full) {
row_pairs.push((Some(left_row), None));
}
}
// Right/Full joins append unmatched right rows after all left-driven pairs.
if matches!(options.join_type, JoinType::Right | JoinType::Full) {
for (position, right_row) in right_indices.iter().copied().enumerate() {
if !matched_right[position] {
row_pairs.push((None, Some(right_row)));
}
}
}
let fields = build_join_fields(self.schema(), right.schema(), options);
let left_rows = row_pairs
.iter()
.map(|(left_row, _)| *left_row)
.collect::<Vec<_>>();
let right_rows = row_pairs
.iter()
.map(|(_, right_row)| *right_row)
.collect::<Vec<_>>();
// Output columns: every left column, then every right column.
let mut columns = Vec::with_capacity(self.ncols() + right.ncols());
for column in self.columns() {
columns.push(build_join_column(column, &left_rows)?);
}
for column in right.columns() {
columns.push(build_join_column(column, &right_rows)?);
}
Table::from_columns(Schema::new(fields)?, columns)
}
/// Groups the selected rows by `keys` and computes `aggs` for each group.
///
/// Groups appear in the output in the order their key was first encountered.
/// The result has one column per key followed by one column per aggregate. A
/// single key is delegated to `group_by_single_key`.
///
/// # Errors
///
/// Returns [`Error::InvalidArgument`] when `keys` or `aggs` is empty, and
/// propagates selector-resolution, aggregate-resolution, state-update and
/// column-construction errors.
pub fn group_by(&self, keys: &[GroupKey], aggs: &[AggregateExpr]) -> Result<Self> {
if keys.is_empty() {
return Err(Error::InvalidArgument(
"group_by requires at least one key".to_string(),
));
}
if aggs.is_empty() {
return Err(Error::InvalidArgument(
"group_by requires at least one aggregate".to_string(),
));
}
if keys.len() == 1 {
return self.group_by_single_key(&keys[0], aggs);
}
// Resolve each key to its (field, column) pair.
let key_columns = keys
.iter()
.map(|key| {
let index = resolve_selector_index(self.schema(), &key.column)?;
let field = self.schema().field(index).ok_or_else(|| {
Error::InvalidSelection(format!("missing field at index {index}"))
})?;
let column = self.column(index).ok_or_else(|| {
Error::InvalidSelection(format!("missing column at index {index}"))
})?;
Ok((field.clone(), column))
})
.collect::<Result<Vec<_>>>()?;
let aggregates = aggs
.iter()
.map(|aggregate| resolve_aggregate(self, aggregate))
.collect::<Result<Vec<_>>>()?;
// `group_index` maps a composite key to its slot in the two parallel vectors
// below, which keep groups in first-seen order.
let mut group_index = HashMap::<Vec<GroupValue>, usize>::new();
let mut grouped_keys = Vec::<Vec<GroupValue>>::new();
let mut grouped_states = Vec::<Vec<AggState>>::new();
for row_index in self.selected_row_indices() {
let key = key_columns
.iter()
.map(|(_, column)| group_value_from_column(column, row_index))
.collect::<Result<Vec<_>>>()?;
let entry = if let Some(existing) = group_index.get(&key) {
*existing
} else {
// First row of a new group: allocate fresh aggregate states.
let index = grouped_states.len();
group_index.insert(key.clone(), index);
grouped_keys.push(key);
grouped_states.push(
aggregates
.iter()
.map(AggState::new)
.collect::<Result<Vec<_>>>()?,
);
index
};
let states = grouped_states
.get_mut(entry)
.ok_or_else(|| Error::InvalidSelection("group index out of bounds".to_string()))?;
for (state, aggregate) in states.iter_mut().zip(aggregates.iter()) {
state.observe(aggregate, row_index)?;
}
}
// Output schema: key fields first, then one field per aggregate.
let mut fields = key_columns
.iter()
.map(|(field, _)| field.clone())
.collect::<Vec<_>>();
fields.extend(
aggregates
.iter()
.map(|aggregate| {
Field::new(Arc::clone(&aggregate.alias), aggregate.output_dtype())
.with_nullability(aggregate.may_produce_null())
})
.collect::<Vec<_>>(),
);
let mut columns = Vec::new();
// Rebuild each key column from the per-group key values.
for (key_index, (field, _)) in key_columns.iter().enumerate() {
let values = grouped_keys
.iter()
.map(|key| key.get(key_index).cloned().unwrap_or(GroupValue::Null))
.collect::<Vec<_>>();
columns.push(build_column_from_group_values(field.dtype, &values)?);
}
// Build each aggregate column from that aggregate's state in every group.
for agg_index in 0..aggregates.len() {
let states = grouped_states
.iter()
.map(|states| {
states.get(agg_index).cloned().ok_or_else(|| {
Error::InvalidSelection("aggregate state missing".to_string())
})
})
.collect::<Result<Vec<_>>>()?;
columns.push(build_column_from_aggregate_states(
aggregates
.get(agg_index)
.ok_or_else(|| Error::InvalidSelection("aggregate spec missing".to_string()))?,
&states,
)?);
}
Table::from_columns(Schema::new(fields)?, columns)
}
/// Copies the selected rows into fresh storage so the result selects all rows.
///
/// A table whose selection already reports `is_all()` is returned as a clone.
///
/// # Errors
///
/// Propagates errors from gathering a column or rebuilding the schema.
pub fn materialize(&self) -> Result<Self> {
    if self.selection.is_all() {
        return Ok(self.clone());
    }
    let rows = self.selected_row_indices();
    let mut gathered = Vec::with_capacity(self.ncols());
    for column in self.storage.columns.iter() {
        gathered.push(column.gather(&rows)?);
    }
    let schema = Schema::new(self.schema().fields().to_vec())?;
    Self::from_columns(schema, gathered)
}
/// Expands the current selection into an explicit list of storage row indices.
///
/// `All` yields `0..nrows`, `Range` yields `offset..offset + len`, `Indices` is
/// copied as stored, and `Bitmap` yields the positions whose bit is set.
pub(crate) fn selected_row_indices(&self) -> Vec<u32> {
    match &self.selection {
        RowSelection::All => {
            let total = self.storage.nrows;
            (0..total).collect()
        }
        RowSelection::Range { offset, len } => {
            let start = *offset;
            let end = start + *len;
            (start..end).collect()
        }
        RowSelection::Indices(indices) => indices.clone(),
        RowSelection::Bitmap(bitmap) => {
            let mut rows = Vec::new();
            for index in 0..bitmap.len() {
                if bitmap.is_set(index) {
                    rows.push(index);
                }
            }
            rows
        }
    }
}
/// Evaluates `predicate` against one storage row.
///
/// `And` and `Or` stop evaluating children as soon as the outcome is decided;
/// an empty `And` is true and an empty `Or` is false.
///
/// # Errors
///
/// Propagates column-resolution and comparison errors from the children that
/// are actually evaluated.
fn evaluate_predicate(&self, row_index: u32, predicate: &Predicate) -> Result<bool> {
    match predicate {
        Predicate::Comparison { column, op, value } => {
            let target = self.resolve_column(column)?;
            self.evaluate_comparison(target, row_index, *op, value.as_ref())
        }
        Predicate::And(children) => {
            for child in children.iter() {
                if !self.evaluate_predicate(row_index, child)? {
                    return Ok(false);
                }
            }
            Ok(true)
        }
        Predicate::Or(children) => {
            for child in children.iter() {
                if self.evaluate_predicate(row_index, child)? {
                    return Ok(true);
                }
            }
            Ok(false)
        }
        Predicate::Not(inner) => Ok(!self.evaluate_predicate(row_index, inner)?),
    }
}
fn resolve_column(&self, selector: &ColumnSelector) -> Result<&Column> {
match selector {
ColumnSelector::Name(name) => self
.column_by_name(name)
.ok_or_else(|| Error::ColumnNotFound(name.to_string())),
ColumnSelector::Index(index) => self.column(*index).ok_or_else(|| {
Error::InvalidSelection(format!("column index {index} is out of bounds"))
}),
}
}
fn evaluate_comparison(
&self,
column: &Column,
row_index: u32,
op: CompareOp,
value: Option<&Literal>,
) -> Result<bool> {
match op {
CompareOp::IsNull => Ok(column.is_null(row_index)),
CompareOp::IsNotNull => Ok(!column.is_null(row_index)),
_ if column.is_null(row_index) => Ok(false),
_ => {
let literal = value.ok_or_else(|| {
Error::InvalidArgument(format!("comparison operator {op:?} requires a literal"))
})?;
match (column, literal) {
(Column::Bool(_), Literal::Bool(expected)) => {
compare_ord(column.bool_value(row_index), Some(*expected), op)
}
(Column::I64(_) | Column::TimestampMs(_), Literal::I64(expected)) => {
compare_ord(column.i64_value(row_index), Some(*expected), op)
}
(Column::Date32(_), Literal::Date32(expected)) => {
compare_ord(column.i32_value(row_index), Some(*expected), op)
}
(Column::F64(_), Literal::F64(expected)) => {
compare_partial(column.f64_value(row_index), Some(*expected), op)
}
(Column::F64(_), Literal::I64(expected)) => {
compare_partial(column.f64_value(row_index), Some(*expected as f64), op)
}
(Column::Utf8(_) | Column::DictUtf8(_), Literal::Utf8(expected)) => {
compare_str(column.utf8_value(row_index), Some(expected.as_ref()), op)
}
_ => Err(Error::TypeMismatch {
expected: column.dtype().to_string(),
actual: literal_type_name(literal).to_string(),
}),
}
}
}
}
/// Hash join specialised for a single key column pair.
///
/// Uses `SingleKey` values as hash keys instead of composite key vectors. Rows
/// whose key is null never match. The result is a freshly materialized table
/// holding all left columns followed by all right columns.
///
/// # Errors
///
/// Fails when either index has no column, when the key types are
/// incompatible, or when building the output columns or schema fails.
fn join_single_key(
&self,
right: &Table,
options: &JoinOptions,
left_index: usize,
right_index: usize,
) -> Result<Self> {
let left_column = self.column(left_index).ok_or_else(|| {
Error::InvalidSelection(format!("missing column at index {left_index}"))
})?;
let right_column = right.column(right_index).ok_or_else(|| {
Error::InvalidSelection(format!("missing column at index {right_index}"))
})?;
if !join_key_types_compatible(left_column.dtype(), right_column.dtype()) {
return Err(Error::TypeMismatch {
expected: left_column.dtype().to_string(),
actual: right_column.dtype().to_string(),
});
}
let left_indices = self.selected_row_indices();
let right_indices = right.selected_row_indices();
// Build side: key -> positions within `right_indices`; null keys are skipped.
let mut right_lookup = HashMap::<SingleKey, Vec<usize>>::new();
for (position, row_index) in right_indices.iter().copied().enumerate() {
let key = single_key_from_column(right_column, row_index);
if !matches!(key, SingleKey::Null) {
right_lookup.entry(key).or_default().push(position);
}
}
// Tracks which right positions matched, for the Right/Full pass below.
let mut matched_right = vec![false; right_indices.len()];
let mut row_pairs = Vec::<(Option<u32>, Option<u32>)>::new();
// Probe side: emit one pair per match, in left-row order.
for left_row in left_indices.iter().copied() {
let key = single_key_from_column(left_column, left_row);
let mut matched_any = false;
if !matches!(key, SingleKey::Null) {
if let Some(right_positions) = right_lookup.get(&key) {
matched_any = true;
for &position in right_positions {
matched_right[position] = true;
row_pairs.push((Some(left_row), right_indices.get(position).copied()));
}
}
}
// Left/Full joins keep unmatched left rows with no right partner.
if !matched_any && matches!(options.join_type, JoinType::Left | JoinType::Full) {
row_pairs.push((Some(left_row), None));
}
}
// Right/Full joins append unmatched right rows after all left-driven pairs.
if matches!(options.join_type, JoinType::Right | JoinType::Full) {
for (position, right_row) in right_indices.iter().copied().enumerate() {
if !matched_right[position] {
row_pairs.push((None, Some(right_row)));
}
}
}
let fields = build_join_fields(self.schema(), right.schema(), options);
let left_rows = row_pairs
.iter()
.map(|(left_row, _)| *left_row)
.collect::<Vec<_>>();
let right_rows = row_pairs
.iter()
.map(|(_, right_row)| *right_row)
.collect::<Vec<_>>();
// Output columns: every left column, then every right column.
let mut columns = Vec::with_capacity(self.ncols() + right.ncols());
for column in self.columns() {
columns.push(build_join_column(column, &left_rows)?);
}
for column in right.columns() {
columns.push(build_join_column(column, &right_rows)?);
}
Table::from_columns(Schema::new(fields)?, columns)
}
/// Group-by specialised for a single key column.
///
/// Uses `SingleKey` values as hash keys. Groups appear in the output in the
/// order their key was first encountered; the result has the key column
/// followed by one column per aggregate.
///
/// # Errors
///
/// Propagates selector-resolution, aggregate-resolution, state-update and
/// column-construction errors.
fn group_by_single_key(&self, key: &GroupKey, aggs: &[AggregateExpr]) -> Result<Self> {
let key_index = resolve_selector_index(self.schema(), &key.column)?;
let key_field = self.schema().field(key_index).cloned().ok_or_else(|| {
Error::InvalidSelection(format!("missing field at index {key_index}"))
})?;
let key_column = self.column(key_index).ok_or_else(|| {
Error::InvalidSelection(format!("missing column at index {key_index}"))
})?;
let aggregates = aggs
.iter()
.map(|aggregate| resolve_aggregate(self, aggregate))
.collect::<Result<Vec<_>>>()?;
// `group_index` maps a key to its slot in the two parallel vectors below,
// which keep groups in first-seen order.
let mut group_index = HashMap::<SingleKey, usize>::new();
let mut grouped_keys = Vec::<SingleKey>::new();
let mut grouped_states = Vec::<Vec<AggState>>::new();
for row_index in self.selected_row_indices() {
let key = single_key_from_column(key_column, row_index);
let entry = if let Some(existing) = group_index.get(&key) {
*existing
} else {
// First row of a new group: allocate fresh aggregate states.
let index = grouped_states.len();
group_index.insert(key.clone(), index);
grouped_keys.push(key);
grouped_states.push(
aggregates
.iter()
.map(AggState::new)
.collect::<Result<Vec<_>>>()?,
);
index
};
let states = grouped_states
.get_mut(entry)
.ok_or_else(|| Error::InvalidSelection("group index out of bounds".to_string()))?;
for (state, aggregate) in states.iter_mut().zip(aggregates.iter()) {
state.observe(aggregate, row_index)?;
}
}
// Output schema: the key field first, then one field per aggregate.
let mut fields = vec![key_field.clone()];
fields.extend(
aggregates
.iter()
.map(|aggregate| {
Field::new(Arc::clone(&aggregate.alias), aggregate.output_dtype())
.with_nullability(aggregate.may_produce_null())
})
.collect::<Vec<_>>(),
);
// The source key column is passed along with the collected keys.
let mut columns = vec![build_column_from_single_keys(
key_field.dtype,
key_column,
&grouped_keys,
)?];
// Build each aggregate column from that aggregate's state in every group.
for agg_index in 0..aggregates.len() {
let states = grouped_states
.iter()
.map(|states| {
states.get(agg_index).cloned().ok_or_else(|| {
Error::InvalidSelection("aggregate state missing".to_string())
})
})
.collect::<Result<Vec<_>>>()?;
columns.push(build_column_from_aggregate_states(
aggregates
.get(agg_index)
.ok_or_else(|| Error::InvalidSelection("aggregate spec missing".to_string()))?,
&states,
)?);
}
Table::from_columns(Schema::new(fields)?, columns)
}
}
fn compare_ord<T: Ord>(actual: Option<T>, expected: Option<T>, op: CompareOp) -> Result<bool> {
let actual =
actual.ok_or_else(|| Error::InvalidSelection("row index out of bounds".to_string()))?;
let expected =
expected.ok_or_else(|| Error::InvalidArgument("missing comparison literal".to_string()))?;
Ok(match op {
CompareOp::Eq => actual == expected,
CompareOp::NotEq => actual != expected,
CompareOp::Lt => actual < expected,
CompareOp::Lte => actual <= expected,
CompareOp::Gt => actual > expected,
CompareOp::Gte => actual >= expected,
other => {
return Err(Error::Unsupported(format!(
"operator {other:?} is not supported for ordered values"
)))
}
})
}
fn compare_partial(actual: Option<f64>, expected: Option<f64>, op: CompareOp) -> Result<bool> {
let actual =
actual.ok_or_else(|| Error::InvalidSelection("row index out of bounds".to_string()))?;
let expected =
expected.ok_or_else(|| Error::InvalidArgument("missing comparison literal".to_string()))?;
Ok(match op {
CompareOp::Eq => actual == expected,
CompareOp::NotEq => actual != expected,
CompareOp::Lt => actual < expected,
CompareOp::Lte => actual <= expected,
CompareOp::Gt => actual > expected,
CompareOp::Gte => actual >= expected,
other => {
return Err(Error::Unsupported(format!(
"operator {other:?} is not supported for floating-point values"
)))
}
})
}
fn compare_str(actual: Option<&str>, expected: Option<&str>, op: CompareOp) -> Result<bool> {
let actual =
actual.ok_or_else(|| Error::InvalidSelection("row index out of bounds".to_string()))?;
let expected =
expected.ok_or_else(|| Error::InvalidArgument("missing comparison literal".to_string()))?;
Ok(match op {
CompareOp::Eq => actual == expected,
CompareOp::NotEq => actual != expected,
CompareOp::Lt => actual < expected,
CompareOp::Lte => actual <= expected,
CompareOp::Gt => actual > expected,
CompareOp::Gte => actual >= expected,
CompareOp::Contains => actual.contains(expected),
CompareOp::StartsWith => actual.starts_with(expected),
CompareOp::EndsWith => actual.ends_with(expected),
other => {
return Err(Error::Unsupported(format!(
"operator {other:?} is not supported for string values"
)))
}
})
}
/// Reports whether two column types may be used as a join key pair.
///
/// Each listed primitive type only pairs with itself; `Utf8` and `DictUtf8`
/// pair with each other in any combination.
fn join_key_types_compatible(left: DataType, right: DataType) -> bool {
    match (left, right) {
        (DataType::Bool, DataType::Bool)
        | (DataType::I64, DataType::I64)
        | (DataType::F64, DataType::F64)
        | (DataType::Date32, DataType::Date32)
        | (DataType::TimestampMs, DataType::TimestampMs) => true,
        (DataType::Utf8 | DataType::DictUtf8, DataType::Utf8 | DataType::DictUtf8) => true,
        _ => false,
    }
}
/// Reports whether a single-key join over these two columns may take the
/// `join_single_key` path: both columns must share one of the listed
/// non-string types.
fn can_use_single_key_join_fast_path(left: &Column, right: &Column) -> bool {
    let pairing = (left.dtype(), right.dtype());
    match pairing {
        (DataType::Bool, DataType::Bool)
        | (DataType::I64, DataType::I64)
        | (DataType::F64, DataType::F64)
        | (DataType::Date32, DataType::Date32)
        | (DataType::TimestampMs, DataType::TimestampMs) => true,
        _ => false,
    }
}
/// Produces the output fields of a join: all left fields, then all right fields.
///
/// A name present on both sides gets that side's suffix appended; any remaining
/// clash is resolved by `allocate_join_name`. Left fields become nullable for
/// Right/Full joins and right fields for Left/Full joins.
fn build_join_fields(
    left_schema: &Schema,
    right_schema: &Schema,
    options: &JoinOptions,
) -> Vec<Field> {
    /// Collects the field names of one schema.
    fn names_of(schema: &Schema) -> HashSet<String> {
        schema
            .fields()
            .iter()
            .map(|field| field.name.to_string())
            .collect()
    }

    let left_names = names_of(left_schema);
    let right_names = names_of(right_schema);
    let duplicate_names: HashSet<String> =
        left_names.intersection(&right_names).cloned().collect();
    let left_force_nullable = matches!(options.join_type, JoinType::Right | JoinType::Full);
    let right_force_nullable = matches!(options.join_type, JoinType::Left | JoinType::Full);

    let left_side = left_schema.fields().iter().map(|field| {
        (
            field,
            &options.left_suffix as &dyn std::fmt::Display,
            left_force_nullable,
        )
    });
    let right_side = right_schema.fields().iter().map(|field| {
        (
            field,
            &options.right_suffix as &dyn std::fmt::Display,
            right_force_nullable,
        )
    });

    let mut used_names = HashSet::new();
    let mut fields = Vec::with_capacity(left_schema.len() + right_schema.len());
    for (field, suffix, force_nullable) in left_side.chain(right_side) {
        let preferred = if duplicate_names.contains(field.name.as_ref()) {
            format!("{}{}", field.name, suffix)
        } else {
            field.name.to_string()
        };
        let name = allocate_join_name(&preferred, &mut used_names);
        let output = Field::new(Arc::<str>::from(name), field.dtype)
            .with_nullability(field.nullable || force_nullable);
        fields.push(output);
    }
    fields
}
/// Reserves a unique output name, recording it in `used_names`.
///
/// Returns `preferred` if it is still free; otherwise tries `preferred_2`,
/// `preferred_3`, … and returns the first candidate that is not yet taken.
fn allocate_join_name(preferred: &str, used_names: &mut HashSet<String>) -> String {
    let mut candidate = preferred.to_string();
    let mut counter = 1_usize;
    while !used_names.insert(candidate.clone()) {
        counter += 1;
        candidate = format!("{preferred}_{counter}");
    }
    candidate
}
/// Builds the composite join key of one row from the given key columns.
///
/// Returns `Ok(None)` as soon as a key value is null; columns after that point
/// are not read.
///
/// # Errors
///
/// Propagates errors from `group_value_from_column`.
fn join_key_from_columns(columns: &[&Column], row_index: u32) -> Result<Option<Vec<GroupValue>>> {
    columns
        .iter()
        .map(|column| {
            group_value_from_column(column, row_index).map(|value| {
                if matches!(value, GroupValue::Null) {
                    None
                } else {
                    Some(value)
                }
            })
        })
        .collect::<Result<Option<Vec<_>>>>()
}
/// Builds one output column of a join from optional source row indices.
///
/// When every slot has a row, the column is simply gathered. Otherwise the
/// column is rebuilt value by value: a missing row produces an invalid (null)
/// entry backed by a default value, and source nulls stay null.
///
/// # Errors
///
/// Propagates errors from `Column::gather` and `build_string_column`.
fn build_join_column(column: &Column, row_indices: &[Option<u32>]) -> Result<Column> {
    /// Reads one value per slot, substituting `T::default()` where the slot is
    /// empty or the reader returns nothing.
    fn values_or_default<T: Default>(
        slots: &[Option<u32>],
        read: impl Fn(u32) -> Option<T>,
    ) -> Vec<T> {
        slots
            .iter()
            .map(|slot| slot.and_then(&read).unwrap_or_default())
            .collect()
    }

    let all_present = row_indices.iter().copied().collect::<Option<Vec<_>>>();
    if let Some(indices) = all_present {
        return column.gather(&indices);
    }

    // An entry is valid only if it has a source row and that row is non-null.
    let validity = row_indices
        .iter()
        .map(|slot| matches!(slot, Some(row) if !column.is_null(*row)))
        .collect::<Vec<_>>();

    match column {
        Column::Bool(_) => {
            let bits = values_or_default(row_indices, |row| column.bool_value(row));
            Ok(Column::Bool(BooleanCol::new(
                Bitmap::from_bools(&bits),
                validity_to_bitmap(validity),
            )))
        }
        Column::I64(_) => {
            let values = values_or_default(row_indices, |row| column.i64_value(row));
            Ok(Column::I64(PrimitiveCol::new(
                values,
                validity_to_bitmap(validity),
            )))
        }
        Column::TimestampMs(_) => {
            let values = values_or_default(row_indices, |row| column.i64_value(row));
            Ok(Column::TimestampMs(PrimitiveCol::new(
                values,
                validity_to_bitmap(validity),
            )))
        }
        Column::Date32(_) => {
            let values = values_or_default(row_indices, |row| column.i32_value(row));
            Ok(Column::Date32(PrimitiveCol::new(
                values,
                validity_to_bitmap(validity),
            )))
        }
        Column::F64(_) => {
            let values = values_or_default(row_indices, |row| column.f64_value(row));
            Ok(Column::F64(PrimitiveCol::new(
                values,
                validity_to_bitmap(validity),
            )))
        }
        Column::Utf8(_) | Column::DictUtf8(_) => {
            let strings = row_indices
                .iter()
                .map(|slot| slot.and_then(|row| column.utf8_value(row)))
                .collect::<Vec<_>>();
            // The flag is `true` only for dictionary-encoded source columns.
            build_string_column(strings, matches!(column, Column::DictUtf8(_)))
        }
    }
}
/// Returns the lowercase type name of a literal, used in type-mismatch errors.
fn literal_type_name(literal: &Literal) -> &'static str {
    let name: &'static str = match literal {
        Literal::Bool(..) => "bool",
        Literal::I64(..) => "i64",
        Literal::F64(..) => "f64",
        Literal::Utf8(..) => "utf8",
        Literal::Date32(..) => "date32",
        Literal::TimestampMs(..) => "timestamp_ms",
    };
    name
}
/// Orders two rows by comparing them key by key; the first key that does not
/// compare equal decides, and rows equal on every key compare equal.
fn compare_rows_for_sort(columns: &[&Column], keys: &[SortKey], left: u32, right: u32) -> Ordering {
    columns
        .iter()
        .zip(keys.iter())
        .map(|(column, key)| compare_column_values_for_sort(column, key, left, right))
        .find(|ordering| *ordering != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
/// Orders two rows of one column according to a sort key.
///
/// Nulls are placed relative to non-null values according to `key.nulls`,
/// non-null values are compared by type, and the resulting ordering is reversed
/// for descending keys.
///
/// Floating-point values use a total order in which NaN compares equal to NaN
/// and greater than every other number. Treating NaN as equal to everything
/// would make the comparator non-transitive, which `slice::sort_by` does not
/// tolerate (unspecified order, possibly a panic).
fn compare_column_values_for_sort(
    column: &Column,
    key: &SortKey,
    left: u32,
    right: u32,
) -> Ordering {
    let left_null = column.is_null(left);
    let right_null = column.is_null(right);
    let ordering = match (left_null, right_null) {
        (true, true) => Ordering::Equal,
        (true, false) => match key.nulls {
            crate::ops::NullOrder::First => Ordering::Less,
            crate::ops::NullOrder::Last => Ordering::Greater,
        },
        (false, true) => match key.nulls {
            crate::ops::NullOrder::First => Ordering::Greater,
            crate::ops::NullOrder::Last => Ordering::Less,
        },
        (false, false) => match column {
            Column::Bool(_) => column.bool_value(left).cmp(&column.bool_value(right)),
            Column::I64(_) | Column::TimestampMs(_) => {
                column.i64_value(left).cmp(&column.i64_value(right))
            }
            Column::Date32(_) => column.i32_value(left).cmp(&column.i32_value(right)),
            Column::F64(_) => match column.f64_value(left).zip(column.f64_value(right)) {
                // `partial_cmp` is `None` only when a NaN is involved; rank
                // NaN above every number so the order stays total.
                Some((left, right)) => left
                    .partial_cmp(&right)
                    .unwrap_or_else(|| left.is_nan().cmp(&right.is_nan())),
                None => Ordering::Equal,
            },
            Column::Utf8(_) | Column::DictUtf8(_) => {
                column.utf8_value(left).cmp(&column.utf8_value(right))
            }
        },
    };
    // NOTE(review): the reversal below also flips the null placement chosen
    // above, so `NullOrder::First` with a descending key puts nulls last.
    // Confirm whether that is intended before changing it.
    match key.order {
        crate::ops::SortOrder::Ascending => ordering,
        crate::ops::SortOrder::Descending => ordering.reverse(),
    }
}
fn resolve_selector_index(schema: &Schema, selector: &ColumnSelector) -> Result<usize> {
match selector {
ColumnSelector::Name(name) => schema
.index_of(name)
.ok_or_else(|| Error::ColumnNotFound(name.to_string())),
ColumnSelector::Index(index) => {
if *index < schema.len() {
Ok(*index)
} else {
Err(Error::InvalidSelection(format!(
"column index {index} is out of bounds"
)))
}
}
}
}
/// Hashable snapshot of one cell, used as a component of composite group and
/// join keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum GroupValue {
/// The cell is null.
Null,
Bool(bool),
I64(i64),
/// Bit pattern of an `f64` (`to_bits`), so floats can be hashed and compared
/// for exact equality.
F64Bits(u64),
/// Owned string value; used for both `Utf8` and `DictUtf8` columns.
Utf8(String),
Date32(i32),
TimestampMs(i64),
}
/// Hashable snapshot of one cell, used as the key on the single-key join and
/// group-by paths.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum SingleKey {
/// The cell is null.
Null,
Bool(bool),
I64(i64),
/// Bit pattern of an `f64` (`to_bits`), so floats can be hashed and compared
/// for exact equality.
F64Bits(u64),
/// Owned string value from a `Utf8` column.
Utf8(Box<str>),
/// Dictionary key read from a `DictUtf8` column via `dict_key_value`.
DictKey(u32),
Date32(i32),
TimestampMs(i64),
}
/// Reads one cell as a [`SingleKey`].
///
/// Null cells map to `SingleKey::Null`. If a typed accessor returns nothing for
/// a non-null cell, the type's default value is used. `DictUtf8` cells are
/// keyed by their dictionary key rather than by string content.
fn single_key_from_column(column: &Column, row_index: u32) -> SingleKey {
    match column {
        _ if column.is_null(row_index) => SingleKey::Null,
        Column::Bool(_) => {
            let flag = column.bool_value(row_index).unwrap_or_default();
            SingleKey::Bool(flag)
        }
        Column::I64(_) => {
            let number = column.i64_value(row_index).unwrap_or_default();
            SingleKey::I64(number)
        }
        Column::TimestampMs(_) => {
            let millis = column.i64_value(row_index).unwrap_or_default();
            SingleKey::TimestampMs(millis)
        }
        Column::Date32(_) => {
            let days = column.i32_value(row_index).unwrap_or_default();
            SingleKey::Date32(days)
        }
        Column::F64(_) => {
            let number = column.f64_value(row_index).unwrap_or_default();
            SingleKey::F64Bits(number.to_bits())
        }
        Column::Utf8(_) => {
            let text = column.utf8_value(row_index).unwrap_or_default();
            SingleKey::Utf8(Box::<str>::from(text))
        }
        Column::DictUtf8(_) => {
            let dict_key = column.dict_key_value(row_index).unwrap_or_default();
            SingleKey::DictKey(dict_key)
        }
    }
}
/// Reads one cell as a [`GroupValue`].
///
/// Null cells map to `GroupValue::Null`. If a typed accessor returns nothing
/// for a non-null cell, the type's default value is used. `Utf8` and `DictUtf8`
/// cells both produce an owned string.
///
/// The current implementation always returns `Ok`.
fn group_value_from_column(column: &Column, row_index: u32) -> Result<GroupValue> {
    let value = match column {
        _ if column.is_null(row_index) => GroupValue::Null,
        Column::Bool(_) => GroupValue::Bool(column.bool_value(row_index).unwrap_or_default()),
        Column::I64(_) => GroupValue::I64(column.i64_value(row_index).unwrap_or_default()),
        Column::TimestampMs(_) => {
            GroupValue::TimestampMs(column.i64_value(row_index).unwrap_or_default())
        }
        Column::Date32(_) => GroupValue::Date32(column.i32_value(row_index).unwrap_or_default()),
        Column::F64(_) => {
            let number = column.f64_value(row_index).unwrap_or_default();
            GroupValue::F64Bits(number.to_bits())
        }
        Column::Utf8(_) | Column::DictUtf8(_) => {
            let text = column.utf8_value(row_index).unwrap_or_default();
            GroupValue::Utf8(text.to_string())
        }
    };
    Ok(value)
}
/// An aggregate expression bound to a concrete table: its output name, its
/// operation, and (when the operation reads a column) the input column and type.
#[derive(Clone)]
struct ResolvedAggregate<'a> {
/// Name of the output field.
alias: Arc<str>,
/// Aggregate operation to apply.
op: crate::ops::AggregateOp,
/// Input column; `None` for `CountRows`, which takes no input.
column: Option<&'a Column>,
/// Data type of `column`, when there is one.
input_dtype: Option<DataType>,
}
impl<'a> ResolvedAggregate<'a> {
fn output_dtype(&self) -> DataType {
match self.op {
crate::ops::AggregateOp::CountRows
| crate::ops::AggregateOp::CountNonNull
| crate::ops::AggregateOp::CountNulls => DataType::I64,
crate::ops::AggregateOp::Sum => match self.input_dtype {
Some(DataType::F64) => DataType::F64,
_ => DataType::I64,
},
crate::ops::AggregateOp::Mean => DataType::F64,
crate::ops::AggregateOp::Min | crate::ops::AggregateOp::Max => match self.input_dtype {
Some(DataType::DictUtf8) => DataType::Utf8,
Some(dtype) => dtype,
None => DataType::I64,
},
}
}
fn may_produce_null(&self) -> bool {
!matches!(
self.op,
crate::ops::AggregateOp::CountRows
| crate::ops::AggregateOp::CountNonNull
| crate::ops::AggregateOp::CountNulls
)
}
}
fn resolve_aggregate<'a>(
table: &'a Table,
aggregate: &AggregateExpr,
) -> Result<ResolvedAggregate<'a>> {
let (column, input_dtype) = match aggregate.op {
crate::ops::AggregateOp::CountRows => (None, None),
crate::ops::AggregateOp::CountNonNull
| crate::ops::AggregateOp::CountNulls
| crate::ops::AggregateOp::Sum
| crate::ops::AggregateOp::Min
| crate::ops::AggregateOp::Max
| crate::ops::AggregateOp::Mean => {
let selector = aggregate.input.as_ref().ok_or_else(|| {
Error::InvalidArgument(format!(
"aggregate '{}' requires an input column",
aggregate.alias
))
})?;
let index = resolve_selector_index(table.schema(), selector)?;
let column = table.column(index).ok_or_else(|| {
Error::InvalidSelection(format!("missing column at index {index}"))
})?;
(Some(column), Some(column.dtype()))
}
};
match aggregate.op {
crate::ops::AggregateOp::Sum | crate::ops::AggregateOp::Mean => match input_dtype {
Some(DataType::I64) | Some(DataType::F64) => {}
Some(dtype) => {
return Err(Error::Unsupported(format!(
"aggregate {:?} is not supported for {dtype}",
aggregate.op
)))
}
None => {}
},
_ => {}
}
Ok(ResolvedAggregate {
alias: Arc::clone(&aggregate.alias),
op: aggregate.op,
column,
input_dtype,
})
}
/// Running state of one aggregate within one group.
///
/// The `Min*` / `Max*` variants start as `None` (see `AggState::new`).
#[derive(Clone)]
enum AggState {
CountRows(i64),
CountNonNull(i64),
CountNulls(i64),
/// Integer sum with a count field.
SumI64 { sum: i64, count: i64 },
/// Floating-point sum with a count field; also the state used for `Mean`.
SumF64 { sum: f64, count: i64 },
MinI64(Option<i64>),
MaxI64(Option<i64>),
MinF64(Option<f64>),
MaxF64(Option<f64>),
/// Used for both `Utf8` and `DictUtf8` inputs.
MinUtf8(Option<String>),
MaxUtf8(Option<String>),
MinDate32(Option<i32>),
MaxDate32(Option<i32>),
MinTimestampMs(Option<i64>),
MaxTimestampMs(Option<i64>),
}
impl AggState {
fn new(spec: &ResolvedAggregate<'_>) -> Result<Self> {
Ok(match spec.op {
crate::ops::AggregateOp::CountRows => Self::CountRows(0),
crate::ops::AggregateOp::CountNonNull => Self::CountNonNull(0),
crate::ops::AggregateOp::CountNulls => Self::CountNulls(0),
crate::ops::AggregateOp::Sum => match spec.input_dtype {
Some(DataType::I64) => Self::SumI64 { sum: 0, count: 0 },
Some(DataType::F64) => Self::SumF64 { sum: 0.0, count: 0 },
Some(dtype) => {
return Err(Error::Unsupported(format!(
"sum is not supported for {dtype}"
)))
}
None => {
return Err(Error::InvalidArgument(
"sum requires an input column".to_string(),
))
}
},
crate::ops::AggregateOp::Mean => match spec.input_dtype {
Some(DataType::I64) | Some(DataType::F64) => Self::SumF64 { sum: 0.0, count: 0 },
Some(dtype) => {
return Err(Error::Unsupported(format!(
"mean is not supported for {dtype}"
)))
}
None => {
return Err(Error::InvalidArgument(
"mean requires an input column".to_string(),
))
}
},
crate::ops::AggregateOp::Min => match spec.input_dtype {
Some(DataType::I64) => Self::MinI64(None),
Some(DataType::F64) => Self::MinF64(None),
Some(DataType::Utf8) | Some(DataType::DictUtf8) => Self::MinUtf8(None),
Some(DataType::Date32) => Self::MinDate32(None),
Some(DataType::TimestampMs) => Self::MinTimestampMs(None),
Some(dtype) => {
return Err(Error::Unsupported(format!(
"min is not supported for {dtype}"
)))
}
None => {
return Err(Error::InvalidArgument(
"min requires an input column".to_string(),
))
}
},
crate::ops::AggregateOp::Max => match spec.input_dtype {
Some(DataType::I64) => Self::MaxI64(None),
Some(DataType::F64) => Self::MaxF64(None),
Some(DataType::Utf8) | Some(DataType::DictUtf8) => Self::MaxUtf8(None),
Some(DataType::Date32) => Self::MaxDate32(None),
Some(DataType::TimestampMs) => Self::MaxTimestampMs(None),
Some(dtype) => {
return Err(Error::Unsupported(format!(
"max is not supported for {dtype}"
)))
}
None => {
return Err(Error::InvalidArgument(
"max requires an input column".to_string(),
))
}
},
})
}
fn observe(&mut self, spec: &ResolvedAggregate<'_>, row_index: u32) -> Result<()> {
match self {
Self::CountRows(count) => *count += 1,
Self::CountNonNull(count) => {
if let Some(column) = spec.column {
if !column.is_null(row_index) {
*count += 1;
}
}
}
Self::CountNulls(count) => {
if let Some(column) = spec.column {
if column.is_null(row_index) {
*count += 1;
}
}
}
Self::SumI64 { sum, count } => {
if let Some(column) = spec.column {
if let Some(value) = column.i64_value(row_index) {
*sum += value;
*count += 1;
}
}
}
Self::SumF64 { sum, count } => {
if let Some(column) = spec.column {
if let Some(value) = column.f64_value(row_index) {
*sum += value;
*count += 1;
} else if matches!(spec.input_dtype, Some(DataType::I64)) {
if let Some(value) = column.i64_value(row_index) {
*sum += value as f64;
*count += 1;
}
}
}
}
Self::MinI64(current) => {
update_min(current, spec.column.and_then(|c| c.i64_value(row_index)))
}
Self::MaxI64(current) => {
update_max(current, spec.column.and_then(|c| c.i64_value(row_index)))
}
Self::MinF64(current) => {
update_min_partial(current, spec.column.and_then(|c| c.f64_value(row_index)))
}
Self::MaxF64(current) => {
update_max_partial(current, spec.column.and_then(|c| c.f64_value(row_index)))
}
Self::MinUtf8(current) => {
update_min_string(current, spec.column.and_then(|c| c.utf8_value(row_index)))
}
Self::MaxUtf8(current) => {
update_max_string(current, spec.column.and_then(|c| c.utf8_value(row_index)))
}
Self::MinDate32(current) => {
update_min(current, spec.column.and_then(|c| c.i32_value(row_index)))
}
Self::MaxDate32(current) => {
update_max(current, spec.column.and_then(|c| c.i32_value(row_index)))
}
Self::MinTimestampMs(current) => {
update_min(current, spec.column.and_then(|c| c.i64_value(row_index)))
}
Self::MaxTimestampMs(current) => {
update_max(current, spec.column.and_then(|c| c.i64_value(row_index)))
}
}
Ok(())
}
}
/// Lowers `current` to `value` when `value` is present and smaller.
///
/// A `None` value leaves `current` untouched; the first present value is
/// stored as-is.
fn update_min<T: Ord + Copy>(current: &mut Option<T>, value: Option<T>) {
    if let Some(candidate) = value {
        let smallest = current.map_or(candidate, |best| best.min(candidate));
        *current = Some(smallest);
    }
}
/// Raises `current` to `value` when `value` is present and larger.
///
/// A `None` value leaves `current` untouched; the first present value is
/// stored as-is.
fn update_max<T: Ord + Copy>(current: &mut Option<T>, value: Option<T>) {
    if let Some(candidate) = value {
        let largest = current.map_or(candidate, |best| best.max(candidate));
        *current = Some(largest);
    }
}
/// Lowers the running `f64` minimum to `value` when it is present and smaller.
///
/// NaN handling is order-independent: a NaN candidate never replaces an
/// existing value, and an existing NaN is replaced by the first real number.
/// The result is therefore NaN only if every observed value was NaN.
/// (Previously a NaN in the middle of the input reset the accumulator, so
/// `[1.0, NaN, 2.0]` produced `2.0`.)
///
/// A `None` value leaves `current` untouched. On ties the existing value is kept.
fn update_min_partial(current: &mut Option<f64>, value: Option<f64>) {
    if let Some(candidate) = value {
        let keep_existing = match *current {
            // `existing <= candidate` is false when `existing` is NaN, so a
            // stored NaN yields to any real candidate.
            Some(existing) => candidate.is_nan() || existing <= candidate,
            None => false,
        };
        if !keep_existing {
            *current = Some(candidate);
        }
    }
}
/// Raises the running `f64` maximum to `value` when it is present and larger.
///
/// NaN handling is order-independent: a NaN candidate never replaces an
/// existing value, and an existing NaN is replaced by the first real number.
/// The result is therefore NaN only if every observed value was NaN.
/// (Previously a NaN in the middle of the input reset the accumulator, so
/// `[3.0, NaN, 2.0]` produced `2.0`.)
///
/// A `None` value leaves `current` untouched. On ties the existing value is kept.
fn update_max_partial(current: &mut Option<f64>, value: Option<f64>) {
    if let Some(candidate) = value {
        let keep_existing = match *current {
            // `existing >= candidate` is false when `existing` is NaN, so a
            // stored NaN yields to any real candidate.
            Some(existing) => candidate.is_nan() || existing >= candidate,
            None => false,
        };
        if !keep_existing {
            *current = Some(candidate);
        }
    }
}
/// Lowers the running string minimum to `value` when it is present and
/// compares strictly smaller (byte-wise `str` ordering).
///
/// A `None` value leaves `current` untouched. The stored `String` is only
/// written when the candidate actually wins, and its buffer is reused, so
/// observing a row that does not change the minimum allocates nothing.
fn update_min_string(current: &mut Option<String>, value: Option<&str>) {
    if let Some(candidate) = value {
        match current.as_mut() {
            Some(existing) => {
                if candidate < existing.as_str() {
                    existing.clear();
                    existing.push_str(candidate);
                }
            }
            None => *current = Some(candidate.to_string()),
        }
    }
}
/// Raises the running string maximum to `value` when it is present and
/// compares strictly larger (byte-wise `str` ordering).
///
/// A `None` value leaves `current` untouched. The stored `String` is only
/// written when the candidate actually wins, and its buffer is reused, so
/// observing a row that does not change the maximum allocates nothing.
fn update_max_string(current: &mut Option<String>, value: Option<&str>) {
    if let Some(candidate) = value {
        match current.as_mut() {
            Some(existing) => {
                if candidate > existing.as_str() {
                    existing.clear();
                    existing.push_str(candidate);
                }
            }
            None => *current = Some(candidate.to_string()),
        }
    }
}
/// Materialises one group-key column of type `dtype` from per-group values.
///
/// For fixed-width types, a value whose variant does not match `dtype` is
/// stored as the type's zero (`false`, `0`, `0.0`), and the validity bitmap
/// comes from `group_validity`, which marks only `GroupValue::Null` as null.
/// For the string types, every value that is not `GroupValue::Utf8` becomes
/// a null entry.
fn build_column_from_group_values(dtype: DataType, values: &[GroupValue]) -> Result<Column> {
    // Shared by the two string layouts; only the dictionary flag differs.
    let string_column = |as_dictionary: bool| -> Result<Column> {
        let texts: Vec<Option<&str>> = values
            .iter()
            .map(|value| {
                if let GroupValue::Utf8(text) = value {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect();
        build_string_column(texts, as_dictionary)
    };

    let column = match dtype {
        DataType::Utf8 => return string_column(false),
        DataType::DictUtf8 => return string_column(true),
        DataType::Bool => {
            let flags: Vec<bool> = values
                .iter()
                .map(|value| matches!(value, GroupValue::Bool(true)))
                .collect();
            Column::Bool(BooleanCol::new(
                Bitmap::from_bools(&flags),
                group_validity(values),
            ))
        }
        DataType::I64 => {
            let data: Vec<_> = values
                .iter()
                .map(|value| if let GroupValue::I64(number) = value { *number } else { 0 })
                .collect();
            Column::I64(PrimitiveCol::new(data, group_validity(values)))
        }
        DataType::Date32 => {
            let data: Vec<_> = values
                .iter()
                .map(|value| if let GroupValue::Date32(days) = value { *days } else { 0 })
                .collect();
            Column::Date32(PrimitiveCol::new(data, group_validity(values)))
        }
        DataType::TimestampMs => {
            let data: Vec<_> = values
                .iter()
                .map(|value| {
                    if let GroupValue::TimestampMs(millis) = value {
                        *millis
                    } else {
                        0
                    }
                })
                .collect();
            Column::TimestampMs(PrimitiveCol::new(data, group_validity(values)))
        }
        DataType::F64 => {
            let data: Vec<_> = values
                .iter()
                .map(|value| {
                    if let GroupValue::F64Bits(bits) = value {
                        f64::from_bits(*bits)
                    } else {
                        0.0
                    }
                })
                .collect();
            Column::F64(PrimitiveCol::new(data, group_validity(values)))
        }
    };
    Ok(column)
}
fn build_column_from_single_keys(
dtype: DataType,
source_column: &Column,
keys: &[SingleKey],
) -> Result<Column> {
match dtype {
DataType::Bool => {
let mut values = Vec::with_capacity(keys.len());
let mut validity = Vec::with_capacity(keys.len());
for key in keys {
match key {
SingleKey::Bool(value) => {
values.push(*value);
validity.push(true);
}
SingleKey::Null => {
values.push(false);
validity.push(false);
}
other => {
return Err(Error::TypeMismatch {
expected: "bool".to_string(),
actual: single_key_type_name(other).to_string(),
})
}
}
}
Ok(Column::Bool(BooleanCol::new(
Bitmap::from_bools(&values),
validity_to_bitmap(validity),
)))
}
DataType::I64 => {
let mut values = Vec::with_capacity(keys.len());
let mut validity = Vec::with_capacity(keys.len());
for key in keys {
match key {
SingleKey::I64(value) => {
values.push(*value);
validity.push(true);
}
SingleKey::Null => {
values.push(0);
validity.push(false);
}
other => {
return Err(Error::TypeMismatch {
expected: "i64".to_string(),
actual: single_key_type_name(other).to_string(),
})
}
}
}
Ok(Column::I64(PrimitiveCol::new(
values,
validity_to_bitmap(validity),
)))
}
DataType::Date32 => {
let mut values = Vec::with_capacity(keys.len());
let mut validity = Vec::with_capacity(keys.len());
for key in keys {
match key {
SingleKey::Date32(value) => {
values.push(*value);
validity.push(true);
}
SingleKey::Null => {
values.push(0);
validity.push(false);
}
other => {
return Err(Error::TypeMismatch {
expected: "date32".to_string(),
actual: single_key_type_name(other).to_string(),
})
}
}
}
Ok(Column::Date32(PrimitiveCol::new(
values,
validity_to_bitmap(validity),
)))
}
DataType::TimestampMs => {
let mut values = Vec::with_capacity(keys.len());
let mut validity = Vec::with_capacity(keys.len());
for key in keys {
match key {
SingleKey::TimestampMs(value) => {
values.push(*value);
validity.push(true);
}
SingleKey::Null => {
values.push(0);
validity.push(false);
}
other => {
return Err(Error::TypeMismatch {
expected: "timestamp_ms".to_string(),
actual: single_key_type_name(other).to_string(),
})
}
}
}
Ok(Column::TimestampMs(PrimitiveCol::new(
values,
validity_to_bitmap(validity),
)))
}
DataType::F64 => {
let mut values = Vec::with_capacity(keys.len());
let mut validity = Vec::with_capacity(keys.len());
for key in keys {
match key {
SingleKey::F64Bits(bits) => {
values.push(f64::from_bits(*bits));
validity.push(true);
}
SingleKey::Null => {
values.push(0.0);
validity.push(false);
}
other => {
return Err(Error::TypeMismatch {
expected: "f64".to_string(),
actual: single_key_type_name(other).to_string(),
})
}
}
}
Ok(Column::F64(PrimitiveCol::new(
values,
validity_to_bitmap(validity),
)))
}
DataType::Utf8 => build_string_column(
keys.iter()
.map(|key| match key {
SingleKey::Utf8(value) => Some(value.as_ref()),
SingleKey::Null => None,
_ => None,
})
.collect::<Vec<_>>(),
false,
),
DataType::DictUtf8 => match source_column {
Column::DictUtf8(column) => {
let mut dict_keys = Vec::with_capacity(keys.len());
let mut validity = Vec::with_capacity(keys.len());
for key in keys {
match key {
SingleKey::DictKey(value) => {
dict_keys.push(*value);
validity.push(true);
}
SingleKey::Null => {
dict_keys.push(0);
validity.push(false);
}
other => {
return Err(Error::TypeMismatch {
expected: "dict_utf8".to_string(),
actual: single_key_type_name(other).to_string(),
})
}
}
}
Ok(Column::DictUtf8(DictionaryCol::new(
dict_keys,
column.values.clone(),
validity_to_bitmap(validity),
)))
}
_ => build_string_column(
keys.iter()
.map(|key| match key {
SingleKey::Utf8(value) => Some(value.as_ref()),
SingleKey::Null => None,
_ => None,
})
.collect::<Vec<_>>(),
true,
),
},
}
}
fn build_column_from_aggregate_states(
aggregate: &ResolvedAggregate<'_>,
states: &[AggState],
) -> Result<Column> {
match aggregate.output_dtype() {
DataType::I64 => {
let (values, validity) = states
.iter()
.map(|state| match state {
AggState::CountRows(count)
| AggState::CountNonNull(count)
| AggState::CountNulls(count) => (*count, true),
AggState::SumI64 { sum, count } => (*sum, *count > 0),
AggState::MinI64(value)
| AggState::MaxI64(value)
| AggState::MinTimestampMs(value)
| AggState::MaxTimestampMs(value) => {
(value.unwrap_or_default(), value.is_some())
}
_ => (0, false),
})
.unzip::<_, _, Vec<_>, Vec<_>>();
Ok(Column::I64(PrimitiveCol::new(
values,
validity_to_bitmap(validity),
)))
}
DataType::Date32 => {
let (values, validity) = states
.iter()
.map(|state| match state {
AggState::MinDate32(value) | AggState::MaxDate32(value) => {
(value.unwrap_or_default(), value.is_some())
}
_ => (0, false),
})
.unzip::<_, _, Vec<_>, Vec<_>>();
Ok(Column::Date32(PrimitiveCol::new(
values,
validity_to_bitmap(validity),
)))
}
DataType::TimestampMs => {
let (values, validity) = states
.iter()
.map(|state| match state {
AggState::MinTimestampMs(value) | AggState::MaxTimestampMs(value) => {
(value.unwrap_or_default(), value.is_some())
}
_ => (0, false),
})
.unzip::<_, _, Vec<_>, Vec<_>>();
Ok(Column::TimestampMs(PrimitiveCol::new(
values,
validity_to_bitmap(validity),
)))
}
DataType::F64 => {
let (values, validity) = states
.iter()
.map(|state| match state {
AggState::SumF64 { sum, count }
if matches!(aggregate.op, crate::ops::AggregateOp::Mean) =>
{
if *count > 0 {
(*sum / *count as f64, true)
} else {
(0.0, false)
}
}
AggState::SumF64 { sum, count } => (*sum, *count > 0),
AggState::MinF64(value) | AggState::MaxF64(value) => {
(value.unwrap_or_default(), value.is_some())
}
_ => (0.0, false),
})
.unzip::<_, _, Vec<_>, Vec<_>>();
Ok(Column::F64(PrimitiveCol::new(
values,
validity_to_bitmap(validity),
)))
}
DataType::Utf8 | DataType::DictUtf8 => build_string_column(
states
.iter()
.map(|state| match state {
AggState::MinUtf8(value) | AggState::MaxUtf8(value) => value.as_deref(),
_ => None,
})
.collect::<Vec<_>>(),
matches!(aggregate.output_dtype(), DataType::DictUtf8),
),
DataType::Bool => Err(Error::Unsupported(
"boolean aggregates are not implemented yet".to_string(),
)),
}
}
/// Validity bitmap for a slice of group values: an entry is valid unless it
/// is `GroupValue::Null`. Returns `None` when every entry is valid.
fn group_validity(values: &[GroupValue]) -> Option<Bitmap> {
    let mut present = Vec::with_capacity(values.len());
    for value in values {
        let is_null = matches!(value, GroupValue::Null);
        present.push(!is_null);
    }
    validity_to_bitmap(present)
}
/// Lower-case type name of a key's variant, as used in `TypeMismatch` errors.
fn single_key_type_name(key: &SingleKey) -> &'static str {
    match *key {
        // Numeric and temporal payloads.
        SingleKey::I64(_) => "i64",
        SingleKey::F64Bits(_) => "f64",
        SingleKey::Date32(_) => "date32",
        SingleKey::TimestampMs(_) => "timestamp_ms",
        // String payloads, plain and dictionary-encoded.
        SingleKey::Utf8(_) => "utf8",
        SingleKey::DictKey(_) => "dict_utf8",
        // Everything else.
        SingleKey::Bool(_) => "bool",
        SingleKey::Null => "null",
    }
}
/// Converts per-row validity flags into an optional bitmap.
///
/// Returns `None` when every flag is `true` (including the empty case), and
/// otherwise a bitmap built from the flags.
fn validity_to_bitmap(validity: Vec<bool>) -> Option<Bitmap> {
    let has_null = validity.iter().any(|is_valid| !*is_valid);
    has_null.then(|| Bitmap::from_bools(&validity))
}
fn build_string_column(values: Vec<Option<&str>>, dictionary: bool) -> Result<Column> {
if dictionary {
let mut keys = Vec::with_capacity(values.len());
let mut validity = Vec::with_capacity(values.len());
let mut dictionary_index = HashMap::<String, u32>::new();
let mut dictionary_values = Vec::<String>::new();
for value in values {
match value {
Some(value) => {
let key = if let Some(existing) = dictionary_index.get(value) {
*existing
} else {
let key = dictionary_values.len() as u32;
dictionary_values.push(value.to_string());
dictionary_index.insert(value.to_string(), key);
key
};
keys.push(key);
validity.push(true);
}
None => {
keys.push(0);
validity.push(false);
}
}
}
let dictionary_values = {
let mut offsets = Vec::with_capacity(dictionary_values.len() + 1);
let mut bytes = Vec::new();
offsets.push(0);
for value in dictionary_values {
bytes.extend_from_slice(value.as_bytes());
offsets.push(bytes.len() as u32);
}
Utf8Col::new(offsets, bytes, None)?
};
Ok(Column::DictUtf8(DictionaryCol::new(
keys,
dictionary_values,
validity_to_bitmap(validity),
)))
} else {
let mut offsets = Vec::with_capacity(values.len() + 1);
let mut bytes = Vec::new();
let mut validity = Vec::with_capacity(values.len());
offsets.push(0);
for value in values {
match value {
Some(value) => {
bytes.extend_from_slice(value.as_bytes());
offsets.push(bytes.len() as u32);
validity.push(true);
}
None => {
offsets.push(bytes.len() as u32);
validity.push(false);
}
}
}
Ok(Column::Utf8(Utf8Col::new(
offsets,
bytes,
validity_to_bitmap(validity),
)?))
}
}
// Unit tests for table construction, sorting, grouping and joins.
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::{build_string_column, Table};
use crate::ops::{
AggregateExpr, AggregateOp, ColumnSelector, GroupKey, JoinKey, JoinOptions, JoinType,
NullOrder, SortKey, SortOrder,
};
use crate::table::{Bitmap, Column, DataType, Field, PrimitiveCol, Schema};
// Reads every row of `column` through `i64_value`, one entry per row.
fn i64_values(column: &Column) -> Vec<Option<i64>> {
(0..column.len())
.map(|index| column.i64_value(index as u32))
.collect()
}
// Reads every row of `column` through `f64_value`, one entry per row.
fn f64_values(column: &Column) -> Vec<Option<f64>> {
(0..column.len())
.map(|index| column.f64_value(index as u32))
.collect()
}
// Reads every row of `column` through `utf8_value`, copying each string.
fn utf8_values(column: &Column) -> Vec<Option<String>> {
(0..column.len())
.map(|index| column.utf8_value(index as u32).map(ToString::to_string))
.collect()
}
// `Table::from_columns` must reject columns of different lengths (2 rows vs 1 row).
#[test]
fn rejects_mismatched_column_lengths() {
let schema = Schema::new(vec![
Field::new(Arc::<str>::from("id"), DataType::I64),
Field::new(Arc::<str>::from("score"), DataType::I64),
])
.expect("schema should be valid");
let columns = vec![
Column::I64(PrimitiveCol::new(vec![1_i64, 2_i64], None)),
Column::I64(PrimitiveCol::new(vec![10_i64], None)),
];
let error =
Table::from_columns(schema, columns).expect_err("mismatched lengths should fail");
assert!(error.to_string().contains("same row count"));
}
// Ascending sort on a nullable column with `NullOrder::Last`: the row whose
// score is null (third input row) must end up last, and `id` must follow the
// same permutation.
#[test]
fn sorts_rows_by_numeric_column_with_nulls_last() {
let schema = Schema::new(vec![
Field::new(Arc::<str>::from("id"), DataType::I64).not_null(),
Field::new(Arc::<str>::from("score"), DataType::I64),
])
.expect("schema should be valid");
let columns = vec![
Column::I64(PrimitiveCol::new(
vec![10_i64, 11_i64, 12_i64, 13_i64],
None,
)),
Column::I64(PrimitiveCol::new(
vec![3_i64, 1_i64, 0_i64, 2_i64],
Some(Bitmap::from_bools(&[true, true, false, true])),
)),
];
let table = Table::from_columns(schema, columns).expect("table should build");
let sorted = table
.sort_by(&[SortKey {
column: ColumnSelector::from("score"),
order: SortOrder::Ascending,
nulls: NullOrder::Last,
}])
.expect("sort should succeed")
.materialize()
.expect("materialize should succeed");
assert_eq!(
i64_values(sorted.column_by_name("id").expect("id column should exist")),
vec![Some(11), Some(13), Some(10), Some(12)]
);
assert_eq!(
i64_values(
sorted
.column_by_name("score")
.expect("score column should exist")
),
vec![Some(1), Some(2), Some(3), None]
);
}
// Groups by a nullable string key and checks four aggregate kinds at once:
// row count, integer sum, float mean and string max. The expected output
// lists groups as "a", "b", then the null key as its own group.
#[test]
fn groups_rows_and_computes_mixed_aggregates() {
let schema = Schema::new(vec![
Field::new(Arc::<str>::from("team"), DataType::Utf8),
Field::new(Arc::<str>::from("points"), DataType::I64),
Field::new(Arc::<str>::from("efficiency"), DataType::F64),
Field::new(Arc::<str>::from("label"), DataType::Utf8),
])
.expect("schema should be valid");
let columns = vec![
build_string_column(vec![Some("a"), Some("b"), Some("a"), None], false)
.expect("team column should build"),
Column::I64(PrimitiveCol::new(vec![5_i64, 1_i64, 7_i64, 2_i64], None)),
Column::F64(PrimitiveCol::new(
vec![1.5_f64, 3.0_f64, 2.5_f64, 4.0_f64],
None,
)),
build_string_column(vec![Some("x"), Some("z"), Some("y"), Some("n")], false)
.expect("label column should build"),
];
let table = Table::from_columns(schema, columns).expect("table should build");
let grouped = table
.group_by(
&[GroupKey {
column: ColumnSelector::from("team"),
}],
&[
AggregateExpr {
input: None,
op: AggregateOp::CountRows,
alias: Arc::from("rows"),
},
AggregateExpr {
input: Some(ColumnSelector::from("points")),
op: AggregateOp::Sum,
alias: Arc::from("points_sum"),
},
AggregateExpr {
input: Some(ColumnSelector::from("efficiency")),
op: AggregateOp::Mean,
alias: Arc::from("efficiency_mean"),
},
AggregateExpr {
input: Some(ColumnSelector::from("label")),
op: AggregateOp::Max,
alias: Arc::from("label_max"),
},
],
)
.expect("group by should succeed");
assert_eq!(grouped.nrows(), 3);
assert_eq!(
utf8_values(
grouped
.column_by_name("team")
.expect("team column should exist")
),
vec![Some("a".to_string()), Some("b".to_string()), None]
);
assert_eq!(
i64_values(
grouped
.column_by_name("rows")
.expect("rows column should exist")
),
vec![Some(2), Some(1), Some(1)]
);
assert_eq!(
i64_values(
grouped
.column_by_name("points_sum")
.expect("points_sum column should exist")
),
vec![Some(12), Some(1), Some(2)]
);
assert_eq!(
f64_values(
grouped
.column_by_name("efficiency_mean")
.expect("efficiency_mean column should exist")
),
vec![Some(2.0), Some(3.0), Some(4.0)]
);
assert_eq!(
utf8_values(
grouped
.column_by_name("label_max")
.expect("label_max column should exist")
),
vec![
Some("y".to_string()),
Some("z".to_string()),
Some("n".to_string())
]
);
}
// Inner join on `id` where both sides hold two rows with id 2: expects the
// 2x2 cross product (4 rows) and `_left`/`_right` suffixes on every column
// name that exists on both sides.
#[test]
fn joins_inner_rows_and_suffixes_duplicate_names() {
let left_schema = Schema::new(vec![
Field::new(Arc::<str>::from("id"), DataType::I64).not_null(),
Field::new(Arc::<str>::from("name"), DataType::Utf8),
])
.expect("left schema should be valid");
let right_schema = Schema::new(vec![
Field::new(Arc::<str>::from("id"), DataType::I64).not_null(),
Field::new(Arc::<str>::from("name"), DataType::Utf8),
])
.expect("right schema should be valid");
let left = Table::from_columns(
left_schema,
vec![
Column::I64(PrimitiveCol::new(vec![1_i64, 2_i64, 2_i64], None)),
build_string_column(vec![Some("a"), Some("b"), Some("c")], false)
.expect("left name column should build"),
],
)
.expect("left table should build");
let right = Table::from_columns(
right_schema,
vec![
Column::I64(PrimitiveCol::new(vec![2_i64, 2_i64, 3_i64], None)),
build_string_column(vec![Some("x"), Some("y"), Some("z")], false)
.expect("right name column should build"),
],
)
.expect("right table should build");
let joined = left
.join(
&right,
&JoinOptions {
join_type: JoinType::Inner,
keys: vec![JoinKey {
left: ColumnSelector::from("id"),
right: ColumnSelector::from("id"),
}],
..JoinOptions::default()
},
)
.expect("join should succeed");
assert_eq!(joined.nrows(), 4);
assert_eq!(
joined
.schema()
.fields()
.iter()
.map(|field| field.name.as_ref())
.collect::<Vec<_>>(),
vec!["id_left", "name_left", "id_right", "name_right"]
);
assert_eq!(
i64_values(
joined
.column_by_name("id_left")
.expect("left id should exist")
),
vec![Some(2), Some(2), Some(2), Some(2)]
);
assert_eq!(
utf8_values(
joined
.column_by_name("name_left")
.expect("left name should exist")
),
vec![
Some("b".to_string()),
Some("b".to_string()),
Some("c".to_string()),
Some("c".to_string())
]
);
assert_eq!(
utf8_values(
joined
.column_by_name("name_right")
.expect("right name should exist")
),
vec![
Some("x".to_string()),
Some("y".to_string()),
Some("x".to_string()),
Some("y".to_string())
]
);
}
// Full outer join on a nullable `id`. Expected: 5 rows, with only id 2
// matched. The null-key rows from each side are NOT matched to each other;
// each appears as its own unmatched row, null-filled on the opposite side.
#[test]
fn supports_full_outer_join_with_null_fill() {
let left_schema = Schema::new(vec![
Field::new(Arc::<str>::from("id"), DataType::I64),
Field::new(Arc::<str>::from("left_value"), DataType::Utf8),
])
.expect("left schema should be valid");
let right_schema = Schema::new(vec![
Field::new(Arc::<str>::from("id"), DataType::I64),
Field::new(Arc::<str>::from("right_value"), DataType::Utf8),
])
.expect("right schema should be valid");
let left = Table::from_columns(
left_schema,
vec![
Column::I64(PrimitiveCol::new(
vec![1_i64, 2_i64, 0_i64],
Some(Bitmap::from_bools(&[true, true, false])),
)),
build_string_column(
vec![Some("left-1"), Some("left-2"), Some("left-null")],
false,
)
.expect("left string column should build"),
],
)
.expect("left table should build");
let right = Table::from_columns(
right_schema,
vec![
Column::I64(PrimitiveCol::new(
vec![2_i64, 3_i64, 0_i64],
Some(Bitmap::from_bools(&[true, true, false])),
)),
build_string_column(
vec![Some("right-2"), Some("right-3"), Some("right-null")],
false,
)
.expect("right string column should build"),
],
)
.expect("right table should build");
let joined = left
.join(
&right,
&JoinOptions {
join_type: JoinType::Full,
keys: vec![JoinKey {
left: ColumnSelector::from("id"),
right: ColumnSelector::from("id"),
}],
..JoinOptions::default()
},
)
.expect("full join should succeed");
assert_eq!(joined.nrows(), 5);
assert_eq!(
i64_values(
joined
.column_by_name("id_left")
.expect("left id column should exist")
),
vec![Some(1), Some(2), None, None, None]
);
assert_eq!(
utf8_values(
joined
.column_by_name("left_value")
.expect("left value column should exist")
),
vec![
Some("left-1".to_string()),
Some("left-2".to_string()),
Some("left-null".to_string()),
None,
None
]
);
assert_eq!(
i64_values(
joined
.column_by_name("id_right")
.expect("right id column should exist")
),
vec![None, Some(2), None, Some(3), None]
);
assert_eq!(
utf8_values(
joined
.column_by_name("right_value")
.expect("right value column should exist")
),
vec![
None,
Some("right-2".to_string()),
None,
Some("right-3".to_string()),
Some("right-null".to_string())
]
);
}
}