1use super::log::LogFunc;
20
21use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};
22use arrow::array::{Array, ArrayRef};
23use arrow::datatypes::i256;
24use arrow::datatypes::{
25 ArrowNativeType, ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type,
26 Decimal128Type, Decimal256Type, Float64Type, Int64Type,
27};
28use arrow::error::ArrowError;
29use datafusion_common::types::{NativeType, logical_float64, logical_int64};
30use datafusion_common::utils::take_function_args;
31use datafusion_common::{Result, ScalarValue, internal_err};
32use datafusion_expr::expr::ScalarFunction;
33use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
34use datafusion_expr::{
35 Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF,
36 ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, lit,
37};
38use datafusion_macros::user_doc;
39use num_traits::{NumCast, ToPrimitive};
40
41#[inline]
43fn float64_power_checked(base: f64, exp: f64) -> Result<f64, ArrowError> {
44 if base == 0.0 && exp < 0.0 {
45 return Err(ArrowError::ComputeError(
46 "zero raised to a negative power is undefined".to_string(),
47 ));
48 }
49 Ok(base.powf(exp))
50}
51
// Scalar UDF implementing `power(base, exponent)`. Its user-facing documentation
// is declared through the `user_doc` attribute below and returned from
// `ScalarUDFImpl::documentation`.
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Returns a base expression raised to the power of an exponent.",
    syntax_example = "power(base, exponent)",
    sql_example = r#"```sql
