datafusion_functions/math/
power.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Math function: `power()`.
19use std::any::Any;
20use std::sync::Arc;
21
22use super::log::LogFunc;
23
24use arrow::array::{ArrayRef, AsArray, Int64Array};
25use arrow::datatypes::{ArrowNativeTypeOp, DataType, Float64Type};
26use datafusion_common::{
27    arrow_datafusion_err, exec_datafusion_err, exec_err, internal_datafusion_err,
28    plan_datafusion_err, DataFusionError, Result, ScalarValue,
29};
30use datafusion_expr::expr::ScalarFunction;
31use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
32use datafusion_expr::{
33    ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, TypeSignature,
34};
35use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
36use datafusion_macros::user_doc;
37
/// Scalar UDF implementing `power(base, exponent)`.
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Returns a base expression raised to the power of an exponent.",
    syntax_example = "power(base, exponent)",
    standard_argument(name = "base", prefix = "Numeric"),
    standard_argument(name = "exponent", prefix = "Exponent numeric")
)]
#[derive(Debug)]
pub struct PowerFunc {
    /// Accepted argument type combinations and the function's volatility.
    signature: Signature,
    /// Alternative names for the function; `new` sets this to `["pow"]`.
    aliases: Vec<String>,
}
50
51impl Default for PowerFunc {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl PowerFunc {
58    pub fn new() -> Self {
59        use DataType::*;
60        Self {
61            signature: Signature::one_of(
62                vec![
63                    TypeSignature::Exact(vec![Int64, Int64]),
64                    TypeSignature::Exact(vec![Float64, Float64]),
65                ],
66                Volatility::Immutable,
67            ),
68            aliases: vec![String::from("pow")],
69        }
70    }
71}
72
73impl ScalarUDFImpl for PowerFunc {
74    fn as_any(&self) -> &dyn Any {
75        self
76    }
77    fn name(&self) -> &str {
78        "power"
79    }
80
81    fn signature(&self) -> &Signature {
82        &self.signature
83    }
84
85    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
86        match arg_types[0] {
87            DataType::Int64 => Ok(DataType::Int64),
88            _ => Ok(DataType::Float64),
89        }
90    }
91
92    fn aliases(&self) -> &[String] {
93        &self.aliases
94    }
95
96    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
97        let args = ColumnarValue::values_to_arrays(&args.args)?;
98
99        let arr: ArrayRef = match args[0].data_type() {
100            DataType::Float64 => {
101                let bases = args[0].as_primitive::<Float64Type>();
102                let exponents = args[1].as_primitive::<Float64Type>();
103                let result = arrow::compute::binary::<_, _, _, Float64Type>(
104                    bases,
105                    exponents,
106                    f64::powf,
107                )?;
108                Arc::new(result) as _
109            }
110            DataType::Int64 => {
111                let bases = downcast_named_arg!(&args[0], "base", Int64Array);
112                let exponents = downcast_named_arg!(&args[1], "exponent", Int64Array);
113                bases
114                    .iter()
115                    .zip(exponents.iter())
116                    .map(|(base, exp)| match (base, exp) {
117                        (Some(base), Some(exp)) => Ok(Some(base.pow_checked(
118                            exp.try_into().map_err(|_| {
119                                exec_datafusion_err!(
120                                    "Can't use negative exponents: {exp} in integer computation, please use Float."
121                                )
122                            })?,
123                        ).map_err(|e| arrow_datafusion_err!(e))?)),
124                        _ => Ok(None),
125                    })
126                    .collect::<Result<Int64Array>>()
127                    .map(Arc::new)? as _
128            }
129
130            other => {
131                return exec_err!(
132                    "Unsupported data type {other:?} for function {}",
133                    self.name()
134                )
135            }
136        };
137
138        Ok(ColumnarValue::Array(arr))
139    }
140
141    /// Simplify the `power` function by the relevant rules:
142    /// 1. Power(a, 0) ===> 0
143    /// 2. Power(a, 1) ===> a
144    /// 3. Power(a, Log(a, b)) ===> b
145    fn simplify(
146        &self,
147        mut args: Vec<Expr>,
148        info: &dyn SimplifyInfo,
149    ) -> Result<ExprSimplifyResult> {
150        let exponent = args.pop().ok_or_else(|| {
151            plan_datafusion_err!("Expected power to have 2 arguments, got 0")
152        })?;
153        let base = args.pop().ok_or_else(|| {
154            plan_datafusion_err!("Expected power to have 2 arguments, got 1")
155        })?;
156
157        let exponent_type = info.get_data_type(&exponent)?;
158        match exponent {
159            Expr::Literal(value, _)
160                if value == ScalarValue::new_zero(&exponent_type)? =>
161            {
162                Ok(ExprSimplifyResult::Simplified(Expr::Literal(
163                    ScalarValue::new_one(&info.get_data_type(&base)?)?,
164                    None,
165                )))
166            }
167            Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => {
168                Ok(ExprSimplifyResult::Simplified(base))
169            }
170            Expr::ScalarFunction(ScalarFunction { func, mut args })
171                if is_log(&func) && args.len() == 2 && base == args[0] =>
172            {
173                let b = args.pop().unwrap(); // length checked above
174                Ok(ExprSimplifyResult::Simplified(b))
175            }
176            _ => Ok(ExprSimplifyResult::Original(vec![base, exponent])),
177        }
178    }
179
180    fn documentation(&self) -> Option<&Documentation> {
181        self.doc()
182    }
183}
184
185/// Return true if this function call is a call to `Log`
186fn is_log(func: &ScalarUDF) -> bool {
187    func.inner().as_any().downcast_ref::<LogFunc>().is_some()
188}
189
#[cfg(test)]
mod tests {
    use arrow::array::Float64Array;
    use arrow::datatypes::Field;
    use datafusion_common::cast::{as_float64_array, as_int64_array};

