use crate::{AggregateExpr, PhysicalSortExpr};
use arrow::array::ArrayRef;
use arrow::datatypes::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION};
use arrow_schema::{DataType, Field};
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::Accumulator;
use std::any::Any;
use std::sync::Arc;
/// Converts the current state of `accum` into a list of single-row arrays.
///
/// Each scalar returned by [`Accumulator::state`] is expanded with
/// `to_array_of_size(1)`, so the result has one array of length one per
/// state value, in the same order as the state.
///
/// # Errors
///
/// Propagates any error returned by `accum.state()`.
pub fn get_accum_scalar_values_as_arrays(
    accum: &dyn Accumulator,
) -> Result<Vec<ArrayRef>> {
    let state = accum.state()?;
    let mut arrays = Vec::with_capacity(state.len());
    for scalar in &state {
        arrays.push(scalar.to_array_of_size(1));
    }
    Ok(arrays)
}
/// Computes the average of `Decimal128` values from a sum and a count,
/// rescaling the sum to a target scale and checking the result against
/// the value range of a target precision.
pub(crate) struct Decimal128Averager {
/// Scale factor of the sum: `10^sum_scale`.
sum_mul: i128,
/// Scale factor of the result: `10^target_scale`.
target_mul: i128,
/// Inclusive lower bound accepted for the averaged value, looked up in
/// `MIN_DECIMAL_FOR_EACH_PRECISION` for the target precision.
target_min: i128,
/// Inclusive upper bound accepted for the averaged value, looked up in
/// `MAX_DECIMAL_FOR_EACH_PRECISION` for the target precision.
target_max: i128,
}
impl Decimal128Averager {
    /// Returns `10^scale` as an `i128`.
    ///
    /// Fails instead of panicking (or silently wrapping) when `scale` is
    /// negative or when the power does not fit in an `i128`.
    fn pow10(scale: i8) -> Result<i128> {
        if scale < 0 {
            // A negative scale has no integral power-of-ten multiplier; casting
            // it to `u32` would wrap to a huge exponent.
            return exec_err!("Arithmetic Overflow in AvgAccumulator");
        }
        // Lossless: `scale` is known to be non-negative here.
        match 10_i128.checked_pow(scale as u32) {
            Some(mul) => Ok(mul),
            None => exec_err!("Arithmetic Overflow in AvgAccumulator"),
        }
    }

    /// Creates an averager that converts a sum with scale `sum_scale` into an
    /// average with `target_precision` and `target_scale`.
    ///
    /// # Errors
    ///
    /// Returns an execution error when:
    /// * either scale is negative or `10^scale` overflows `i128`,
    /// * `target_precision` is zero or has no entry in the decimal bounds
    ///   tables,
    /// * the target scale is smaller than the sum scale.
    pub fn try_new(
        sum_scale: i8,
        target_precision: u8,
        target_scale: i8,
    ) -> Result<Self> {
        let sum_mul = Self::pow10(sum_scale)?;
        let target_mul = Self::pow10(target_scale)?;

        // The bounds tables are indexed by `precision - 1`; use checked
        // arithmetic and checked indexing so a bad precision is an error
        // rather than a panic.
        let bounds = usize::from(target_precision)
            .checked_sub(1)
            .and_then(|idx| {
                let min = MIN_DECIMAL_FOR_EACH_PRECISION.get(idx).copied()?;
                let max = MAX_DECIMAL_FOR_EACH_PRECISION.get(idx).copied()?;
                Some((min, max))
            });
        let (target_min, target_max) = match bounds {
            Some(bounds) => bounds,
            None => {
                return exec_err!(
                    "Invalid decimal precision {} in AvgAccumulator",
                    target_precision
                )
            }
        };

        if target_mul >= sum_mul {
            Ok(Self {
                sum_mul,
                target_mul,
                target_min,
                target_max,
            })
        } else {
            // Reducing the scale is not supported by this averager.
            exec_err!("Arithmetic Overflow in AvgAccumulator")
        }
    }

    /// Returns `sum / count`, rescaled from the sum scale to the target scale.
    ///
    /// The sum is first multiplied by `target_mul / sum_mul` and then divided
    /// by `count` using integer division.
    ///
    /// # Errors
    ///
    /// Returns an execution error when `count` is zero, when the rescaling or
    /// the division overflows `i128`, or when the result lies outside
    /// `[target_min, target_max]`.
    #[inline(always)]
    pub fn avg(&self, sum: i128, count: i128) -> Result<i128> {
        if count == 0 {
            // Plain `/` would panic here.
            return exec_err!("Division by zero in AvgAccumulator");
        }

        let averaged = sum
            .checked_mul(self.target_mul / self.sum_mul)
            // `checked_div` also covers the `i128::MIN / -1` overflow.
            .and_then(|scaled| scaled.checked_div(count));

        match averaged {
            Some(value) if value >= self.target_min && value <= self.target_max => {
                Ok(value)
            }
            _ => exec_err!("Arithmetic Overflow in AvgAccumulator"),
        }
    }
}
/// Unwraps an `Arc<dyn AggregateExpr>` or `Box<dyn AggregateExpr>` hidden
/// behind `any` and returns the inner expression's own `as_any()` view.
///
/// If `any` is neither of those wrapper types it is returned unchanged.
pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
    if let Some(shared) = any.downcast_ref::<Arc<dyn AggregateExpr>>() {
        return shared.as_any();
    }
    if let Some(boxed) = any.downcast_ref::<Box<dyn AggregateExpr>>() {
        return boxed.as_any();
    }
    any
}
/// Builds one nullable [`Field`] per ordering expression.
///
/// The field name is the expression's `to_string()` rendering and the field
/// type is the entry at the same position in `data_types`. The two slices are
/// paired positionally; if their lengths differ, the extra entries of the
/// longer one are ignored.
pub(crate) fn ordering_fields(
    ordering_req: &[PhysicalSortExpr],
    data_types: &[DataType],
) -> Vec<Field> {
    let pair_count = ordering_req.len().min(data_types.len());
    let mut fields = Vec::with_capacity(pair_count);
    for (sort_expr, data_type) in ordering_req.iter().zip(data_types) {
        let name = sort_expr.to_string();
        // Every ordering field is declared nullable.
        fields.push(Field::new(name.as_str(), data_type.clone(), true));
    }
    fields
}