datafusion_functions/math/power.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Math function: `power()`.
19use std::any::Any;
20
21use super::log::LogFunc;
22
23use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};
24use arrow::array::{Array, ArrayRef};
25use arrow::datatypes::{
26    ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
27    Decimal256Type, Float64Type, Int64Type,
28};
29use arrow::error::ArrowError;
30use datafusion_common::types::{NativeType, logical_float64, logical_int64};
31use datafusion_common::utils::take_function_args;
32use datafusion_common::{Result, ScalarValue, internal_err};
33use datafusion_expr::expr::ScalarFunction;
34use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
35use datafusion_expr::{
36    Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF,
37    ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, lit,
38};
39use datafusion_macros::user_doc;
40
41#[user_doc(
42    doc_section(label = "Math Functions"),
43    description = "Returns a base expression raised to the power of an exponent.",
44    syntax_example = "power(base, exponent)",
45    sql_example = r#"```sql
46> SELECT power(2, 3);
47+-------------+
48| power(2,3)  |
49+-------------+
50| 8           |
51+-------------+
52```"#,
53    standard_argument(name = "base", prefix = "Numeric"),
54    standard_argument(name = "exponent", prefix = "Exponent numeric")
55)]
56#[derive(Debug, PartialEq, Eq, Hash)]
57pub struct PowerFunc {
58    signature: Signature,
59    aliases: Vec<String>,
60}
61
62impl Default for PowerFunc {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68impl PowerFunc {
69    pub fn new() -> Self {
70        let integer = Coercion::new_implicit(
71            TypeSignatureClass::Native(logical_int64()),
72            vec![TypeSignatureClass::Integer],
73            NativeType::Int64,
74        );
75        let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
76        let float = Coercion::new_implicit(
77            TypeSignatureClass::Native(logical_float64()),
78            vec![TypeSignatureClass::Numeric],
79            NativeType::Float64,
80        );
81        Self {
82            signature: Signature::one_of(
83                vec![
84                    TypeSignature::Coercible(vec![decimal.clone(), integer]),
85                    TypeSignature::Coercible(vec![decimal.clone(), float.clone()]),
86                    TypeSignature::Coercible(vec![float; 2]),
87                ],
88                Volatility::Immutable,
89            ),
90            aliases: vec![String::from("pow")],
91        }
92    }
93}
94
95/// Binary function to calculate a math power to integer exponent
96/// for scaled integer types.
97///
98/// Formula
99/// The power for a scaled integer `b` is
100///
101/// ```text
102/// (b * 10^(-s)) ^ e
103/// ```
104/// However, the result should be scaled back from scale 0 to scale `s`,
105/// which is done by multiplying by `10^s`.
106/// At the end, the formula is:
107///
108/// ```text
109///   b^e * 10^(-s * e) * 10^s = b^e / 10^(s * (e-1))
110/// ```
111/// Example of 2.5 ^ 4 = 39:
112///   2.5 is represented as 25 with scale 1
113///   The unscaled result is 25^4 = 390625
114///   Scale it back to 1: 390625 / 10^4 = 39
115///
116/// Returns error if base is invalid
117fn pow_decimal_int<T>(base: T, scale: i8, exp: i64) -> Result<T, ArrowError>
118where
119    T: From<i32> + ArrowNativeTypeOp,
120{
121    let exp: u32 = exp.try_into().map_err(|_| {
122        ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
123    })?;
124    // Handle edge case for exp == 0
125    // If scale < 0, 10^scale (e.g., 10^-2 = 0.01) becomes 0 in integer arithmetic.
126    if exp == 0 {
127        return if scale >= 0 {
128            T::from(10).pow_checked(scale as u32).map_err(|_| {
129                ArrowError::ArithmeticOverflow(format!(
130                    "Cannot make unscale factor for {scale} and {exp}"
131                ))
132            })
133        } else {
134            Ok(T::from(0))
135        };
136    }
137    let powered: T = base.pow_checked(exp).map_err(|_| {
138        ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to exp {exp}"))
139    })?;
140
141    // Calculate the scale adjustment: s * (e - 1)
142    // We use i64 to prevent overflow during the intermediate multiplication
143    let mul_exp = (scale as i64).wrapping_mul(exp as i64 - 1);
144
145    if mul_exp == 0 {
146        return Ok(powered);
147    }
148
149    // If mul_exp is positive, we divide (standard case).
150    // If mul_exp is negative, we multiply (negative scale case).
151    if mul_exp > 0 {
152        let div_factor: T = T::from(10).pow_checked(mul_exp as u32).map_err(|_| {
153            ArrowError::ArithmeticOverflow(format!(
154                "Cannot make div factor for {scale} and {exp}"
155            ))
156        })?;
157        powered.div_checked(div_factor)
158    } else {
159        // mul_exp is negative, so we multiply by 10^(-mul_exp)
160        let abs_exp = mul_exp.checked_neg().ok_or_else(|| {
161            ArrowError::ArithmeticOverflow(
162                "Overflow while negating scale exponent".to_string(),
163            )
164        })?;
165        let mul_factor: T = T::from(10).pow_checked(abs_exp as u32).map_err(|_| {
166            ArrowError::ArithmeticOverflow(format!(
167                "Cannot make mul factor for {scale} and {exp}"
168            ))
169        })?;
170        powered.mul_checked(mul_factor)
171    }
172}
173
174/// Binary function to calculate a math power to float exponent
175/// for scaled integer types.
176/// Returns error if exponent is negative or non-integer, or base invalid
177fn pow_decimal_float<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
178where
179    T: From<i32> + ArrowNativeTypeOp,
180{
181    if !exp.is_finite() || exp.trunc() != exp {
182        return Err(ArrowError::ComputeError(format!(
183            "Cannot use non-integer exp: {exp}"
184        )));
185    }
186    if exp < 0f64 || exp >= u32::MAX as f64 {
187        return Err(ArrowError::ArithmeticOverflow(format!(
188            "Unsupported exp value: {exp}"
189        )));
190    }
191    pow_decimal_int(base, scale, exp as i64)
192}
193
194impl ScalarUDFImpl for PowerFunc {
195    fn as_any(&self) -> &dyn Any {
196        self
197    }
198
199    fn name(&self) -> &str {
200        "power"
201    }
202
203    fn signature(&self) -> &Signature {
204        &self.signature
205    }
206
207    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
208        if arg_types[0].is_null() {
209            Ok(DataType::Float64)
210        } else {
211            Ok(arg_types[0].clone())
212        }
213    }
214
215    fn aliases(&self) -> &[String] {
216        &self.aliases
217    }
218
219    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
220        let [base, exponent] = take_function_args(self.name(), &args.args)?;
221        let base = base.to_array(args.number_rows)?;
222
223        let arr: ArrayRef = match (base.data_type(), exponent.data_type()) {
224            (DataType::Float64, DataType::Float64) => {
225                calculate_binary_math::<Float64Type, Float64Type, Float64Type, _>(
226                    &base,
227                    exponent,
228                    |b, e| Ok(f64::powf(b, e)),
229                )?
230            }
231            (DataType::Decimal32(precision, scale), DataType::Int64) => {
232                calculate_binary_decimal_math::<Decimal32Type, Int64Type, Decimal32Type, _>(
233                    &base,
234                    exponent,
235                    |b, e| pow_decimal_int(b, *scale, e),
236                    *precision,
237                    *scale,
238                )?
239            }
240            (DataType::Decimal32(precision, scale), DataType::Float64) => {
241                calculate_binary_decimal_math::<
242                    Decimal32Type,
243                    Float64Type,
244                    Decimal32Type,
245                    _,
246                >(
247                    &base,
248                    exponent,
249                    |b, e| pow_decimal_float(b, *scale, e),
250                    *precision,
251                    *scale,
252                )?
253            }
254            (DataType::Decimal64(precision, scale), DataType::Int64) => {
255                calculate_binary_decimal_math::<Decimal64Type, Int64Type, Decimal64Type, _>(
256                    &base,
257                    exponent,
258                    |b, e| pow_decimal_int(b, *scale, e),
259                    *precision,
260                    *scale,
261                )?
262            }
263            (DataType::Decimal64(precision, scale), DataType::Float64) => {
264                calculate_binary_decimal_math::<
265                    Decimal64Type,
266                    Float64Type,
267                    Decimal64Type,
268                    _,
269                >(
270                    &base,
271                    exponent,
272                    |b, e| pow_decimal_float(b, *scale, e),
273                    *precision,
274                    *scale,
275                )?
276            }
277            (DataType::Decimal128(precision, scale), DataType::Int64) => {
278                calculate_binary_decimal_math::<
279                    Decimal128Type,
280                    Int64Type,
281                    Decimal128Type,
282                    _,
283                >(
284                    &base,
285                    exponent,
286                    |b, e| pow_decimal_int(b, *scale, e),
287                    *precision,
288                    *scale,
289                )?
290            }
291            (DataType::Decimal128(precision, scale), DataType::Float64) => {
292                calculate_binary_decimal_math::<
293                    Decimal128Type,
294                    Float64Type,
295                    Decimal128Type,
296                    _,
297                >(
298                    &base,
299                    exponent,
300                    |b, e| pow_decimal_float(b, *scale, e),
301                    *precision,
302                    *scale,
303                )?
304            }
305            (DataType::Decimal256(precision, scale), DataType::Int64) => {
306                calculate_binary_decimal_math::<
307                    Decimal256Type,
308                    Int64Type,
309                    Decimal256Type,
310                    _,
311                >(
312                    &base,
313                    exponent,
314                    |b, e| pow_decimal_int(b, *scale, e),
315                    *precision,
316                    *scale,
317                )?
318            }
319            (DataType::Decimal256(precision, scale), DataType::Float64) => {
320                calculate_binary_decimal_math::<
321                    Decimal256Type,
322                    Float64Type,
323                    Decimal256Type,
324                    _,
325                >(
326                    &base,
327                    exponent,
328                    |b, e| pow_decimal_float(b, *scale, e),
329                    *precision,
330                    *scale,
331                )?
332            }
333            (base_type, exp_type) => {
334                return internal_err!(
335                    "Unsupported data types for base {base_type:?} and exponent {exp_type:?} for power"
336                );
337            }
338        };
339        Ok(ColumnarValue::Array(arr))
340    }
341
342    /// Simplify the `power` function by the relevant rules:
343    /// 1. Power(a, 0) ===> 1
344    /// 2. Power(a, 1) ===> a
345    /// 3. Power(a, Log(a, b)) ===> b
346    fn simplify(
347        &self,
348        args: Vec<Expr>,
349        info: &dyn SimplifyInfo,
350    ) -> Result<ExprSimplifyResult> {
351        let [base, exponent] = take_function_args("power", args)?;
352        let base_type = info.get_data_type(&base)?;
353        let exponent_type = info.get_data_type(&exponent)?;
354
355        // Null propagation
356        if base_type.is_null() || exponent_type.is_null() {
357            let return_type = self.return_type(&[base_type, exponent_type])?;
358            return Ok(ExprSimplifyResult::Simplified(lit(
359                ScalarValue::Null.cast_to(&return_type)?
360            )));
361        }
362
363        match exponent {
364            Expr::Literal(value, _)
365                if value == ScalarValue::new_zero(&exponent_type)? =>
366            {
367                Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one(
368                    &base_type,
369                )?)))
370            }
371            Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => {
372                Ok(ExprSimplifyResult::Simplified(base))
373            }
374            Expr::ScalarFunction(ScalarFunction { func, mut args })
375                if is_log(&func) && args.len() == 2 && base == args[0] =>
376            {
377                let b = args.pop().unwrap(); // length checked above
378                Ok(ExprSimplifyResult::Simplified(b))
379            }
380            _ => Ok(ExprSimplifyResult::Original(vec![base, exponent])),
381        }
382    }
383
384    fn documentation(&self) -> Option<&Documentation> {
385        self.doc()
386    }
387}
388
389/// Return true if this function call is a call to `Log`
390fn is_log(func: &ScalarUDF) -> bool {
391    func.inner().as_any().downcast_ref::<LogFunc>().is_some()
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pow_decimal128_helper() {
        // Expression: 2.5 ^ 4 = 39.0625, truncated to the input scale
        assert_eq!(pow_decimal_int(25, 1, 4).unwrap(), i128::from(390));
        assert_eq!(pow_decimal_int(2500, 3, 4).unwrap(), i128::from(39062));
        assert_eq!(pow_decimal_int(25000, 4, 4).unwrap(), i128::from(390625));

        // Expression: 25 ^ 4 = 390625
        assert_eq!(pow_decimal_int(25, 0, 4).unwrap(), i128::from(390625));

        // Expressions for edge cases
        assert_eq!(pow_decimal_int(25, 1, 1).unwrap(), i128::from(25));
        assert_eq!(pow_decimal_int(25, 0, 1).unwrap(), i128::from(25));
        assert_eq!(pow_decimal_int(25, 0, 0).unwrap(), i128::from(1));
        assert_eq!(pow_decimal_int(25, 1, 0).unwrap(), i128::from(10));

        // Negative scale: 25 at scale -1 is 250, and 250 ^ 4 = 3906250000
        assert_eq!(pow_decimal_int(25, -1, 4).unwrap(), i128::from(390625000));
    }

    #[test]
    fn test_pow_decimal_int_rejects_negative_exp() {
        assert!(pow_decimal_int(25i128, 1, -1).is_err());
    }

    #[test]
    fn test_pow_decimal_float_helper() {
        // 2.5 ^ 2 = 6.25, truncated to scale 1
        assert_eq!(pow_decimal_float(25i128, 1, 2.0).unwrap(), 62);
        // Whole float exponents agree with the integer path
        assert_eq!(
            pow_decimal_float(25i128, 1, 4.0).unwrap(),
            pow_decimal_int(25i128, 1, 4).unwrap()
        );
        // Fractional, non-finite and negative exponents are rejected
        assert!(pow_decimal_float(25i128, 1, 2.5).is_err());
        assert!(pow_decimal_float(25i128, 1, f64::NAN).is_err());
        assert!(pow_decimal_float(25i128, 1, f64::INFINITY).is_err());
        assert!(pow_decimal_float(25i128, 1, -1.0).is_err());
    }
}