    use super::*;

    /// Invokes `power` on a base column and an exponent column that share
    /// `data_type`, and returns the resulting array.
    ///
    /// Panics if the invocation fails or yields a scalar instead of an array.
    fn run_power(
        bases: ArrayRef,
        exponents: ArrayRef,
        data_type: DataType,
        number_rows: usize,
    ) -> ArrayRef {
        let arg_fields = vec![
            Field::new("a", data_type.clone(), true).into(),
            Field::new("a", data_type.clone(), true).into(),
        ];
        let call = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(bases),
                ColumnarValue::Array(exponents),
            ],
            arg_fields,
            number_rows,
            return_field: Field::new("f", data_type, true).into(),
        };

        let outcome = PowerFunc::new()
            .invoke_with_args(call)
            .expect("failed to initialize function power");

        match outcome {
            ColumnarValue::Array(arr) => arr,
            ColumnarValue::Scalar(_) => {
                panic!("Expected an array value")
            }
        }
    }

    #[test]
    fn test_power_f64() {
        let bases = Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]);
        let exponents = Float64Array::from(vec![3.0, 2.0, 4.0, 4.0]);

        let output =
            run_power(Arc::new(bases), Arc::new(exponents), DataType::Float64, 4);
        let floats = as_float64_array(&output)
            .expect("failed to convert result to a Float64Array");

        assert_eq!(floats.len(), 4);
        for (row, expected) in [8.0, 4.0, 81.0, 625.0].iter().enumerate() {
            assert_eq!(floats.value(row), *expected);
        }
    }

    #[test]
    fn test_power_i64() {
        let bases = Int64Array::from(vec![2, 2, 3, 5]);
        let exponents = Int64Array::from(vec![3, 2, 4, 4]);

        let output = run_power(Arc::new(bases), Arc::new(exponents), DataType::Int64, 4);
        let ints =
            as_int64_array(&output).expect("failed to convert result to a Int64Array");

        assert_eq!(ints.len(), 4);
        for (row, expected) in [8_i64, 4, 81, 625].iter().enumerate() {
            assert_eq!(ints.value(row), *expected);
        }
    }
}