1use std::any::Any;
20
21use super::log::LogFunc;
22
23use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};
24use arrow::array::{Array, ArrayRef};
25use arrow::datatypes::i256;
26use arrow::datatypes::{
27 ArrowNativeType, ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type,
28 Decimal128Type, Decimal256Type, Float64Type, Int64Type,
29};
30use arrow::error::ArrowError;
31use datafusion_common::types::{NativeType, logical_float64, logical_int64};
32use datafusion_common::utils::take_function_args;
33use datafusion_common::{Result, ScalarValue, internal_err};
34use datafusion_expr::expr::ScalarFunction;
35use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
36use datafusion_expr::{
37 Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF,
38 ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, lit,
39};
40use datafusion_macros::user_doc;
41use num_traits::{NumCast, ToPrimitive};
42
43#[user_doc(
44 doc_section(label = "Math Functions"),
45 description = "Returns a base expression raised to the power of an exponent.",
46 syntax_example = "power(base, exponent)",
47 sql_example = r#"```sql
48> SELECT power(2, 3);
49+-------------+
50| power(2,3) |
51+-------------+
52| 8 |
53+-------------+
54```"#,
55 standard_argument(name = "base", prefix = "Numeric"),
56 standard_argument(name = "exponent", prefix = "Exponent numeric")
57)]
58#[derive(Debug, PartialEq, Eq, Hash)]
59pub struct PowerFunc {
60 signature: Signature,
61 aliases: Vec<String>,
62}
63
64impl Default for PowerFunc {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl PowerFunc {
71 pub fn new() -> Self {
72 let integer = Coercion::new_implicit(
73 TypeSignatureClass::Native(logical_int64()),
74 vec![TypeSignatureClass::Integer],
75 NativeType::Int64,
76 );
77 let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
78 let float = Coercion::new_implicit(
79 TypeSignatureClass::Native(logical_float64()),
80 vec![TypeSignatureClass::Numeric],
81 NativeType::Float64,
82 );
83 Self {
84 signature: Signature::one_of(
85 vec![
86 TypeSignature::Coercible(vec![decimal.clone(), integer]),
87 TypeSignature::Coercible(vec![decimal.clone(), float.clone()]),
88 TypeSignature::Coercible(vec![float; 2]),
89 ],
90 Volatility::Immutable,
91 ),
92 aliases: vec![String::from("pow")],
93 }
94 }
95}
96
97fn pow_decimal_int<T>(base: T, scale: i8, exp: i64) -> Result<T, ArrowError>
118where
119 T: ArrowNativeType + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy,
120{
121 if exp < 0 {
123 return pow_decimal_float(base, scale, exp as f64);
124 }
125
126 let exp: u32 = exp.try_into().map_err(|_| {
127 ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
128 })?;
129 if exp == 0 {
132 return if scale >= 0 {
133 T::usize_as(10).pow_checked(scale as u32).map_err(|_| {
134 ArrowError::ArithmeticOverflow(format!(
135 "Cannot make unscale factor for {scale} and {exp}"
136 ))
137 })
138 } else {
139 Ok(T::ZERO)
140 };
141 }
142 let powered: T = base.pow_checked(exp).map_err(|_| {
143 ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to exp {exp}"))
144 })?;
145
146 let mul_exp = (scale as i64).wrapping_mul(exp as i64 - 1);
149
150 if mul_exp == 0 {
151 return Ok(powered);
152 }
153
154 if mul_exp > 0 {
157 let div_factor: T =
158 T::usize_as(10).pow_checked(mul_exp as u32).map_err(|_| {
159 ArrowError::ArithmeticOverflow(format!(
160 "Cannot make div factor for {scale} and {exp}"
161 ))
162 })?;
163 powered.div_checked(div_factor)
164 } else {
165 let abs_exp = mul_exp.checked_neg().ok_or_else(|| {
167 ArrowError::ArithmeticOverflow(
168 "Overflow while negating scale exponent".to_string(),
169 )
170 })?;
171 let mul_factor: T =
172 T::usize_as(10).pow_checked(abs_exp as u32).map_err(|_| {
173 ArrowError::ArithmeticOverflow(format!(
174 "Cannot make mul factor for {scale} and {exp}"
175 ))
176 })?;
177 powered.mul_checked(mul_factor)
178 }
179}
180
181fn pow_decimal_float<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
184where
185 T: ArrowNativeType + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy,
186{
187 if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 {
188 return pow_decimal_int(base, scale, exp as i64);
189 }
190
191 if !exp.is_finite() {
192 return Err(ArrowError::ComputeError(format!(
193 "Cannot use non-finite exp: {exp}"
194 )));
195 }
196
197 pow_decimal_float_fallback(base, scale, exp)
198}
199
200#[inline]
203fn compute_pow_f64_result(
204 base_f64: f64,
205 scale: i8,
206 exp: f64,
207) -> Result<i128, ArrowError> {
208 let result_f64 = base_f64.powf(exp);
209
210 if !result_f64.is_finite() {
211 return Err(ArrowError::ArithmeticOverflow(format!(
212 "Result of {base_f64}^{exp} is not finite"
213 )));
214 }
215
216 let scale_factor = 10f64.powi(scale as i32);
217 let result_scaled = result_f64 * scale_factor;
218 let result_rounded = result_scaled.round();
219
220 if result_rounded.abs() > i128::MAX as f64 {
221 return Err(ArrowError::ArithmeticOverflow(format!(
222 "Result {result_rounded} is too large for the target decimal type"
223 )));
224 }
225
226 Ok(result_rounded as i128)
227}
228
229#[inline]
232fn decimal_from_i128<T>(value: i128) -> Result<T, ArrowError>
233where
234 T: NumCast,
235{
236 NumCast::from(value).ok_or_else(|| {
237 ArrowError::ArithmeticOverflow(format!(
238 "Value {value} is too large for the target decimal type"
239 ))
240 })
241}
242
243fn pow_decimal_float_fallback<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
246where
247 T: ToPrimitive + NumCast + Copy,
248{
249 if scale < 0 {
250 return Err(ArrowError::NotYetImplemented(format!(
251 "Negative scale is not yet supported: {scale}"
252 )));
253 }
254
255 let scale_factor = 10f64.powi(scale as i32);
256 let base_f64 = base.to_f64().ok_or_else(|| {
257 ArrowError::ComputeError("Cannot convert base to f64".to_string())
258 })? / scale_factor;
259
260 let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?;
261
262 decimal_from_i128(result_i128)
263}
264
265fn pow_decimal256_float(base: i256, scale: i8, exp: f64) -> Result<i256, ArrowError> {
267 if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 {
268 return pow_decimal256_int(base, scale, exp as i64);
269 }
270
271 if !exp.is_finite() {
272 return Err(ArrowError::ComputeError(format!(
273 "Cannot use non-finite exp: {exp}"
274 )));
275 }
276
277 pow_decimal256_float_fallback(base, scale, exp)
278}
279
280fn pow_decimal256_int(base: i256, scale: i8, exp: i64) -> Result<i256, ArrowError> {
282 if exp < 0 {
283 return pow_decimal256_float(base, scale, exp as f64);
284 }
285
286 let exp: u32 = exp.try_into().map_err(|_| {
287 ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
288 })?;
289
290 if exp == 0 {
291 return if scale >= 0 {
292 i256::from_i128(10).pow_checked(scale as u32).map_err(|_| {
293 ArrowError::ArithmeticOverflow(format!(
294 "Cannot make unscale factor for {scale} and {exp}"
295 ))
296 })
297 } else {
298 Ok(i256::from_i128(0))
299 };
300 }
301
302 let powered: i256 = base.pow_checked(exp).map_err(|_| {
303 ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to exp {exp}"))
304 })?;
305
306 let mul_exp = (scale as i64).wrapping_mul(exp as i64 - 1);
307
308 if mul_exp == 0 {
309 return Ok(powered);
310 }
311
312 if mul_exp > 0 {
313 let div_factor: i256 =
314 i256::from_i128(10)
315 .pow_checked(mul_exp as u32)
316 .map_err(|_| {
317 ArrowError::ArithmeticOverflow(format!(
318 "Cannot make div factor for {scale} and {exp}"
319 ))
320 })?;
321 powered.div_checked(div_factor)
322 } else {
323 let abs_exp = mul_exp.checked_neg().ok_or_else(|| {
324 ArrowError::ArithmeticOverflow(
325 "Overflow while negating scale exponent".to_string(),
326 )
327 })?;
328 let mul_factor: i256 =
329 i256::from_i128(10)
330 .pow_checked(abs_exp as u32)
331 .map_err(|_| {
332 ArrowError::ArithmeticOverflow(format!(
333 "Cannot make mul factor for {scale} and {exp}"
334 ))
335 })?;
336 powered.mul_checked(mul_factor)
337 }
338}
339
340fn pow_decimal256_float_fallback(
342 base: i256,
343 scale: i8,
344 exp: f64,
345) -> Result<i256, ArrowError> {
346 if scale < 0 {
347 return Err(ArrowError::NotYetImplemented(format!(
348 "Negative scale is not yet supported: {scale}"
349 )));
350 }
351
352 let scale_factor = 10f64.powi(scale as i32);
353 let base_f64 = base.to_f64().ok_or_else(|| {
354 ArrowError::ComputeError("Cannot convert base to f64".to_string())
355 })? / scale_factor;
356
357 let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?;
358
359 Ok(i256::from_i128(result_i128))
361}
362
363fn pow_decimal_with_float_fallback(
367 base: &ArrayRef,
368 exponent: &ColumnarValue,
369 num_rows: usize,
370) -> Result<ColumnarValue> {
371 use arrow::compute::cast;
372
373 let original_type = base.data_type().clone();
374 let base_f64 = cast(base.as_ref(), &DataType::Float64)?;
375
376 let exp_f64 = match exponent {
377 ColumnarValue::Array(arr) => cast(arr.as_ref(), &DataType::Float64)?,
378 ColumnarValue::Scalar(scalar) => {
379 let scalar_f64 = scalar.cast_to(&DataType::Float64)?;
380 scalar_f64.to_array_of_size(num_rows)?
381 }
382 };
383
384 let result_f64 = calculate_binary_math::<Float64Type, Float64Type, Float64Type, _>(
385 &base_f64,
386 &ColumnarValue::Array(exp_f64),
387 |b, e| Ok(f64::powf(b, e)),
388 )?;
389
390 let result = cast(result_f64.as_ref(), &original_type)?;
391 Ok(ColumnarValue::Array(result))
392}
393
394impl ScalarUDFImpl for PowerFunc {
395 fn as_any(&self) -> &dyn Any {
396 self
397 }
398
399 fn name(&self) -> &str {
400 "power"
401 }
402
403 fn signature(&self) -> &Signature {
404 &self.signature
405 }
406
407 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
408 if arg_types[0].is_null() {
409 Ok(DataType::Float64)
410 } else {
411 Ok(arg_types[0].clone())
412 }
413 }
414
415 fn aliases(&self) -> &[String] {
416 &self.aliases
417 }
418
419 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
420 let [base, exponent] = take_function_args(self.name(), &args.args)?;
421
422 let use_float_fallback = matches!(
426 base.data_type(),
427 DataType::Decimal32(_, _)
428 | DataType::Decimal64(_, _)
429 | DataType::Decimal128(_, _)
430 | DataType::Decimal256(_, _)
431 ) && matches!(exponent, ColumnarValue::Array(_));
432
433 let base = base.to_array(args.number_rows)?;
434
435 if use_float_fallback {
437 return pow_decimal_with_float_fallback(&base, exponent, args.number_rows);
438 }
439
440 let arr: ArrayRef = match (base.data_type(), exponent.data_type()) {
441 (DataType::Float64, DataType::Float64) => {
442 calculate_binary_math::<Float64Type, Float64Type, Float64Type, _>(
443 &base,
444 exponent,
445 |b, e| Ok(f64::powf(b, e)),
446 )?
447 }
448 (DataType::Decimal32(precision, scale), DataType::Int64) => {
449 calculate_binary_decimal_math::<Decimal32Type, Int64Type, Decimal32Type, _>(
450 &base,
451 exponent,
452 |b, e| pow_decimal_int(b, *scale, e),
453 *precision,
454 *scale,
455 )?
456 }
457 (DataType::Decimal32(precision, scale), DataType::Float64) => {
458 calculate_binary_decimal_math::<
459 Decimal32Type,
460 Float64Type,
461 Decimal32Type,
462 _,
463 >(
464 &base,
465 exponent,
466 |b, e| pow_decimal_float(b, *scale, e),
467 *precision,
468 *scale,
469 )?
470 }
471 (DataType::Decimal64(precision, scale), DataType::Int64) => {
472 calculate_binary_decimal_math::<Decimal64Type, Int64Type, Decimal64Type, _>(
473 &base,
474 exponent,
475 |b, e| pow_decimal_int(b, *scale, e),
476 *precision,
477 *scale,
478 )?
479 }
480 (DataType::Decimal64(precision, scale), DataType::Float64) => {
481 calculate_binary_decimal_math::<
482 Decimal64Type,
483 Float64Type,
484 Decimal64Type,
485 _,
486 >(
487 &base,
488 exponent,
489 |b, e| pow_decimal_float(b, *scale, e),
490 *precision,
491 *scale,
492 )?
493 }
494 (DataType::Decimal128(precision, scale), DataType::Int64) => {
495 calculate_binary_decimal_math::<
496 Decimal128Type,
497 Int64Type,
498 Decimal128Type,
499 _,
500 >(
501 &base,
502 exponent,
503 |b, e| pow_decimal_int(b, *scale, e),
504 *precision,
505 *scale,
506 )?
507 }
508 (DataType::Decimal128(precision, scale), DataType::Float64) => {
509 calculate_binary_decimal_math::<
510 Decimal128Type,
511 Float64Type,
512 Decimal128Type,
513 _,
514 >(
515 &base,
516 exponent,
517 |b, e| pow_decimal_float(b, *scale, e),
518 *precision,
519 *scale,
520 )?
521 }
522 (DataType::Decimal256(precision, scale), DataType::Int64) => {
523 calculate_binary_decimal_math::<
524 Decimal256Type,
525 Int64Type,
526 Decimal256Type,
527 _,
528 >(
529 &base,
530 exponent,
531 |b, e| pow_decimal256_int(b, *scale, e),
532 *precision,
533 *scale,
534 )?
535 }
536 (DataType::Decimal256(precision, scale), DataType::Float64) => {
537 calculate_binary_decimal_math::<
538 Decimal256Type,
539 Float64Type,
540 Decimal256Type,
541 _,
542 >(
543 &base,
544 exponent,
545 |b, e| pow_decimal256_float(b, *scale, e),
546 *precision,
547 *scale,
548 )?
549 }
550 (base_type, exp_type) => {
551 return internal_err!(
552 "Unsupported data types for base {base_type:?} and exponent {exp_type:?} for power"
553 );
554 }
555 };
556 Ok(ColumnarValue::Array(arr))
557 }
558
559 fn simplify(
564 &self,
565 args: Vec<Expr>,
566 info: &SimplifyContext,
567 ) -> Result<ExprSimplifyResult> {
568 let [base, exponent] = take_function_args("power", args)?;
569 let base_type = info.get_data_type(&base)?;
570 let exponent_type = info.get_data_type(&exponent)?;
571
572 if base_type.is_null() || exponent_type.is_null() {
574 let return_type = self.return_type(&[base_type, exponent_type])?;
575 return Ok(ExprSimplifyResult::Simplified(lit(
576 ScalarValue::Null.cast_to(&return_type)?
577 )));
578 }
579
580 match exponent {
581 Expr::Literal(value, _)
582 if value == ScalarValue::new_zero(&exponent_type)? =>
583 {
584 Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one(
585 &base_type,
586 )?)))
587 }
588 Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => {
589 Ok(ExprSimplifyResult::Simplified(base))
590 }
591 Expr::ScalarFunction(ScalarFunction { func, mut args })
592 if is_log(&func) && args.len() == 2 && base == args[0] =>
593 {
594 let b = args.pop().unwrap(); Ok(ExprSimplifyResult::Simplified(b))
596 }
597 _ => Ok(ExprSimplifyResult::Original(vec![base, exponent])),
598 }
599 }
600
601 fn documentation(&self) -> Option<&Documentation> {
602 self.doc()
603 }
604}
605
606fn is_log(func: &ScalarUDF) -> bool {
608 func.inner().as_any().downcast_ref::<LogFunc>().is_some()
609}
610
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer-exponent kernel checked against hand-computed unscaled values.
    #[test]
    fn test_pow_decimal128_helper() {
        // (unscaled base, scale, exponent, expected unscaled result)
        let cases: [(i128, i8, i64, i128); 9] = [
            // 2.5 ^ 4 = 39.0625 at increasing scales
            (25, 1, 4, 390),
            (2500, 3, 4, 39062),
            (25000, 4, 4, 390625),
            // 25 ^ 4 = 390625
            (25, 0, 4, 390625),
            // exponent 1 and exponent 0 edge cases
            (25, 1, 1, 25),
            (25, 0, 1, 25),
            (25, 0, 0, 1),
            (25, 1, 0, 10),
            // negative scale: 250 ^ 4 = 3906250000, stored at scale -1
            (25, -1, 4, 390625000),
        ];

        for (base, scale, exp, expected) in cases {
            assert_eq!(pow_decimal_int(base, scale, exp).unwrap(), expected);
        }
    }

    /// Float-exponent entry point: fallback path, integer shortcut, bad input.
    #[test]
    fn test_pow_decimal_float_fallback() {
        // (unscaled base, scale, exponent, expected unscaled result)
        let cases: [(i128, i8, f64, i128); 5] = [
            // 4.00 ^ -1 = 0.25
            (400, 2, -1.0, 25),
            // 4.00 ^ 0.5 = 2.00
            (400, 2, 0.5, 200),
            // 8.0 ^ (1/3) = 2.0
            (80, 1, 1.0 / 3.0, 20),
            // (-2.0) ^ 3 = -8.0, whole exponent takes the integer path
            (-20, 1, 3.0, -80),
            // 2.5 ^ 4 = 39.0625 -> 39.0
            (25, 1, 4.0, 390),
        ];

        for (base, scale, exp, expected) in cases {
            let actual: i128 = pow_decimal_float(base, scale, exp).unwrap();
            assert_eq!(actual, expected);
        }

        // Non-finite exponents are rejected.
        for bad_exp in [f64::NAN, f64::INFINITY] {
            assert!(pow_decimal_float(100i128, 2, bad_exp).is_err());
        }
    }
}