datafusion_functions/math/
power.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Math function: `power()`.
19use std::any::Any;
20use std::sync::Arc;
21
22use super::log::LogFunc;
23
24use arrow::array::{ArrayRef, AsArray, Int64Array};
25use arrow::datatypes::{ArrowNativeTypeOp, DataType, Float64Type};
26use datafusion_common::{
27    arrow_datafusion_err, exec_datafusion_err, exec_err, plan_datafusion_err,
28    DataFusionError, Result, ScalarValue,
29};
30use datafusion_expr::expr::ScalarFunction;
31use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
32use datafusion_expr::{
33    ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, TypeSignature,
34};
35use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
36use datafusion_macros::user_doc;
37
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Returns a base expression raised to the power of an exponent.",
    syntax_example = "power(base, exponent)",
    sql_example = r#"```sql
> SELECT power(2, 3);
+-------------+
| power(2,3)  |
+-------------+
| 8           |
+-------------+
```"#,
    standard_argument(name = "base", prefix = "Numeric"),
    standard_argument(name = "exponent", prefix = "Exponent numeric")
)]
/// Scalar UDF implementing SQL `power(base, exponent)`, also registered as `pow`.
///
/// Accepts either two `Int64` arguments (evaluated with checked integer
/// arithmetic, returning `Int64`) or two `Float64` arguments (returning `Float64`).
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct PowerFunc {
    /// Accepted argument type combinations and the function's volatility.
    signature: Signature,
    /// Alternative SQL names under which this function can be called.
    aliases: Vec<String>,
}
58
59impl Default for PowerFunc {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl PowerFunc {
66    pub fn new() -> Self {
67        use DataType::*;
68        Self {
69            signature: Signature::one_of(
70                vec![
71                    TypeSignature::Exact(vec![Int64, Int64]),
72                    TypeSignature::Exact(vec![Float64, Float64]),
73                ],
74                Volatility::Immutable,
75            ),
76            aliases: vec![String::from("pow")],
77        }
78    }
79}
80
81impl ScalarUDFImpl for PowerFunc {
82    fn as_any(&self) -> &dyn Any {
83        self
84    }
85    fn name(&self) -> &str {
86        "power"
87    }
88
89    fn signature(&self) -> &Signature {
90        &self.signature
91    }
92
93    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
94        match arg_types[0] {
95            DataType::Int64 => Ok(DataType::Int64),
96            _ => Ok(DataType::Float64),
97        }
98    }
99
100    fn aliases(&self) -> &[String] {
101        &self.aliases
102    }
103
104    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
105        let args = ColumnarValue::values_to_arrays(&args.args)?;
106
107        let arr: ArrayRef = match args[0].data_type() {
108            DataType::Float64 => {
109                let bases = args[0].as_primitive::<Float64Type>();
110                let exponents = args[1].as_primitive::<Float64Type>();
111                let result = arrow::compute::binary::<_, _, _, Float64Type>(
112                    bases,
113                    exponents,
114                    f64::powf,
115                )?;
116                Arc::new(result) as _
117            }
118            DataType::Int64 => {
119                let bases = downcast_named_arg!(&args[0], "base", Int64Array);
120                let exponents = downcast_named_arg!(&args[1], "exponent", Int64Array);
121                bases
122                    .iter()
123                    .zip(exponents.iter())
124                    .map(|(base, exp)| match (base, exp) {
125                        (Some(base), Some(exp)) => Ok(Some(base.pow_checked(
126                            exp.try_into().map_err(|_| {
127                                exec_datafusion_err!(
128                                    "Can't use negative exponents: {exp} in integer computation, please use Float."
129                                )
130                            })?,
131                        ).map_err(|e| arrow_datafusion_err!(e))?)),
132                        _ => Ok(None),
133                    })
134                    .collect::<Result<Int64Array>>()
135                    .map(Arc::new)? as _
136            }
137
138            other => {
139                return exec_err!(
140                    "Unsupported data type {other:?} for function {}",
141                    self.name()
142                )
143            }
144        };
145
146        Ok(ColumnarValue::Array(arr))
147    }
148
149    /// Simplify the `power` function by the relevant rules:
150    /// 1. Power(a, 0) ===> 0
151    /// 2. Power(a, 1) ===> a
152    /// 3. Power(a, Log(a, b)) ===> b
153    fn simplify(
154        &self,
155        mut args: Vec<Expr>,
156        info: &dyn SimplifyInfo,
157    ) -> Result<ExprSimplifyResult> {
158        let exponent = args.pop().ok_or_else(|| {
159            plan_datafusion_err!("Expected power to have 2 arguments, got 0")
160        })?;
161        let base = args.pop().ok_or_else(|| {
162            plan_datafusion_err!("Expected power to have 2 arguments, got 1")
163        })?;
164
165        let exponent_type = info.get_data_type(&exponent)?;
166        match exponent {
167            Expr::Literal(value, _)
168                if value == ScalarValue::new_zero(&exponent_type)? =>
169            {
170                Ok(ExprSimplifyResult::Simplified(Expr::Literal(
171                    ScalarValue::new_one(&info.get_data_type(&base)?)?,
172                    None,
173                )))
174            }
175            Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => {
176                Ok(ExprSimplifyResult::Simplified(base))
177            }
178            Expr::ScalarFunction(ScalarFunction { func, mut args })
179                if is_log(&func) && args.len() == 2 && base == args[0] =>
180            {
181                let b = args.pop().unwrap(); // length checked above
182                Ok(ExprSimplifyResult::Simplified(b))
183            }
184            _ => Ok(ExprSimplifyResult::Original(vec![base, exponent])),
185        }
186    }
187
188    fn documentation(&self) -> Option<&Documentation> {
189        self.doc()
190    }
191}
192
193/// Return true if this function call is a call to `Log`
194fn is_log(func: &ScalarUDF) -> bool {
195    func.inner().as_any().downcast_ref::<LogFunc>().is_some()
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Float64Array;
    use arrow::datatypes::Field;
    use datafusion_common::cast::{as_float64_array, as_int64_array};
    use datafusion_common::config::ConfigOptions;

    /// Invokes `power(base, exponent)` on two array arguments and returns the
    /// resulting array. The return field is typed like `base`, mirroring
    /// `PowerFunc::return_type` for the supported signatures.
    fn invoke_power(base: ArrayRef, exponent: ArrayRef) -> Result<ArrayRef> {
        let number_rows = base.len();
        let return_type = base.data_type().clone();
        let arg_fields = vec![
            Field::new("base", base.data_type().clone(), true).into(),
            Field::new("exponent", exponent.data_type().clone(), true).into(),
        ];
        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(base), ColumnarValue::Array(exponent)],
            arg_fields,
            number_rows,
            return_field: Field::new("f", return_type, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };
        match PowerFunc::new().invoke_with_args(args)? {
            ColumnarValue::Array(arr) => Ok(arr),
            ColumnarValue::Scalar(_) => panic!("Expected an array value"),
        }
    }

    #[test]
    fn test_power_f64() {
        let result = invoke_power(
            Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0])),
            Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0])),
        )
        .expect("power should succeed on Float64 inputs");

        let floats = as_float64_array(&result)
            .expect("failed to convert result to a Float64Array");
        assert_eq!(floats.len(), 4);
        assert_eq!(floats.value(0), 8.0);
        assert_eq!(floats.value(1), 4.0);
        assert_eq!(floats.value(2), 81.0);
        assert_eq!(floats.value(3), 625.0);
    }

    #[test]
    fn test_power_i64() {
        let result = invoke_power(
            Arc::new(Int64Array::from(vec![2, 2, 3, 5])),
            Arc::new(Int64Array::from(vec![3, 2, 4, 4])),
        )
        .expect("power should succeed on Int64 inputs");

        let ints =
            as_int64_array(&result).expect("failed to convert result to a Int64Array");
        assert_eq!(ints.len(), 4);
        assert_eq!(ints.value(0), 8);
        assert_eq!(ints.value(1), 4);
        assert_eq!(ints.value(2), 81);
        assert_eq!(ints.value(3), 625);
    }

    #[test]
    fn test_power_propagates_nulls() {
        let result = invoke_power(
            Arc::new(Int64Array::from(vec![Some(2), None, Some(3)])),
            Arc::new(Int64Array::from(vec![None, Some(2), Some(2)])),
        )
        .expect("power should succeed on Int64 inputs containing NULLs");

        let ints =
            as_int64_array(&result).expect("failed to convert result to a Int64Array");
        assert_eq!(ints.iter().collect::<Vec<_>>(), vec![None, None, Some(9)]);
    }

    #[test]
    fn test_power_i64_rejects_negative_exponent() {
        let err = invoke_power(
            Arc::new(Int64Array::from(vec![2])),
            Arc::new(Int64Array::from(vec![-1])),
        )
        .expect_err("negative integer exponents should be rejected");
        assert!(
            err.to_string().contains("Can't use negative exponents"),
            "unexpected error: {err}"
        );
    }
}