use crate::AggregateExpr;
use arrow::array::ArrayRef;
use arrow::datatypes::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION};
use arrow_schema::DataType;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::Accumulator;
use std::any::Any;
use std::sync::Arc;
/// Turns each state value reported by `accum` into an array holding that value
/// once (`to_array_of_size(1)`), preserving the order of [`Accumulator::state`].
///
/// # Errors
///
/// Propagates any error returned by [`Accumulator::state`].
pub fn get_accum_scalar_values_as_arrays(
    accum: &dyn Accumulator,
) -> Result<Vec<ArrayRef>> {
    let state = accum.state()?;
    let mut arrays = Vec::with_capacity(state.len());
    for scalar in state.iter() {
        arrays.push(scalar.to_array_of_size(1));
    }
    Ok(arrays)
}
/// Rescales `lit_value` into `target_type` and divides it by `count`.
///
/// `lit_value` is an unscaled decimal integer with `scale` fractional digits
/// (presumably the accumulated sum of an average). It is multiplied by
/// `10^(target_scale - scale)` and then divided by `count`. The division is
/// integer division and truncates toward zero.
///
/// # Errors
///
/// Returns [`DataFusionError::Internal`] when:
/// - `target_type` is not `Decimal128`;
/// - the target precision is 0 or not covered by
///   `MIN_DECIMAL_FOR_EACH_PRECISION` / `MAX_DECIMAL_FOR_EACH_PRECISION`;
/// - either scale is negative, or its power of ten does not fit in `i128`;
/// - the target scale is smaller than `scale` (only upscaling is supported);
/// - `count` is zero;
/// - rescaling or division overflows `i128`, or the result lies outside the
///   target precision's range.
pub fn calculate_result_decimal_for_avg(
    lit_value: i128,
    count: i128,
    scale: i8,
    target_type: &DataType,
) -> Result<ScalarValue> {
    match target_type {
        DataType::Decimal128(p, s) => {
            // Single error value reused for every overflow / out-of-range condition.
            let overflow = || {
                DataFusionError::Internal(
                    "Arithmetic Overflow in AvgAccumulator".to_string(),
                )
            };
            // 10^exp as i128. Returns `None` for a negative exponent, or when the
            // power does not fit. (A plain `as u32` cast + `pow` would panic or wrap.)
            let pow10 =
                |exp: i8| u32::try_from(exp).ok().and_then(|e| 10_i128.checked_pow(e));

            // The bound tables are indexed by `precision - 1`. Checked lookups make
            // precision 0, or a precision past the table, an error instead of a panic.
            let (target_min, target_max) = (*p as usize)
                .checked_sub(1)
                .and_then(|idx| {
                    Some((
                        *MIN_DECIMAL_FOR_EACH_PRECISION.get(idx)?,
                        *MAX_DECIMAL_FOR_EACH_PRECISION.get(idx)?,
                    ))
                })
                .ok_or_else(|| {
                    DataFusionError::Internal(format!(
                        "Unsupported decimal precision {p} in AvgAccumulator"
                    ))
                })?;

            let (target_mul, lit_scale_mul) = match (pow10(*s), pow10(scale)) {
                (Some(target_mul), Some(lit_scale_mul)) => (target_mul, lit_scale_mul),
                _ => {
                    return Err(DataFusionError::Internal(format!(
                        "Unsupported decimal scale (input {scale}, target {s}) in AvgAccumulator"
                    )))
                }
            };

            // Only rescaling to an equal or larger scale is supported.
            if target_mul < lit_scale_mul {
                return Err(overflow());
            }
            let rescaled = lit_value
                .checked_mul(target_mul / lit_scale_mul)
                .ok_or_else(overflow)?;

            if count == 0 {
                return Err(DataFusionError::Internal(
                    "Division by zero count in AvgAccumulator".to_string(),
                ));
            }
            // checked_div also catches the i128::MIN / -1 overflow case.
            let new_value = rescaled.checked_div(count).ok_or_else(overflow)?;

            if (target_min..=target_max).contains(&new_value) {
                Ok(ScalarValue::Decimal128(Some(new_value), *p, *s))
            } else {
                Err(overflow())
            }
        }
        other => Err(DataFusionError::Internal(format!(
            "Error returned data type in AvgAccumulator {other:?}"
        ))),
    }
}
/// Looks through an [`AggregateExpr`] trait-object wrapper so the caller can
/// downcast to the concrete expression type.
///
/// If `any` holds an `Arc<dyn AggregateExpr>` or a `Box<dyn AggregateExpr>`, the
/// wrapped expression's `as_any()` is returned. Otherwise `any` is returned unchanged.
pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
    // A single `downcast_ref` both tests the type and borrows the value. This
    // avoids a separate `is` check followed by an `unwrap`.
    if let Some(expr) = any.downcast_ref::<Arc<dyn AggregateExpr>>() {
        expr.as_any()
    } else if let Some(expr) = any.downcast_ref::<Box<dyn AggregateExpr>>() {
        expr.as_any()
    } else {
        any
    }
}