use std::any::Any;
use std::sync::Arc;
use datafusion::arrow::array::{
Array, ArrayRef, AsArray, Float32Array, Float64Array, GenericListArray, Int32Array, Int64Array,
OffsetSizeTrait, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::{exec_err, Result, ScalarValue};
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
Volatility,
};
/// Scalar UDF that maps a numeric value `x` to a bucket index given a list of
/// bin boundaries.
///
/// The struct only carries the function's [`Signature`]; the evaluation logic
/// lives in its `ScalarUDFImpl` implementation.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct WidthBucketArrayUdf {
// Argument signature handed back to DataFusion via `ScalarUDFImpl::signature`.
signature: Signature,
}
impl Default for WidthBucketArrayUdf {
fn default() -> Self {
Self::new()
}
}
impl WidthBucketArrayUdf {
    /// Builds the UDF with an immutable signature that accepts any two
    /// arguments; the argument types are checked when the function is invoked.
    pub fn new() -> Self {
        let signature = Signature::new(TypeSignature::Any(2), Volatility::Immutable);
        Self { signature }
    }
}
impl ScalarUDFImpl for WidthBucketArrayUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name reported for this function.
    fn name(&self) -> &str {
        "hamelin_width_bucket"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The bucket index is always reported as `Int64`, regardless of the
    /// argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int64)
    }

    /// Evaluates the function for the four scalar/array combinations of
    /// `(x, bins)`.
    ///
    /// # Errors
    ///
    /// Fails when the argument count is not two, when `x` or the bins are of
    /// an unsupported type, when the two array arguments differ in length, or
    /// when a helper rejects the bin boundaries.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let operands = args.args;
        let [x_arg, bins_arg] = operands.as_slice() else {
            return exec_err!(
                "width_bucket expects exactly 2 arguments, got {}",
                operands.len()
            );
        };

        match (x_arg, bins_arg) {
            // Both constant: the result is a single scalar.
            (ColumnarValue::Scalar(x_scalar), ColumnarValue::Scalar(bins_scalar)) => {
                let x = scalar_to_f64(x_scalar)?;
                let bins = scalar_to_bins(bins_scalar)?;
                let bucket = bins.as_deref().and_then(|edges| compute_bucket(x, edges));
                Ok(ColumnarValue::Scalar(ScalarValue::Int64(bucket)))
            }
            // Column of values against one constant set of boundaries.
            (ColumnarValue::Array(x_array), ColumnarValue::Scalar(bins_scalar)) => {
                let bins = scalar_to_bins(bins_scalar)?;
                compute_bucket_array(x_array, &bins).map(ColumnarValue::Array)
            }
            // Constant value against a per-row list of boundaries.
            (ColumnarValue::Scalar(x_scalar), ColumnarValue::Array(bins_array)) => {
                let x = scalar_to_f64(x_scalar)?;
                let buckets = bucket_from_list_column(bins_array, |_row, edges| {
                    Ok(compute_bucket(x, edges))
                })?;
                let column: ArrayRef = Arc::new(buckets);
                Ok(ColumnarValue::Array(column))
            }
            // Row-wise pairing of values and boundary lists.
            (ColumnarValue::Array(x_array), ColumnarValue::Array(bins_array)) => {
                if x_array.len() != bins_array.len() {
                    return exec_err!("width_bucket x and bins arrays must have same length");
                }
                let buckets = bucket_from_list_column(bins_array, |row, edges| {
                    get_x_f64(x_array, row).map(|x| compute_bucket(x, edges))
                })?;
                let column: ArrayRef = Arc::new(buckets);
                Ok(ColumnarValue::Array(column))
            }
        }
    }
}
/// Outcome of extracting bin boundaries from one array of boundary values.
enum BinsResult {
/// Every element was non-null and the boundaries passed validation.
Valid,
/// A null element was found among the boundaries.
Null,
}
/// Dispatches on the offset width of a list column of bins and computes one
/// bucket per row through `f`.
///
/// # Errors
///
/// Fails when `bins_array` is neither a `List` nor a `LargeList`, or when the
/// per-row processing fails.
fn bucket_from_list_column(
    bins_array: &ArrayRef,
    f: impl FnMut(usize, &[f64]) -> Result<Option<i64>>,
) -> Result<Int64Array> {
    let bins_type = bins_array.data_type();
    if matches!(bins_type, DataType::List(_)) {
        return bucket_from_generic_list(bins_array.as_list::<i32>(), f);
    }
    if matches!(bins_type, DataType::LargeList(_)) {
        return bucket_from_generic_list(bins_array.as_list::<i64>(), f);
    }
    exec_err!("width_bucket bins must be an array, got {:?}", bins_type)
}
/// Walks every row of a list array of bins and produces one bucket per row.
///
/// A null row, or a row whose boundaries contain a null, yields a null bucket.
/// Otherwise `f` is called with the row index and the extracted boundaries.
///
/// # Errors
///
/// Stops at the first row whose boundaries are rejected or for which `f`
/// fails, and returns that error.
fn bucket_from_generic_list<O: OffsetSizeTrait>(
    list_array: &GenericListArray<O>,
    mut f: impl FnMut(usize, &[f64]) -> Result<Option<i64>>,
) -> Result<Int64Array> {
    // One scratch buffer is reused for the boundaries of every row.
    let mut scratch: Vec<f64> = Vec::new();
    (0..list_array.len())
        .map(|row| -> Result<Option<i64>> {
            if list_array.is_null(row) {
                return Ok(None);
            }
            scratch.clear();
            let elements = list_array.value(row);
            match array_to_bins_into(&elements, &mut scratch)? {
                BinsResult::Valid => f(row, &scratch),
                BinsResult::Null => Ok(None),
            }
        })
        .collect()
}
/// Widens a numeric scalar to `f64`, keeping SQL NULL as `None`.
///
/// # Errors
///
/// Fails for any scalar that is not one of the handled integer or float
/// variants (or `Null`).
fn scalar_to_f64(scalar: &ScalarValue) -> Result<Option<f64>> {
    let widened = match scalar {
        ScalarValue::Null => None,
        ScalarValue::Float64(v) => *v,
        ScalarValue::Float32(v) => v.map(f64::from),
        ScalarValue::Int8(v) => v.map(f64::from),
        ScalarValue::Int16(v) => v.map(f64::from),
        ScalarValue::Int32(v) => v.map(f64::from),
        ScalarValue::UInt8(v) => v.map(f64::from),
        ScalarValue::UInt16(v) => v.map(f64::from),
        ScalarValue::UInt32(v) => v.map(f64::from),
        // 64-bit integers have no lossless conversion to f64, so these use `as`.
        ScalarValue::Int64(v) => v.map(|n| n as f64),
        ScalarValue::UInt64(v) => v.map(|n| n as f64),
        other => return exec_err!("width_bucket x must be numeric, got {:?}", other),
    };
    Ok(widened)
}
fn scalar_to_bins(scalar: &ScalarValue) -> Result<Option<Vec<f64>>> {
match scalar {
ScalarValue::List(arr) => {
if arr.is_empty() || arr.is_null(0) {
return Ok(None);
}
array_to_bins(&arr.value(0))
}
ScalarValue::LargeList(arr) => {
if arr.is_empty() || arr.is_null(0) {
return Ok(None);
}
array_to_bins(&arr.value(0))
}
ScalarValue::Null => Ok(None),
_ => exec_err!("width_bucket bins must be an array, got {:?}", scalar),
}
}
/// Allocating wrapper around `array_to_bins_into`: returns the boundaries as
/// an owned vector, or `None` when the extraction reported a null.
fn array_to_bins(array: &dyn Array) -> Result<Option<Vec<f64>>> {
    let mut boundaries = Vec::new();
    let outcome = array_to_bins_into(array, &mut boundaries)?;
    Ok(matches!(outcome, BinsResult::Valid).then_some(boundaries))
}
fn array_to_bins_into(array: &dyn Array, buf: &mut Vec<f64>) -> Result<BinsResult> {
macro_rules! extract_bins {
($arr:expr) => {{
buf.reserve($arr.len());
for i in 0..$arr.len() {
if $arr.is_null(i) {
return Ok(BinsResult::Null);
}
let v = $arr.value(i) as f64;
if v.is_nan() {
return exec_err!("width_bucket bin boundaries must not contain NaN");
}
buf.push(v);
}
}};
}
if let Some(arr) = array.as_any().downcast_ref::<Float64Array>() {
extract_bins!(arr);
} else if let Some(arr) = array.as_any().downcast_ref::<Float32Array>() {
extract_bins!(arr);
} else if let Some(arr) = array.as_any().downcast_ref::<Int64Array>() {
extract_bins!(arr);
} else if let Some(arr) = array.as_any().downcast_ref::<Int32Array>() {
extract_bins!(arr);
} else if let Some(arr) = array.as_any().downcast_ref::<UInt64Array>() {
extract_bins!(arr);
} else if let Some(arr) = array.as_any().downcast_ref::<UInt32Array>() {
extract_bins!(arr);
} else if let Some(arr) = array.as_any().downcast_ref::<UInt16Array>() {
extract_bins!(arr);
} else if let Some(arr) = array.as_any().downcast_ref::<UInt8Array>() {
extract_bins!(arr);
} else {
return exec_err!(
"width_bucket bins must be a numeric array, got {:?}",
array.data_type()
);
}
for w in buf.windows(2) {
if w[0] > w[1] {
return exec_err!(
"width_bucket bin boundaries must be sorted in ascending order, \
found {} followed by {}",
w[0],
w[1]
);
}
}
Ok(BinsResult::Valid)
}
/// Reads element `i` of a numeric array as `f64`; a null element gives `None`.
///
/// # Errors
///
/// Fails when the array is not one of the supported numeric array types. The
/// null check runs first, so a null element never reaches the type check.
fn get_x_f64(array: &ArrayRef, i: usize) -> Result<Option<f64>> {
    if array.is_null(i) {
        return Ok(None);
    }

    let any = array.as_any();
    // Tries each concrete array type in turn and returns on the first match.
    macro_rules! read_as {
        ($($ty:ty),+ $(,)?) => {
            $(
                if let Some(typed) = any.downcast_ref::<$ty>() {
                    return Ok(Some(typed.value(i) as f64));
                }
            )+
        };
    }
    read_as!(
        Float64Array,
        Float32Array,
        Int64Array,
        Int32Array,
        UInt64Array,
        UInt32Array,
        UInt16Array,
        UInt8Array,
    );

    exec_err!(
        "width_bucket x must be numeric array, got {:?}",
        array.data_type()
    )
}
fn compute_bucket_array(x_array: &ArrayRef, bins: &Option<Vec<f64>>) -> Result<ArrayRef> {
let bins = match bins {
Some(b) => b,
None => {
let results: Int64Array = (0..x_array.len()).map(|_| None::<i64>).collect();
return Ok(Arc::new(results) as ArrayRef);
}
};
macro_rules! map_bucket {
($arr:expr) => {{
let results: Int64Array = $arr
.iter()
.map(|v| compute_bucket(v.map(|x| x as f64), bins))
.collect();
return Ok(Arc::new(results) as ArrayRef);
}};
}
if let Some(arr) = x_array.as_any().downcast_ref::<Int64Array>() {
map_bucket!(arr);
}
if let Some(arr) = x_array.as_any().downcast_ref::<Float64Array>() {
let results: Int64Array = arr.iter().map(|v| compute_bucket(v, bins)).collect();
return Ok(Arc::new(results) as ArrayRef);
}
if let Some(arr) = x_array.as_any().downcast_ref::<Int32Array>() {
map_bucket!(arr);
}
if let Some(arr) = x_array.as_any().downcast_ref::<Float32Array>() {
map_bucket!(arr);
}
if let Some(arr) = x_array.as_any().downcast_ref::<UInt64Array>() {
map_bucket!(arr);
}
if let Some(arr) = x_array.as_any().downcast_ref::<UInt32Array>() {
map_bucket!(arr);
}
if let Some(arr) = x_array.as_any().downcast_ref::<UInt16Array>() {
map_bucket!(arr);
}
if let Some(arr) = x_array.as_any().downcast_ref::<UInt8Array>() {
map_bucket!(arr);
}
exec_err!(
"width_bucket x must be numeric array, got {:?}",
x_array.data_type()
)
}
/// Returns the bucket index of `x` for ascending boundaries `bins`.
///
/// The index is the number of boundaries that are less than or equal to `x`:
/// `0` when `x` is below the first boundary, `bins.len()` when it is at or
/// above the last one. A `None` input gives `None`, and a NaN `x` is placed in
/// the last bucket (`bins.len()`). An empty `bins` always yields `0`.
///
/// `bins` is expected to be sorted in ascending order and free of NaN.
/// Duplicate boundaries are allowed.
fn compute_bucket(x: Option<f64>, bins: &[f64]) -> Option<i64> {
    let x = x?;
    if x.is_nan() {
        return Some(bins.len() as i64);
    }
    // `partition_point` counts the leading boundaries that are `<= x`. Unlike
    // `binary_search_by`, whose result is unspecified when several boundaries
    // equal `x`, this is deterministic for duplicated boundaries.
    let bucket = bins.partition_point(|&boundary| boundary <= x);
    Some(bucket as i64)
}
/// Wraps a fresh [`WidthBucketArrayUdf`] in a [`ScalarUDF`].
pub fn width_bucket_array_udf() -> ScalarUDF {
    let implementation = WidthBucketArrayUdf::new();
    ScalarUDF::new_from_impl(implementation)
}