datafusion_functions/math/
power.rs1use std::any::Any;
20use std::sync::Arc;
21
22use super::log::LogFunc;
23
24use arrow::array::{ArrayRef, AsArray, Int64Array};
25use arrow::datatypes::{ArrowNativeTypeOp, DataType, Float64Type};
26use datafusion_common::{
27 arrow_datafusion_err, exec_datafusion_err, exec_err, internal_datafusion_err,
28 plan_datafusion_err, DataFusionError, Result, ScalarValue,
29};
30use datafusion_expr::expr::ScalarFunction;
31use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
32use datafusion_expr::{
33 ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, TypeSignature,
34};
35use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
36use datafusion_macros::user_doc;
37
38#[user_doc(
39 doc_section(label = "Math Functions"),
40 description = "Returns a base expression raised to the power of an exponent.",
41 syntax_example = "power(base, exponent)",
42 standard_argument(name = "base", prefix = "Numeric"),
43 standard_argument(name = "exponent", prefix = "Exponent numeric")
44)]
45#[derive(Debug)]
46pub struct PowerFunc {
47 signature: Signature,
48 aliases: Vec<String>,
49}
50
51impl Default for PowerFunc {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl PowerFunc {
58 pub fn new() -> Self {
59 use DataType::*;
60 Self {
61 signature: Signature::one_of(
62 vec![
63 TypeSignature::Exact(vec![Int64, Int64]),
64 TypeSignature::Exact(vec![Float64, Float64]),
65 ],
66 Volatility::Immutable,
67 ),
68 aliases: vec![String::from("pow")],
69 }
70 }
71}
72
73impl ScalarUDFImpl for PowerFunc {
74 fn as_any(&self) -> &dyn Any {
75 self
76 }
77 fn name(&self) -> &str {
78 "power"
79 }
80
81 fn signature(&self) -> &Signature {
82 &self.signature
83 }
84
85 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
86 match arg_types[0] {
87 DataType::Int64 => Ok(DataType::Int64),
88 _ => Ok(DataType::Float64),
89 }
90 }
91
92 fn aliases(&self) -> &[String] {
93 &self.aliases
94 }
95
96 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
97 let args = ColumnarValue::values_to_arrays(&args.args)?;
98
99 let arr: ArrayRef = match args[0].data_type() {
100 DataType::Float64 => {
101 let bases = args[0].as_primitive::<Float64Type>();
102 let exponents = args[1].as_primitive::<Float64Type>();
103 let result = arrow::compute::binary::<_, _, _, Float64Type>(
104 bases,
105 exponents,
106 f64::powf,
107 )?;
108 Arc::new(result) as _
109 }
110 DataType::Int64 => {
111 let bases = downcast_named_arg!(&args[0], "base", Int64Array);
112 let exponents = downcast_named_arg!(&args[1], "exponent", Int64Array);
113 bases
114 .iter()
115 .zip(exponents.iter())
116 .map(|(base, exp)| match (base, exp) {
117 (Some(base), Some(exp)) => Ok(Some(base.pow_checked(
118 exp.try_into().map_err(|_| {
119 exec_datafusion_err!(
120 "Can't use negative exponents: {exp} in integer computation, please use Float."
121 )
122 })?,
123 ).map_err(|e| arrow_datafusion_err!(e))?)),
124 _ => Ok(None),
125 })
126 .collect::<Result<Int64Array>>()
127 .map(Arc::new)? as _
128 }
129
130 other => {
131 return exec_err!(
132 "Unsupported data type {other:?} for function {}",
133 self.name()
134 )
135 }
136 };
137
138 Ok(ColumnarValue::Array(arr))
139 }
140
141 fn simplify(
146 &self,
147 mut args: Vec<Expr>,
148 info: &dyn SimplifyInfo,
149 ) -> Result<ExprSimplifyResult> {
150 let exponent = args.pop().ok_or_else(|| {
151 plan_datafusion_err!("Expected power to have 2 arguments, got 0")
152 })?;
153 let base = args.pop().ok_or_else(|| {
154 plan_datafusion_err!("Expected power to have 2 arguments, got 1")
155 })?;
156
157 let exponent_type = info.get_data_type(&exponent)?;
158 match exponent {
159 Expr::Literal(value, _)
160 if value == ScalarValue::new_zero(&exponent_type)? =>
161 {
162 Ok(ExprSimplifyResult::Simplified(Expr::Literal(
163 ScalarValue::new_one(&info.get_data_type(&base)?)?,
164 None,
165 )))
166 }
167 Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => {
168 Ok(ExprSimplifyResult::Simplified(base))
169 }
170 Expr::ScalarFunction(ScalarFunction { func, mut args })
171 if is_log(&func) && args.len() == 2 && base == args[0] =>
172 {
173 let b = args.pop().unwrap(); Ok(ExprSimplifyResult::Simplified(b))
175 }
176 _ => Ok(ExprSimplifyResult::Original(vec![base, exponent])),
177 }
178 }
179
180 fn documentation(&self) -> Option<&Documentation> {
181 self.doc()
182 }
183}
184
185fn is_log(func: &ScalarUDF) -> bool {
187 func.inner().as_any().downcast_ref::<LogFunc>().is_some()
188}
189
#[cfg(test)]
mod tests {
    use arrow::array::Float64Array;
    use arrow::datatypes::Field;
    use datafusion_common::cast::{as_float64_array, as_int64_array};

