// datafusion_comet_spark_expr/math_funcs/internal/make_decimal.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::math_funcs::utils::get_precision_scale;
use arrow::datatypes::DataType;
use arrow::{
    array::{AsArray, Decimal128Builder},
    datatypes::{validate_decimal_precision, Int64Type},
};
use datafusion::common::{internal_err, Result as DataFusionResult, ScalarValue};
use datafusion::physical_plan::ColumnarValue;
use std::sync::Arc;

28/// Spark-compatible `MakeDecimal` expression (internal to Spark optimizer)
29pub fn spark_make_decimal(
30    args: &[ColumnarValue],
31    data_type: &DataType,
32) -> DataFusionResult<ColumnarValue> {
33    let (precision, scale) = get_precision_scale(data_type);
34    match &args[0] {
35        ColumnarValue::Scalar(v) => match v {
36            ScalarValue::Int64(n) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
37                long_to_decimal(n, precision),
38                precision,
39                scale,
40            ))),
41            sv => internal_err!("Expected Int64 but found {sv:?}"),
42        },
43        ColumnarValue::Array(a) => {
44            let arr = a.as_primitive::<Int64Type>();
45            let mut result = Decimal128Builder::new();
46            for v in arr.into_iter() {
47                result.append_option(long_to_decimal(&v, precision))
48            }
49            let result_type = DataType::Decimal128(precision, scale);
50
51            Ok(ColumnarValue::Array(Arc::new(
52                result.finish().with_data_type(result_type),
53            )))
54        }
55    }
56}
57
58/// Convert the input long to decimal with the given maximum precision. If overflows, returns null
59/// instead.
60#[inline]
61fn long_to_decimal(v: &Option<i64>, precision: u8) -> Option<i128> {
62    match v {
63        Some(v) if validate_decimal_precision(*v as i128, precision).is_ok() => Some(*v as i128),
64        _ => None,
65    }
66}