use std::collections::HashMap;
use std::fmt::Debug;
use std::mem::{size_of, size_of_val};
use std::sync::Arc;
use arrow::array::{
ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, PrimitiveBuilder,
};
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
use arrow::{
array::{Array, ArrayRef, AsArray},
datatypes::{DataType, Field, FieldRef, Float16Type, Float32Type, Float64Type},
};
use num_traits::AsPrimitive;
use arrow::array::ArrowNativeTypeOp;
use datafusion_common::internal_err;
use datafusion_common::types::{NativeType, logical_float64};
use datafusion_functions_aggregate_common::noop_accumulator::NoopAccumulator;
use crate::min_max::{max_udaf, min_udaf};
use datafusion_common::{
Result, ScalarValue, internal_datafusion_err, utils::take_function_args,
};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, Signature,
TypeSignatureClass, Volatility,
};
use datafusion_expr::{EmitTo, GroupsAccumulator};
use datafusion_expr::{
expr::{AggregateFunction, Sort},
function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs},
simplify::SimplifyContext,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
use datafusion_functions_aggregate_common::utils::{GenericDistinctBuffer, Hashable};
use datafusion_macros::user_doc;
use crate::utils::validate_percentile_expr;
// Granularity (1e-6) used to quantize the linear-interpolation weight in
// `calculate_percentile`; the fractional rank is truncated to this many steps
// before interpolating.
const INTERPOLATION_PRECISION: f64 = 1_000_000.0;

// Generates the singleton `AggregateUDF` constructor `percentile_cont_udaf()`
// for `PercentileCont` (crate-local macro).
create_func!(PercentileCont, percentile_cont_udaf);
/// Builds a `percentile_cont(expr, percentile) WITHIN GROUP (ORDER BY ...)`
/// aggregate expression.
///
/// The ordered expression is passed both as the first argument and as the
/// `order_by` clause of the aggregate call.
pub fn percentile_cont(order_by: Sort, percentile: Expr) -> Expr {
    let target = order_by.expr.clone();
    Expr::AggregateFunction(AggregateFunction::new_udf(
        percentile_cont_udaf(),
        vec![target, percentile],
        false,
        None,
        vec![order_by],
        None,
    ))
}
// SQL-level documentation rendered by the `user_doc` macro. The raw-string
// SQL examples below are emitted verbatim into the generated docs.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the exact percentile of input values, interpolating between values if needed.",
    syntax_example = "percentile_cont(percentile) WITHIN GROUP (ORDER BY expression)",
    sql_example = r#"```sql
> SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+----------------------------------------------------------+
| percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
+----------------------------------------------------------+
| 45.5 |
+----------------------------------------------------------+
```
An alternate syntax is also supported:
```sql
> SELECT percentile_cont(column_name, 0.75) FROM table_name;
+---------------------------------------+
| percentile_cont(column_name, 0.75) |
+---------------------------------------+
| 45.5 |
+---------------------------------------+
```"#,
    standard_argument(name = "expression", prefix = "The"),
    argument(
        name = "percentile",
        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
    )
)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct PercentileCont {
    // Two-argument coercible signature; see `PercentileCont::new` for the
    // exact coercions applied to the value and percentile arguments.
    signature: Signature,
    // Alternate registration names (currently `quantile_cont`).
    aliases: Vec<String>,
}
impl Default for PercentileCont {
fn default() -> Self {
Self::new()
}
}
impl PercentileCont {
    /// Creates the `percentile_cont` UDAF definition.
    ///
    /// The first argument (the value expression) is coerced to a float type
    /// from any numeric input; the second (the percentile) is coerced to
    /// `Float64`.
    pub fn new() -> Self {
        // One coercion per argument, in positional order.
        let coercions = vec![
            Coercion::new_implicit(
                TypeSignatureClass::Float,
                vec![TypeSignatureClass::Numeric],
                NativeType::Float64,
            ),
            Coercion::new_implicit(
                TypeSignatureClass::Native(logical_float64()),
                vec![TypeSignatureClass::Numeric],
                NativeType::Float64,
            ),
        ];
        let signature = Signature::coercible(coercions, Volatility::Immutable)
            .with_parameter_names(vec!["expr", "percentile"])
            .unwrap();
        Self {
            signature,
            aliases: vec!["quantile_cont".to_string()],
        }
    }
}
impl AggregateUDFImpl for PercentileCont {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "percentile_cont"
    }

    fn aliases(&self) -> &[String] {
        &self.aliases
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result type mirrors the input type; a `Null`-typed input falls
    /// back to `Float64`.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        match &arg_types[0] {
            DataType::Null => Ok(DataType::Float64),
            dt => Ok(dt.clone()),
        }
    }

    /// Intermediate state is a single `List` field containing every buffered
    /// input value — an exact percentile cannot be computed from a bounded
    /// summary, so all values must flow through the state.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        let input_type = args.input_fields[0].data_type().clone();
        if input_type.is_null() {
            // Null input carries no data; state is a single Null field.
            return Ok(vec![
                Field::new(
                    format_state_name(args.name, self.name()),
                    DataType::Null,
                    true,
                )
                .into(),
            ]);
        }
        let field = Field::new_list_field(input_type, true);
        // Distinct aggregation gets a distinct state name so plans are
        // distinguishable.
        let state_name = if args.is_distinct {
            "distinct_percentile_cont"
        } else {
            "percentile_cont"
        };
        Ok(vec![
            Field::new(
                format_state_name(args.name, state_name),
                DataType::List(Arc::new(field)),
                true,
            )
            .into(),
        ])
    }

    /// Creates a row accumulator: a no-op for Null input, otherwise a
    /// (distinct or plain) float accumulator for f16/f32/f64 inputs only —
    /// other numerics should have been coerced away by the signature.
    fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let percentile = get_percentile(&args)?;
        let input_dt = args.expr_fields[0].data_type();
        if input_dt.is_null() {
            return Ok(Box::new(NoopAccumulator::new(ScalarValue::Float64(None))));
        }
        if args.is_distinct {
            match input_dt {
                DataType::Float16 => Ok(Box::new(DistinctPercentileContAccumulator::<
                    Float16Type,
                >::new(percentile))),
                DataType::Float32 => Ok(Box::new(DistinctPercentileContAccumulator::<
                    Float32Type,
                >::new(percentile))),
                DataType::Float64 => Ok(Box::new(DistinctPercentileContAccumulator::<
                    Float64Type,
                >::new(percentile))),
                dt => internal_err!("Unsupported datatype for percentile cont: {dt}"),
            }
        } else {
            match input_dt {
                DataType::Float16 => Ok(Box::new(
                    PercentileContAccumulator::<Float16Type>::new(percentile),
                )),
                DataType::Float32 => Ok(Box::new(
                    PercentileContAccumulator::<Float32Type>::new(percentile),
                )),
                DataType::Float64 => Ok(Box::new(
                    PercentileContAccumulator::<Float64Type>::new(percentile),
                )),
                dt => internal_err!("Unsupported datatype for percentile cont: {dt}"),
            }
        }
    }

    /// Group accumulation is only supported for the non-distinct, non-null
    /// case (see `create_groups_accumulator`).
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        !args.is_distinct && !args.expr_fields[0].data_type().is_null()
    }

    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        let percentile = get_percentile(&args)?;
        let input_dt = args.expr_fields[0].data_type();
        match input_dt {
            DataType::Float16 => Ok(Box::new(PercentileContGroupsAccumulator::<
                Float16Type,
            >::new(percentile))),
            DataType::Float32 => Ok(Box::new(PercentileContGroupsAccumulator::<
                Float32Type,
            >::new(percentile))),
            DataType::Float64 => Ok(Box::new(PercentileContGroupsAccumulator::<
                Float64Type,
            >::new(percentile))),
            dt => internal_err!("Unsupported datatype for percentile cont: {dt}"),
        }
    }

    /// Registers the percentile-0/1 → min/max rewrite (see
    /// `simplify_percentile_cont_aggregate`).
    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
        Some(Box::new(|aggregate_function, info| {
            simplify_percentile_cont_aggregate(aggregate_function, info)
        }))
    }

    fn supports_within_group_clause(&self) -> bool {
        true
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
/// Extracts and validates the percentile literal from the second argument.
///
/// When the `WITHIN GROUP` ordering is descending, the percentile is mirrored
/// (`p -> 1 - p`) so the accumulators can always work on ascending order.
fn get_percentile(args: &AccumulatorArgs) -> Result<f64> {
    let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?;
    let descending = args
        .order_bys
        .first()
        .is_some_and(|sort_expr| sort_expr.options.descending);
    Ok(if descending { 1.0 - percentile } else { percentile })
}
/// Rewrites `percentile_cont(expr, 0.0)` / `percentile_cont(expr, 1.0)` into
/// `min(expr)` / `max(expr)`, which have much cheaper accumulators.
///
/// The `ORDER BY` direction is taken into account: percentile 0 over a
/// descending ordering is the maximum, and vice versa. Any other percentile —
/// or a `Null`-typed input — is returned unchanged.
fn simplify_percentile_cont_aggregate(
    aggregate_function: AggregateFunction,
    info: &SimplifyContext,
) -> Result<Expr> {
    enum PercentileRewriteTarget {
        Min,
        Max,
    }
    let params = &aggregate_function.params;
    // Fixed: this argument was garbled (`¶ms.args`, an HTML-entity
    // mangling of `&params.args`) and did not compile.
    let [value, percentile] = take_function_args("percentile_cont", &params.args)?;
    let input_type = info.get_data_type(value)?;
    if input_type.is_null() {
        // Null input is handled by the NoopAccumulator; do not rewrite.
        return Ok(Expr::AggregateFunction(aggregate_function));
    }
    let is_descending = params
        .order_by
        .first()
        .map(|sort| !sort.asc)
        .unwrap_or(false);
    // Only exact literal 0.0 / 1.0 qualify for the rewrite.
    let rewrite_target = match percentile {
        Expr::Literal(ScalarValue::Float64(Some(0.0)), _) => {
            if is_descending {
                PercentileRewriteTarget::Max
            } else {
                PercentileRewriteTarget::Min
            }
        }
        Expr::Literal(ScalarValue::Float64(Some(1.0)), _) => {
            if is_descending {
                PercentileRewriteTarget::Min
            } else {
                PercentileRewriteTarget::Max
            }
        }
        _ => return Ok(Expr::AggregateFunction(aggregate_function)),
    };
    let udaf = match rewrite_target {
        PercentileRewriteTarget::Min => min_udaf(),
        PercentileRewriteTarget::Max => max_udaf(),
    };
    // Preserve DISTINCT / FILTER / null-treatment; the ORDER BY is dropped
    // because min/max are order-insensitive.
    let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf(
        udaf,
        vec![value.clone()],
        params.distinct,
        params.filter.clone(),
        vec![],
        params.null_treatment,
    ));
    Ok(rewritten)
}
/// Row accumulator for non-distinct `percentile_cont`: buffers every non-null
/// input value and computes the interpolated percentile at evaluation time.
#[derive(Debug)]
struct PercentileContAccumulator<T: ArrowNumericType + Debug> {
    // All buffered values; order is not significant (evaluation re-selects).
    all_values: Vec<T::Native>,
    // Requested percentile in [0, 1], already adjusted for DESC ordering.
    percentile: f64,
}
impl<T: ArrowNumericType + Debug> PercentileContAccumulator<T> {
    /// Creates an empty accumulator for the given (pre-adjusted) percentile.
    fn new(percentile: f64) -> Self {
        Self {
            percentile,
            all_values: Vec::new(),
        }
    }
}
impl<T> Accumulator for PercentileContAccumulator<T>
where
    T: ArrowNumericType + Debug,
    T::Native: Copy + AsPrimitive<f64>,
    f64: AsPrimitive<T::Native>,
{
    /// Serializes all buffered values as a single-row `ListArray`.
    ///
    /// `all_values` is drained via `mem::take` to avoid copying the buffer.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let offsets =
            OffsetBuffer::new(ScalarBuffer::from(vec![0, self.all_values.len() as i32]));
        let values_array = PrimitiveArray::<T>::new(
            ScalarBuffer::from(std::mem::take(&mut self.all_values)),
            None,
        );
        let list_array = ListArray::new(
            Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
            offsets,
            Arc::new(values_array),
            None,
        );
        Ok(vec![ScalarValue::List(Arc::new(list_array))])
    }

    /// Buffers all non-null input values.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let values = values[0].as_primitive::<T>();
        self.all_values.reserve(values.len() - values.null_count());
        self.all_values.extend(values.iter().flatten());
        Ok(())
    }

    /// Merges intermediate states produced by `state`.
    ///
    /// Fix: iterate over *every* row of the state list array. The previous
    /// code read only `array.value(0)`, silently dropping all but the first
    /// partial state when several states arrive in one batch (one row per
    /// partial accumulator).
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let array = states[0].as_list::<i32>();
        for values in array.iter().flatten() {
            self.update_batch(&[values])?;
        }
        Ok(())
    }

    /// Computes the configured percentile over the buffered values
    /// (reorders `all_values` in place via selection).
    fn evaluate(&mut self) -> Result<ScalarValue> {
        let value = calculate_percentile::<T>(&mut self.all_values, self.percentile);
        ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE)
    }

    fn size(&self) -> usize {
        size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
    }

    /// Removes one buffered occurrence per retracted non-null value,
    /// enabling sliding-window evaluation.
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // Count how many occurrences of each value must be removed.
        let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new();
        for i in 0..values[0].len() {
            let v = ScalarValue::try_from_array(&values[0], i)?;
            if !v.is_null() {
                *to_remove.entry(v).or_default() += 1;
            }
        }
        // Single pass over the buffer; `swap_remove` keeps each removal O(1)
        // since `calculate_percentile` does not rely on element order.
        let mut i = 0;
        while i < self.all_values.len() {
            let k =
                ScalarValue::new_primitive::<T>(Some(self.all_values[i]), &T::DATA_TYPE)?;
            if let Some(count) = to_remove.get_mut(&k)
                && *count > 0
            {
                // Note: `i` is intentionally not advanced — swap_remove moved
                // a yet-unexamined element into position `i`.
                self.all_values.swap_remove(i);
                *count -= 1;
                if *count == 0 {
                    to_remove.remove(&k);
                    if to_remove.is_empty() {
                        break;
                    }
                }
            } else {
                i += 1;
            }
        }
        Ok(())
    }

    fn supports_retract_batch(&self) -> bool {
        true
    }
}
/// Grouped accumulator for non-distinct `percentile_cont`: one value buffer
/// per group, percentile computed per group at emission time.
#[derive(Debug)]
struct PercentileContGroupsAccumulator<T: ArrowNumericType + Send> {
    // Buffered non-null values, indexed by group index.
    group_values: Vec<Vec<T::Native>>,
    // Requested percentile in [0, 1], already adjusted for DESC ordering.
    percentile: f64,
}
impl<T: ArrowNumericType + Send> PercentileContGroupsAccumulator<T> {
    /// Creates an accumulator with no groups allocated yet.
    fn new(percentile: f64) -> Self {
        Self {
            percentile,
            group_values: Vec::new(),
        }
    }
}
impl<T> GroupsAccumulator for PercentileContGroupsAccumulator<T>
where
    T: ArrowNumericType + Send,
    T::Native: Copy + AsPrimitive<f64>,
    f64: AsPrimitive<T::Native>,
{
    /// Appends each non-null (and filter-passing) input value to its group's
    /// buffer.
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        let values = values[0].as_primitive::<T>();
        // Ensure a buffer exists for every group seen so far.
        self.group_values.resize(total_num_groups, Vec::new());
        accumulate(
            group_indices,
            values,
            opt_filter,
            |group_index, new_value| {
                self.group_values[group_index].push(new_value);
            },
        );
        Ok(())
    }

    /// Merges partial states (one list of values per row) into the group
    /// buffers; the filter does not apply to merges.
    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        _opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "one argument to merge_batch");
        let input_group_values = values[0].as_list::<i32>();
        self.group_values.resize(total_num_groups, Vec::new());
        group_indices
            .iter()
            .zip(input_group_values.iter())
            .for_each(|(&group_index, values_opt)| {
                if let Some(values) = values_opt {
                    let values = values.as_primitive::<T>();
                    self.group_values[group_index].extend(values.values().iter());
                }
            });
        Ok(())
    }

    /// Emits the intermediate state: a `ListArray` with one list of buffered
    /// values per emitted group.
    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let emit_group_values = emit_to.take_needed(&mut self.group_values);
        // Fix: size the offsets buffer from the number of *emitted* groups.
        // The previous code used `self.group_values.len()`, which after
        // `take_needed` is the count of groups remaining, not emitted.
        let mut offsets = Vec::with_capacity(emit_group_values.len() + 1);
        offsets.push(0);
        let mut cur_len = 0_i32;
        for group_value in &emit_group_values {
            cur_len += group_value.len() as i32;
            offsets.push(cur_len);
        }
        let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
        let flatten_group_values =
            emit_group_values.into_iter().flatten().collect::<Vec<_>>();
        let group_values_array =
            PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None);
        let result_list_array = ListArray::new(
            Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
            offsets,
            Arc::new(group_values_array),
            None,
        );
        Ok(vec![Arc::new(result_list_array)])
    }

    /// Computes the final percentile for each emitted group.
    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        let mut emit_group_values = emit_to.take_needed(&mut self.group_values);
        let mut evaluate_result_builder =
            PrimitiveBuilder::<T>::with_capacity(emit_group_values.len());
        for values in &mut emit_group_values {
            // calculate_percentile reorders the slice in place; the buffer is
            // discarded afterwards, so that is fine.
            let value = calculate_percentile::<T>(values.as_mut_slice(), self.percentile);
            evaluate_result_builder.append_option(value);
        }
        Ok(Arc::new(evaluate_result_builder.finish()))
    }

    /// Converts raw input directly into the state layout (one single-element
    /// list per input row), used to skip the accumulate step for pre-grouped
    /// data. Filtered-out rows become null lists.
    fn convert_to_state(
        &self,
        values: &[ArrayRef],
        opt_filter: Option<&BooleanArray>,
    ) -> Result<Vec<ArrayRef>> {
        assert_eq!(values.len(), 1, "one argument to merge_batch");
        let input_array = values[0].as_primitive::<T>();
        let values = PrimitiveArray::<T>::new(input_array.values().clone(), None);
        let offset_end = i32::try_from(input_array.len()).map_err(|e| {
            internal_datafusion_err!(
                "cast array_len to i32 failed in convert_to_state of group percentile_cont, err:{e:?}"
            )
        })?;
        // Offsets 0..=len give exactly one value per list entry.
        let offsets = (0..=offset_end).collect::<Vec<_>>();
        // SAFETY: the offsets are monotonically increasing and within bounds
        // of the values array by construction.
        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
        let nulls = filtered_null_mask(opt_filter, input_array);
        let converted_list_array = ListArray::new(
            Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
            offsets,
            Arc::new(values),
            nulls,
        );
        Ok(vec![Arc::new(converted_list_array)])
    }

    fn supports_convert_to_state(&self) -> bool {
        true
    }

    /// Heap footprint: per-group buffer capacities plus the outer Vec's own
    /// capacity of `Vec` headers.
    fn size(&self) -> usize {
        self.group_values
            .iter()
            .map(|values| values.capacity() * size_of::<T::Native>())
            .sum::<usize>()
            + self.group_values.capacity() * size_of::<Vec<T::Native>>()
    }
}
/// Row accumulator for `percentile_cont(DISTINCT ...)`: deduplicates inputs
/// in a hash-based buffer, then computes the percentile over the unique values.
#[derive(Debug)]
struct DistinctPercentileContAccumulator<T: ArrowNumericType> {
    // Deduplicated input values (shared distinct-buffer helper).
    distinct_values: GenericDistinctBuffer<T>,
    // Requested percentile in [0, 1], already adjusted for DESC ordering.
    percentile: f64,
}
impl<T: ArrowNumericType + Debug> DistinctPercentileContAccumulator<T> {
    /// Creates an empty distinct accumulator for the given percentile.
    fn new(percentile: f64) -> Self {
        Self {
            percentile,
            distinct_values: GenericDistinctBuffer::new(T::DATA_TYPE),
        }
    }
}
impl<T> Accumulator for DistinctPercentileContAccumulator<T>
where
    T: ArrowNumericType + Debug,
    T::Native: Copy + AsPrimitive<f64>,
    f64: AsPrimitive<T::Native>,
{
    // State / update / merge all delegate to the shared distinct buffer.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        self.distinct_values.state()
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        self.distinct_values.update_batch(values)
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        self.distinct_values.merge_batch(states)
    }

    /// Computes the percentile over the *distinct* values. The set is copied
    /// into a Vec because `calculate_percentile` reorders its input in place.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        let mut values: Vec<T::Native> =
            self.distinct_values.values.iter().map(|v| v.0).collect();
        let value = calculate_percentile::<T>(&mut values, self.percentile);
        ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE)
    }

    fn size(&self) -> usize {
        size_of_val(self) + self.distinct_values.size()
    }

    /// Removes retracted non-null values from the distinct set. A single
    /// retraction drops the value entirely (the set stores each value once).
    /// NOTE(review): if a sliding window retracts one of several duplicate
    /// inputs, the value disappears even though copies remain in the window —
    /// confirm this matches intended DISTINCT-window semantics.
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }
        let arr = values[0].as_primitive::<T>();
        for value in arr.iter().flatten() {
            self.distinct_values.values.remove(&Hashable(value));
        }
        Ok(())
    }

    fn supports_retract_batch(&self) -> bool {
        true
    }
}
/// Computes the `percentile`-th (in `[0, 1]`) value of `values` using linear
/// interpolation between the two nearest order statistics ("continuous"
/// percentile, matching SQL `PERCENTILE_CONT`).
///
/// The slice is reordered in place by `select_nth_unstable_by`; callers must
/// not rely on its order afterwards. Returns `None` for empty input.
fn calculate_percentile<T: ArrowNumericType>(
    values: &mut [T::Native],
    percentile: f64,
) -> Option<T::Native>
where
    T::Native: Copy + AsPrimitive<f64>,
    f64: AsPrimitive<T::Native>,
{
    // Total-order comparator from arrow's native-type comparison (handles
    // floats consistently, unlike partial_cmp).
    let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
    let len = values.len();
    if len == 0 {
        None
    } else if len == 1 {
        Some(values[0])
    } else if percentile == 0.0 {
        // Percentile 0 is exactly the minimum; no selection/interpolation needed.
        Some(
            *values
                .iter()
                .min_by(|a, b| cmp(a, b))
                .expect("we checked for len > 0 a few lines above"),
        )
    } else if percentile == 1.0 {
        // Percentile 1 is exactly the maximum.
        Some(
            *values
                .iter()
                .max_by(|a, b| cmp(a, b))
                .expect("we checked for len > 0 a few lines above"),
        )
    } else {
        // Fractional rank of the requested percentile.
        let index = percentile * ((len - 1) as f64);
        let lower_index = index.floor() as usize;
        let upper_index = index.ceil() as usize;
        if lower_index == upper_index {
            // Integral rank: select that element directly (O(n) average).
            let (_, value, _) = values.select_nth_unstable_by(lower_index, cmp);
            Some(*value)
        } else {
            // Select both neighboring order statistics. The second selection
            // re-partitions the whole slice, which is still correct.
            let (_, lower_value, _) = values.select_nth_unstable_by(lower_index, cmp);
            let lower_value = *lower_value;
            let (_, upper_value, _) = values.select_nth_unstable_by(upper_index, cmp);
            let upper_value = *upper_value;
            // Truncate the interpolation weight to 1e-6 steps
            // (NOTE(review): presumed to stabilize results against float
            // rounding in the rank computation — confirm original rationale).
            let fraction = index - (lower_index as f64);
            let scaled = (fraction * INTERPOLATION_PRECISION) as usize;
            let weight = scaled as f64 / INTERPOLATION_PRECISION;
            // Interpolate in f64 so narrow native types (e.g. f16) don't
            // overflow mid-computation; only the result is cast back.
            let lower_f: f64 = lower_value.as_();
            let upper_f: f64 = upper_value.as_();
            let interpolated_f = lower_f + (upper_f - lower_f) * weight;
            Some(interpolated_f.as_())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::calculate_percentile;
    use half::f16;

    /// Regression test: interpolating between 0 and f16::MAX (65504) must not
    /// overflow to NaN/inf. `calculate_percentile` interpolates in f64 and
    /// only casts the final result back to f16, so the midpoint stays finite.
    #[test]
    fn f16_interpolation_does_not_overflow_to_nan() {
        let mut values = vec![f16::from_f32(0.0), f16::from_f32(65504.0)];
        let result =
            calculate_percentile::<arrow::datatypes::Float16Type>(&mut values, 0.5)
                .expect("non-empty input");
        let result_f = result.to_f32();
        assert!(
            !result_f.is_nan(),
            "expected non-NaN result, got {result_f}"
        );
        // 32752 = (0 + 65504) / 2; allow for f16 rounding.
        assert!(
            (result_f - 32752.0).abs() < 1.0,
            "unexpected result {result_f}"
        );
    }
}