datafusion_functions/math/
power.rs1use std::any::Any;
20use std::sync::Arc;
21
22use super::log::LogFunc;
23
24use arrow::array::{ArrayRef, AsArray, Int64Array};
25use arrow::datatypes::{ArrowNativeTypeOp, DataType, Float64Type};
26use datafusion_common::{
27 arrow_datafusion_err, exec_datafusion_err, exec_err, plan_datafusion_err,
28 DataFusionError, Result, ScalarValue,
29};
30use datafusion_expr::expr::ScalarFunction;
31use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
32use datafusion_expr::{
33 ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, TypeSignature,
34};
35use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
36use datafusion_macros::user_doc;
37
/// Scalar UDF implementing `power(base, exponent)` (also registered under the
/// alias `pow`).
///
/// The `#[user_doc]` attribute below carries the user-facing documentation
/// (description, syntax and SQL example) for the function.
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Returns a base expression raised to the power of an exponent.",
    syntax_example = "power(base, exponent)",
    sql_example = r#"```sql
> SELECT power(2, 3);
+-------------+
| power(2,3) |
+-------------+
| 8 |
+-------------+
```"#,
    standard_argument(name = "base", prefix = "Numeric"),
    standard_argument(name = "exponent", prefix = "Exponent numeric")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct PowerFunc {
    /// Accepted argument type combinations; handed out by
    /// `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Alternative names for the function; handed out by
    /// `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
58
59impl Default for PowerFunc {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl PowerFunc {
66 pub fn new() -> Self {
67 use DataType::*;
68 Self {
69 signature: Signature::one_of(
70 vec![
71 TypeSignature::Exact(vec![Int64, Int64]),
72 TypeSignature::Exact(vec![Float64, Float64]),
73 ],
74 Volatility::Immutable,
75 ),
76 aliases: vec![String::from("pow")],
77 }
78 }
79}
80
81impl ScalarUDFImpl for PowerFunc {
82 fn as_any(&self) -> &dyn Any {
83 self
84 }
85 fn name(&self) -> &str {
86 "power"
87 }
88
89 fn signature(&self) -> &Signature {
90 &self.signature
91 }
92
93 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
94 match arg_types[0] {
95 DataType::Int64 => Ok(DataType::Int64),
96 _ => Ok(DataType::Float64),
97 }
98 }
99
100 fn aliases(&self) -> &[String] {
101 &self.aliases
102 }
103
104 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
105 let args = ColumnarValue::values_to_arrays(&args.args)?;
106
107 let arr: ArrayRef = match args[0].data_type() {
108 DataType::Float64 => {
109 let bases = args[0].as_primitive::<Float64Type>();
110 let exponents = args[1].as_primitive::<Float64Type>();
111 let result = arrow::compute::binary::<_, _, _, Float64Type>(
112 bases,
113 exponents,
114 f64::powf,
115 )?;
116 Arc::new(result) as _
117 }
118 DataType::Int64 => {
119 let bases = downcast_named_arg!(&args[0], "base", Int64Array);
120 let exponents = downcast_named_arg!(&args[1], "exponent", Int64Array);
121 bases
122 .iter()
123 .zip(exponents.iter())
124 .map(|(base, exp)| match (base, exp) {
125 (Some(base), Some(exp)) => Ok(Some(base.pow_checked(
126 exp.try_into().map_err(|_| {
127 exec_datafusion_err!(
128 "Can't use negative exponents: {exp} in integer computation, please use Float."
129 )
130 })?,
131 ).map_err(|e| arrow_datafusion_err!(e))?)),
132 _ => Ok(None),
133 })
134 .collect::<Result<Int64Array>>()
135 .map(Arc::new)? as _
136 }
137
138 other => {
139 return exec_err!(
140 "Unsupported data type {other:?} for function {}",
141 self.name()
142 )
143 }
144 };
145
146 Ok(ColumnarValue::Array(arr))
147 }
148
149 fn simplify(
154 &self,
155 mut args: Vec<Expr>,
156 info: &dyn SimplifyInfo,
157 ) -> Result<ExprSimplifyResult> {
158 let exponent = args.pop().ok_or_else(|| {
159 plan_datafusion_err!("Expected power to have 2 arguments, got 0")
160 })?;
161 let base = args.pop().ok_or_else(|| {
162 plan_datafusion_err!("Expected power to have 2 arguments, got 1")
163 })?;
164
165 let exponent_type = info.get_data_type(&exponent)?;
166 match exponent {
167 Expr::Literal(value, _)
168 if value == ScalarValue::new_zero(&exponent_type)? =>
169 {
170 Ok(ExprSimplifyResult::Simplified(Expr::Literal(
171 ScalarValue::new_one(&info.get_data_type(&base)?)?,
172 None,
173 )))
174 }
175 Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => {
176 Ok(ExprSimplifyResult::Simplified(base))
177 }
178 Expr::ScalarFunction(ScalarFunction { func, mut args })
179 if is_log(&func) && args.len() == 2 && base == args[0] =>
180 {
181 let b = args.pop().unwrap(); Ok(ExprSimplifyResult::Simplified(b))
183 }
184 _ => Ok(ExprSimplifyResult::Original(vec![base, exponent])),
185 }
186 }
187
188 fn documentation(&self) -> Option<&Documentation> {
189 self.doc()
190 }
191}
192
193fn is_log(func: &ScalarUDF) -> bool {
195 func.inner().as_any().downcast_ref::<LogFunc>().is_some()
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Float64Array;
    use arrow::datatypes::Field;
    use datafusion_common::cast::{as_float64_array, as_int64_array};
    use datafusion_common::config::ConfigOptions;

    /// Invokes `power` on two equally typed argument arrays and returns the
    /// raw result; panics if the invocation itself fails.
    fn invoke_power(
        base: ArrayRef,
        exponent: ArrayRef,
        data_type: DataType,
    ) -> ColumnarValue {
        let number_rows = base.len();
        let arg_fields = vec![
            Field::new("a", data_type.clone(), true).into(),
            Field::new("a", data_type.clone(), true).into(),
        ];
        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(base), ColumnarValue::Array(exponent)],
            arg_fields,
            number_rows,
            return_field: Field::new("f", data_type, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };

        PowerFunc::new()
            .invoke_with_args(args)
            .expect("failed to initialize function power")
    }

    #[test]
    fn test_power_f64() {
        let result = invoke_power(
            Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0])),
            Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0])),
            DataType::Float64,
        );

        match result {
            ColumnarValue::Scalar(_) => {
                panic!("Expected an array value")
            }
            ColumnarValue::Array(arr) => {
                let floats = as_float64_array(&arr)
                    .expect("failed to convert result to a Float64Array");
                let expected = [8.0, 4.0, 81.0, 625.0];
                assert_eq!(floats.len(), 4);
                for (row, want) in expected.iter().enumerate() {
                    assert_eq!(floats.value(row), *want);
                }
            }
        }
    }

    #[test]
    fn test_power_i64() {
        let result = invoke_power(
            Arc::new(Int64Array::from(vec![2, 2, 3, 5])),
            Arc::new(Int64Array::from(vec![3, 2, 4, 4])),
            DataType::Int64,
        );

        match result {
            ColumnarValue::Scalar(_) => {
                panic!("Expected an array value")
            }
            ColumnarValue::Array(arr) => {
                let ints = as_int64_array(&arr)
                    .expect("failed to convert result to a Int64Array");
                let expected = [8_i64, 4, 81, 625];
                assert_eq!(ints.len(), 4);
                for (row, want) in expected.iter().enumerate() {
                    assert_eq!(ints.value(row), *want);
                }
            }
        }
    }
}