Skip to main content

datafusion_functions/math/
power.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Math function: `power()`.
19use std::any::Any;
20
21use super::log::LogFunc;
22
23use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};
24use arrow::array::{Array, ArrayRef};
25use arrow::datatypes::i256;
26use arrow::datatypes::{
27    ArrowNativeType, ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type,
28    Decimal128Type, Decimal256Type, Float64Type, Int64Type,
29};
30use arrow::error::ArrowError;
31use datafusion_common::types::{NativeType, logical_float64, logical_int64};
32use datafusion_common::utils::take_function_args;
33use datafusion_common::{Result, ScalarValue, internal_err};
34use datafusion_expr::expr::ScalarFunction;
35use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
36use datafusion_expr::{
37    Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF,
38    ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, lit,
39};
40use datafusion_macros::user_doc;
41use num_traits::{NumCast, ToPrimitive};
42
/// Scalar UDF implementing `power(base, exponent)`; its alias is `pow`.
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Returns a base expression raised to the power of an exponent.",
    syntax_example = "power(base, exponent)",
    sql_example = r#"```sql
> SELECT power(2, 3);
+-------------+
| power(2,3)  |
+-------------+
| 8           |
+-------------+
```"#,
    standard_argument(name = "base", prefix = "Numeric"),
    standard_argument(name = "exponent", prefix = "Exponent numeric")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct PowerFunc {
    /// Accepted argument type combinations, built in `PowerFunc::new`.
    signature: Signature,
    /// Alternative SQL names for this function (`pow`).
    aliases: Vec<String>,
}
63
64impl Default for PowerFunc {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl PowerFunc {
71    pub fn new() -> Self {
72        let integer = Coercion::new_implicit(
73            TypeSignatureClass::Native(logical_int64()),
74            vec![TypeSignatureClass::Integer],
75            NativeType::Int64,
76        );
77        let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
78        let float = Coercion::new_implicit(
79            TypeSignatureClass::Native(logical_float64()),
80            vec![TypeSignatureClass::Numeric],
81            NativeType::Float64,
82        );
83        Self {
84            signature: Signature::one_of(
85                vec![
86                    TypeSignature::Coercible(vec![decimal.clone(), integer]),
87                    TypeSignature::Coercible(vec![decimal.clone(), float.clone()]),
88                    TypeSignature::Coercible(vec![float; 2]),
89                ],
90                Volatility::Immutable,
91            ),
92            aliases: vec![String::from("pow")],
93        }
94    }
95}
96
97/// Binary function to calculate a math power to integer exponent
98/// for scaled integer types.
99///
100/// Formula
101/// The power for a scaled integer `b` is
102///
103/// ```text
104/// (b * 10^(-s)) ^ e
105/// ```
106/// However, the result should be scaled back from scale 0 to scale `s`,
107/// which is done by multiplying by `10^s`.
108/// At the end, the formula is:
109///
110/// ```text
111///   b^e * 10^(-s * e) * 10^s = b^e / 10^(s * (e-1))
112/// ```
113/// Example of 2.5 ^ 4 = 39:
114///   2.5 is represented as 25 with scale 1
115///   The unscaled result is 25^4 = 390625
116///   Scale it back to 1: 390625 / 10^4 = 39
117fn pow_decimal_int<T>(base: T, scale: i8, exp: i64) -> Result<T, ArrowError>
118where
119    T: ArrowNativeType + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy,
120{
121    // Negative exponent: fall back to float computation
122    if exp < 0 {
123        return pow_decimal_float(base, scale, exp as f64);
124    }
125
126    let exp: u32 = exp.try_into().map_err(|_| {
127        ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
128    })?;
129    // Handle edge case for exp == 0
130    // If scale < 0, 10^scale (e.g., 10^-2 = 0.01) becomes 0 in integer arithmetic.
131    if exp == 0 {
132        return if scale >= 0 {
133            T::usize_as(10).pow_checked(scale as u32).map_err(|_| {
134                ArrowError::ArithmeticOverflow(format!(
135                    "Cannot make unscale factor for {scale} and {exp}"
136                ))
137            })
138        } else {
139            Ok(T::ZERO)
140        };
141    }
142    let powered: T = base.pow_checked(exp).map_err(|_| {
143        ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to exp {exp}"))
144    })?;
145
146    // Calculate the scale adjustment: s * (e - 1)
147    // We use i64 to prevent overflow during the intermediate multiplication
148    let mul_exp = (scale as i64).wrapping_mul(exp as i64 - 1);
149
150    if mul_exp == 0 {
151        return Ok(powered);
152    }
153
154    // If mul_exp is positive, we divide (standard case).
155    // If mul_exp is negative, we multiply (negative scale case).
156    if mul_exp > 0 {
157        let div_factor: T =
158            T::usize_as(10).pow_checked(mul_exp as u32).map_err(|_| {
159                ArrowError::ArithmeticOverflow(format!(
160                    "Cannot make div factor for {scale} and {exp}"
161                ))
162            })?;
163        powered.div_checked(div_factor)
164    } else {
165        // mul_exp is negative, so we multiply by 10^(-mul_exp)
166        let abs_exp = mul_exp.checked_neg().ok_or_else(|| {
167            ArrowError::ArithmeticOverflow(
168                "Overflow while negating scale exponent".to_string(),
169            )
170        })?;
171        let mul_factor: T =
172            T::usize_as(10).pow_checked(abs_exp as u32).map_err(|_| {
173                ArrowError::ArithmeticOverflow(format!(
174                    "Cannot make mul factor for {scale} and {exp}"
175                ))
176            })?;
177        powered.mul_checked(mul_factor)
178    }
179}
180
181/// Binary function to calculate a math power to float exponent
182/// for scaled integer types.
183fn pow_decimal_float<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
184where
185    T: ArrowNativeType + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy,
186{
187    if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 {
188        return pow_decimal_int(base, scale, exp as i64);
189    }
190
191    if !exp.is_finite() {
192        return Err(ArrowError::ComputeError(format!(
193            "Cannot use non-finite exp: {exp}"
194        )));
195    }
196
197    pow_decimal_float_fallback(base, scale, exp)
198}
199
200/// Compute the f64 power result and scale it back.
201/// Returns the rounded i128 result for conversion to target type.
202#[inline]
203fn compute_pow_f64_result(
204    base_f64: f64,
205    scale: i8,
206    exp: f64,
207) -> Result<i128, ArrowError> {
208    let result_f64 = base_f64.powf(exp);
209
210    if !result_f64.is_finite() {
211        return Err(ArrowError::ArithmeticOverflow(format!(
212            "Result of {base_f64}^{exp} is not finite"
213        )));
214    }
215
216    let scale_factor = 10f64.powi(scale as i32);
217    let result_scaled = result_f64 * scale_factor;
218    let result_rounded = result_scaled.round();
219
220    if result_rounded.abs() > i128::MAX as f64 {
221        return Err(ArrowError::ArithmeticOverflow(format!(
222            "Result {result_rounded} is too large for the target decimal type"
223        )));
224    }
225
226    Ok(result_rounded as i128)
227}
228
229/// Convert i128 result to target decimal native type using NumCast.
230/// Returns error if value overflows the target type.
231#[inline]
232fn decimal_from_i128<T>(value: i128) -> Result<T, ArrowError>
233where
234    T: NumCast,
235{
236    NumCast::from(value).ok_or_else(|| {
237        ArrowError::ArithmeticOverflow(format!(
238            "Value {value} is too large for the target decimal type"
239        ))
240    })
241}
242
243/// Fallback implementation using f64 for negative or non-integer exponents.
244/// This handles cases that cannot be computed using integer arithmetic.
245fn pow_decimal_float_fallback<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
246where
247    T: ToPrimitive + NumCast + Copy,
248{
249    if scale < 0 {
250        return Err(ArrowError::NotYetImplemented(format!(
251            "Negative scale is not yet supported: {scale}"
252        )));
253    }
254
255    let scale_factor = 10f64.powi(scale as i32);
256    let base_f64 = base.to_f64().ok_or_else(|| {
257        ArrowError::ComputeError("Cannot convert base to f64".to_string())
258    })? / scale_factor;
259
260    let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?;
261
262    decimal_from_i128(result_i128)
263}
264
265/// Decimal256 specialized float exponent version.
266fn pow_decimal256_float(base: i256, scale: i8, exp: f64) -> Result<i256, ArrowError> {
267    if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 {
268        return pow_decimal256_int(base, scale, exp as i64);
269    }
270
271    if !exp.is_finite() {
272        return Err(ArrowError::ComputeError(format!(
273            "Cannot use non-finite exp: {exp}"
274        )));
275    }
276
277    pow_decimal256_float_fallback(base, scale, exp)
278}
279
280/// Decimal256 specialized integer exponent version.
281fn pow_decimal256_int(base: i256, scale: i8, exp: i64) -> Result<i256, ArrowError> {
282    if exp < 0 {
283        return pow_decimal256_float(base, scale, exp as f64);
284    }
285
286    let exp: u32 = exp.try_into().map_err(|_| {
287        ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
288    })?;
289
290    if exp == 0 {
291        return if scale >= 0 {
292            i256::from_i128(10).pow_checked(scale as u32).map_err(|_| {
293                ArrowError::ArithmeticOverflow(format!(
294                    "Cannot make unscale factor for {scale} and {exp}"
295                ))
296            })
297        } else {
298            Ok(i256::from_i128(0))
299        };
300    }
301
302    let powered: i256 = base.pow_checked(exp).map_err(|_| {
303        ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to exp {exp}"))
304    })?;
305
306    let mul_exp = (scale as i64).wrapping_mul(exp as i64 - 1);
307
308    if mul_exp == 0 {
309        return Ok(powered);
310    }
311
312    if mul_exp > 0 {
313        let div_factor: i256 =
314            i256::from_i128(10)
315                .pow_checked(mul_exp as u32)
316                .map_err(|_| {
317                    ArrowError::ArithmeticOverflow(format!(
318                        "Cannot make div factor for {scale} and {exp}"
319                    ))
320                })?;
321        powered.div_checked(div_factor)
322    } else {
323        let abs_exp = mul_exp.checked_neg().ok_or_else(|| {
324            ArrowError::ArithmeticOverflow(
325                "Overflow while negating scale exponent".to_string(),
326            )
327        })?;
328        let mul_factor: i256 =
329            i256::from_i128(10)
330                .pow_checked(abs_exp as u32)
331                .map_err(|_| {
332                    ArrowError::ArithmeticOverflow(format!(
333                        "Cannot make mul factor for {scale} and {exp}"
334                    ))
335                })?;
336        powered.mul_checked(mul_factor)
337    }
338}
339
340/// Fallback implementation for Decimal256.
341fn pow_decimal256_float_fallback(
342    base: i256,
343    scale: i8,
344    exp: f64,
345) -> Result<i256, ArrowError> {
346    if scale < 0 {
347        return Err(ArrowError::NotYetImplemented(format!(
348            "Negative scale is not yet supported: {scale}"
349        )));
350    }
351
352    let scale_factor = 10f64.powi(scale as i32);
353    let base_f64 = base.to_f64().ok_or_else(|| {
354        ArrowError::ComputeError("Cannot convert base to f64".to_string())
355    })? / scale_factor;
356
357    let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?;
358
359    // i256 can be constructed from i128 directly
360    Ok(i256::from_i128(result_i128))
361}
362
363/// Fallback implementation for decimal power when exponent is an array.
364/// Casts decimal to float64, computes power, and casts back to original decimal type.
365/// This is used for performance when exponent varies per-row.
366fn pow_decimal_with_float_fallback(
367    base: &ArrayRef,
368    exponent: &ColumnarValue,
369    num_rows: usize,
370) -> Result<ColumnarValue> {
371    use arrow::compute::cast;
372
373    let original_type = base.data_type().clone();
374    let base_f64 = cast(base.as_ref(), &DataType::Float64)?;
375
376    let exp_f64 = match exponent {
377        ColumnarValue::Array(arr) => cast(arr.as_ref(), &DataType::Float64)?,
378        ColumnarValue::Scalar(scalar) => {
379            let scalar_f64 = scalar.cast_to(&DataType::Float64)?;
380            scalar_f64.to_array_of_size(num_rows)?
381        }
382    };
383
384    let result_f64 = calculate_binary_math::<Float64Type, Float64Type, Float64Type, _>(
385        &base_f64,
386        &ColumnarValue::Array(exp_f64),
387        |b, e| Ok(f64::powf(b, e)),
388    )?;
389
390    let result = cast(result_f64.as_ref(), &original_type)?;
391    Ok(ColumnarValue::Array(result))
392}
393
394impl ScalarUDFImpl for PowerFunc {
395    fn as_any(&self) -> &dyn Any {
396        self
397    }
398
399    fn name(&self) -> &str {
400        "power"
401    }
402
403    fn signature(&self) -> &Signature {
404        &self.signature
405    }
406
407    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
408        if arg_types[0].is_null() {
409            Ok(DataType::Float64)
410        } else {
411            Ok(arg_types[0].clone())
412        }
413    }
414
415    fn aliases(&self) -> &[String] {
416        &self.aliases
417    }
418
419    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
420        let [base, exponent] = take_function_args(self.name(), &args.args)?;
421
422        // For decimal types, only use native decimal
423        // operations when we have a scalar exponent. When the exponent is an array,
424        // fall back to float computation for better performance.
425        let use_float_fallback = matches!(
426            base.data_type(),
427            DataType::Decimal32(_, _)
428                | DataType::Decimal64(_, _)
429                | DataType::Decimal128(_, _)
430                | DataType::Decimal256(_, _)
431        ) && matches!(exponent, ColumnarValue::Array(_));
432
433        let base = base.to_array(args.number_rows)?;
434
435        // If decimal with array exponent, cast to float and compute
436        if use_float_fallback {
437            return pow_decimal_with_float_fallback(&base, exponent, args.number_rows);
438        }
439
440        let arr: ArrayRef = match (base.data_type(), exponent.data_type()) {
441            (DataType::Float64, DataType::Float64) => {
442                calculate_binary_math::<Float64Type, Float64Type, Float64Type, _>(
443                    &base,
444                    exponent,
445                    |b, e| Ok(f64::powf(b, e)),
446                )?
447            }
448            (DataType::Decimal32(precision, scale), DataType::Int64) => {
449                calculate_binary_decimal_math::<Decimal32Type, Int64Type, Decimal32Type, _>(
450                    &base,
451                    exponent,
452                    |b, e| pow_decimal_int(b, *scale, e),
453                    *precision,
454                    *scale,
455                )?
456            }
457            (DataType::Decimal32(precision, scale), DataType::Float64) => {
458                calculate_binary_decimal_math::<
459                    Decimal32Type,
460                    Float64Type,
461                    Decimal32Type,
462                    _,
463                >(
464                    &base,
465                    exponent,
466                    |b, e| pow_decimal_float(b, *scale, e),
467                    *precision,
468                    *scale,
469                )?
470            }
471            (DataType::Decimal64(precision, scale), DataType::Int64) => {
472                calculate_binary_decimal_math::<Decimal64Type, Int64Type, Decimal64Type, _>(
473                    &base,
474                    exponent,
475                    |b, e| pow_decimal_int(b, *scale, e),
476                    *precision,
477                    *scale,
478                )?
479            }
480            (DataType::Decimal64(precision, scale), DataType::Float64) => {
481                calculate_binary_decimal_math::<
482                    Decimal64Type,
483                    Float64Type,
484                    Decimal64Type,
485                    _,
486                >(
487                    &base,
488                    exponent,
489                    |b, e| pow_decimal_float(b, *scale, e),
490                    *precision,
491                    *scale,
492                )?
493            }
494            (DataType::Decimal128(precision, scale), DataType::Int64) => {
495                calculate_binary_decimal_math::<
496                    Decimal128Type,
497                    Int64Type,
498                    Decimal128Type,
499                    _,
500                >(
501                    &base,
502                    exponent,
503                    |b, e| pow_decimal_int(b, *scale, e),
504                    *precision,
505                    *scale,
506                )?
507            }
508            (DataType::Decimal128(precision, scale), DataType::Float64) => {
509                calculate_binary_decimal_math::<
510                    Decimal128Type,
511                    Float64Type,
512                    Decimal128Type,
513                    _,
514                >(
515                    &base,
516                    exponent,
517                    |b, e| pow_decimal_float(b, *scale, e),
518                    *precision,
519                    *scale,
520                )?
521            }
522            (DataType::Decimal256(precision, scale), DataType::Int64) => {
523                calculate_binary_decimal_math::<
524                    Decimal256Type,
525                    Int64Type,
526                    Decimal256Type,
527                    _,
528                >(
529                    &base,
530                    exponent,
531                    |b, e| pow_decimal256_int(b, *scale, e),
532                    *precision,
533                    *scale,
534                )?
535            }
536            (DataType::Decimal256(precision, scale), DataType::Float64) => {
537                calculate_binary_decimal_math::<
538                    Decimal256Type,
539                    Float64Type,
540                    Decimal256Type,
541                    _,
542                >(
543                    &base,
544                    exponent,
545                    |b, e| pow_decimal256_float(b, *scale, e),
546                    *precision,
547                    *scale,
548                )?
549            }
550            (base_type, exp_type) => {
551                return internal_err!(
552                    "Unsupported data types for base {base_type:?} and exponent {exp_type:?} for power"
553                );
554            }
555        };
556        Ok(ColumnarValue::Array(arr))
557    }
558
559    /// Simplify the `power` function by the relevant rules:
560    /// 1. Power(a, 0) ===> 1
561    /// 2. Power(a, 1) ===> a
562    /// 3. Power(a, Log(a, b)) ===> b
563    fn simplify(
564        &self,
565        args: Vec<Expr>,
566        info: &SimplifyContext,
567    ) -> Result<ExprSimplifyResult> {
568        let [base, exponent] = take_function_args("power", args)?;
569        let base_type = info.get_data_type(&base)?;
570        let exponent_type = info.get_data_type(&exponent)?;
571
572        // Null propagation
573        if base_type.is_null() || exponent_type.is_null() {
574            let return_type = self.return_type(&[base_type, exponent_type])?;
575            return Ok(ExprSimplifyResult::Simplified(lit(
576                ScalarValue::Null.cast_to(&return_type)?
577            )));
578        }
579
580        match exponent {
581            Expr::Literal(value, _)
582                if value == ScalarValue::new_zero(&exponent_type)? =>
583            {
584                Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one(
585                    &base_type,
586                )?)))
587            }
588            Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => {
589                Ok(ExprSimplifyResult::Simplified(base))
590            }
591            Expr::ScalarFunction(ScalarFunction { func, mut args })
592                if is_log(&func) && args.len() == 2 && base == args[0] =>
593            {
594                let b = args.pop().unwrap(); // length checked above
595                Ok(ExprSimplifyResult::Simplified(b))
596            }
597            _ => Ok(ExprSimplifyResult::Original(vec![base, exponent])),
598        }
599    }
600
601    fn documentation(&self) -> Option<&Documentation> {
602        self.doc()
603    }
604}
605
606/// Return true if this function call is a call to `Log`
607fn is_log(func: &ScalarUDF) -> bool {
608    func.inner().as_any().downcast_ref::<LogFunc>().is_some()
609}
610
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pow_decimal128_helper() {
        // (unscaled base, scale, exponent, expected unscaled result)
        let cases: [(i128, i8, i64, i128); 9] = [
            // 2.5 ^ 4 = 39.0625, truncated to the input scale
            (25, 1, 4, 390),
            (2500, 3, 4, 39062),
            (25000, 4, 4, 390625),
            // 25 ^ 4 = 390625
            (25, 0, 4, 390625),
            // exponents 0 and 1
            (25, 1, 1, 25),
            (25, 0, 1, 25),
            (25, 0, 0, 1),
            (25, 1, 0, 10),
            // negative scale: 25 at scale -1 is 250, and 250 ^ 4 is kept at scale -1
            (25, -1, 4, 390625000),
        ];
        for (base, scale, exp, expected) in cases {
            assert_eq!(pow_decimal_int(base, scale, exp).unwrap(), expected);
        }
    }

    #[test]
    fn test_pow_decimal_float_fallback() {
        // (unscaled base, scale, exponent, expected unscaled result)
        let cases: [(i128, i8, f64, i128); 5] = [
            // negative exponent: 4 ^ -1 = 0.25 -> 25 at scale 2
            (400, 2, -1.0, 25),
            // non-integer exponent: 4 ^ 0.5 = 2 -> 200 at scale 2
            (400, 2, 0.5, 200),
            // cube root: 8 ^ (1/3) = 2 -> 20 at scale 1
            (80, 1, 1.0 / 3.0, 20),
            // negative base, integer exponent: (-2) ^ 3 = -8 -> -80 at scale 1
            (-20, 1, 3.0, -80),
            // whole-number exponent takes the (truncating) integer path:
            // 2.5 ^ 4 = 39.0625 -> 390 at scale 1
            (25, 1, 4.0, 390),
        ];
        for (base, scale, exp, expected) in cases {
            let result: i128 = pow_decimal_float(base, scale, exp).unwrap();
            assert_eq!(result, expected);
        }

        // non-finite exponents are rejected
        for exp in [f64::NAN, f64::INFINITY] {
            assert!(pow_decimal_float(100i128, 2, exp).is_err());
        }
    }
}