use std::marker::PhantomData;
use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::arrays::ConstantArray;
use vortex_array::builders::ArrayBuilder;
use vortex_array::builders::BoolBuilder;
use vortex_array::builders::builder_with_capacity;
use vortex_array::dtype::DType;
use vortex_array::dtype::FieldName;
use vortex_array::dtype::Nullability;
use vortex_array::expr::stats::Stat;
use vortex_array::scalar::Scalar;
use vortex_array::scalar::ScalarTruncation;
use vortex_array::scalar::lower_bound;
use vortex_array::scalar::upper_bound;
use vortex_buffer::BufferString;
use vortex_buffer::ByteBuffer;
use vortex_error::VortexResult;
pub const MAX_IS_TRUNCATED: &str = "max_is_truncated";
pub const MIN_IS_TRUNCATED: &str = "min_is_truncated";
pub fn stats_builder_with_capacity(
stat: Stat,
dtype: &DType,
capacity: usize,
max_length: usize,
) -> Box<dyn StatsArrayBuilder> {
let values_builder = builder_with_capacity(dtype, capacity);
match stat {
Stat::Max => match dtype {
DType::Utf8(_) => Box::new(TruncatedMaxBinaryStatsBuilder::<BufferString>::new(
values_builder,
BoolBuilder::with_capacity(Nullability::NonNullable, capacity),
max_length,
)),
DType::Binary(_) => Box::new(TruncatedMaxBinaryStatsBuilder::<ByteBuffer>::new(
values_builder,
BoolBuilder::with_capacity(Nullability::NonNullable, capacity),
max_length,
)),
_ => Box::new(StatNameArrayBuilder::new(stat, values_builder)),
},
Stat::Min => match dtype {
DType::Utf8(_) => Box::new(TruncatedMinBinaryStatsBuilder::<BufferString>::new(
values_builder,
BoolBuilder::with_capacity(Nullability::NonNullable, capacity),
max_length,
)),
DType::Binary(_) => Box::new(TruncatedMinBinaryStatsBuilder::<ByteBuffer>::new(
values_builder,
BoolBuilder::with_capacity(Nullability::NonNullable, capacity),
max_length,
)),
_ => Box::new(StatNameArrayBuilder::new(stat, values_builder)),
},
_ => Box::new(StatNameArrayBuilder::new(stat, values_builder)),
}
}
pub struct NamedArrays {
pub names: Vec<FieldName>,
pub arrays: Vec<ArrayRef>,
}
impl NamedArrays {
pub fn all_invalid(&self) -> VortexResult<bool> {
self.arrays[0].all_invalid()
}
}
pub trait StatsArrayBuilder: Send {
fn stat(&self) -> Stat;
fn append_scalar(&mut self, value: Scalar) -> VortexResult<()>;
fn append_null(&mut self);
fn finish(&mut self) -> NamedArrays;
}
pub struct StatNameArrayBuilder {
stat: Stat,
builder: Box<dyn ArrayBuilder>,
}
impl StatNameArrayBuilder {
pub fn new(stat: Stat, builder: Box<dyn ArrayBuilder>) -> Self {
Self { stat, builder }
}
}
impl StatsArrayBuilder for StatNameArrayBuilder {
fn stat(&self) -> Stat {
self.stat
}
fn append_scalar(&mut self, value: Scalar) -> VortexResult<()> {
self.builder.append_scalar(&value)
}
fn append_null(&mut self) {
self.builder.append_null()
}
fn finish(&mut self) -> NamedArrays {
let array = self.builder.finish();
let len = array.len();
match self.stat {
Stat::Max => NamedArrays {
names: vec![self.stat.name().into(), MAX_IS_TRUNCATED.into()],
arrays: vec![array, ConstantArray::new(false, len).into_array()],
},
Stat::Min => NamedArrays {
names: vec![self.stat.name().into(), MIN_IS_TRUNCATED.into()],
arrays: vec![array, ConstantArray::new(false, len).into_array()],
},
_ => NamedArrays {
names: vec![self.stat.name().into()],
arrays: vec![array],
},
}
}
}
struct TruncatedMaxBinaryStatsBuilder<T: ScalarTruncation> {
values: Box<dyn ArrayBuilder>,
is_truncated: BoolBuilder,
max_value_length: usize,
_marker: PhantomData<T>,
}
impl<T: ScalarTruncation> TruncatedMaxBinaryStatsBuilder<T> {
pub fn new(
values: Box<dyn ArrayBuilder>,
is_truncated: BoolBuilder,
max_value_length: usize,
) -> Self {
Self {
values,
is_truncated,
max_value_length,
_marker: PhantomData,
}
}
}
struct TruncatedMinBinaryStatsBuilder<T: ScalarTruncation> {
values: Box<dyn ArrayBuilder>,
is_truncated: BoolBuilder,
max_value_length: usize,
_marker: PhantomData<T>,
}
impl<T: ScalarTruncation> TruncatedMinBinaryStatsBuilder<T> {
pub fn new(
values: Box<dyn ArrayBuilder>,
is_truncated: BoolBuilder,
max_value_length: usize,
) -> Self {
Self {
values,
is_truncated,
max_value_length,
_marker: PhantomData,
}
}
}
impl<T: ScalarTruncation> StatsArrayBuilder for TruncatedMaxBinaryStatsBuilder<T> {
fn stat(&self) -> Stat {
Stat::Max
}
fn append_scalar(&mut self, value: Scalar) -> VortexResult<()> {
let nullability = value.dtype().nullability();
if let Some((upper_bound, truncated)) =
upper_bound(T::from_scalar(value)?, self.max_value_length, nullability)
{
self.values.append_scalar(&upper_bound)?;
self.is_truncated.append_value(truncated);
} else {
self.append_null()
}
Ok(())
}
#[inline]
fn append_null(&mut self) {
ArrayBuilder::append_null(self.values.as_mut());
self.is_truncated.append_value(false);
}
fn finish(&mut self) -> NamedArrays {
NamedArrays {
names: vec![Stat::Max.name().into(), MAX_IS_TRUNCATED.into()],
arrays: vec![
ArrayBuilder::finish(self.values.as_mut()),
ArrayBuilder::finish(&mut self.is_truncated),
],
}
}
}
impl<T: ScalarTruncation> StatsArrayBuilder for TruncatedMinBinaryStatsBuilder<T> {
fn stat(&self) -> Stat {
Stat::Min
}
fn append_scalar(&mut self, value: Scalar) -> VortexResult<()> {
let nullability = value.dtype().nullability();
if let Some((lower_bound, truncated)) =
lower_bound(T::from_scalar(value)?, self.max_value_length, nullability)
{
self.values.append_scalar(&lower_bound)?;
self.is_truncated.append_value(truncated);
} else {
self.append_null()
}
Ok(())
}
#[inline]
fn append_null(&mut self) {
ArrayBuilder::append_null(self.values.as_mut());
self.is_truncated.append_value(false);
}
fn finish(&mut self) -> NamedArrays {
NamedArrays {
names: vec![Stat::Min.name().into(), MIN_IS_TRUNCATED.into()],
arrays: vec![
ArrayBuilder::finish(self.values.as_mut()),
ArrayBuilder::finish(&mut self.is_truncated),
],
}
}
}