use std::fmt::Debug;
use std::fmt::Display;
use std::fmt::Formatter;
use std::fmt::{self};
use std::hash::Hash;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;
use vortex_session::VortexSession;
use crate::ArrayRef;
use crate::Columnar;
use crate::ExecutionCtx;
use crate::aggregate_fn::AggregateFnId;
use crate::aggregate_fn::AggregateFnVTable;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::dtype::FieldName;
use crate::dtype::FieldNames;
use crate::dtype::Nullability;
use crate::dtype::StructFields;
use crate::scalar::Scalar;
/// A two-element wrapper holding a left value (`.0`) and a right value (`.1`).
///
/// `CombinedOptions` uses it to carry the options of both child aggregates. Unlike a plain
/// tuple, it implements [`Display`] whenever both elements do.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PairOptions<L, R>(pub L, pub R);

/// Renders the pair as `(left, right)` using each element's own [`Display`] output.
impl<L, R> Display for PairOptions<L, R>
where
    L: Display,
    R: Display,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self(first, second) = self;
        write!(f, "({}, {})", first, second)
    }
}
// Shorthand aliases that project the `Options` / `Partial` associated types of the two child
// aggregates declared by a `BinaryCombined` implementor `T`.

/// `Options` type of `T`'s left child aggregate.
type LeftOptions<T> = <<T as BinaryCombined>::Left as AggregateFnVTable>::Options;
/// `Options` type of `T`'s right child aggregate.
type RightOptions<T> = <<T as BinaryCombined>::Right as AggregateFnVTable>::Options;
/// `Partial` (in-progress state) type of `T`'s left child aggregate.
type LeftPartial<T> = <<T as BinaryCombined>::Left as AggregateFnVTable>::Partial;
/// `Partial` (in-progress state) type of `T`'s right child aggregate.
type RightPartial<T> = <<T as BinaryCombined>::Right as AggregateFnVTable>::Partial;
/// Options of a combined aggregate: the left child's options in `.0` and the right child's
/// options in `.1`.
pub type CombinedOptions<T> = PairOptions<LeftOptions<T>, RightOptions<T>>;
pub trait BinaryCombined: 'static + Send + Sync + Clone {
type Left: AggregateFnVTable;
type Right: AggregateFnVTable;
fn id(&self) -> AggregateFnId;
fn left(&self) -> Self::Left;
fn right(&self) -> Self::Right;
fn left_name(&self) -> &'static str {
"left"
}
fn right_name(&self) -> &'static str {
"right"
}
fn return_dtype(&self, input_dtype: &DType) -> Option<DType>;
fn finalize(&self, left: ArrayRef, right: ArrayRef) -> VortexResult<ArrayRef>;
fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult<Scalar>;
fn serialize(&self, options: &CombinedOptions<Self>) -> VortexResult<Option<Vec<u8>>> {
let _ = options;
Ok(None)
}
fn deserialize(
&self,
metadata: &[u8],
session: &VortexSession,
) -> VortexResult<CombinedOptions<Self>> {
let _ = (metadata, session);
vortex_bail!(
"Combined aggregate function {} is not deserializable",
BinaryCombined::id(self)
);
}
fn coerce_args(
&self,
options: &CombinedOptions<Self>,
input_dtype: &DType,
) -> VortexResult<DType> {
let left_coerced = self.left().coerce_args(&options.0, input_dtype)?;
self.right().coerce_args(&options.1, &left_coerced)
}
fn partial_struct_dtype(&self, left: DType, right: DType) -> DType {
DType::Struct(
StructFields::new(
FieldNames::from_iter([
FieldName::from(self.left_name()),
FieldName::from(self.right_name()),
]),
vec![left, right],
),
Nullability::NonNullable,
)
}
}
/// Wrapper around a [`BinaryCombined`] implementor.
///
/// The wrapped value is public as field `.0`; this file implements [`AggregateFnVTable`] for
/// the wrapper by delegating to it and to its two child aggregates.
#[derive(Debug, Clone)]
pub struct Combined<T>(pub T)
where
    T: BinaryCombined;

