use crate::base::{
database::{
filter_util::filter_column_by_index, order_by_util::compare_indexes_by_columns, Column,
},
if_rayon,
scalar::Scalar,
};
use alloc::vec::Vec;
use bumpalo::Bump;
use core::cmp::Ordering;
use itertools::Itertools;
#[cfg(feature = "rayon")]
use rayon::prelude::ParallelSliceMut;
use snafu::Snafu;
/// Per-group results produced by [`aggregate_columns`].
///
/// Every column in this struct holds one entry per group, and all of them are built
/// from the same group ordering, so entry `i` of each column refers to the same group.
#[derive(Debug)]
pub struct AggregatedColumns<'a, S: Scalar> {
/// The group-by columns, reduced to one row per group.
pub group_by_columns: Vec<Column<'a, S>>,
/// For each summed input column, the per-group sums.
pub sum_columns: Vec<&'a [S]>,
/// For each input column a maximum was requested for, the per-group maxima.
/// An entry is `None` when the group contributed no rows.
#[cfg_attr(not(test), expect(dead_code, reason = "only used by tests for now"))]
pub max_columns: Vec<&'a [Option<S>]>,
/// For each input column a minimum was requested for, the per-group minima.
/// An entry is `None` when the group contributed no rows.
#[cfg_attr(not(test), expect(dead_code, reason = "only used by tests for now"))]
pub min_columns: Vec<&'a [Option<S>]>,
/// The number of selected rows in each group.
pub count_column: &'a [i64],
}
/// Errors that [`aggregate_columns`] can report.
#[derive(Snafu, Debug, PartialEq, Eq)]
pub enum AggregateColumnsError {
/// At least one input column has a length that differs from the length of the
/// selection column.
#[snafu(display("Column length mismatch"))]
ColumnLengthMismatch,
}
/// Groups the selected rows by the group-by columns and computes per-group aggregates.
///
/// Only rows whose flag in `selection_column_in` is `true` take part. Those rows are
/// sorted with `compare_indexes_by_columns` over `group_by_columns_in`; rows that compare
/// equal form one group. For every group the result contains:
///
/// * one row of each group-by column,
/// * the sum of each column in `sum_columns_in`,
/// * the maximum of each column in `max_columns_in`,
/// * the minimum of each column in `min_columns_in`,
/// * the number of rows in the group.
///
/// Groups appear in the order established by the sort. All output slices are allocated
/// in `alloc`.
///
/// # Errors
///
/// Returns [`AggregateColumnsError::ColumnLengthMismatch`] if any group-by, sum, max or
/// min input column has a length different from that of `selection_column_in`.
///
/// Each group size is converted to `i64`; that conversion is expected never to fail and
/// the function aborts with a panic if it does.
#[expect(clippy::missing_panics_doc)]
pub fn aggregate_columns<'a, S: Scalar>(
    alloc: &'a Bump,
    group_by_columns_in: &[Column<'a, S>],
    sum_columns_in: &[Column<S>],
    max_columns_in: &[Column<S>],
    min_columns_in: &[Column<S>],
    selection_column_in: &[bool],
) -> Result<AggregatedColumns<'a, S>, AggregateColumnsError> {
    // Every input column must supply exactly one value per selection flag.
    let expected_len = selection_column_in.len();
    let lengths_agree = group_by_columns_in
        .iter()
        .chain(sum_columns_in.iter())
        .chain(max_columns_in.iter())
        .chain(min_columns_in.iter())
        .all(|column| column.len() == expected_len);
    if !lengths_agree {
        return Err(AggregateColumnsError::ColumnLengthMismatch);
    }

    // Indexes of the rows that passed the selection.
    let mut filtered_indexes: Vec<usize> = selection_column_in
        .iter()
        .enumerate()
        .filter_map(|(row, &selected)| selected.then_some(row))
        .collect();

    // Orders two rows by their group-by key; shared by the sort and the grouping below.
    let compare_rows =
        |lhs: usize, rhs: usize| compare_indexes_by_columns(group_by_columns_in, lhs, rhs);

    // Sorting places the rows of each group next to each other.
    if_rayon!(
        filtered_indexes.par_sort_unstable_by(|&lhs, &rhs| compare_rows(lhs, rhs)),
        filtered_indexes.sort_unstable_by(|&lhs, &rhs| compare_rows(lhs, rhs))
    );

    // Collapse each run of equal keys into (run length, representative row index).
    let (counts, group_by_result_indexes): (Vec<_>, Vec<_>) = filtered_indexes
        .iter()
        .dedup_by_with_count(|&&lhs, &&rhs| compare_rows(lhs, rhs) == Ordering::Equal)
        .multiunzip();

    let group_by_columns: Vec<_> = group_by_columns_in
        .iter()
        .map(|column| filter_column_by_index(alloc, column, &group_by_result_indexes))
        .collect();
    let sum_columns: Vec<_> = sum_columns_in
        .iter()
        .map(|column| {
            sum_aggregate_column_by_index_counts(alloc, column, &counts, &filtered_indexes)
        })
        .collect();
    let max_columns: Vec<_> = max_columns_in
        .iter()
        .map(|column| {
            max_aggregate_column_by_index_counts(alloc, column, &counts, &filtered_indexes)
        })
        .collect();
    let min_columns: Vec<_> = min_columns_in
        .iter()
        .map(|column| {
            min_aggregate_column_by_index_counts(alloc, column, &counts, &filtered_indexes)
        })
        .collect();

    // `counts` is consumed last, once every aggregate has borrowed it.
    let count_column: &'a [i64] = alloc.alloc_slice_fill_iter(
        counts
            .into_iter()
            .map(|group_len| i64::try_from(group_len).expect("Count should fit within i64")),
    );

    Ok(AggregatedColumns {
        group_by_columns,
        sum_columns,
        max_columns,
        min_columns,
        count_column,
    })
}
/// Computes the per-group sums of `column`.
///
/// Selects the underlying slice for the column's variant and forwards it to
/// [`sum_aggregate_slice_by_index_counts`].
///
/// * `alloc` - arena the result slice is allocated in.
/// * `column` - the column to sum.
/// * `counts` - the size of each group, in order.
/// * `indexes` - row indexes into `column`, laid out so that each group occupies a
///   contiguous run whose length is the matching entry of `counts`.
///
/// Returns one sum per entry of `counts`.
///
/// # Panics
///
/// Panics if `column` is a `VarChar`, `TimestampTZ`, `Boolean` or `VarBinary` column.
pub(crate) fn sum_aggregate_column_by_index_counts<'a, S: Scalar>(
    alloc: &'a Bump,
    column: &Column<S>,
    counts: &[usize],
    indexes: &[usize],
) -> &'a [S] {
    match column {
        // Non-numeric variants are rejected outright.
        Column::Boolean(_)
        | Column::VarChar(_)
        | Column::VarBinary(_)
        | Column::TimestampTZ(_, _, _) => {
            unreachable!("SUM can not be applied to non-numeric types")
        }
        Column::Uint8(vals) => sum_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::TinyInt(vals) => sum_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::SmallInt(vals) => {
            sum_aggregate_slice_by_index_counts(alloc, vals, counts, indexes)
        }
        Column::Int(vals) => sum_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::BigInt(vals) => sum_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::Int128(vals) => sum_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::Decimal75(_, _, vals) => {
            sum_aggregate_slice_by_index_counts(alloc, vals, counts, indexes)
        }
        Column::Scalar(vals) => sum_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
    }
}
/// Computes the per-group maxima of `column`.
///
/// Selects the underlying slice for the column's variant and forwards it to
/// [`max_aggregate_slice_by_index_counts`].
///
/// * `alloc` - arena the result slice is allocated in.
/// * `column` - the column to take maxima of.
/// * `counts` - the size of each group, in order.
/// * `indexes` - row indexes into `column`, laid out so that each group occupies a
///   contiguous run whose length is the matching entry of `counts`.
///
/// Returns one entry per entry of `counts`; an entry is `None` for a group of size zero.
///
/// # Panics
///
/// Panics if `column` is a `VarBinary` or `VarChar` column.
pub(crate) fn max_aggregate_column_by_index_counts<'a, S: Scalar>(
    alloc: &'a Bump,
    column: &Column<S>,
    counts: &[usize],
    indexes: &[usize],
) -> &'a [Option<S>] {
    match column {
        // Unsupported variants first; each keeps its own message.
        Column::VarBinary(_) => {
            unreachable!("MAX can not be applied to varbinary")
        }
        Column::VarChar(_) => {
            unreachable!("MAX can not be applied to varchar")
        }
        Column::Boolean(vals) => max_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::Uint8(vals) => max_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::TinyInt(vals) => max_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::SmallInt(vals) => {
            max_aggregate_slice_by_index_counts(alloc, vals, counts, indexes)
        }
        Column::Int(vals) => max_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::BigInt(vals) => max_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::Int128(vals) => max_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::Decimal75(_, _, vals) => {
            max_aggregate_slice_by_index_counts(alloc, vals, counts, indexes)
        }
        Column::TimestampTZ(_, _, vals) => {
            max_aggregate_slice_by_index_counts(alloc, vals, counts, indexes)
        }
        Column::Scalar(vals) => max_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
    }
}
/// Computes the per-group minima of `column`.
///
/// Selects the underlying slice for the column's variant and forwards it to
/// [`min_aggregate_slice_by_index_counts`].
///
/// * `alloc` - arena the result slice is allocated in.
/// * `column` - the column to take minima of.
/// * `counts` - the size of each group, in order.
/// * `indexes` - row indexes into `column`, laid out so that each group occupies a
///   contiguous run whose length is the matching entry of `counts`.
///
/// Returns one entry per entry of `counts`; an entry is `None` for a group of size zero.
///
/// # Panics
///
/// Panics if `column` is a `VarBinary` or `VarChar` column.
pub(crate) fn min_aggregate_column_by_index_counts<'a, S: Scalar>(
    alloc: &'a Bump,
    column: &Column<S>,
    counts: &[usize],
    indexes: &[usize],
) -> &'a [Option<S>] {
    match column {
        // Unsupported variants first; each keeps its own message.
        Column::VarBinary(_) => unreachable!("MIN can not be applied to varbinary"),
        Column::VarChar(_) => {
            unreachable!("MIN can not be applied to varchar")
        }
        Column::Boolean(vals) => min_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::Uint8(vals) => min_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::TinyInt(vals) => min_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::SmallInt(vals) => {
            min_aggregate_slice_by_index_counts(alloc, vals, counts, indexes)
        }
        Column::Int(vals) => min_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::BigInt(vals) => min_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::Int128(vals) => min_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
        Column::Decimal75(_, _, vals) => {
            min_aggregate_slice_by_index_counts(alloc, vals, counts, indexes)
        }
        Column::TimestampTZ(_, _, vals) => {
            min_aggregate_slice_by_index_counts(alloc, vals, counts, indexes)
        }
        Column::Scalar(vals) => min_aggregate_slice_by_index_counts(alloc, vals, counts, indexes),
    }
}
/// Sums the values of `slice` group by group.
///
/// `indexes` is read as consecutive runs, the `k`-th run having length `counts[k]`.
/// For each run, the elements of `slice` at those indexes are converted to `S` and
/// added together.
///
/// * `alloc` - arena the result slice is allocated in.
/// * `slice` - the values to aggregate.
/// * `counts` - the length of each run of `indexes`.
/// * `indexes` - positions into `slice`, grouped into contiguous runs.
///
/// Returns one sum per entry of `counts`.
///
/// # Panics
///
/// Panics if the runs described by `counts` extend past the end of `indexes`, or if an
/// index is out of bounds for `slice`.
pub(crate) fn sum_aggregate_slice_by_index_counts<'a, S, T>(
    alloc: &'a Bump,
    slice: &[T],
    counts: &[usize],
    indexes: &[usize],
) -> &'a [S]
where
    for<'b> S: From<&'b T> + Scalar,
{
    // Start of the run of `indexes` belonging to the group currently being processed.
    let mut group_start = 0;
    let group_sums = counts.iter().map(|&group_len| {
        let group_end = group_start + group_len;
        let group_rows = &indexes[group_start..group_end];
        group_start = group_end;
        group_rows
            .iter()
            .map(|&row| S::from(&slice[row]))
            .sum::<S>()
    });
    alloc.alloc_slice_fill_iter(group_sums)
}
/// Finds the maximum of the values of `slice` group by group.
///
/// `indexes` is read as consecutive runs, the `k`-th run having length `counts[k]`.
/// For each run, the elements of `slice` at those indexes are converted to `S` and the
/// greatest one according to `ScalarExt::signed_cmp` is kept.
///
/// * `alloc` - arena the result slice is allocated in.
/// * `slice` - the values to aggregate.
/// * `counts` - the length of each run of `indexes`.
/// * `indexes` - positions into `slice`, grouped into contiguous runs.
///
/// Returns one entry per entry of `counts`; an entry is `None` for a run of length zero.
///
/// # Panics
///
/// Panics if the runs described by `counts` extend past the end of `indexes`, or if an
/// index is out of bounds for `slice`.
pub(crate) fn max_aggregate_slice_by_index_counts<'a, S, T>(
    alloc: &'a Bump,
    slice: &[T],
    counts: &[usize],
    indexes: &[usize],
) -> &'a [Option<S>]
where
    for<'b> S: From<&'b T> + Scalar,
{
    // Start of the run of `indexes` belonging to the group currently being processed.
    let mut group_start = 0;
    let group_maxima = counts.iter().map(|&group_len| {
        let group_end = group_start + group_len;
        let group_rows = &indexes[group_start..group_end];
        group_start = group_end;
        group_rows
            .iter()
            .map(|&row| S::from(&slice[row]))
            .max_by(super::super::scalar::ScalarExt::signed_cmp)
    });
    alloc.alloc_slice_fill_iter(group_maxima)
}
/// Finds the minimum of the values of `slice` group by group.
///
/// `indexes` is read as consecutive runs, the `k`-th run having length `counts[k]`.
/// For each run, the elements of `slice` at those indexes are converted to `S` and the
/// least one according to `ScalarExt::signed_cmp` is kept.
///
/// * `alloc` - arena the result slice is allocated in.
/// * `slice` - the values to aggregate.
/// * `counts` - the length of each run of `indexes`.
/// * `indexes` - positions into `slice`, grouped into contiguous runs.
///
/// Returns one entry per entry of `counts`; an entry is `None` for a run of length zero.
///
/// # Panics
///
/// Panics if the runs described by `counts` extend past the end of `indexes`, or if an
/// index is out of bounds for `slice`.
pub(crate) fn min_aggregate_slice_by_index_counts<'a, S, T>(
    alloc: &'a Bump,
    slice: &[T],
    counts: &[usize],
    indexes: &[usize],
) -> &'a [Option<S>]
where
    for<'b> S: From<&'b T> + Scalar,
{
    // Start of the run of `indexes` belonging to the group currently being processed.
    let mut group_start = 0;
    let group_minima = counts.iter().map(|&group_len| {
        let group_end = group_start + group_len;
        let group_rows = &indexes[group_start..group_end];
        group_start = group_end;
        group_rows
            .iter()
            .map(|&row| S::from(&slice[row]))
            .min_by(super::super::scalar::ScalarExt::signed_cmp)
    });
    alloc.alloc_slice_fill_iter(group_minima)
}