> SELECT power(2, 3);
+-------------+
| power(2,3) |
+-------------+
| 8 |
+-------------+
```"#,
    standard_argument(name = "base", prefix = "Numeric"),
    standard_argument(name = "exponent", prefix = "Exponent numeric")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct PowerFunc {
    // Accepted argument type combinations; constructed in `PowerFunc::new`.
    signature: Signature,
    // Alternative SQL names for the function (set to `pow` in `PowerFunc::new`).
    aliases: Vec<String>,
}
72
73impl Default for PowerFunc {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl PowerFunc {
80 pub fn new() -> Self {
81 let integer = Coercion::new_implicit(
82 TypeSignatureClass::Native(logical_int64()),
83 vec![TypeSignatureClass::Integer],
84 NativeType::Int64,
85 );
86 let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
87 let float = Coercion::new_implicit(
88 TypeSignatureClass::Native(logical_float64()),
89 vec![TypeSignatureClass::Numeric],
90 NativeType::Float64,
91 );
92 Self {
93 signature: Signature::one_of(
94 vec![
95 TypeSignature::Coercible(vec![decimal.clone(), integer]),
96 TypeSignature::Coercible(vec![decimal.clone(), float.clone()]),
97 TypeSignature::Coercible(vec![float; 2]),
98 ],
99 Volatility::Immutable,
100 ),
101 aliases: vec![String::from("pow")],
102 }
103 }
104}
105
106fn pow_decimal_int<T>(base: T, scale: i8, exp: i64) -> Result<T, ArrowError>
127where
128 T: ArrowNativeType + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy,
129{
130 if exp < 0 {
132 return pow_decimal_float(base, scale, exp as f64);
133 }
134
135 let exp: u32 = exp.try_into().map_err(|_| {
136 ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
137 })?;
138 if exp == 0 {
141 return if scale >= 0 {
142 T::usize_as(10).pow_checked(scale as u32).map_err(|_| {
143 ArrowError::ArithmeticOverflow(format!(
144 "Cannot make unscale factor for {scale} and {exp}"
145 ))
146 })
147 } else {
148 Ok(T::ZERO)
149 };
150 }
151 let powered: T = base.pow_checked(exp).map_err(|_| {
152 ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to exp {exp}"))
153 })?;
154
155 let mul_exp = (scale as i64).wrapping_mul(exp as i64 - 1);
158
159 if mul_exp == 0 {
160 return Ok(powered);
161 }
162
163 if mul_exp > 0 {
166 let div_factor: T =
167 T::usize_as(10).pow_checked(mul_exp as u32).map_err(|_| {
168 ArrowError::ArithmeticOverflow(format!(
169 "Cannot make div factor for {scale} and {exp}"
170 ))
171 })?;
172 powered.div_checked(div_factor)
173 } else {
174 let abs_exp = mul_exp.checked_neg().ok_or_else(|| {
176 ArrowError::ArithmeticOverflow(
177 "Overflow while negating scale exponent".to_string(),
178 )
179 })?;
180 let mul_factor: T =
181 T::usize_as(10).pow_checked(abs_exp as u32).map_err(|_| {
182 ArrowError::ArithmeticOverflow(format!(
183 "Cannot make mul factor for {scale} and {exp}"
184 ))
185 })?;
186 powered.mul_checked(mul_factor)
187 }
188}
189
190fn pow_decimal_float<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
193where
194 T: ArrowNativeType + ArrowNativeTypeOp + ToPrimitive + NumCast + Copy,
195{
196 if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 {
197 return pow_decimal_int(base, scale, exp as i64);
198 }
199
200 if !exp.is_finite() {
201 return Err(ArrowError::ComputeError(format!(
202 "Cannot use non-finite exp: {exp}"
203 )));
204 }
205
206 pow_decimal_float_fallback(base, scale, exp)
207}
208
209#[inline]
212fn compute_pow_f64_result(
213 base_f64: f64,
214 scale: i8,
215 exp: f64,
216) -> Result<i128, ArrowError> {
217 let result_f64 = float64_power_checked(base_f64, exp)?;
218
219 if !result_f64.is_finite() {
220 return Err(ArrowError::ArithmeticOverflow(format!(
221 "Result of {base_f64}^{exp} is not finite"
222 )));
223 }
224
225 let scale_factor = 10f64.powi(scale as i32);
226 let result_scaled = result_f64 * scale_factor;
227 let result_rounded = result_scaled.round();
228
229 if result_rounded.abs() > i128::MAX as f64 {
230 return Err(ArrowError::ArithmeticOverflow(format!(
231 "Result {result_rounded} is too large for the target decimal type"
232 )));
233 }
234
235 Ok(result_rounded as i128)
236}
237
238#[inline]
241fn decimal_from_i128<T>(value: i128) -> Result<T, ArrowError>
242where
243 T: NumCast,
244{
245 NumCast::from(value).ok_or_else(|| {
246 ArrowError::ArithmeticOverflow(format!(
247 "Value {value} is too large for the target decimal type"
248 ))
249 })
250}
251
252fn pow_decimal_float_fallback<T>(base: T, scale: i8, exp: f64) -> Result<T, ArrowError>
255where
256 T: ToPrimitive + NumCast + Copy,
257{
258 if scale < 0 {
259 return Err(ArrowError::NotYetImplemented(format!(
260 "Negative scale is not yet supported: {scale}"
261 )));
262 }
263
264 let scale_factor = 10f64.powi(scale as i32);
265 let base_f64 = base.to_f64().ok_or_else(|| {
266 ArrowError::ComputeError("Cannot convert base to f64".to_string())
267 })? / scale_factor;
268
269 let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?;
270
271 decimal_from_i128(result_i128)
272}
273
274fn pow_decimal256_float(base: i256, scale: i8, exp: f64) -> Result<i256, ArrowError> {
276 if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 {
277 return pow_decimal256_int(base, scale, exp as i64);
278 }
279
280 if !exp.is_finite() {
281 return Err(ArrowError::ComputeError(format!(
282 "Cannot use non-finite exp: {exp}"
283 )));
284 }
285
286 pow_decimal256_float_fallback(base, scale, exp)
287}
288
289fn pow_decimal256_int(base: i256, scale: i8, exp: i64) -> Result<i256, ArrowError> {
291 if exp < 0 {
292 return pow_decimal256_float(base, scale, exp as f64);
293 }
294
295 let exp: u32 = exp.try_into().map_err(|_| {
296 ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}"))
297 })?;
298
299 if exp == 0 {
300 return if scale >= 0 {
301 i256::from_i128(10).pow_checked(scale as u32).map_err(|_| {
302 ArrowError::ArithmeticOverflow(format!(
303 "Cannot make unscale factor for {scale} and {exp}"
304 ))
305 })
306 } else {
307 Ok(i256::from_i128(0))
308 };
309 }
310
311 let powered: i256 = base.pow_checked(exp).map_err(|_| {
312 ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to exp {exp}"))
313 })?;
314
315 let mul_exp = (scale as i64).wrapping_mul(exp as i64 - 1);
316
317 if mul_exp == 0 {
318 return Ok(powered);
319 }
320
321 if mul_exp > 0 {
322 let div_factor: i256 =
323 i256::from_i128(10)
324 .pow_checked(mul_exp as u32)
325 .map_err(|_| {
326 ArrowError::ArithmeticOverflow(format!(
327 "Cannot make div factor for {scale} and {exp}"
328 ))
329 })?;
330 powered.div_checked(div_factor)
331 } else {
332 let abs_exp = mul_exp.checked_neg().ok_or_else(|| {
333 ArrowError::ArithmeticOverflow(
334 "Overflow while negating scale exponent".to_string(),
335 )
336 })?;
337 let mul_factor: i256 =
338 i256::from_i128(10)
339 .pow_checked(abs_exp as u32)
340 .map_err(|_| {
341 ArrowError::ArithmeticOverflow(format!(
342 "Cannot make mul factor for {scale} and {exp}"
343 ))
344 })?;
345 powered.mul_checked(mul_factor)
346 }
347}
348
349fn pow_decimal256_float_fallback(
351 base: i256,
352 scale: i8,
353 exp: f64,
354) -> Result<i256, ArrowError> {
355 if scale < 0 {
356 return Err(ArrowError::NotYetImplemented(format!(
357 "Negative scale is not yet supported: {scale}"
358 )));
359 }
360
361 let scale_factor = 10f64.powi(scale as i32);
362 let base_f64 = base.to_f64().ok_or_else(|| {
363 ArrowError::ComputeError("Cannot convert base to f64".to_string())
364 })? / scale_factor;
365
366 let result_i128 = compute_pow_f64_result(base_f64, scale, exp)?;
367
368 Ok(i256::from_i128(result_i128))
370}
371
372fn pow_decimal_with_float_fallback(
376 base: &ArrayRef,
377 exponent: &ColumnarValue,
378 num_rows: usize,
379) -> Result<ColumnarValue> {
380 use arrow::compute::cast;
381
382 let original_type = base.data_type().clone();
383 let base_f64 = cast(base.as_ref(), &DataType::Float64)?;
384
385 let exp_f64 = match exponent {
386 ColumnarValue::Array(arr) => cast(arr.as_ref(), &DataType::Float64)?,
387 ColumnarValue::Scalar(scalar) => {
388 let scalar_f64 = scalar.cast_to(&DataType::Float64)?;
389 scalar_f64.to_array_of_size(num_rows)?
390 }
391 };
392
393 let result_f64 = calculate_binary_math::<Float64Type, Float64Type, Float64Type, _>(
394 &base_f64,
395 &ColumnarValue::Array(exp_f64),
396 float64_power_checked,
397 )?;
398
399 let result = cast(result_f64.as_ref(), &original_type)?;
400 Ok(ColumnarValue::Array(result))
401}
402
403impl ScalarUDFImpl for PowerFunc {
404 fn name(&self) -> &str {
405 "power"
406 }
407
408 fn signature(&self) -> &Signature {
409 &self.signature
410 }
411
412 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
413 if arg_types[0].is_null() {
414 Ok(DataType::Float64)
415 } else {
416 Ok(arg_types[0].clone())
417 }
418 }
419
420 fn aliases(&self) -> &[String] {
421 &self.aliases
422 }
423
424 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
425 let [base, exponent] = take_function_args(self.name(), &args.args)?;
426
427 let use_float_fallback = matches!(
431 base.data_type(),
432 DataType::Decimal32(_, _)
433 | DataType::Decimal64(_, _)
434 | DataType::Decimal128(_, _)
435 | DataType::Decimal256(_, _)
436 ) && matches!(exponent, ColumnarValue::Array(_));
437
438 let base = base.to_array(args.number_rows)?;
439
440 if use_float_fallback {
442 return pow_decimal_with_float_fallback(&base, exponent, args.number_rows);
443 }
444
445 let arr: ArrayRef = match (base.data_type(), exponent.data_type()) {
446 (DataType::Float64, DataType::Float64) => {
447 calculate_binary_math::<Float64Type, Float64Type, Float64Type, _>(
448 &base,
449 exponent,
450 float64_power_checked,
451 )?
452 }
453 (DataType::Decimal32(precision, scale), DataType::Int64) => {
454 calculate_binary_decimal_math::<Decimal32Type, Int64Type, Decimal32Type, _>(
455 &base,
456 exponent,
457 |b, e| pow_decimal_int(b, *scale, e),
458 *precision,
459 *scale,
460 )?
461 }
462 (DataType::Decimal32(precision, scale), DataType::Float64) => {
463 calculate_binary_decimal_math::<
464 Decimal32Type,
465 Float64Type,
466 Decimal32Type,
467 _,
468 >(
469 &base,
470 exponent,
471 |b, e| pow_decimal_float(b, *scale, e),
472 *precision,
473 *scale,
474 )?
475 }
476 (DataType::Decimal64(precision, scale), DataType::Int64) => {
477 calculate_binary_decimal_math::<Decimal64Type, Int64Type, Decimal64Type, _>(
478 &base,
479 exponent,
480 |b, e| pow_decimal_int(b, *scale, e),
481 *precision,
482 *scale,
483 )?
484 }
485 (DataType::Decimal64(precision, scale), DataType::Float64) => {
486 calculate_binary_decimal_math::<
487 Decimal64Type,
488 Float64Type,
489 Decimal64Type,
490 _,
491 >(
492 &base,
493 exponent,
494 |b, e| pow_decimal_float(b, *scale, e),
495 *precision,
496 *scale,
497 )?
498 }
499 (DataType::Decimal128(precision, scale), DataType::Int64) => {
500 calculate_binary_decimal_math::<
501 Decimal128Type,
502 Int64Type,
503 Decimal128Type,
504 _,
505 >(
506 &base,
507 exponent,
508 |b, e| pow_decimal_int(b, *scale, e),
509 *precision,
510 *scale,
511 )?
512 }
513 (DataType::Decimal128(precision, scale), DataType::Float64) => {
514 calculate_binary_decimal_math::<
515 Decimal128Type,
516 Float64Type,
517 Decimal128Type,
518 _,
519 >(
520 &base,
521 exponent,
522 |b, e| pow_decimal_float(b, *scale, e),
523 *precision,
524 *scale,
525 )?
526 }
527 (DataType::Decimal256(precision, scale), DataType::Int64) => {
528 calculate_binary_decimal_math::<
529 Decimal256Type,
530 Int64Type,
531 Decimal256Type,
532 _,
533 >(
534 &base,
535 exponent,
536 |b, e| pow_decimal256_int(b, *scale, e),
537 *precision,
538 *scale,
539 )?
540 }
541 (DataType::Decimal256(precision, scale), DataType::Float64) => {
542 calculate_binary_decimal_math::<
543 Decimal256Type,
544 Float64Type,
545 Decimal256Type,
546 _,
547 >(
548 &base,
549 exponent,
550 |b, e| pow_decimal256_float(b, *scale, e),
551 *precision,
552 *scale,
553 )?
554 }
555 (base_type, exp_type) => {
556 return internal_err!(
557 "Unsupported data types for base {base_type:?} and exponent {exp_type:?} for power"
558 );
559 }
560 };
561 Ok(ColumnarValue::Array(arr))
562 }
563
564 fn simplify(
569 &self,
570 args: Vec<Expr>,
571 info: &SimplifyContext,
572 ) -> Result<ExprSimplifyResult> {
573 let [base, exponent] = take_function_args("power", args)?;
574 let base_type = info.get_data_type(&base)?;
575 let exponent_type = info.get_data_type(&exponent)?;
576
577 if base_type.is_null() || exponent_type.is_null() {
579 let return_type = self.return_type(&[base_type, exponent_type])?;
580 return Ok(ExprSimplifyResult::Simplified(lit(
581 ScalarValue::Null.cast_to(&return_type)?
582 )));
583 }
584
585 match exponent {
586 Expr::Literal(value, _)
587 if value == ScalarValue::new_zero(&exponent_type)? =>
588 {
589 Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one(
590 &base_type,
591 )?)))
592 }
593 Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => {
594 Ok(ExprSimplifyResult::Simplified(base))
595 }
596 Expr::ScalarFunction(ScalarFunction { func, mut args })
597 if is_log(&func) && args.len() == 2 && base == args[0] =>
598 {
599 let b = args.pop().unwrap(); Ok(ExprSimplifyResult::Simplified(b))
601 }
602 _ => Ok(ExprSimplifyResult::Original(vec![base, exponent])),
603 }
604 }
605
606 fn documentation(&self) -> Option<&Documentation> {
607 self.doc()
608 }
609}
610
611fn is_log(func: &ScalarUDF) -> bool {
613 func.inner().is::<LogFunc>()
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pow_decimal128_helper() {
        // (raw base, scale, exponent, expected raw result)
        let cases: &[(i128, i8, i64, i128)] = &[
            (25, 1, 4, 390), // 2.5^4 = 39.0625, truncated to 39.0
            (2500, 3, 4, 39062),
            (25000, 4, 4, 390625),
            (25, 0, 4, 390625),
            (25, 1, 1, 25),
            (25, 0, 1, 25),
            (25, 0, 0, 1),
            (25, 1, 0, 10),
            (25, -1, 4, 390625000), // negative scale: 250^4
        ];
        for &(base, scale, exp, expected) in cases {
            assert_eq!(pow_decimal_int(base, scale, exp).unwrap(), expected);
        }
    }

    #[test]
    fn test_pow_decimal_float_fallback() {
        // (raw base, scale, exponent, expected raw result)
        let ok_cases: &[(i128, i8, f64, i128)] = &[
            (400, 2, -1.0, 25),     // 4.00^-1 = 0.25
            (400, 2, 0.5, 200),     // sqrt(4.00) = 2.00
            (80, 1, 1.0 / 3.0, 20), // cbrt(8.0) = 2.0
            (-20, 1, 3.0, -80),     // (-2.0)^3 = -8.0
            (25, 1, 4.0, 390),      // integral exponent takes the exact path
        ];
        for &(base, scale, exp, expected) in ok_cases {
            assert_eq!(pow_decimal_float(base, scale, exp).unwrap(), expected);
        }

        let err_cases: &[(i128, i8, f64)] = &[
            (100, 2, f64::NAN),
            (100, 2, f64::INFINITY),
            (0, 2, -1.0), // zero raised to a negative power
        ];
        for &(base, scale, exp) in err_cases {
            assert!(pow_decimal_float(base, scale, exp).is_err());
        }
    }

    #[test]
    fn test_pow_decimal256_zero_to_negative_exp_errors() {
        let zero = i256::ZERO;
        assert!(pow_decimal256_float(zero, 2, -1.0).is_err());
        assert!(pow_decimal256_int(zero, 2, -1).is_err());
    }

    #[test]
    fn test_float64_power_checked_zero_negative_exp() {
        assert_eq!(float64_power_checked(0.0, 1.0).unwrap(), 0.0);
        assert_eq!(float64_power_checked(2.0, -1.0).unwrap(), 0.5);
        for base in [0.0f64, -0.0] {
            for exp in [-1.0, -0.5] {
                assert!(float64_power_checked(base, exp).is_err());
            }
        }
    }
}