use crate::math_funcs::utils::get_precision_scale;
use arrow::datatypes::DataType;
use arrow::{
array::{AsArray, Decimal128Builder},
datatypes::{validate_decimal_precision, Int64Type},
};
use datafusion::common::{internal_err, Result as DataFusionResult, ScalarValue};
use datafusion::physical_plan::ColumnarValue;
use std::sync::Arc;
pub fn spark_make_decimal(
args: &[ColumnarValue],
data_type: &DataType,
) -> DataFusionResult<ColumnarValue> {
let (precision, scale) = get_precision_scale(data_type);
match &args[0] {
ColumnarValue::Scalar(v) => match v {
ScalarValue::Int64(n) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
long_to_decimal(n, precision),
precision,
scale,
))),
sv => internal_err!("Expected Int64 but found {sv:?}"),
},
ColumnarValue::Array(a) => {
let arr = a.as_primitive::<Int64Type>();
let mut result = Decimal128Builder::new();
for v in arr.into_iter() {
result.append_option(long_to_decimal(&v, precision))
}
let result_type = DataType::Decimal128(precision, scale);
Ok(ColumnarValue::Array(Arc::new(
result.finish().with_data_type(result_type),
)))
}
}
}
/// Reinterprets an optional `i64` as an unscaled `i128` decimal value.
///
/// Returns `None` when the input is `None`. Also returns `None` when
/// `validate_decimal_precision` rejects the value for `precision`, so an
/// out-of-range value maps to null rather than an error.
#[inline]
fn long_to_decimal(v: &Option<i64>, precision: u8) -> Option<i128> {
    let unscaled = i128::from((*v)?);
    validate_decimal_precision(unscaled, precision)
        .is_ok()
        .then_some(unscaled)
}