    use super::*;

    /// Invokes `power` on two array arguments.
    ///
    /// `return_type` is declared as the type of the output field.
    fn invoke_power(
        bases: ArrayRef,
        exponents: ArrayRef,
        return_type: DataType,
    ) -> Result<ColumnarValue> {
        let number_rows = bases.len();
        let arg_fields = vec![
            Field::new("a", bases.data_type().clone(), true).into(),
            Field::new("a", exponents.data_type().clone(), true).into(),
        ];
        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(bases), ColumnarValue::Array(exponents)],
            arg_fields,
            number_rows,
            return_field: Field::new("f", return_type, true).into(),
        };
        PowerFunc::new().invoke_with_args(args)
    }

    /// Unwraps an array result, panicking if a scalar came back.
    fn expect_array(result: ColumnarValue) -> ArrayRef {
        match result {
            ColumnarValue::Array(arr) => arr,
            ColumnarValue::Scalar(_) => panic!("Expected an array value"),
        }
    }

    #[test]
    fn test_power_f64() {
        let result = invoke_power(
            Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0])),
            Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0])),
            DataType::Float64,
        )
        .expect("failed to initialize function power");

        let arr = expect_array(result);
        let floats =
            as_float64_array(&arr).expect("failed to convert result to a Float64Array");
        assert_eq!(floats.len(), 4);
        assert_eq!(floats.value(0), 8.0);
        assert_eq!(floats.value(1), 4.0);
        assert_eq!(floats.value(2), 81.0);
        assert_eq!(floats.value(3), 625.0);
    }

    #[test]
    fn test_power_i64() {
        let result = invoke_power(
            Arc::new(Int64Array::from(vec![2, 2, 3, 5])),
            Arc::new(Int64Array::from(vec![3, 2, 4, 4])),
            DataType::Int64,
        )
        .expect("failed to initialize function power");

        let arr = expect_array(result);
        let ints = as_int64_array(&arr).expect("failed to convert result to a Int64Array");
        assert_eq!(ints.len(), 4);
        assert_eq!(ints.value(0), 8);
        assert_eq!(ints.value(1), 4);
        assert_eq!(ints.value(2), 81);
        assert_eq!(ints.value(3), 625);
    }

    #[test]
    fn test_power_i64_propagates_nulls() {
        let result = invoke_power(
            Arc::new(Int64Array::from(vec![Some(2), None, Some(3)])),
            Arc::new(Int64Array::from(vec![Some(2), Some(3), None])),
            DataType::Int64,
        )
        .expect("power should succeed on nullable inputs");

        let arr = expect_array(result);
        let ints = as_int64_array(&arr).expect("failed to convert result to a Int64Array");
        assert_eq!(ints.iter().collect::<Vec<_>>(), vec![Some(4), None, None]);
    }

    #[test]
    fn test_power_i64_negative_exponent_is_error() {
        let err = invoke_power(
            Arc::new(Int64Array::from(vec![2])),
            Arc::new(Int64Array::from(vec![-1])),
            DataType::Int64,
        )
        .expect_err("negative integer exponents must be rejected");
        assert!(err.to_string().contains("Can't use negative exponents"));
    }

    #[test]
    fn test_power_i64_overflow_is_error() {
        let result = invoke_power(
            Arc::new(Int64Array::from(vec![i64::MAX])),
            Arc::new(Int64Array::from(vec![2])),
            DataType::Int64,
        );
        assert!(result.is_err(), "i64 overflow must be reported, not wrapped");
    }
}