impl<T> Combined<T>
where
    T: BinaryCombined,
{
    /// Wraps `inner` in a [`Combined`].
    pub fn new(inner: T) -> Self {
        Self(inner)
    }
}
impl<T: BinaryCombined> AggregateFnVTable for Combined<T> {
type Options = CombinedOptions<T>;
type Partial = (LeftPartial<T>, RightPartial<T>);
fn id(&self) -> AggregateFnId {
self.0.id()
}
fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
BinaryCombined::serialize(&self.0, options)
}
fn deserialize(&self, metadata: &[u8], session: &VortexSession) -> VortexResult<Self::Options> {
BinaryCombined::deserialize(&self.0, metadata, session)
}
fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
BinaryCombined::coerce_args(&self.0, options, input_dtype)
}
fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
BinaryCombined::return_dtype(&self.0, input_dtype)
}
fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
let l = self.0.left().partial_dtype(&options.0, input_dtype)?;
let r = self.0.right().partial_dtype(&options.1, input_dtype)?;
Some(self.0.partial_struct_dtype(l, r))
}
fn empty_partial(
&self,
options: &Self::Options,
input_dtype: &DType,
) -> VortexResult<Self::Partial> {
Ok((
self.0.left().empty_partial(&options.0, input_dtype)?,
self.0.right().empty_partial(&options.1, input_dtype)?,
))
}
fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
if other.is_null() {
return Ok(());
}
let s = other.as_struct();
let lname = self.0.left_name();
let rname = self.0.right_name();
let l_field = s
.field(lname)
.ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", lname))?;
let r_field = s
.field(rname)
.ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", rname))?;
self.0.left().combine_partials(&mut partial.0, l_field)?;
self.0.right().combine_partials(&mut partial.1, r_field)?;
Ok(())
}
fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
let l_scalar = self.0.left().to_scalar(&partial.0)?;
let r_scalar = self.0.right().to_scalar(&partial.1)?;
let dtype = self
.0
.partial_struct_dtype(l_scalar.dtype().clone(), r_scalar.dtype().clone());
Ok(Scalar::struct_(dtype, vec![l_scalar, r_scalar]))
}
fn reset(&self, partial: &mut Self::Partial) {
self.0.left().reset(&mut partial.0);
self.0.right().reset(&mut partial.1);
}
fn is_saturated(&self, partial: &Self::Partial) -> bool {
self.0.left().is_saturated(&partial.0) && self.0.right().is_saturated(&partial.1)
}
fn try_accumulate(
&self,
state: &mut Self::Partial,
batch: &ArrayRef,
ctx: &mut ExecutionCtx,
) -> VortexResult<bool> {
let mut canonical: Option<Columnar> = None;
if !self.0.left().try_accumulate(&mut state.0, batch, ctx)? {
let c = canonical.insert(batch.clone().execute::<Columnar>(ctx)?);
self.0.left().accumulate(&mut state.0, c, ctx)?;
}
if !self.0.right().try_accumulate(&mut state.1, batch, ctx)? {
let c = match canonical.as_ref() {
Some(c) => c,
None => canonical.insert(batch.clone().execute::<Columnar>(ctx)?),
};
self.0.right().accumulate(&mut state.1, c, ctx)?;
}
Ok(true)
}
fn accumulate(
&self,
_state: &mut Self::Partial,
_batch: &Columnar,
_ctx: &mut ExecutionCtx,
) -> VortexResult<()> {
unreachable!("Combined::try_accumulate handles all batches")
}
fn finalize(&self, states: ArrayRef) -> VortexResult<ArrayRef> {
let l_field = states.get_item(FieldName::from(self.0.left_name()))?;
let r_field = states.get_item(FieldName::from(self.0.right_name()))?;
let l_finalized = self.0.left().finalize(l_field)?;
let r_finalized = self.0.right().finalize(r_field)?;
BinaryCombined::finalize(&self.0, l_finalized, r_finalized)
}
fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
let l_scalar = self.0.left().finalize_scalar(&partial.0)?;
let r_scalar = self.0.right().finalize_scalar(&partial.1)?;
BinaryCombined::finalize_scalar(&self.0, l_scalar, r_scalar)
}
}