use std::any::Any;
use super::log::LogFunc;
use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};
use arrow::array::{Array, ArrayRef};
use arrow::datatypes::i256;
use arrow::datatypes::{
ArrowNativeType, ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type,
Decimal128Type, Decimal256Type, Float64Type, Int64Type,
};
use arrow::error::ArrowError;
use datafusion_common::types::{NativeType, logical_float64, logical_int64};
use datafusion_common::utils::take_function_args;
use datafusion_common::{Result, ScalarValue, internal_err};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF,
ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, lit,
};
use datafusion_macros::user_doc;
use num_traits::{NumCast, ToPrimitive};
/// Scalar UDF implementing `power(base, exponent)`.
///
/// The function is registered under the name `power` with the alias `pow`.
/// User-facing documentation is declared through the `#[user_doc]` attribute
/// below.
#[user_doc(
doc_section(label = "Math Functions"),
description = "Returns a base expression raised to the power of an exponent.",
syntax_example = "power(base, exponent)",
sql_example = r#"```sql
> SELECT power(2, 3);
+-------------+
| power(2,3) |
+-------------+
| 8 |
+-------------+
```"#,
standard_argument(name = "base", prefix = "Numeric"),
standard_argument(name = "exponent", prefix = "Exponent numeric")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct PowerFunc {
/// Accepted argument type combinations, built once in `PowerFunc::new`.
signature: Signature,
/// Alternative names for the function (currently just `pow`).
aliases: Vec<String>,
}
impl Default for PowerFunc {
fn default() -> Self {
Self::new()
}
}
impl PowerFunc {
    /// Creates the `power` UDF with its type signature and the `pow` alias.
    ///
    /// The signature lists three coercible argument shapes, in this order:
    /// 1. `(decimal, integer)` - integer exponents are coerced to `Int64`
    /// 2. `(decimal, float)`   - numeric exponents are coerced to `Float64`
    /// 3. `(float, float)`     - both arguments are coerced to `Float64`
    pub fn new() -> Self {
        // Decimal bases are taken as-is (exact match on the decimal class).
        let decimal_base = Coercion::new_exact(TypeSignatureClass::Decimal);
        let int64_arg = Coercion::new_implicit(
            TypeSignatureClass::Native(logical_int64()),
            vec![TypeSignatureClass::Integer],
            NativeType::Int64,
        );
        let float64_arg = Coercion::new_implicit(
            TypeSignatureClass::Native(logical_float64()),
            vec![TypeSignatureClass::Numeric],
            NativeType::Float64,
        );

        let accepted_shapes = vec![
            TypeSignature::Coercible(vec![decimal_base.clone(), int64_arg]),
            TypeSignature::Coercible(vec![decimal_base, float64_arg.clone()]),
            TypeSignature::Coercible(vec![float64_arg.clone(), float64_arg]),
        ];

        Self {
            signature: Signature::one_of(accepted_shapes, Volatility::Immutable),
            aliases: vec!["pow".to_string()],
        }
    }
}
/// Raises a scaled-integer decimal to an integer power.
///
/// `base` is the unscaled integer of a decimal with the given `scale`, i.e.
/// the logical value is `base * 10^(-scale)`. The result is returned at the
/// same scale:
///
/// ```text
/// (b * 10^(-s))^e * 10^s = b^e / 10^(s * (e - 1))
/// ```
///
/// Example: `2.5 ^ 4` with scale 1 is `25^4 / 10^3 = 390625 / 1000 = 390`,
/// which reads as `39.0` at scale 1.
///
/// Negative exponents are delegated to [`pow_decimal_float`].
///
/// # Errors
/// Returns `ArrowError::ArithmeticOverflow` when the exponent does not fit in
/// `u32`, when `base^exp` overflows `T`, or when the rescaling power of ten
/// cannot be represented (including when its exponent does not fit in `u32`).
fn pow_decimal_int<T>(base: T, scale: i8, exp: i64) -> Result<T, ArrowError>
where
    T: ArrowNativeType + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy,
{
    // Integer powers cannot express negative exponents; use the float path.
    if exp < 0 {
        return pow_decimal_float(base, scale, exp as f64);
    }
    let exp: u32 = exp.try_into().map_err(|_| {
        ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
    })?;

    // x^0 == 1, which at scale `s` is the unscaled value 10^s. For a negative
    // scale that value is fractional, so the integer representation is zero.
    if exp == 0 {
        return if scale >= 0 {
            T::usize_as(10).pow_checked(scale as u32).map_err(|_| {
                ArrowError::ArithmeticOverflow(format!(
                    "Cannot make unscale factor for {scale} and {exp}"
                ))
            })
        } else {
            Ok(T::ZERO)
        };
    }

    let powered: T = base.pow_checked(exp).map_err(|_| {
        ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to exp {exp}"))
    })?;

    // Scale adjustment s * (e - 1). |scale| <= 128 and exp - 1 < 2^32, so the
    // product always fits in i64, but it may NOT fit in u32.
    let mul_exp = (scale as i64) * (exp as i64 - 1);
    if mul_exp == 0 {
        return Ok(powered);
    }

    if mul_exp > 0 {
        // Positive adjustment: divide by 10^mul_exp.
        let div_overflow = || {
            ArrowError::ArithmeticOverflow(format!(
                "Cannot make div factor for {scale} and {exp}"
            ))
        };
        // Checked narrowing: a truncating `as u32` would silently pick a
        // different (smaller) power of ten and return a wrong value.
        let div_exp: u32 = mul_exp.try_into().map_err(|_| div_overflow())?;
        let div_factor: T = T::usize_as(10)
            .pow_checked(div_exp)
            .map_err(|_| div_overflow())?;
        powered.div_checked(div_factor)
    } else {
        // Negative adjustment (negative scale): multiply by 10^(-mul_exp).
        let mul_overflow = || {
            ArrowError::ArithmeticOverflow(format!(
                "Cannot make mul factor for {scale} and {exp}"
            ))
        };
        let abs_exp: u32 = mul_exp
            .unsigned_abs()
            .try_into()
            .map_err(|_| mul_overflow())?;
        let mul_factor: T = T::usize_as(10)
            .pow_checked(abs_exp)
            .map_err(|_| mul_overflow())?;
        powered.mul_checked(mul_factor)
    }
}
/// Raises a scaled-integer decimal to a floating-point power.
///
/// Whole, non-negative exponents below `u32::MAX` are routed to
/// [`pow_decimal_int`]; every other finite exponent goes through
/// [`pow_decimal_float_fallback`].
///
/// # Errors
/// Returns `ArrowError::ComputeError` for NaN or infinite exponents, and
/// propagates any error from the selected implementation.
fn pow_decimal_float<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
where
    T: ArrowNativeType + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy,
{
    if !exp.is_finite() {
        return Err(ArrowError::ComputeError(format!(
            "Cannot use non-finite exp: {exp}"
        )));
    }

    let is_whole = exp.trunc() == exp;
    let in_u32_range = exp >= 0f64 && exp < u32::MAX as f64;
    if is_whole && in_u32_range {
        pow_decimal_int(base, scale, exp as i64)
    } else {
        pow_decimal_float_fallback(base, scale, exp)
    }
}
/// Computes `base_f64 ^ exp` in `f64`, rescales the result by `10^scale`,
/// rounds it and returns it as an `i128`.
///
/// # Errors
/// Returns `ArrowError::ArithmeticOverflow` when the power is not finite or
/// when the rounded, rescaled magnitude exceeds `i128::MAX as f64`.
#[inline]
fn compute_pow_f64_result(
    base_f64: f64,
    scale: i8,
    exp: f64,
) -> Result<i128, ArrowError> {
    let powered = base_f64.powf(exp);
    if powered.is_nan() || powered.is_infinite() {
        return Err(ArrowError::ArithmeticOverflow(format!(
            "Result of {base_f64}^{exp} is not finite"
        )));
    }

    // Bring the plain value back to the decimal's unscaled representation.
    let result_rounded = (powered * 10f64.powi(scale as i32)).round();

    let out_of_range = result_rounded.abs() > i128::MAX as f64;
    if out_of_range {
        return Err(ArrowError::ArithmeticOverflow(format!(
            "Result {result_rounded} is too large for the target decimal type"
        )));
    }
    Ok(result_rounded as i128)
}
/// Converts an `i128` into the target native type `T` via `NumCast`.
///
/// # Errors
/// Returns `ArrowError::ArithmeticOverflow` when the cast yields `None`.
#[inline]
fn decimal_from_i128<T>(value: i128) -> Result<T, ArrowError>
where
    T: NumCast,
{
    match <T as NumCast>::from(value) {
        Some(converted) => Ok(converted),
        None => Err(ArrowError::ArithmeticOverflow(format!(
            "Value {value} is too large for the target decimal type"
        ))),
    }
}
fn pow_decimal_float_fallback<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
where
T: ToPrimitive + NumCast + Copy,
{
if scale < 0 {
return Err(ArrowError::NotYetImplemented(format!(
"Negative scale is not yet supported: {scale}"
)));
}
let scale_factor = 10f64.powi(scale as i32);
let base_f64 = base.to_f64().ok_or_else(|| {
ArrowError::ComputeError("Cannot convert base to f64".to_string())
})? / scale_factor;
let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?;
decimal_from_i128(result_i128)
}
/// `i256` counterpart of the float-exponent power.
///
/// Whole, non-negative exponents below `u32::MAX` are routed to
/// [`pow_decimal256_int`]; every other finite exponent goes through
/// [`pow_decimal256_float_fallback`].
///
/// # Errors
/// Returns `ArrowError::ComputeError` for NaN or infinite exponents, and
/// propagates any error from the selected implementation.
fn pow_decimal256_float(base: i256, scale: i8, exp: f64) -> Result<i256, ArrowError> {
    if !exp.is_finite() {
        return Err(ArrowError::ComputeError(format!(
            "Cannot use non-finite exp: {exp}"
        )));
    }

    let is_whole = exp.trunc() == exp;
    let in_u32_range = exp >= 0f64 && exp < u32::MAX as f64;
    if is_whole && in_u32_range {
        pow_decimal256_int(base, scale, exp as i64)
    } else {
        pow_decimal256_float_fallback(base, scale, exp)
    }
}
/// `i256` counterpart of the integer-exponent power.
///
/// `base` is the unscaled integer of a decimal with the given `scale`; the
/// result is returned at the same scale:
///
/// ```text
/// (b * 10^(-s))^e * 10^s = b^e / 10^(s * (e - 1))
/// ```
///
/// Negative exponents are delegated to [`pow_decimal256_float`].
///
/// # Errors
/// Returns `ArrowError::ArithmeticOverflow` when the exponent does not fit in
/// `u32`, when `base^exp` overflows `i256`, or when the rescaling power of ten
/// cannot be represented (including when its exponent does not fit in `u32`).
fn pow_decimal256_int(base: i256, scale: i8, exp: i64) -> Result<i256, ArrowError> {
    // Integer powers cannot express negative exponents; use the float path.
    if exp < 0 {
        return pow_decimal256_float(base, scale, exp as f64);
    }
    let exp: u32 = exp.try_into().map_err(|_| {
        ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
    })?;

    // x^0 == 1, which at scale `s` is the unscaled value 10^s. For a negative
    // scale that value is fractional, so the integer representation is zero.
    if exp == 0 {
        return if scale >= 0 {
            i256::from_i128(10).pow_checked(scale as u32).map_err(|_| {
                ArrowError::ArithmeticOverflow(format!(
                    "Cannot make unscale factor for {scale} and {exp}"
                ))
            })
        } else {
            Ok(i256::from_i128(0))
        };
    }

    let powered: i256 = base.pow_checked(exp).map_err(|_| {
        ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to exp {exp}"))
    })?;

    // Scale adjustment s * (e - 1). |scale| <= 128 and exp - 1 < 2^32, so the
    // product always fits in i64, but it may NOT fit in u32.
    let mul_exp = (scale as i64) * (exp as i64 - 1);
    if mul_exp == 0 {
        return Ok(powered);
    }

    if mul_exp > 0 {
        // Positive adjustment: divide by 10^mul_exp.
        let div_overflow = || {
            ArrowError::ArithmeticOverflow(format!(
                "Cannot make div factor for {scale} and {exp}"
            ))
        };
        // Checked narrowing: a truncating `as u32` would silently pick a
        // different (smaller) power of ten and return a wrong value.
        let div_exp: u32 = mul_exp.try_into().map_err(|_| div_overflow())?;
        let div_factor: i256 = i256::from_i128(10)
            .pow_checked(div_exp)
            .map_err(|_| div_overflow())?;
        powered.div_checked(div_factor)
    } else {
        // Negative adjustment (negative scale): multiply by 10^(-mul_exp).
        let mul_overflow = || {
            ArrowError::ArithmeticOverflow(format!(
                "Cannot make mul factor for {scale} and {exp}"
            ))
        };
        let abs_exp: u32 = mul_exp
            .unsigned_abs()
            .try_into()
            .map_err(|_| mul_overflow())?;
        let mul_factor: i256 = i256::from_i128(10)
            .pow_checked(abs_exp)
            .map_err(|_| mul_overflow())?;
        powered.mul_checked(mul_factor)
    }
}
/// `f64`-based power for `i256` decimals.
///
/// The unscaled `base` is converted to `f64`, divided by `10^scale`, raised
/// to `exp`, rescaled, and widened from the intermediate `i128` to `i256`.
///
/// # Errors
/// Returns `ArrowError::NotYetImplemented` for negative scales,
/// `ArrowError::ComputeError` when `base` cannot be converted to `f64`, and
/// propagates errors from the power computation.
fn pow_decimal256_float_fallback(
    base: i256,
    scale: i8,
    exp: f64,
) -> Result<i256, ArrowError> {
    if scale < 0 {
        return Err(ArrowError::NotYetImplemented(format!(
            "Negative scale is not yet supported: {scale}"
        )));
    }

    let unscaled = match base.to_f64() {
        Some(value) => value,
        None => {
            return Err(ArrowError::ComputeError(
                "Cannot convert base to f64".to_string(),
            ));
        }
    };
    let base_f64 = unscaled / 10f64.powi(scale as i32);

    compute_pow_f64_result(base_f64, scale, exp).map(i256::from_i128)
}
/// Array-wide power for a decimal base computed entirely in `Float64`.
///
/// Both operands are cast to `Float64` (a scalar exponent is first expanded
/// to `num_rows` values), `powf` is applied element-wise, and the result is
/// cast back to the base array's original data type.
fn pow_decimal_with_float_fallback(
    base: &ArrayRef,
    exponent: &ColumnarValue,
    num_rows: usize,
) -> Result<ColumnarValue> {
    use arrow::compute::cast;

    let float64 = DataType::Float64;
    // Remember the input type so the output can be cast back to it.
    let decimal_type = base.data_type().clone();

    let base_as_float = cast(base.as_ref(), &float64)?;
    let exponent_as_float = match exponent {
        ColumnarValue::Scalar(value) => {
            value.cast_to(&float64)?.to_array_of_size(num_rows)?
        }
        ColumnarValue::Array(values) => cast(values.as_ref(), &float64)?,
    };

    let powered = calculate_binary_math::<Float64Type, Float64Type, Float64Type, _>(
        &base_as_float,
        &ColumnarValue::Array(exponent_as_float),
        |b, e| Ok(f64::powf(b, e)),
    )?;

    Ok(ColumnarValue::Array(cast(powered.as_ref(), &decimal_type)?))
}
// `ScalarUDFImpl` wiring for `power`: naming, return type resolution,
// evaluation (`invoke_with_args`) and plan-time rewriting (`simplify`).
impl ScalarUDFImpl for PowerFunc {
/// Exposes `self` as `&dyn Any` (used by callers for downcasting).
fn as_any(&self) -> &dyn Any {
self
}
/// Primary name of the function: `"power"`.
fn name(&self) -> &str {
"power"
}
/// Accepted argument type combinations, as built in `PowerFunc::new`.
fn signature(&self) -> &Signature {
&self.signature
}
/// The result has the type of the base (first) argument, except that a
/// null-typed base yields `Float64`.
///
/// Indexes `arg_types[0]` directly, so at least one argument type must be
/// supplied.
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types[0].is_null() {
Ok(DataType::Float64)
} else {
Ok(arg_types[0].clone())
}
}
/// Alternative names for the function (`pow`).
fn aliases(&self) -> &[String] {
&self.aliases
}
/// Evaluates `power(base, exponent)`.
///
/// * Decimal base with an *array* exponent: both sides are cast to
///   `Float64` via `pow_decimal_with_float_fallback`.
/// * Otherwise the (base type, exponent type) pair selects a kernel:
///   `f64::powf` for floats, or the decimal integer/float helpers for
///   `Decimal32/64/128/256` bases.
///
/// Returns an internal error for any other type combination.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let [base, exponent] = take_function_args(self.name(), &args.args)?;
// Decide on the float fallback before `base` is shadowed by its array
// form: it applies only to decimal bases whose exponent is an array.
let use_float_fallback = matches!(
base.data_type(),
DataType::Decimal32(_, _)
| DataType::Decimal64(_, _)
| DataType::Decimal128(_, _)
| DataType::Decimal256(_, _)
) && matches!(exponent, ColumnarValue::Array(_));
// Expand the base to an array of `number_rows` values.
let base = base.to_array(args.number_rows)?;
if use_float_fallback {
return pow_decimal_with_float_fallback(&base, exponent, args.number_rows);
}
// Dispatch on the concrete (base, exponent) types. Decimal arms pass the
// base's precision and scale through so the output keeps the same type.
let arr: ArrayRef = match (base.data_type(), exponent.data_type()) {
(DataType::Float64, DataType::Float64) => {
calculate_binary_math::<Float64Type, Float64Type, Float64Type, _>(
&base,
exponent,
|b, e| Ok(f64::powf(b, e)),
)?
}
(DataType::Decimal32(precision, scale), DataType::Int64) => {
calculate_binary_decimal_math::<Decimal32Type, Int64Type, Decimal32Type, _>(
&base,
exponent,
|b, e| pow_decimal_int(b, *scale, e),
*precision,
*scale,
)?
}
(DataType::Decimal32(precision, scale), DataType::Float64) => {
calculate_binary_decimal_math::<
Decimal32Type,
Float64Type,
Decimal32Type,
_,
>(
&base,
exponent,
|b, e| pow_decimal_float(b, *scale, e),
*precision,
*scale,
)?
}
(DataType::Decimal64(precision, scale), DataType::Int64) => {
calculate_binary_decimal_math::<Decimal64Type, Int64Type, Decimal64Type, _>(
&base,
exponent,
|b, e| pow_decimal_int(b, *scale, e),
*precision,
*scale,
)?
}
(DataType::Decimal64(precision, scale), DataType::Float64) => {
calculate_binary_decimal_math::<
Decimal64Type,
Float64Type,
Decimal64Type,
_,
>(
&base,
exponent,
|b, e| pow_decimal_float(b, *scale, e),
*precision,
*scale,
)?
}
(DataType::Decimal128(precision, scale), DataType::Int64) => {
calculate_binary_decimal_math::<
Decimal128Type,
Int64Type,
Decimal128Type,
_,
>(
&base,
exponent,
|b, e| pow_decimal_int(b, *scale, e),
*precision,
*scale,
)?
}
(DataType::Decimal128(precision, scale), DataType::Float64) => {
calculate_binary_decimal_math::<
Decimal128Type,
Float64Type,
Decimal128Type,
_,
>(
&base,
exponent,
|b, e| pow_decimal_float(b, *scale, e),
*precision,
*scale,
)?
}
// Decimal256 uses the dedicated `i256` helpers rather than the generic ones.
(DataType::Decimal256(precision, scale), DataType::Int64) => {
calculate_binary_decimal_math::<
Decimal256Type,
Int64Type,
Decimal256Type,
_,
>(
&base,
exponent,
|b, e| pow_decimal256_int(b, *scale, e),
*precision,
*scale,
)?
}
(DataType::Decimal256(precision, scale), DataType::Float64) => {
calculate_binary_decimal_math::<
Decimal256Type,
Float64Type,
Decimal256Type,
_,
>(
&base,
exponent,
|b, e| pow_decimal256_float(b, *scale, e),
*precision,
*scale,
)?
}
// Any other pairing is reported as an internal error.
(base_type, exp_type) => {
return internal_err!(
"Unsupported data types for base {base_type:?} and exponent {exp_type:?} for power"
);
}
};
Ok(ColumnarValue::Array(arr))
}
/// Simplifies a `power` call using these rules:
///
/// 0. A null-typed base or exponent becomes a null literal of the return type
/// 1. `power(a, 0)` becomes `1` (a one of the base's type)
/// 2. `power(a, 1)` becomes `a`
/// 3. `power(a, log(a, b))` becomes `b`
///
/// Anything else is returned unchanged as `ExprSimplifyResult::Original`.
fn simplify(
&self,
args: Vec<Expr>,
info: &SimplifyContext,
) -> Result<ExprSimplifyResult> {
let [base, exponent] = take_function_args("power", args)?;
let base_type = info.get_data_type(&base)?;
let exponent_type = info.get_data_type(&exponent)?;
// Null propagation: replace the call with a typed null literal.
if base_type.is_null() || exponent_type.is_null() {
let return_type = self.return_type(&[base_type, exponent_type])?;
return Ok(ExprSimplifyResult::Simplified(lit(
ScalarValue::Null.cast_to(&return_type)?
)));
}
match exponent {
// Rule 1: literal zero exponent.
Expr::Literal(value, _)
if value == ScalarValue::new_zero(&exponent_type)? =>
{
Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one(
&base_type,
)?)))
}
// Rule 2: literal one exponent.
Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => {
Ok(ExprSimplifyResult::Simplified(base))
}
// Rule 3: the exponent is a two-argument log call whose first argument
// equals the power's base.
Expr::ScalarFunction(ScalarFunction { func, mut args })
if is_log(&func) && args.len() == 2 && base == args[0] =>
{
// The guard checked `args.len() == 2`, so `pop` yields the log's second argument.
let b = args.pop().unwrap(); Ok(ExprSimplifyResult::Simplified(b))
}
_ => Ok(ExprSimplifyResult::Original(vec![base, exponent])),
}
}
/// Returns `self.doc()`.
// NOTE(review): `doc()` is not defined in this file; presumably it is
// generated by the `#[user_doc]` attribute on `PowerFunc` - confirm.
fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}
fn is_log(func: &ScalarUDF) -> bool {
func.inner().as_any().downcast_ref::<LogFunc>().is_some()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer-exponent helper on `i128`-backed decimals.
    #[test]
    fn test_pow_decimal128_helper() {
        // (unscaled base, scale, exponent, expected unscaled result)
        let cases: [(i128, i8, i64, i128); 9] = [
            // 2.5 ^ 4 = 39.0625, truncated to the input scale
            (25, 1, 4, 390),
            (2500, 3, 4, 39062),
            (25000, 4, 4, 390625),
            // 25 ^ 4 = 390625
            (25, 0, 4, 390625),
            // exponent 1 returns the base unchanged
            (25, 1, 1, 25),
            (25, 0, 1, 25),
            // exponent 0 returns one at the input scale
            (25, 0, 0, 1),
            (25, 1, 0, 10),
            // negative scale: 250 ^ 4 expressed at scale -1
            (25, -1, 4, 390625000),
        ];
        for &(base, scale, exp, expected) in cases.iter() {
            assert_eq!(pow_decimal_int(base, scale, exp).unwrap(), expected);
        }
    }

    /// Float-exponent entry point, including the `f64` fallback path.
    #[test]
    fn test_pow_decimal_float_fallback() {
        // (unscaled base, scale, exponent, expected unscaled result)
        let cases: [(i128, i8, f64, i128); 5] = [
            // 4.00 ^ -1 = 0.25
            (400, 2, -1.0, 25),
            // 4.00 ^ 0.5 = 2.00
            (400, 2, 0.5, 200),
            // 8.0 ^ (1/3) = 2.0
            (80, 1, 1.0 / 3.0, 20),
            // (-2.0) ^ 3 = -8.0, whole exponent takes the integer path
            (-20, 1, 3.0, -80),
            // 2.5 ^ 4 = 39.0 at scale 1, whole exponent takes the integer path
            (25, 1, 4.0, 390),
        ];
        for &(base, scale, exp, expected) in cases.iter() {
            let actual: i128 = pow_decimal_float(base, scale, exp).unwrap();
            assert_eq!(actual, expected);
        }

        // Non-finite exponents are rejected.
        assert!(pow_decimal_float(100i128, 2, f64::NAN).is_err());
        assert!(pow_decimal_float(100i128, 2, f64::INFINITY).is_err());
    }
}