1use std::any::Any;
20
21use super::log::LogFunc;
22
23use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};
24use arrow::array::{Array, ArrayRef};
25use arrow::datatypes::{
26 ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
27 Decimal256Type, Float64Type, Int64Type,
28};
29use arrow::error::ArrowError;
30use datafusion_common::types::{NativeType, logical_float64, logical_int64};
31use datafusion_common::utils::take_function_args;
32use datafusion_common::{Result, ScalarValue, internal_err};
33use datafusion_expr::expr::ScalarFunction;
34use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
35use datafusion_expr::{
36 Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF,
37 ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, lit,
38};
39use datafusion_macros::user_doc;
40
41#[user_doc(
42 doc_section(label = "Math Functions"),
43 description = "Returns a base expression raised to the power of an exponent.",
44 syntax_example = "power(base, exponent)",
45 sql_example = r#"```sql
46> SELECT power(2, 3);
47+-------------+
48| power(2,3) |
49+-------------+
50| 8 |
51+-------------+
52```"#,
53 standard_argument(name = "base", prefix = "Numeric"),
54 standard_argument(name = "exponent", prefix = "Exponent numeric")
55)]
56#[derive(Debug, PartialEq, Eq, Hash)]
57pub struct PowerFunc {
58 signature: Signature,
59 aliases: Vec<String>,
60}
61
62impl Default for PowerFunc {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68impl PowerFunc {
69 pub fn new() -> Self {
70 let integer = Coercion::new_implicit(
71 TypeSignatureClass::Native(logical_int64()),
72 vec![TypeSignatureClass::Integer],
73 NativeType::Int64,
74 );
75 let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
76 let float = Coercion::new_implicit(
77 TypeSignatureClass::Native(logical_float64()),
78 vec![TypeSignatureClass::Numeric],
79 NativeType::Float64,
80 );
81 Self {
82 signature: Signature::one_of(
83 vec![
84 TypeSignature::Coercible(vec![decimal.clone(), integer]),
85 TypeSignature::Coercible(vec![decimal.clone(), float.clone()]),
86 TypeSignature::Coercible(vec![float; 2]),
87 ],
88 Volatility::Immutable,
89 ),
90 aliases: vec![String::from("pow")],
91 }
92 }
93}
94
95fn pow_decimal_int<T>(base: T, scale: i8, exp: i64) -> Result<T, ArrowError>
118where
119 T: From<i32> + ArrowNativeTypeOp,
120{
121 let exp: u32 = exp.try_into().map_err(|_| {
122 ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
123 })?;
124 if exp == 0 {
127 return if scale >= 0 {
128 T::from(10).pow_checked(scale as u32).map_err(|_| {
129 ArrowError::ArithmeticOverflow(format!(
130 "Cannot make unscale factor for {scale} and {exp}"
131 ))
132 })
133 } else {
134 Ok(T::from(0))
135 };
136 }
137 let powered: T = base.pow_checked(exp).map_err(|_| {
138 ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to exp {exp}"))
139 })?;
140
141 let mul_exp = (scale as i64).wrapping_mul(exp as i64 - 1);
144
145 if mul_exp == 0 {
146 return Ok(powered);
147 }
148
149 if mul_exp > 0 {
152 let div_factor: T = T::from(10).pow_checked(mul_exp as u32).map_err(|_| {
153 ArrowError::ArithmeticOverflow(format!(
154 "Cannot make div factor for {scale} and {exp}"
155 ))
156 })?;
157 powered.div_checked(div_factor)
158 } else {
159 let abs_exp = mul_exp.checked_neg().ok_or_else(|| {
161 ArrowError::ArithmeticOverflow(
162 "Overflow while negating scale exponent".to_string(),
163 )
164 })?;
165 let mul_factor: T = T::from(10).pow_checked(abs_exp as u32).map_err(|_| {
166 ArrowError::ArithmeticOverflow(format!(
167 "Cannot make mul factor for {scale} and {exp}"
168 ))
169 })?;
170 powered.mul_checked(mul_factor)
171 }
172}
173
174fn pow_decimal_float<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
178where
179 T: From<i32> + ArrowNativeTypeOp,
180{
181 if !exp.is_finite() || exp.trunc() != exp {
182 return Err(ArrowError::ComputeError(format!(
183 "Cannot use non-integer exp: {exp}"
184 )));
185 }
186 if exp < 0f64 || exp >= u32::MAX as f64 {
187 return Err(ArrowError::ArithmeticOverflow(format!(
188 "Unsupported exp value: {exp}"
189 )));
190 }
191 pow_decimal_int(base, scale, exp as i64)
192}
193
194impl ScalarUDFImpl for PowerFunc {
195 fn as_any(&self) -> &dyn Any {
196 self
197 }
198
199 fn name(&self) -> &str {
200 "power"
201 }
202
203 fn signature(&self) -> &Signature {
204 &self.signature
205 }
206
207 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
208 if arg_types[0].is_null() {
209 Ok(DataType::Float64)
210 } else {
211 Ok(arg_types[0].clone())
212 }
213 }
214
215 fn aliases(&self) -> &[String] {
216 &self.aliases
217 }
218
219 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
220 let [base, exponent] = take_function_args(self.name(), &args.args)?;
221 let base = base.to_array(args.number_rows)?;
222
223 let arr: ArrayRef = match (base.data_type(), exponent.data_type()) {
224 (DataType::Float64, DataType::Float64) => {
225 calculate_binary_math::<Float64Type, Float64Type, Float64Type, _>(
226 &base,
227 exponent,
228 |b, e| Ok(f64::powf(b, e)),
229 )?
230 }
231 (DataType::Decimal32(precision, scale), DataType::Int64) => {
232 calculate_binary_decimal_math::<Decimal32Type, Int64Type, Decimal32Type, _>(
233 &base,
234 exponent,
235 |b, e| pow_decimal_int(b, *scale, e),
236 *precision,
237 *scale,
238 )?
239 }
240 (DataType::Decimal32(precision, scale), DataType::Float64) => {
241 calculate_binary_decimal_math::<
242 Decimal32Type,
243 Float64Type,
244 Decimal32Type,
245 _,
246 >(
247 &base,
248 exponent,
249 |b, e| pow_decimal_float(b, *scale, e),
250 *precision,
251 *scale,
252 )?
253 }
254 (DataType::Decimal64(precision, scale), DataType::Int64) => {
255 calculate_binary_decimal_math::<Decimal64Type, Int64Type, Decimal64Type, _>(
256 &base,
257 exponent,
258 |b, e| pow_decimal_int(b, *scale, e),
259 *precision,
260 *scale,
261 )?
262 }
263 (DataType::Decimal64(precision, scale), DataType::Float64) => {
264 calculate_binary_decimal_math::<
265 Decimal64Type,
266 Float64Type,
267 Decimal64Type,
268 _,
269 >(
270 &base,
271 exponent,
272 |b, e| pow_decimal_float(b, *scale, e),
273 *precision,
274 *scale,
275 )?
276 }
277 (DataType::Decimal128(precision, scale), DataType::Int64) => {
278 calculate_binary_decimal_math::<
279 Decimal128Type,
280 Int64Type,
281 Decimal128Type,
282 _,
283 >(
284 &base,
285 exponent,
286 |b, e| pow_decimal_int(b, *scale, e),
287 *precision,
288 *scale,
289 )?
290 }
291 (DataType::Decimal128(precision, scale), DataType::Float64) => {
292 calculate_binary_decimal_math::<
293 Decimal128Type,
294 Float64Type,
295 Decimal128Type,
296 _,
297 >(
298 &base,
299 exponent,
300 |b, e| pow_decimal_float(b, *scale, e),
301 *precision,
302 *scale,
303 )?
304 }
305 (DataType::Decimal256(precision, scale), DataType::Int64) => {
306 calculate_binary_decimal_math::<
307 Decimal256Type,
308 Int64Type,
309 Decimal256Type,
310 _,
311 >(
312 &base,
313 exponent,
314 |b, e| pow_decimal_int(b, *scale, e),
315 *precision,
316 *scale,
317 )?
318 }
319 (DataType::Decimal256(precision, scale), DataType::Float64) => {
320 calculate_binary_decimal_math::<
321 Decimal256Type,
322 Float64Type,
323 Decimal256Type,
324 _,
325 >(
326 &base,
327 exponent,
328 |b, e| pow_decimal_float(b, *scale, e),
329 *precision,
330 *scale,
331 )?
332 }
333 (base_type, exp_type) => {
334 return internal_err!(
335 "Unsupported data types for base {base_type:?} and exponent {exp_type:?} for power"
336 );
337 }
338 };
339 Ok(ColumnarValue::Array(arr))
340 }
341
342 fn simplify(
347 &self,
348 args: Vec<Expr>,
349 info: &dyn SimplifyInfo,
350 ) -> Result<ExprSimplifyResult> {
351 let [base, exponent] = take_function_args("power", args)?;
352 let base_type = info.get_data_type(&base)?;
353 let exponent_type = info.get_data_type(&exponent)?;
354
355 if base_type.is_null() || exponent_type.is_null() {
357 let return_type = self.return_type(&[base_type, exponent_type])?;
358 return Ok(ExprSimplifyResult::Simplified(lit(
359 ScalarValue::Null.cast_to(&return_type)?
360 )));
361 }
362
363 match exponent {
364 Expr::Literal(value, _)
365 if value == ScalarValue::new_zero(&exponent_type)? =>
366 {
367 Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one(
368 &base_type,
369 )?)))
370 }
371 Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => {
372 Ok(ExprSimplifyResult::Simplified(base))
373 }
374 Expr::ScalarFunction(ScalarFunction { func, mut args })
375 if is_log(&func) && args.len() == 2 && base == args[0] =>
376 {
377 let b = args.pop().unwrap(); Ok(ExprSimplifyResult::Simplified(b))
379 }
380 _ => Ok(ExprSimplifyResult::Original(vec![base, exponent])),
381 }
382 }
383
384 fn documentation(&self) -> Option<&Documentation> {
385 self.doc()
386 }
387}
388
389fn is_log(func: &ScalarUDF) -> bool {
391 func.inner().as_any().downcast_ref::<LogFunc>().is_some()
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pow_decimal128_helper() {
        // 2.5^4 = 39.0625, truncated to the input scale.
        assert_eq!(pow_decimal_int(25, 1, 4).unwrap(), i128::from(390));
        assert_eq!(pow_decimal_int(2500, 3, 4).unwrap(), i128::from(39062));
        assert_eq!(pow_decimal_int(25000, 4, 4).unwrap(), i128::from(390625));

        // Scale 0 is a plain integer power.
        assert_eq!(pow_decimal_int(25, 0, 4).unwrap(), i128::from(390625));

        // Exponent 1 is the identity.
        assert_eq!(pow_decimal_int(25, 1, 1).unwrap(), i128::from(25));
        assert_eq!(pow_decimal_int(25, 0, 1).unwrap(), i128::from(25));
        // Exponent 0 yields 1 expressed in the input scale (10^scale).
        assert_eq!(pow_decimal_int(25, 0, 0).unwrap(), i128::from(1));
        assert_eq!(pow_decimal_int(25, 1, 0).unwrap(), i128::from(10));

        // Negative scale: 25 with scale -1 is 250; 250^4 in scale -1 is 390625000.
        assert_eq!(pow_decimal_int(25, -1, 4).unwrap(), i128::from(390625000));
    }

    #[test]
    fn test_pow_decimal_int_other_width_and_edges() {
        // The helper is generic over the native width (here i64, as for Decimal64).
        assert_eq!(pow_decimal_int(25i64, 1, 2).unwrap(), 62i64);
        // 1 is not representable with a negative scale, so exp = 0 yields 0.
        assert_eq!(pow_decimal_int(25i128, -2, 0).unwrap(), 0i128);
    }

    #[test]
    fn test_pow_decimal_int_errors() {
        // Negative exponents are unsupported.
        assert!(pow_decimal_int(25i128, 1, -1).is_err());
        // Exponents beyond u32::MAX are rejected.
        assert!(pow_decimal_int(25i128, 1, i64::from(u32::MAX) + 1).is_err());
        // base^exp overflowing the native type is an error.
        assert!(pow_decimal_int(i128::MAX, 0, 2).is_err());
        // The rescaling factor 10^(40 * (2 - 1)) does not fit in i128.
        assert!(pow_decimal_int(10i128, 40, 2).is_err());
    }

    #[test]
    fn test_pow_decimal_float_helper() {
        // Integral float exponents behave like the integer helper.
        assert_eq!(pow_decimal_float(25i128, 1, 4.0).unwrap(), 390i128);
        assert_eq!(pow_decimal_float(25i128, 1, 0.0).unwrap(), 10i128);
        // Non-integral, non-finite and negative exponents are rejected.
        assert!(pow_decimal_float(25i128, 1, 2.5).is_err());
        assert!(pow_decimal_float(25i128, 1, f64::NAN).is_err());
        assert!(pow_decimal_float(25i128, 1, f64::INFINITY).is_err());
        assert!(pow_decimal_float(25i128, 1, -1.0).is_err());
    }
}