use std::sync::Arc;
use arrow::array::RecordBatch;
use arrow::datatypes::Schema;
use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err, plan_err};
use datafusion_expr::ColumnarValue;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
/// Evaluates `expr` against an empty `RecordBatch` and returns the resulting
/// scalar value.
///
/// This only makes sense for expressions that do not reference any columns
/// (e.g. literals or foldable constant expressions), since the batch carries
/// an empty schema and zero rows.
///
/// # Errors
/// - Propagates any error raised by `expr.evaluate`.
/// - Returns an internal error if evaluation produces a
///   `ColumnarValue::Array` rather than a scalar.
pub(crate) fn get_scalar_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
    // An empty schema/batch suffices: a constant expression reads no columns.
    // (Pass the Arc directly; no need for an intermediate binding + clone.)
    let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
    if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? {
        Ok(s)
    } else {
        internal_err!("Didn't expect ColumnarValue::Array")
    }
}
/// Extracts the percentile argument of the aggregate `fn_name` and validates it.
///
/// The expression must evaluate to a `Float32` or `Float64` literal whose
/// value lies in `[0.0, 1.0]`; the value is widened to and returned as `f64`.
///
/// # Errors
/// Returns a plan error when the expression is not a literal, is not a
/// float type, or falls outside the inclusive `[0.0, 1.0]` range.
pub(crate) fn validate_percentile_expr(
    expr: &Arc<dyn PhysicalExpr>,
    fn_name: &str,
) -> Result<f64> {
    // Evaluation can only succeed for a constant expression, so any failure
    // here means the argument was not a literal.
    let scalar = get_scalar_value(expr).map_err(|_e| {
        DataFusionError::Plan(format!(
            "Percentile value for '{fn_name}' must be a literal"
        ))
    })?;

    let value = match scalar {
        ScalarValue::Float64(Some(v)) => v,
        ScalarValue::Float32(Some(v)) => f64::from(v),
        other => {
            return plan_err!(
                "Percentile value for '{fn_name}' must be Float32 or Float64 literal (got data type {})",
                other.data_type()
            );
        }
    };

    // NaN is rejected here too, since it never satisfies the range check.
    match value {
        p if (0.0..=1.0).contains(&p) => Ok(p),
        p => plan_err!(
            "Percentile value must be between 0.0 and 1.0 inclusive, {} is invalid",
            p
        ),
    }
}