use crate::{AggregateExpr, PhysicalSortExpr};
use arrow::array::ArrayRef;
use arrow::datatypes::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION};
use arrow_array::cast::AsArray;
use arrow_array::types::{
Decimal128Type, TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType, TimestampSecondType,
};
use arrow_schema::{DataType, Field};
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::Accumulator;
use std::any::Any;
use std::sync::Arc;
/// Converts the accumulator's intermediate state into single-element arrays.
///
/// Each `ScalarValue` returned by [`Accumulator::state`] is materialized as a
/// one-row `ArrayRef`; errors from `state()` are propagated to the caller.
pub fn get_accum_scalar_values_as_arrays(
    accum: &dyn Accumulator,
) -> Result<Vec<ArrayRef>> {
    let state = accum.state()?;
    let mut arrays = Vec::with_capacity(state.len());
    for scalar in &state {
        arrays.push(scalar.to_array_of_size(1));
    }
    Ok(arrays)
}
/// Precomputed factors for averaging `Decimal128` values: rescales a sum with
/// one scale to a target precision/scale, range-checking the result.
pub(crate) struct Decimal128Averager {
    /// 10^sum_scale — the multiplier implied by the scale of the input sum
    sum_mul: i128,
    /// 10^target_scale — the multiplier implied by the scale of the output
    target_mul: i128,
    /// minimum representable value for the target precision
    /// (from `MIN_DECIMAL_FOR_EACH_PRECISION`)
    target_min: i128,
    /// maximum representable value for the target precision
    /// (from `MAX_DECIMAL_FOR_EACH_PRECISION`)
    target_max: i128,
}
impl Decimal128Averager {
pub fn try_new(
sum_scale: i8,
target_precision: u8,
target_scale: i8,
) -> Result<Self> {
let sum_mul = 10_i128.pow(sum_scale as u32);
let target_mul = 10_i128.pow(target_scale as u32);
let target_min = MIN_DECIMAL_FOR_EACH_PRECISION[target_precision as usize - 1];
let target_max = MAX_DECIMAL_FOR_EACH_PRECISION[target_precision as usize - 1];
if target_mul >= sum_mul {
Ok(Self {
sum_mul,
target_mul,
target_min,
target_max,
})
} else {
exec_err!("Arithmetic Overflow in AvgAccumulator")
}
}
#[inline(always)]
pub fn avg(&self, sum: i128, count: i128) -> Result<i128> {
if let Some(value) = sum.checked_mul(self.target_mul / self.sum_mul) {
let new_value = value / count;
if new_value >= self.target_min && new_value <= self.target_max {
Ok(new_value)
} else {
exec_err!("Arithmetic Overflow in AvgAccumulator")
}
} else {
exec_err!("Arithmetic Overflow in AvgAccumulator")
}
}
}
pub fn adjust_output_array(
data_type: &DataType,
array: ArrayRef,
) -> Result<ArrayRef, DataFusionError> {
let array = match data_type {
DataType::Decimal128(p, s) => Arc::new(
array
.as_primitive::<Decimal128Type>()
.clone()
.with_precision_and_scale(*p, *s)?,
) as ArrayRef,
DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, tz) => Arc::new(
array
.as_primitive::<TimestampNanosecondType>()
.clone()
.with_timezone_opt(tz.clone()),
),
DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, tz) => Arc::new(
array
.as_primitive::<TimestampMicrosecondType>()
.clone()
.with_timezone_opt(tz.clone()),
),
DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, tz) => Arc::new(
array
.as_primitive::<TimestampMillisecondType>()
.clone()
.with_timezone_opt(tz.clone()),
),
DataType::Timestamp(arrow_schema::TimeUnit::Second, tz) => Arc::new(
array
.as_primitive::<TimestampSecondType>()
.clone()
.with_timezone_opt(tz.clone()),
),
_ => array,
};
Ok(array)
}
/// Peels smart-pointer wrappers off a type-erased aggregate expression.
///
/// If `any` is an `Arc<dyn AggregateExpr>` or `Box<dyn AggregateExpr>`, this
/// returns the inner expression's `as_any()` so callers can downcast to the
/// concrete expression type; otherwise the input is returned unchanged.
///
/// Rewritten from `is::<T>()` + `downcast_ref().unwrap()` to `if let`, which
/// performs each type check once and removes the panic path.
pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
    if let Some(aggr) = any.downcast_ref::<Arc<dyn AggregateExpr>>() {
        aggr.as_any()
    } else if let Some(aggr) = any.downcast_ref::<Box<dyn AggregateExpr>>() {
        aggr.as_any()
    } else {
        any
    }
}
/// Builds one nullable `Field` per ordering expression, pairing each sort
/// expression (stringified as the field name) with its corresponding data
/// type. Extra entries on either side are dropped by the zip.
pub(crate) fn ordering_fields(
    ordering_req: &[PhysicalSortExpr],
    data_types: &[DataType],
) -> Vec<Field> {
    let mut fields = Vec::with_capacity(ordering_req.len().min(data_types.len()));
    for (sort_expr, field_type) in ordering_req.iter().zip(data_types) {
        // Nullable so the field can represent missing ordering values
        fields.push(Field::new(sort_expr.to_string(), field_type.clone(), true));
    }
    fields
}