use std::sync::Arc;
use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
use arrow::datatypes::{ArrowNativeTypeOp, DecimalType};
use arrow::error::ArrowError;
use arrow_buffer::ArrowNativeType;
use datafusion_common::{DataFusionError, Result};
/// Applies a rounding-style `op` to every value of a decimal array and returns the result
/// with the same data type as the input.
///
/// `op` is called as `op(raw_value, 10^scale)`, where `raw_value` is the unscaled native
/// integer stored in the array. Each value it produces is checked against `precision`
/// (and `scale`) before being accepted.
///
/// When `scale <= 0` there are no fractional digits to act on, so the input array is
/// returned as-is (a new `Arc` handle to the same data).
///
/// NOTE(review): `precision` and `scale` are presumably the parameters of `array`'s own
/// decimal type; callers should pass matching values. `as_primitive::<T>()` panics if
/// `array` is not a primitive array of type `T`.
///
/// # Errors
///
/// - Propagates the error from `decimal_scale_factor` if `10^scale` cannot be computed.
/// - Returns an arrow compute error, "Decimal overflow while applying {fn_name}", if any
///   value returned by `op` does not fit `precision`.
pub(super) fn apply_decimal_op<T, F>(
    array: &ArrayRef,
    precision: u8,
    scale: i8,
    fn_name: &str,
    op: F,
) -> Result<ArrayRef>
where
    T: DecimalType,
    T::Native: ArrowNativeType + ArrowNativeTypeOp,
    F: Fn(T::Native, T::Native) -> T::Native,
{
    // Integral (or coarser) decimals are already "rounded": nothing to do.
    if scale <= 0 {
        return Ok(Arc::clone(array));
    }

    let factor = decimal_scale_factor::<T>(scale, fn_name)?;
    let overflow =
        || ArrowError::ComputeError(format!("Decimal overflow while applying {fn_name}"));

    // `try_unary` (per arrow's docs) only evaluates the closure for non-null slots.
    let rounded: PrimitiveArray<T> = array.as_primitive::<T>().try_unary(|raw| {
        let adjusted = op(raw, factor);
        T::validate_decimal_precision(adjusted, precision, scale)
            .map(|_| adjusted)
            .map_err(|_| overflow())
    })?;

    // `try_unary` yields the type's default decimal parameters; restore the input's
    // precision/scale so the output type matches the input type.
    Ok(Arc::new(rounded.with_data_type(array.data_type().clone())))
}
/// Computes `10^scale` in the native representation of decimal type `T`.
///
/// The result is the factor used to split a raw (unscaled) decimal value into its
/// integral and fractional parts.
///
/// # Errors
///
/// Returns `DataFusionError::Execution` when:
/// - `scale` is negative, because there is no integral factor for it;
/// - the literal `10` cannot be converted to the native type;
/// - `10^scale` overflows the native type. This message names `fn_name`.
fn decimal_scale_factor<T>(scale: i8, fn_name: &str) -> Result<T::Native>
where
    T: DecimalType,
    T::Native: ArrowNativeType + ArrowNativeTypeOp,
{
    // A plain `scale as u32` would turn a negative scale into a huge exponent
    // (e.g. -1 -> 4294967295) and surface as a misleading overflow error.
    // Reject negative scales explicitly instead.
    let exponent = u32::try_from(scale).map_err(|_| {
        DataFusionError::Execution(format!(
            "Negative decimal scale {scale} is not supported while applying {fn_name}"
        ))
    })?;
    let base = <T::Native as ArrowNativeType>::from_usize(10).ok_or_else(|| {
        DataFusionError::Execution(format!(
            "Cannot get 10_{} from usize: {:?}",
            std::any::type_name::<T::Native>(),
            10_usize
        ))
    })?;
    base.pow_checked(exponent).map_err(|_| {
        DataFusionError::Execution(format!("Decimal overflow while applying {fn_name}"))
    })
}
/// Rounds a raw (unscaled) decimal value up, toward positive infinity, to the nearest
/// multiple of `factor`.
///
/// Values that are already a multiple of `factor` are returned unchanged.
///
/// `factor` is expected to be a positive power of ten; a zero `factor` makes `%` panic
/// for the primitive integer types. The additions and subtractions use wrapping
/// arithmetic, so a result that does not fit `T` wraps silently. Callers are expected
/// to validate the result, e.g. with a precision check.
pub(super) fn ceil_decimal_value<T>(value: T, factor: T) -> T
where
    T: ArrowNativeTypeOp + std::ops::Rem<Output = T>,
{
    // Rust's `%` takes the sign of the dividend, so `rem` is negative for negative values.
    let rem = value % factor;
    match (rem == T::ZERO, value >= T::ZERO) {
        // Exact multiple: nothing to round.
        (true, _) => value,
        // Positive: step forward by the distance to the next multiple.
        (false, true) => value.add_wrapping(factor.sub_wrapping(rem)),
        // Negative: dropping the (negative) remainder moves up toward zero.
        (false, false) => value.sub_wrapping(rem),
    }
}
/// Rounds a raw (unscaled) decimal value down, toward negative infinity, to the nearest
/// multiple of `factor`.
///
/// Values that are already a multiple of `factor` are returned unchanged.
///
/// `factor` is expected to be a positive power of ten; a zero `factor` makes `%` panic
/// for the primitive integer types. The additions and subtractions use wrapping
/// arithmetic, so a result that does not fit `T` wraps silently. Callers are expected
/// to validate the result, e.g. with a precision check.
pub(super) fn floor_decimal_value<T>(value: T, factor: T) -> T
where
    T: ArrowNativeTypeOp + std::ops::Rem<Output = T>,
{
    // Rust's `%` takes the sign of the dividend, so `rem` is negative for negative values.
    let rem = value % factor;
    if rem == T::ZERO {
        // Exact multiple: nothing to round.
        value
    } else if value < T::ZERO {
        // Negative: `factor + rem` is the distance down to the next lower multiple.
        value.sub_wrapping(factor.add_wrapping(rem))
    } else {
        // Positive: dropping the remainder truncates down to the multiple below.
        value.sub_wrapping(rem)
    }
}