datafusion_functions/math/
log.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Math function: `log()`.
19
20use std::any::Any;
21
22use super::power::PowerFunc;
23
24use crate::utils::{calculate_binary_math, decimal128_to_i128};
25use arrow::array::{Array, ArrayRef};
26use arrow::compute::kernels::cast;
27use arrow::datatypes::{
28    DataType, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type,
29};
30use arrow::error::ArrowError;
31use arrow_buffer::i256;
32use datafusion_common::types::NativeType;
33use datafusion_common::{
34    exec_err, internal_err, plan_datafusion_err, plan_err, Result, ScalarValue,
35};
36use datafusion_expr::expr::ScalarFunction;
37use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
38use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
39use datafusion_expr::{
40    lit, Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF,
41    TypeSignature, TypeSignatureClass,
42};
43use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
44use datafusion_macros::user_doc;
45use num_traits::Float;
46
/// Scalar UDF implementing `log(value)` and `log(base, value)`.
///
/// When the base is omitted, the logarithm is taken in base 10.
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Returns the base-x logarithm of a number. Can either provide a specified base, or if omitted then takes the base-10 of a number.",
    syntax_example = r#"log(base, numeric_expression)
log(numeric_expression)"#,
    sql_example = r#"```sql
> SELECT log(10);
+---------+
| log(10) |
+---------+
| 1.0     |
+---------+
```"#,
    standard_argument(name = "base", prefix = "Base numeric"),
    standard_argument(name = "numeric_expression", prefix = "Numeric")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct LogFunc {
    /// Accepted argument type combinations; built in [`LogFunc::new`].
    signature: Signature,
}
67
68impl Default for LogFunc {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74impl LogFunc {
75    pub fn new() -> Self {
76        // Converts decimals & integers to float64, accepting other floats as is
77        let as_float = Coercion::new_implicit(
78            TypeSignatureClass::Float,
79            vec![TypeSignatureClass::Numeric],
80            NativeType::Float64,
81        );
82        Self {
83            signature: Signature::one_of(
84                // Ensure decimals have precedence over floats since we have
85                // a native decimal implementation for log
86                vec![
87                    // log(value)
88                    TypeSignature::Coercible(vec![Coercion::new_exact(
89                        TypeSignatureClass::Decimal,
90                    )]),
91                    TypeSignature::Coercible(vec![as_float.clone()]),
92                    // log(base, value)
93                    TypeSignature::Coercible(vec![
94                        as_float.clone(),
95                        Coercion::new_exact(TypeSignatureClass::Decimal),
96                    ]),
97                    TypeSignature::Coercible(vec![as_float.clone(), as_float.clone()]),
98                ],
99                Volatility::Immutable,
100            ),
101        }
102    }
103}
104
105/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base
106/// Returns error if base is invalid
107fn log_decimal128(value: i128, scale: i8, base: f64) -> Result<f64, ArrowError> {
108    if !base.is_finite() || base.trunc() != base {
109        return Err(ArrowError::ComputeError(format!(
110            "Log cannot use non-integer base: {base}"
111        )));
112    }
113    if (base as u32) < 2 {
114        return Err(ArrowError::ComputeError(format!(
115            "Log base must be greater than 1: {base}"
116        )));
117    }
118
119    let unscaled_value = decimal128_to_i128(value, scale)?;
120    if unscaled_value > 0 {
121        let log_value: u32 = unscaled_value.ilog(base as i128);
122        Ok(log_value as f64)
123    } else {
124        // Reflect f64::log behaviour
125        Ok(f64::NAN)
126    }
127}
128
129/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base
130/// Returns error if base is invalid or if value is out of bounds of Decimal128
131fn log_decimal256(value: i256, scale: i8, base: f64) -> Result<f64, ArrowError> {
132    match value.to_i128() {
133        Some(value) => log_decimal128(value, scale, base),
134        None => Err(ArrowError::NotYetImplemented(format!(
135            "Log of Decimal256 larger than Decimal128 is not yet supported: {value}"
136        ))),
137    }
138}
139
140impl ScalarUDFImpl for LogFunc {
141    fn as_any(&self) -> &dyn Any {
142        self
143    }
144    fn name(&self) -> &str {
145        "log"
146    }
147
148    fn signature(&self) -> &Signature {
149        &self.signature
150    }
151
152    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
153        // Check last argument (value)
154        match &arg_types.last().ok_or(plan_datafusion_err!("No args"))? {
155            DataType::Float16 => Ok(DataType::Float16),
156            DataType::Float32 => Ok(DataType::Float32),
157            _ => Ok(DataType::Float64),
158        }
159    }
160
161    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
162        let (base_sort_properties, num_sort_properties) = if input.len() == 1 {
163            // log(x) defaults to log(10, x)
164            (SortProperties::Singleton, input[0].sort_properties)
165        } else {
166            (input[0].sort_properties, input[1].sort_properties)
167        };
168        match (num_sort_properties, base_sort_properties) {
169            (first @ SortProperties::Ordered(num), SortProperties::Ordered(base))
170                if num.descending != base.descending
171                    && num.nulls_first == base.nulls_first =>
172            {
173                Ok(first)
174            }
175            (
176                first @ (SortProperties::Ordered(_) | SortProperties::Singleton),
177                SortProperties::Singleton,
178            ) => Ok(first),
179            (SortProperties::Singleton, second @ SortProperties::Ordered(_)) => {
180                Ok(-second)
181            }
182            _ => Ok(SortProperties::Unordered),
183        }
184    }
185
186    // Support overloaded log(base, x) and log(x) which defaults to log(10, x)
187    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
188        if args.arg_fields.iter().any(|a| a.data_type().is_null()) {
189            return ColumnarValue::Scalar(ScalarValue::Null)
190                .cast_to(args.return_type(), None);
191        }
192
193        let (base, value) = if args.args.len() == 2 {
194            (args.args[0].clone(), &args.args[1])
195        } else {
196            // no base specified, default to 10
197            (
198                ColumnarValue::Scalar(ScalarValue::new_ten(args.return_type())?),
199                &args.args[0],
200            )
201        };
202        let value = value.to_array(args.number_rows)?;
203
204        let output: ArrayRef = match value.data_type() {
205            DataType::Float16 => {
206                calculate_binary_math::<Float16Type, Float16Type, Float16Type, _>(
207                    &value,
208                    &base,
209                    |value, base| Ok(value.log(base)),
210                )?
211            }
212            DataType::Float32 => {
213                calculate_binary_math::<Float32Type, Float32Type, Float32Type, _>(
214                    &value,
215                    &base,
216                    |value, base| Ok(value.log(base)),
217                )?
218            }
219            DataType::Float64 => {
220                calculate_binary_math::<Float64Type, Float64Type, Float64Type, _>(
221                    &value,
222                    &base,
223                    |value, base| Ok(value.log(base)),
224                )?
225            }
226            // TODO: native log support for decimal 32 & 64; right now upcast
227            //       to decimal128 to calculate
228            //       https://github.com/apache/datafusion/issues/17555
229            DataType::Decimal32(precision, scale)
230            | DataType::Decimal64(precision, scale) => {
231                calculate_binary_math::<Decimal128Type, Float64Type, Float64Type, _>(
232                    &cast(&value, &DataType::Decimal128(*precision, *scale))?,
233                    &base,
234                    |value, base| log_decimal128(value, *scale, base),
235                )?
236            }
237            DataType::Decimal128(_, scale) => {
238                calculate_binary_math::<Decimal128Type, Float64Type, Float64Type, _>(
239                    &value,
240                    &base,
241                    |value, base| log_decimal128(value, *scale, base),
242                )?
243            }
244            DataType::Decimal256(_, scale) => {
245                calculate_binary_math::<Decimal256Type, Float64Type, Float64Type, _>(
246                    &value,
247                    &base,
248                    |value, base| log_decimal256(value, *scale, base),
249                )?
250            }
251            other => {
252                return exec_err!("Unsupported data type {other:?} for function log")
253            }
254        };
255
256        Ok(ColumnarValue::Array(output))
257    }
258
259    fn documentation(&self) -> Option<&Documentation> {
260        self.doc()
261    }
262
263    /// Simplify the `log` function by the relevant rules:
264    /// 1. Log(a, 1) ===> 0
265    /// 2. Log(a, Power(a, b)) ===> b
266    /// 3. Log(a, a) ===> 1
267    fn simplify(
268        &self,
269        mut args: Vec<Expr>,
270        info: &dyn SimplifyInfo,
271    ) -> Result<ExprSimplifyResult> {
272        let mut arg_types = args
273            .iter()
274            .map(|arg| info.get_data_type(arg))
275            .collect::<Result<Vec<_>>>()?;
276        let return_type = self.return_type(&arg_types)?;
277
278        // Null propagation
279        if arg_types.iter().any(|dt| dt.is_null()) {
280            return Ok(ExprSimplifyResult::Simplified(lit(
281                ScalarValue::Null.cast_to(&return_type)?
282            )));
283        }
284
285        // Args are either
286        // log(number)
287        // log(base, number)
288        let num_args = args.len();
289        if num_args != 1 && num_args != 2 {
290            return plan_err!("Expected log to have 1 or 2 arguments, got {num_args}");
291        }
292        let number = args.pop().unwrap();
293        let number_datatype = arg_types.pop().unwrap();
294        // default to base 10
295        let base = if let Some(base) = args.pop() {
296            base
297        } else {
298            lit(ScalarValue::new_ten(&number_datatype)?)
299        };
300
301        match number {
302            Expr::Literal(value, _)
303                if value == ScalarValue::new_one(&number_datatype)? =>
304            {
305                Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero(
306                    &info.get_data_type(&base)?,
307                )?)))
308            }
309            Expr::ScalarFunction(ScalarFunction { func, mut args })
310                if is_pow(&func) && args.len() == 2 && base == args[0] =>
311            {
312                let b = args.pop().unwrap(); // length checked above
313                Ok(ExprSimplifyResult::Simplified(b))
314            }
315            number => {
316                if number == base {
317                    Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one(
318                        &number_datatype,
319                    )?)))
320                } else {
321                    let args = match num_args {
322                        1 => vec![number],
323                        2 => vec![base, number],
324                        _ => {
325                            return internal_err!(
326                                "Unexpected number of arguments in log::simplify"
327                            )
328                        }
329                    };
330                    Ok(ExprSimplifyResult::Original(args))
331                }
332            }
333        }
334    }
335}
336
337/// Returns true if the function is `PowerFunc`
338fn is_pow(func: &ScalarUDF) -> bool {
339    func.inner().as_any().downcast_ref::<PowerFunc>().is_some()
340}
341
342#[cfg(test)]
343mod tests {
344    use std::collections::HashMap;
345    use std::sync::Arc;
346
347    use super::*;
348
349    use arrow::array::{
350        Date32Array, Decimal128Array, Decimal256Array, Float32Array, Float64Array,
351    };
352    use arrow::compute::SortOptions;
353    use arrow::datatypes::{Field, DECIMAL256_MAX_PRECISION};
354    use datafusion_common::cast::{as_float32_array, as_float64_array};
355    use datafusion_common::config::ConfigOptions;
356    use datafusion_common::DFSchema;
357    use datafusion_expr::execution_props::ExecutionProps;
358    use datafusion_expr::simplify::SimplifyContext;
359
360    #[test]
361    fn test_log_decimal_native() {
362        let value = 10_i128.pow(35);
363        assert_eq!((value as f64).log2(), 116.26748332105768);
364        assert_eq!(
365            log_decimal128(value, 0, 2.0).unwrap(),
366            // TODO: see we're losing our decimal points compared to above
367            //       https://github.com/apache/datafusion/issues/18524
368            116.0
369        );
370    }
371
372    #[test]
373    fn test_log_invalid_base_type() {
374        let arg_fields = vec![
375            Field::new("b", DataType::Date32, false).into(),
376            Field::new("n", DataType::Float64, false).into(),
377        ];
378        let args = ScalarFunctionArgs {
379            args: vec![
380                ColumnarValue::Array(Arc::new(Date32Array::from(vec![5, 10, 15, 20]))), // base
381                ColumnarValue::Array(Arc::new(Float64Array::from(vec![
382                    10.0, 100.0, 1000.0, 10000.0,
383                ]))), // num
384            ],
385            arg_fields,
386            number_rows: 4,
387            return_field: Field::new("f", DataType::Float64, true).into(),
388            config_options: Arc::new(ConfigOptions::default()),
389        };
390        let result = LogFunc::new().invoke_with_args(args);
391        assert!(result.is_err());
392        assert_eq!(
393            result.unwrap_err().to_string().lines().next().unwrap(),
394            "Arrow error: Cast error: Casting from Date32 to Float64 not supported"
395        );
396    }
397
398    #[test]
399    fn test_log_invalid_value() {
400        let arg_field = Field::new("a", DataType::Date32, false).into();
401        let args = ScalarFunctionArgs {
402            args: vec![
403                ColumnarValue::Array(Arc::new(Date32Array::from(vec![10]))), // num
404            ],
405            arg_fields: vec![arg_field],
406            number_rows: 1,
407            return_field: Field::new("f", DataType::Float64, true).into(),
408            config_options: Arc::new(ConfigOptions::default()),
409        };
410
411        let result = LogFunc::new().invoke_with_args(args);
412        result.expect_err("expected error");
413    }
414
415    #[test]
416    fn test_log_scalar_f32_unary() {
417        let arg_field = Field::new("a", DataType::Float32, false).into();
418        let args = ScalarFunctionArgs {
419            args: vec![
420                ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num
421            ],
422            arg_fields: vec![arg_field],
423            number_rows: 1,
424            return_field: Field::new("f", DataType::Float32, true).into(),
425            config_options: Arc::new(ConfigOptions::default()),
426        };
427        let result = LogFunc::new()
428            .invoke_with_args(args)
429            .expect("failed to initialize function log");
430
431        match result {
432            ColumnarValue::Array(arr) => {
433                let floats = as_float32_array(&arr)
434                    .expect("failed to convert result to a Float32Array");
435
436                assert_eq!(floats.len(), 1);
437                assert!((floats.value(0) - 1.0).abs() < 1e-10);
438            }
439            ColumnarValue::Scalar(_) => {
440                panic!("Expected an array value")
441            }
442        }
443    }
444
445    #[test]
446    fn test_log_scalar_f64_unary() {
447        let arg_field = Field::new("a", DataType::Float64, false).into();
448        let args = ScalarFunctionArgs {
449            args: vec![
450                ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num
451            ],
452            arg_fields: vec![arg_field],
453            number_rows: 1,
454            return_field: Field::new("f", DataType::Float64, true).into(),
455            config_options: Arc::new(ConfigOptions::default()),
456        };
457        let result = LogFunc::new()
458            .invoke_with_args(args)
459            .expect("failed to initialize function log");
460
461        match result {
462            ColumnarValue::Array(arr) => {
463                let floats = as_float64_array(&arr)
464                    .expect("failed to convert result to a Float64Array");
465
466                assert_eq!(floats.len(), 1);
467                assert!((floats.value(0) - 1.0).abs() < 1e-10);
468            }
469            ColumnarValue::Scalar(_) => {
470                panic!("Expected an array value")
471            }
472        }
473    }
474
475    #[test]
476    fn test_log_scalar_f32() {
477        let arg_fields = vec![
478            Field::new("a", DataType::Float32, false).into(),
479            Field::new("a", DataType::Float32, false).into(),
480        ];
481        let args = ScalarFunctionArgs {
482            args: vec![
483                ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // base
484                ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num
485            ],
486            arg_fields,
487            number_rows: 1,
488            return_field: Field::new("f", DataType::Float32, true).into(),
489            config_options: Arc::new(ConfigOptions::default()),
490        };
491        let result = LogFunc::new()
492            .invoke_with_args(args)
493            .expect("failed to initialize function log");
494
495        match result {
496            ColumnarValue::Array(arr) => {
497                let floats = as_float32_array(&arr)
498                    .expect("failed to convert result to a Float32Array");
499
500                assert_eq!(floats.len(), 1);
501                assert!((floats.value(0) - 5.0).abs() < 1e-10);
502            }
503            ColumnarValue::Scalar(_) => {
504                panic!("Expected an array value")
505            }
506        }
507    }
508
509    #[test]
510    fn test_log_scalar_f64() {
511        let arg_fields = vec![
512            Field::new("a", DataType::Float64, false).into(),
513            Field::new("a", DataType::Float64, false).into(),
514        ];
515        let args = ScalarFunctionArgs {
516            args: vec![
517                ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // base
518                ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num
519            ],
520            arg_fields,
521            number_rows: 1,
522            return_field: Field::new("f", DataType::Float64, true).into(),
523            config_options: Arc::new(ConfigOptions::default()),
524        };
525        let result = LogFunc::new()
526            .invoke_with_args(args)
527            .expect("failed to initialize function log");
528
529        match result {
530            ColumnarValue::Array(arr) => {
531                let floats = as_float64_array(&arr)
532                    .expect("failed to convert result to a Float64Array");
533
534                assert_eq!(floats.len(), 1);
535                assert!((floats.value(0) - 6.0).abs() < 1e-10);
536            }
537            ColumnarValue::Scalar(_) => {
538                panic!("Expected an array value")
539            }
540        }
541    }
542
543    #[test]
544    fn test_log_f64_unary() {
545        let arg_field = Field::new("a", DataType::Float64, false).into();
546        let args = ScalarFunctionArgs {
547            args: vec![
548                ColumnarValue::Array(Arc::new(Float64Array::from(vec![
549                    10.0, 100.0, 1000.0, 10000.0,
550                ]))), // num
551            ],
552            arg_fields: vec![arg_field],
553            number_rows: 4,
554            return_field: Field::new("f", DataType::Float64, true).into(),
555            config_options: Arc::new(ConfigOptions::default()),
556        };
557        let result = LogFunc::new()
558            .invoke_with_args(args)
559            .expect("failed to initialize function log");
560
561        match result {
562            ColumnarValue::Array(arr) => {
563                let floats = as_float64_array(&arr)
564                    .expect("failed to convert result to a Float64Array");
565
566                assert_eq!(floats.len(), 4);
567                assert!((floats.value(0) - 1.0).abs() < 1e-10);
568                assert!((floats.value(1) - 2.0).abs() < 1e-10);
569                assert!((floats.value(2) - 3.0).abs() < 1e-10);
570                assert!((floats.value(3) - 4.0).abs() < 1e-10);
571            }
572            ColumnarValue::Scalar(_) => {
573                panic!("Expected an array value")
574            }
575        }
576    }
577
578    #[test]
579    fn test_log_f32_unary() {
580        let arg_field = Field::new("a", DataType::Float32, false).into();
581        let args = ScalarFunctionArgs {
582            args: vec![
583                ColumnarValue::Array(Arc::new(Float32Array::from(vec![
584                    10.0, 100.0, 1000.0, 10000.0,
585                ]))), // num
586            ],
587            arg_fields: vec![arg_field],
588            number_rows: 4,
589            return_field: Field::new("f", DataType::Float32, true).into(),
590            config_options: Arc::new(ConfigOptions::default()),
591        };
592        let result = LogFunc::new()
593            .invoke_with_args(args)
594            .expect("failed to initialize function log");
595
596        match result {
597            ColumnarValue::Array(arr) => {
598                let floats = as_float32_array(&arr)
599                    .expect("failed to convert result to a Float64Array");
600
601                assert_eq!(floats.len(), 4);
602                assert!((floats.value(0) - 1.0).abs() < 1e-10);
603                assert!((floats.value(1) - 2.0).abs() < 1e-10);
604                assert!((floats.value(2) - 3.0).abs() < 1e-10);
605                assert!((floats.value(3) - 4.0).abs() < 1e-10);
606            }
607            ColumnarValue::Scalar(_) => {
608                panic!("Expected an array value")
609            }
610        }
611    }
612
613    #[test]
614    fn test_log_f64() {
615        let arg_fields = vec![
616            Field::new("a", DataType::Float64, false).into(),
617            Field::new("a", DataType::Float64, false).into(),
618        ];
619        let args = ScalarFunctionArgs {
620            args: vec![
621                ColumnarValue::Array(Arc::new(Float64Array::from(vec![
622                    2.0, 2.0, 3.0, 5.0, 5.0,
623                ]))), // base
624                ColumnarValue::Array(Arc::new(Float64Array::from(vec![
625                    8.0, 4.0, 81.0, 625.0, -123.0,
626                ]))), // num
627            ],
628            arg_fields,
629            number_rows: 5,
630            return_field: Field::new("f", DataType::Float64, true).into(),
631            config_options: Arc::new(ConfigOptions::default()),
632        };
633        let result = LogFunc::new()
634            .invoke_with_args(args)
635            .expect("failed to initialize function log");
636
637        match result {
638            ColumnarValue::Array(arr) => {
639                let floats = as_float64_array(&arr)
640                    .expect("failed to convert result to a Float64Array");
641
642                assert_eq!(floats.len(), 5);
643                assert!((floats.value(0) - 3.0).abs() < 1e-10);
644                assert!((floats.value(1) - 2.0).abs() < 1e-10);
645                assert!((floats.value(2) - 4.0).abs() < 1e-10);
646                assert!((floats.value(3) - 4.0).abs() < 1e-10);
647                assert!(floats.value(4).is_nan());
648            }
649            ColumnarValue::Scalar(_) => {
650                panic!("Expected an array value")
651            }
652        }
653    }
654
655    #[test]
656    fn test_log_f32() {
657        let arg_fields = vec![
658            Field::new("a", DataType::Float32, false).into(),
659            Field::new("a", DataType::Float32, false).into(),
660        ];
661        let args = ScalarFunctionArgs {
662            args: vec![
663                ColumnarValue::Array(Arc::new(Float32Array::from(vec![
664                    2.0, 2.0, 3.0, 5.0,
665                ]))), // base
666                ColumnarValue::Array(Arc::new(Float32Array::from(vec![
667                    8.0, 4.0, 81.0, 625.0,
668                ]))), // num
669            ],
670            arg_fields,
671            number_rows: 4,
672            return_field: Field::new("f", DataType::Float32, true).into(),
673            config_options: Arc::new(ConfigOptions::default()),
674        };
675        let result = LogFunc::new()
676            .invoke_with_args(args)
677            .expect("failed to initialize function log");
678
679        match result {
680            ColumnarValue::Array(arr) => {
681                let floats = as_float32_array(&arr)
682                    .expect("failed to convert result to a Float32Array");
683
684                assert_eq!(floats.len(), 4);
685                assert!((floats.value(0) - 3.0).abs() < f32::EPSILON);
686                assert!((floats.value(1) - 2.0).abs() < f32::EPSILON);
687                assert!((floats.value(2) - 4.0).abs() < f32::EPSILON);
688                assert!((floats.value(3) - 4.0).abs() < f32::EPSILON);
689            }
690            ColumnarValue::Scalar(_) => {
691                panic!("Expected an array value")
692            }
693        }
694    }
695    #[test]
696    // Test log() simplification errors
697    fn test_log_simplify_errors() {
698        let props = ExecutionProps::new();
699        let schema =
700            Arc::new(DFSchema::new_with_metadata(vec![], HashMap::new()).unwrap());
701        let context = SimplifyContext::new(&props).with_schema(schema);
702        // Expect 0 args to error
703        let _ = LogFunc::new().simplify(vec![], &context).unwrap_err();
704        // Expect 3 args to error
705        let _ = LogFunc::new()
706            .simplify(vec![lit(1), lit(2), lit(3)], &context)
707            .unwrap_err();
708    }
709
710    #[test]
711    // Test that non-simplifiable log() expressions are unchanged after simplification
712    fn test_log_simplify_original() {
713        let props = ExecutionProps::new();
714        let schema =
715            Arc::new(DFSchema::new_with_metadata(vec![], HashMap::new()).unwrap());
716        let context = SimplifyContext::new(&props).with_schema(schema);
717        // One argument with no simplifications
718        let result = LogFunc::new().simplify(vec![lit(2)], &context).unwrap();
719        let ExprSimplifyResult::Original(args) = result else {
720            panic!("Expected ExprSimplifyResult::Original")
721        };
722        assert_eq!(args.len(), 1);
723        assert_eq!(args[0], lit(2));
724        // Two arguments with no simplifications
725        let result = LogFunc::new()
726            .simplify(vec![lit(2), lit(3)], &context)
727            .unwrap();
728        let ExprSimplifyResult::Original(args) = result else {
729            panic!("Expected ExprSimplifyResult::Original")
730        };
731        assert_eq!(args.len(), 2);
732        assert_eq!(args[0], lit(2));
733        assert_eq!(args[1], lit(3));
734    }
735
736    #[test]
737    fn test_log_output_ordering() {
738        // [Unordered, Ascending, Descending, Literal]
739        let orders = [
740            ExprProperties::new_unknown(),
741            ExprProperties::new_unknown().with_order(SortProperties::Ordered(
742                SortOptions {
743                    descending: false,
744                    nulls_first: true,
745                },
746            )),
747            ExprProperties::new_unknown().with_order(SortProperties::Ordered(
748                SortOptions {
749                    descending: true,
750                    nulls_first: true,
751                },
752            )),
753            ExprProperties::new_unknown().with_order(SortProperties::Singleton),
754        ];
755
756        let log = LogFunc::new();
757
758        // Test log(num)
759        for order in orders.iter().cloned() {
760            let result = log.output_ordering(std::slice::from_ref(&order)).unwrap();
761            assert_eq!(result, order.sort_properties);
762        }
763
764        // Test log(base, num), where `nulls_first` is the same
765        let mut results = Vec::with_capacity(orders.len() * orders.len());
766        for base_order in orders.iter() {
767            for num_order in orders.iter().cloned() {
768                let result = log
769                    .output_ordering(&[base_order.clone(), num_order])
770                    .unwrap();
771                results.push(result);
772            }
773        }
774        let expected = [
775            // base: Unordered
776            SortProperties::Unordered,
777            SortProperties::Unordered,
778            SortProperties::Unordered,
779            SortProperties::Unordered,
780            // base: Ascending, num: Unordered
781            SortProperties::Unordered,
782            // base: Ascending, num: Ascending
783            SortProperties::Unordered,
784            // base: Ascending, num: Descending
785            SortProperties::Ordered(SortOptions {
786                descending: true,
787                nulls_first: true,
788            }),
789            // base: Ascending, num: Literal
790            SortProperties::Ordered(SortOptions {
791                descending: true,
792                nulls_first: true,
793            }),
794            // base: Descending, num: Unordered
795            SortProperties::Unordered,
796            // base: Descending, num: Ascending
797            SortProperties::Ordered(SortOptions {
798                descending: false,
799                nulls_first: true,
800            }),
801            // base: Descending, num: Descending
802            SortProperties::Unordered,
803            // base: Descending, num: Literal
804            SortProperties::Ordered(SortOptions {
805                descending: false,
806                nulls_first: true,
807            }),
808            // base: Literal, num: Unordered
809            SortProperties::Unordered,
810            // base: Literal, num: Ascending
811            SortProperties::Ordered(SortOptions {
812                descending: false,
813                nulls_first: true,
814            }),
815            // base: Literal, num: Descending
816            SortProperties::Ordered(SortOptions {
817                descending: true,
818                nulls_first: true,
819            }),
820            // base: Literal, num: Literal
821            SortProperties::Singleton,
822        ];
823        assert_eq!(results, expected);
824
825        // Test with different `nulls_first`
826        let base_order = ExprProperties::new_unknown().with_order(
827            SortProperties::Ordered(SortOptions {
828                descending: true,
829                nulls_first: true,
830            }),
831        );
832        let num_order = ExprProperties::new_unknown().with_order(
833            SortProperties::Ordered(SortOptions {
834                descending: false,
835                nulls_first: false,
836            }),
837        );
838        assert_eq!(
839            log.output_ordering(&[base_order, num_order]).unwrap(),
840            SortProperties::Unordered
841        );
842    }
843
844    #[test]
845    fn test_log_scalar_decimal128_unary() {
846        let arg_field = Field::new("a", DataType::Decimal128(38, 0), false).into();
847        let args = ScalarFunctionArgs {
848            args: vec![
849                ColumnarValue::Scalar(ScalarValue::Decimal128(Some(10), 38, 0)), // num
850            ],
851            arg_fields: vec![arg_field],
852            number_rows: 1,
853            return_field: Field::new("f", DataType::Decimal128(38, 0), true).into(),
854            config_options: Arc::new(ConfigOptions::default()),
855        };
856        let result = LogFunc::new()
857            .invoke_with_args(args)
858            .expect("failed to initialize function log");
859
860        match result {
861            ColumnarValue::Array(arr) => {
862                let floats = as_float64_array(&arr)
863                    .expect("failed to convert result to a Decimal128Array");
864                assert_eq!(floats.len(), 1);
865                assert!((floats.value(0) - 1.0).abs() < 1e-10);
866            }
867            ColumnarValue::Scalar(_) => {
868                panic!("Expected an array value")
869            }
870        }
871    }
872
873    #[test]
874    fn test_log_scalar_decimal128() {
875        let arg_fields = vec![
876            Field::new("b", DataType::Float64, false).into(),
877            Field::new("x", DataType::Decimal128(38, 0), false).into(),
878        ];
879        let args = ScalarFunctionArgs {
880            args: vec![
881                ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // base
882                ColumnarValue::Scalar(ScalarValue::Decimal128(Some(64), 38, 0)), // num
883            ],
884            arg_fields,
885            number_rows: 1,
886            return_field: Field::new("f", DataType::Float64, true).into(),
887            config_options: Arc::new(ConfigOptions::default()),
888        };
889        let result = LogFunc::new()
890            .invoke_with_args(args)
891            .expect("failed to initialize function log");
892
893        match result {
894            ColumnarValue::Array(arr) => {
895                let floats = as_float64_array(&arr)
896                    .expect("failed to convert result to a Float64Array");
897
898                assert_eq!(floats.len(), 1);
899                assert!((floats.value(0) - 6.0).abs() < 1e-10);
900            }
901            ColumnarValue::Scalar(_) => {
902                panic!("Expected an array value")
903            }
904        }
905    }
906
907    #[test]
908    fn test_log_decimal128_unary() {
909        let arg_field = Field::new("a", DataType::Decimal128(38, 0), false).into();
910        let args = ScalarFunctionArgs {
911            args: vec![
912                ColumnarValue::Array(Arc::new(
913                    Decimal128Array::from(vec![10, 100, 1000, 10000, 12600, -123])
914                        .with_precision_and_scale(38, 0)
915                        .unwrap(),
916                )), // num
917            ],
918            arg_fields: vec![arg_field],
919            number_rows: 6,
920            return_field: Field::new("f", DataType::Float64, true).into(),
921            config_options: Arc::new(ConfigOptions::default()),
922        };
923        let result = LogFunc::new()
924            .invoke_with_args(args)
925            .expect("failed to initialize function log");
926
927        match result {
928            ColumnarValue::Array(arr) => {
929                let floats = as_float64_array(&arr)
930                    .expect("failed to convert result to a Float64Array");
931
932                assert_eq!(floats.len(), 6);
933                assert!((floats.value(0) - 1.0).abs() < 1e-10);
934                assert!((floats.value(1) - 2.0).abs() < 1e-10);
935                assert!((floats.value(2) - 3.0).abs() < 1e-10);
936                assert!((floats.value(3) - 4.0).abs() < 1e-10);
937                assert!((floats.value(4) - 4.0).abs() < 1e-10); // Integer rounding
938                assert!(floats.value(5).is_nan());
939            }
940            ColumnarValue::Scalar(_) => {
941                panic!("Expected an array value")
942            }
943        }
944    }
945
946    #[test]
947    fn test_log_decimal128_base_decimal() {
948        // Base stays 2 despite scaling
949        for base in [
950            ScalarValue::Decimal128(Some(i128::from(2)), 38, 0),
951            ScalarValue::Decimal128(Some(i128::from(2000)), 38, 3),
952        ] {
953            let arg_fields = vec![
954                Field::new("b", DataType::Decimal128(38, 0), false).into(),
955                Field::new("x", DataType::Decimal128(38, 0), false).into(),
956            ];
957            let args = ScalarFunctionArgs {
958                args: vec![
959                    ColumnarValue::Scalar(base), // base
960                    ColumnarValue::Scalar(ScalarValue::Decimal128(Some(64), 38, 0)), // num
961                ],
962                arg_fields,
963                number_rows: 1,
964                return_field: Field::new("f", DataType::Float64, true).into(),
965                config_options: Arc::new(ConfigOptions::default()),
966            };
967            let result = LogFunc::new()
968                .invoke_with_args(args)
969                .expect("failed to initialize function log");
970
971            match result {
972                ColumnarValue::Array(arr) => {
973                    let floats = as_float64_array(&arr)
974                        .expect("failed to convert result to a Float64Array");
975
976                    assert_eq!(floats.len(), 1);
977                    assert!((floats.value(0) - 6.0).abs() < 1e-10);
978                }
979                ColumnarValue::Scalar(_) => {
980                    panic!("Expected an array value")
981                }
982            }
983        }
984    }
985
986    #[test]
987    fn test_log_decimal128_value_scale() {
988        // Value stays 1000 despite scaling
989        for value in [
990            ScalarValue::Decimal128(Some(i128::from(1000)), 38, 0),
991            ScalarValue::Decimal128(Some(i128::from(10000)), 38, 1),
992            ScalarValue::Decimal128(Some(i128::from(1000000)), 38, 3),
993        ] {
994            let arg_fields = vec![
995                Field::new("b", DataType::Decimal128(38, 0), false).into(),
996                Field::new("x", DataType::Decimal128(38, 0), false).into(),
997            ];
998            let args = ScalarFunctionArgs {
999                args: vec![
1000                    ColumnarValue::Scalar(value), // base
1001                ],
1002                arg_fields,
1003                number_rows: 1,
1004                return_field: Field::new("f", DataType::Float64, true).into(),
1005                config_options: Arc::new(ConfigOptions::default()),
1006            };
1007            let result = LogFunc::new()
1008                .invoke_with_args(args)
1009                .expect("failed to initialize function log");
1010
1011            match result {
1012                ColumnarValue::Array(arr) => {
1013                    let floats = as_float64_array(&arr)
1014                        .expect("failed to convert result to a Float64Array");
1015
1016                    assert_eq!(floats.len(), 1);
1017                    assert!((floats.value(0) - 3.0).abs() < 1e-10);
1018                }
1019                ColumnarValue::Scalar(_) => {
1020                    panic!("Expected an array value")
1021                }
1022            }
1023        }
1024    }
1025
1026    #[test]
1027    fn test_log_decimal256_unary() {
1028        let arg_field = Field::new(
1029            "a",
1030            DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0),
1031            false,
1032        )
1033        .into();
1034        let args = ScalarFunctionArgs {
1035            args: vec![
1036                ColumnarValue::Array(Arc::new(
1037                    Decimal256Array::from(vec![
1038                        Some(i256::from(10)),
1039                        Some(i256::from(100)),
1040                        Some(i256::from(1000)),
1041                        Some(i256::from(10000)),
1042                        Some(i256::from(12600)),
1043                        // Slightly lower than i128 max - can calculate
1044                        Some(i256::from_i128(i128::MAX) - i256::from(1000)),
1045                        // Give NaN for incorrect inputs, as in f64::log
1046                        Some(i256::from(-123)),
1047                    ])
1048                    .with_precision_and_scale(DECIMAL256_MAX_PRECISION, 0)
1049                    .unwrap(),
1050                )), // num
1051            ],
1052            arg_fields: vec![arg_field],
1053            number_rows: 7,
1054            return_field: Field::new("f", DataType::Float64, true).into(),
1055            config_options: Arc::new(ConfigOptions::default()),
1056        };
1057        let result = LogFunc::new()
1058            .invoke_with_args(args)
1059            .expect("failed to initialize function log");
1060
1061        match result {
1062            ColumnarValue::Array(arr) => {
1063                let floats = as_float64_array(&arr)
1064                    .expect("failed to convert result to a Float64Array");
1065
1066                assert_eq!(floats.len(), 7);
1067                eprintln!("floats {:?}", &floats);
1068                assert!((floats.value(0) - 1.0).abs() < 1e-10);
1069                assert!((floats.value(1) - 2.0).abs() < 1e-10);
1070                assert!((floats.value(2) - 3.0).abs() < 1e-10);
1071                assert!((floats.value(3) - 4.0).abs() < 1e-10);
1072                assert!((floats.value(4) - 4.0).abs() < 1e-10); // Integer rounding for float log
1073                assert!((floats.value(5) - 38.0).abs() < 1e-10);
1074                assert!(floats.value(6).is_nan());
1075            }
1076            ColumnarValue::Scalar(_) => {
1077                panic!("Expected an array value")
1078            }
1079        }
1080    }
1081
1082    #[test]
1083    fn test_log_decimal128_wrong_base() {
1084        let arg_fields = vec![
1085            Field::new("b", DataType::Float64, false).into(),
1086            Field::new("x", DataType::Decimal128(38, 0), false).into(),
1087        ];
1088        let args = ScalarFunctionArgs {
1089            args: vec![
1090                ColumnarValue::Scalar(ScalarValue::Float64(Some(-2.0))), // base
1091                ColumnarValue::Scalar(ScalarValue::Decimal128(Some(64), 38, 0)), // num
1092            ],
1093            arg_fields,
1094            number_rows: 1,
1095            return_field: Field::new("f", DataType::Float64, true).into(),
1096            config_options: Arc::new(ConfigOptions::default()),
1097        };
1098        let result = LogFunc::new().invoke_with_args(args);
1099        assert!(result.is_err());
1100        assert_eq!(
1101            "Arrow error: Compute error: Log base must be greater than 1: -2",
1102            result.unwrap_err().to_string().lines().next().unwrap()
1103        );
1104    }
1105
1106    #[test]
1107    fn test_log_decimal256_error() {
1108        let arg_field = Field::new("a", DataType::Decimal256(38, 0), false).into();
1109        let args = ScalarFunctionArgs {
1110            args: vec![
1111                ColumnarValue::Array(Arc::new(Decimal256Array::from(vec![
1112                    // Slightly larger than i128
1113                    Some(i256::from_i128(i128::MAX) + i256::from(1000)),
1114                ]))), // num
1115            ],
1116            arg_fields: vec![arg_field],
1117            number_rows: 1,
1118            return_field: Field::new("f", DataType::Float64, true).into(),
1119            config_options: Arc::new(ConfigOptions::default()),
1120        };
1121        let result = LogFunc::new().invoke_with_args(args);
1122        assert!(result.is_err());
1123        assert_eq!(result.unwrap_err().to_string().lines().next().unwrap(),
1124            "Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported: 170141183460469231731687303715884106727"
1125        );
1126    }
1127}