use std::any::Any;
use std::sync::Arc;
use datafusion::arrow::array::{
Array, ArrayRef, AsArray, Decimal128Array, Float64Array, GenericListArray, Int64Array,
OffsetSizeTrait,
};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::{exec_err, Result, ScalarValue};
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
Volatility,
};
/// Scalar UDF implementation that averages the elements of a list argument.
///
/// Its `ScalarUDFImpl::name` is `"hamelin_array_avg"`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayAvgUdf {
/// Argument signature reported to DataFusion; `new` sets it to `TypeSignature::Any(1)`
/// with `Volatility::Immutable`.
signature: Signature,
}
impl Default for ArrayAvgUdf {
fn default() -> Self {
Self::new()
}
}
impl ArrayAvgUdf {
    /// Creates the UDF with a signature that accepts exactly one argument of any type
    /// and is declared immutable.
    ///
    /// The argument's type is only checked later, in `return_type` and
    /// `invoke_with_args`.
    pub fn new() -> Self {
        let one_arg_of_any_type = TypeSignature::Any(1);
        let signature = Signature::new(one_arg_of_any_type, Volatility::Immutable);
        Self { signature }
    }
}
impl ScalarUDFImpl for ArrayAvgUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name the function is registered under.
    fn name(&self) -> &str {
        "hamelin_array_avg"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Derives the result type from the list's element type.
    ///
    /// `Int64` and `Float64` elements average to `Float64`; `Decimal128` elements keep
    /// their precision and scale.
    ///
    /// # Errors
    ///
    /// Returns an execution error when the argument count is not exactly one, when the
    /// argument is not a `List`/`LargeList`, or when the element type is unsupported.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        // Mirror the arity check in `invoke_with_args` instead of indexing blindly,
        // so a malformed call yields an error rather than an out-of-bounds panic.
        let [list_type] = arg_types else {
            return exec_err!(
                "array_avg expects exactly 1 argument, got {}",
                arg_types.len()
            );
        };
        match list_element_type(list_type, "array_avg")? {
            DataType::Int64 | DataType::Float64 => Ok(DataType::Float64),
            dt @ DataType::Decimal128(_, _) => Ok(dt.clone()),
            dt => exec_err!("array_avg does not support element type {:?}", dt),
        }
    }

    /// Evaluates the average for a scalar list or for every row of a list column.
    ///
    /// A scalar for which `scalar_list_values` yields no values produces a typed NULL
    /// scalar (`Decimal128` for decimal elements, `Float64` otherwise).
    ///
    /// # Errors
    ///
    /// Returns an execution error when the argument count is not exactly one or the
    /// argument is not a `List`/`LargeList`.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let [arg] = args.args.as_slice() else {
            return exec_err!(
                "array_avg expects exactly 1 argument, got {}",
                args.args.len()
            );
        };
        match arg {
            ColumnarValue::Scalar(scalar) => match scalar_list_values(scalar, "array_avg")? {
                Some(values) => compute_avg_scalar(&values),
                None => {
                    // No list payload: answer with a NULL of the declared result type.
                    let list_type = scalar.data_type();
                    let element_type = list_element_type(&list_type, "array_avg")?.clone();
                    Ok(ColumnarValue::Scalar(avg_null_for_element_type(
                        element_type,
                    )))
                }
            },
            ColumnarValue::Array(array) => match array.data_type() {
                DataType::List(_) => compute_avg_column(array.as_list::<i32>()),
                DataType::LargeList(_) => compute_avg_column(array.as_list::<i64>()),
                dt => exec_err!("array_avg expects array type, got {:?}", dt),
            },
        }
    }
}
/// Returns the typed NULL scalar used as a "no sum" result for `element_type`.
///
/// `Float64` and `Decimal128` map to a NULL of the same type (decimals keep their
/// precision and scale); every other type maps to a NULL `Int64`.
fn null_for_element_type(element_type: DataType) -> ScalarValue {
    if let DataType::Decimal128(precision, scale) = &element_type {
        return ScalarValue::Decimal128(None, *precision, *scale);
    }
    if matches!(element_type, DataType::Float64) {
        ScalarValue::Float64(None)
    } else {
        ScalarValue::Int64(None)
    }
}
/// Returns the typed NULL scalar used as a "no average" result for `element_type`.
///
/// `Decimal128` maps to a NULL decimal with the same precision and scale; every other
/// type maps to a NULL `Float64`.
fn avg_null_for_element_type(element_type: DataType) -> ScalarValue {
    if let DataType::Decimal128(precision, scale) = element_type {
        ScalarValue::Decimal128(None, precision, scale)
    } else {
        ScalarValue::Float64(None)
    }
}
fn compute_avg_scalar(values: &dyn Array) -> Result<ColumnarValue> {
let len = values.len();
if len == 0 {
return Ok(ColumnarValue::Scalar(avg_null_for_element_type(
values.data_type().clone(),
)));
}
if let Some(int_arr) = values.as_any().downcast_ref::<Int64Array>() {
let (sum, count) = int_arr
.iter()
.filter_map(|x| x)
.fold((0.0f64, 0.0f64), |(s, c), v| (s + v as f64, c + 1.0));
if count == 0.0 {
return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
}
return Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(
sum / count,
))));
}
if let Some(float_arr) = values.as_any().downcast_ref::<Float64Array>() {
let (sum, count) = float_arr
.iter()
.filter_map(|x| x)
.fold((0.0f64, 0.0f64), |(s, c), v| (s + v, c + 1.0));
if count == 0.0 {
return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
}
return Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(
sum / count,
))));
}
if let Some(dec_arr) = values.as_any().downcast_ref::<Decimal128Array>() {
let (p, s) = (dec_arr.precision(), dec_arr.scale());
let (sum, count) = dec_arr
.iter()
.filter_map(|x| x)
.fold((0i128, 0i128), |(s, c), v| (s + v, c + 1));
if count == 0 {
return Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(None, p, s)));
}
return Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
Some(sum / count),
p,
s,
)));
}
exec_err!(
"array_avg: unsupported array element type {:?}",
values.data_type()
)
}
/// Averages the non-null values of an `Int64Array` as `f64`.
///
/// Returns `None` when `values` is not an `Int64Array` or has no non-null elements.
/// Each element is converted to `f64` before being added, so the running total is a
/// float.
fn compute_avg_f64_from_ints(values: &dyn Array) -> Option<f64> {
    let ints = values.as_any().downcast_ref::<Int64Array>()?;
    let mut total = 0.0f64;
    let mut seen = 0.0f64;
    for v in ints.iter().flatten() {
        total += v as f64;
        seen += 1.0;
    }
    if seen > 0.0 {
        Some(total / seen)
    } else {
        None
    }
}
/// Averages the non-null values of a `Float64Array`.
///
/// Returns `None` when `values` is not a `Float64Array` or has no non-null elements.
fn compute_avg_f64(values: &dyn Array) -> Option<f64> {
    let floats = values.as_any().downcast_ref::<Float64Array>()?;
    let mut total = 0.0f64;
    let mut seen = 0.0f64;
    for v in floats.iter().flatten() {
        total += v;
        seen += 1.0;
    }
    if seen > 0.0 {
        Some(total / seen)
    } else {
        None
    }
}
/// Wraps a fresh [`ArrayAvgUdf`] in a [`ScalarUDF`].
pub fn array_avg_udf() -> ScalarUDF {
    let implementation = ArrayAvgUdf::new();
    ScalarUDF::new_from_impl(implementation)
}
/// Scalar UDF implementation that sums the elements of a list argument.
///
/// Its `ScalarUDFImpl::name` is `"hamelin_array_sum"`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArraySumUdf {
/// Argument signature reported to DataFusion; `new` sets it to `TypeSignature::Any(1)`
/// with `Volatility::Immutable`.
signature: Signature,
}
impl Default for ArraySumUdf {
fn default() -> Self {
Self::new()
}
}
impl ArraySumUdf {
    /// Creates the UDF with a signature that accepts exactly one argument of any type
    /// and is declared immutable.
    ///
    /// The argument's type is only checked later, in `return_type` and
    /// `invoke_with_args`.
    pub fn new() -> Self {
        let one_arg_of_any_type = TypeSignature::Any(1);
        let signature = Signature::new(one_arg_of_any_type, Volatility::Immutable);
        Self { signature }
    }
}
impl ScalarUDFImpl for ArraySumUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name the function is registered under.
    fn name(&self) -> &str {
        "hamelin_array_sum"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Derives the result type from the list's element type.
    ///
    /// The sum has the same type as the elements: `Int64`, `Float64`, or `Decimal128`
    /// with unchanged precision and scale.
    ///
    /// # Errors
    ///
    /// Returns an execution error when the argument count is not exactly one, when the
    /// argument is not a `List`/`LargeList`, or when the element type is unsupported.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        // Mirror the arity check in `invoke_with_args` instead of indexing blindly,
        // so a malformed call yields an error rather than an out-of-bounds panic.
        let [list_type] = arg_types else {
            return exec_err!(
                "array_sum expects exactly 1 argument, got {}",
                arg_types.len()
            );
        };
        match list_element_type(list_type, "array_sum")? {
            DataType::Int64 => Ok(DataType::Int64),
            DataType::Float64 => Ok(DataType::Float64),
            dt @ DataType::Decimal128(_, _) => Ok(dt.clone()),
            dt => exec_err!("array_sum does not support element type {:?}", dt),
        }
    }

    /// Evaluates the sum for a scalar list or for every row of a list column.
    ///
    /// A scalar for which `scalar_list_values` yields no values produces a typed NULL
    /// scalar chosen by `null_for_element_type`.
    ///
    /// # Errors
    ///
    /// Returns an execution error when the argument count is not exactly one or the
    /// argument is not a `List`/`LargeList`.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let [arg] = args.args.as_slice() else {
            return exec_err!(
                "array_sum expects exactly 1 argument, got {}",
                args.args.len()
            );
        };
        match arg {
            ColumnarValue::Scalar(scalar) => match scalar_list_values(scalar, "array_sum")? {
                Some(values) => compute_sum_scalar(&values),
                None => {
                    // No list payload: answer with a NULL of the declared result type.
                    let list_type = scalar.data_type();
                    let element_type = list_element_type(&list_type, "array_sum")?.clone();
                    Ok(ColumnarValue::Scalar(null_for_element_type(element_type)))
                }
            },
            ColumnarValue::Array(array) => match array.data_type() {
                DataType::List(_) => compute_sum_column(array.as_list::<i32>()),
                DataType::LargeList(_) => compute_sum_column(array.as_list::<i64>()),
                dt => exec_err!("array_sum expects array type, got {:?}", dt),
            },
        }
    }
}
/// Sums the non-null elements of a single list's values and returns a scalar.
///
/// * An empty `values` array yields a typed NULL (see `null_for_element_type`).
/// * `Int64`, `Float64` and `Decimal128` elements are summed in their own type; nulls
///   are skipped, and a non-empty array holding only nulls sums to zero.
/// * `Decimal128` results keep the input precision and scale.
///
/// Integer and decimal sums use wrapping addition so that overflow behaves the same in
/// every build profile instead of panicking where overflow checks are enabled.
///
/// # Errors
///
/// Returns an execution error for any other (non-empty) element type.
fn compute_sum_scalar(values: &dyn Array) -> Result<ColumnarValue> {
    if values.is_empty() {
        return Ok(ColumnarValue::Scalar(null_for_element_type(
            values.data_type().clone(),
        )));
    }

    let any = values.as_any();
    let total = if let Some(ints) = any.downcast_ref::<Int64Array>() {
        let sum = ints.iter().flatten().fold(0i64, i64::wrapping_add);
        ScalarValue::Int64(Some(sum))
    } else if let Some(floats) = any.downcast_ref::<Float64Array>() {
        let sum = floats.iter().flatten().sum::<f64>();
        ScalarValue::Float64(Some(sum))
    } else if let Some(decimals) = any.downcast_ref::<Decimal128Array>() {
        let sum = decimals.iter().flatten().fold(0i128, i128::wrapping_add);
        ScalarValue::Decimal128(Some(sum), decimals.precision(), decimals.scale())
    } else {
        return exec_err!(
            "array_sum: unsupported array element type {:?}",
            values.data_type()
        );
    };
    Ok(ColumnarValue::Scalar(total))
}
/// Sums the non-null values of an `Int64Array`.
///
/// Returns `None` when `values` is not an `Int64Array`; otherwise `Some(sum)`, which is
/// `Some(0)` for an empty or all-null array.
///
/// The sum uses wrapping addition so that overflow behaves the same in every build
/// profile instead of panicking where overflow checks are enabled.
///
/// NOTE(review): `compute_sum_scalar` returns NULL for an empty list while this helper
/// returns 0, so scalar and column inputs disagree on empty lists; confirm which
/// result is intended.
fn compute_sum_i64(values: &dyn Array) -> Option<i64> {
    let ints = values.as_any().downcast_ref::<Int64Array>()?;
    Some(ints.iter().flatten().fold(0i64, i64::wrapping_add))
}
/// Sums the non-null values of a `Float64Array`.
///
/// Returns `None` when `values` is not a `Float64Array`; otherwise `Some(sum)`, even
/// for an empty or all-null array.
fn compute_sum_f64(values: &dyn Array) -> Option<f64> {
    values
        .as_any()
        .downcast_ref::<Float64Array>()
        .map(|floats| floats.iter().flatten().sum::<f64>())
}
/// Returns the element type of a `List` or `LargeList` data type.
///
/// # Errors
///
/// Returns an execution error, prefixed with `fn_name`, for any other data type.
fn list_element_type<'a>(dt: &'a DataType, fn_name: &str) -> Result<&'a DataType> {
    if let DataType::List(field) | DataType::LargeList(field) = dt {
        return Ok(field.data_type());
    }
    exec_err!("{fn_name} expects array type, got {:?}", dt)
}
fn scalar_list_values(scalar: &ScalarValue, fn_name: &str) -> Result<Option<ArrayRef>> {
match scalar {
ScalarValue::List(arr) => {
if arr.is_empty() || arr.is_null(0) {
Ok(None)
} else {
Ok(Some(arr.value(0)))
}
}
ScalarValue::LargeList(arr) => {
if arr.is_empty() || arr.is_null(0) {
Ok(None)
} else {
Ok(Some(arr.value(0)))
}
}
_ => exec_err!("{fn_name} expects List type, got {:?}", scalar),
}
}
/// Computes the per-row average of a list column.
///
/// `Int64` and `Float64` element types produce a `Float64Array`; `Decimal128` produces
/// a decimal array with the same precision and scale, where each row is the raw sum
/// divided by the non-null count using integer division. Null rows, and rows for which
/// the per-row helper returns `None`, become null in the output.
///
/// # Errors
///
/// Returns an execution error for any other element type.
fn compute_avg_column<O: OffsetSizeTrait>(
    list_array: &GenericListArray<O>,
) -> Result<ColumnarValue> {
    /// Integer mean of a raw decimal sum over `count` values.
    fn decimal_mean(sum: i128, count: i128) -> i128 {
        sum / count
    }

    let averaged: ArrayRef = match list_array.value_type() {
        DataType::Int64 => {
            let column: Float64Array = list_array
                .iter()
                .map(|row| row.and_then(|values| compute_avg_f64_from_ints(&values)))
                .collect();
            Arc::new(column) as ArrayRef
        }
        DataType::Float64 => {
            let column: Float64Array = list_array
                .iter()
                .map(|row| row.and_then(|values| compute_avg_f64(&values)))
                .collect();
            Arc::new(column) as ArrayRef
        }
        DataType::Decimal128(precision, scale) => {
            compute_decimal_column(list_array, precision, scale, decimal_mean)?
        }
        dt => return exec_err!("array_avg: unsupported element type {:?}", dt),
    };
    Ok(ColumnarValue::Array(averaged))
}
/// Computes the per-row sum of a list column.
///
/// The output array has the same type as the elements: `Int64Array`, `Float64Array`, or
/// a decimal array with unchanged precision and scale. Null rows, and rows for which
/// the per-row helper returns `None`, become null in the output.
///
/// # Errors
///
/// Returns an execution error for any other element type.
fn compute_sum_column<O: OffsetSizeTrait>(
    list_array: &GenericListArray<O>,
) -> Result<ColumnarValue> {
    /// Decimal reducer that keeps the raw sum and ignores the element count.
    fn decimal_total(sum: i128, _count: i128) -> i128 {
        sum
    }

    let summed: ArrayRef = match list_array.value_type() {
        DataType::Int64 => {
            let column: Int64Array = list_array
                .iter()
                .map(|row| row.and_then(|values| compute_sum_i64(&values)))
                .collect();
            Arc::new(column) as ArrayRef
        }
        DataType::Float64 => {
            let column: Float64Array = list_array
                .iter()
                .map(|row| row.and_then(|values| compute_sum_f64(&values)))
                .collect();
            Arc::new(column) as ArrayRef
        }
        DataType::Decimal128(precision, scale) => {
            compute_decimal_column(list_array, precision, scale, decimal_total)?
        }
        dt => return exec_err!("array_sum: unsupported element type {:?}", dt),
    };
    Ok(ColumnarValue::Array(summed))
}
fn compute_decimal_column<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
precision: u8,
scale: i8,
reduce: fn(i128, i128) -> i128,
) -> Result<ArrayRef> {
let values: Vec<Option<i128>> = (0..list_array.len())
.map(|i| {
if list_array.is_null(i) {
None
} else {
let values = list_array.value(i);
let dec_arr = values.as_any().downcast_ref::<Decimal128Array>()?;
let sum: i128 = dec_arr.iter().filter_map(|x| x).sum();
let count = dec_arr.iter().filter(|x| x.is_some()).count() as i128;
if count == 0 {
None
} else {
Some(reduce(sum, count))
}
}
})
.collect();
let arr = Decimal128Array::from(values).with_precision_and_scale(precision, scale)?;
Ok(Arc::new(arr) as ArrayRef)
}
/// Wraps a fresh [`ArraySumUdf`] in a [`ScalarUDF`].
pub fn array_sum_udf() -> ScalarUDF {
    let implementation = ArraySumUdf::new();
    ScalarUDF::new_from_impl(implementation)
}