Skip to main content

datafusion_functions/math/
trunc.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use crate::utils::make_scalar_function;
21
22use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
23use arrow::datatypes::DataType::{Float32, Float64};
24use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type};
25use datafusion_common::ScalarValue::Int64;
26use datafusion_common::{Result, ScalarValue, exec_err};
27use datafusion_expr::TypeSignature::Exact;
28use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
29use datafusion_expr::{
30    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
31    Volatility,
32};
33use datafusion_macros::user_doc;
34
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Truncates a number to a whole number or truncated to the specified decimal places.",
    syntax_example = "trunc(numeric_expression[, decimal_places])",
    standard_argument(name = "numeric_expression", prefix = "Numeric"),
    argument(
        name = "decimal_places",
        description = r#"Optional. The number of decimal places to
  truncate to. Defaults to 0 (truncate to a whole number). If
  `decimal_places` is a positive integer, truncates digits to the
  right of the decimal point. If `decimal_places` is a negative
  integer, replaces digits to the left of the decimal point with `0`."#
    ),
    sql_example = r#"
  ```sql
  > SELECT trunc(42.738);
  +----------------+
  | trunc(42.738)  |
  +----------------+
  | 42             |
  +----------------+
  ```"#
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// Scalar UDF registered under the name `trunc`.
///
/// The struct only carries the function's [`Signature`]; the evaluation
/// logic lives in its `ScalarUDFImpl` implementation in this file.
pub struct TruncFunc {
    /// Accepted argument type combinations, built once in `TruncFunc::new`.
    signature: Signature,
}
62
63impl Default for TruncFunc {
64    fn default() -> Self {
65        TruncFunc::new()
66    }
67}
68
69impl TruncFunc {
70    pub fn new() -> Self {
71        use DataType::*;
72        Self {
73            // math expressions expect 1 argument of type f64 or f32
74            // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
75            // return the best approximation for it (in f64).
76            // We accept f32 because in this case it is clear that the best approximation
77            // will be as good as the number of digits in the number
78            signature: Signature::one_of(
79                vec![
80                    Exact(vec![Float32, Int64]),
81                    Exact(vec![Float64, Int64]),
82                    Exact(vec![Float64]),
83                    Exact(vec![Float32]),
84                ],
85                Volatility::Immutable,
86            ),
87        }
88    }
89}
90
91impl ScalarUDFImpl for TruncFunc {
92    fn name(&self) -> &str {
93        "trunc"
94    }
95
96    fn signature(&self) -> &Signature {
97        &self.signature
98    }
99
100    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
101        match arg_types[0] {
102            Float32 => Ok(Float32),
103            _ => Ok(Float64),
104        }
105    }
106
107    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
108        // Extract precision from second argument (default 0)
109        let precision = match args.args.get(1) {
110            Some(ColumnarValue::Scalar(Int64(Some(p)))) => Some(*p),
111            Some(ColumnarValue::Scalar(Int64(None))) => None, // null precision
112            Some(ColumnarValue::Array(_)) => {
113                // Precision is an array - use array path
114                return make_scalar_function(trunc, vec![])(&args.args);
115            }
116            None => Some(0), // default precision
117            Some(cv) => {
118                return exec_err!(
119                    "trunc function requires precision to be Int64, got {:?}",
120                    cv.data_type()
121                );
122            }
123        };
124
125        // Scalar fast path using tuple matching for (value, precision)
126        match (&args.args[0], precision) {
127            // Null cases
128            (ColumnarValue::Scalar(sv), _) if sv.is_null() => {
129                ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None)
130            }
131            (_, None) => {
132                ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None)
133            }
134            // Scalar cases
135            (ColumnarValue::Scalar(ScalarValue::Float64(Some(v))), Some(p)) => Ok(
136                ColumnarValue::Scalar(ScalarValue::Float64(Some(if p == 0 {
137                    v.trunc()
138                } else {
139                    compute_truncate64(*v, p)
140                }))),
141            ),
142            (ColumnarValue::Scalar(ScalarValue::Float32(Some(v))), Some(p)) => Ok(
143                ColumnarValue::Scalar(ScalarValue::Float32(Some(if p == 0 {
144                    v.trunc()
145                } else {
146                    compute_truncate32(*v, p)
147                }))),
148            ),
149            // Array path for everything else
150            _ => make_scalar_function(trunc, vec![])(&args.args),
151        }
152    }
153
154    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
155        // trunc preserves the order of the first argument
156        let value = &input[0];
157        let precision = input.get(1);
158
159        if precision
160            .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
161            .unwrap_or(true)
162        {
163            Ok(value.sort_properties)
164        } else {
165            Ok(SortProperties::Unordered)
166        }
167    }
168
169    fn documentation(&self) -> Option<&Documentation> {
170        self.doc()
171    }
172}
173
174/// Truncate(numeric, decimalPrecision) and trunc(numeric) SQL function
175fn trunc(args: &[ArrayRef]) -> Result<ArrayRef> {
176    if args.len() != 1 && args.len() != 2 {
177        return exec_err!(
178            "truncate function requires one or two arguments, got {}",
179            args.len()
180        );
181    }
182
183    // If only one arg then invoke toolchain trunc(num) and precision = 0 by default
184    // or then invoke the compute_truncate method to process precision
185    let num = &args[0];
186    let precision = if args.len() == 1 {
187        ColumnarValue::Scalar(Int64(Some(0)))
188    } else {
189        ColumnarValue::Array(Arc::clone(&args[1]))
190    };
191
192    match num.data_type() {
193        Float64 => match precision {
194            ColumnarValue::Scalar(Int64(Some(0))) => {
195                Ok(Arc::new(
196                    args[0]
197                        .as_primitive::<Float64Type>()
198                        .unary::<_, Float64Type>(|x: f64| {
199                            if x == 0_f64 { 0_f64 } else { x.trunc() }
200                        }),
201                ) as ArrayRef)
202            }
203            ColumnarValue::Array(precision) => {
204                let num_array = num.as_primitive::<Float64Type>();
205                let precision_array = precision.as_primitive::<Int64Type>();
206                let result: PrimitiveArray<Float64Type> =
207                    arrow::compute::binary(num_array, precision_array, |x, y| {
208                        compute_truncate64(x, y)
209                    })?;
210
211                Ok(Arc::new(result) as ArrayRef)
212            }
213            _ => exec_err!("trunc function requires a scalar or array for precision"),
214        },
215        Float32 => match precision {
216            ColumnarValue::Scalar(Int64(Some(0))) => {
217                Ok(Arc::new(
218                    args[0]
219                        .as_primitive::<Float32Type>()
220                        .unary::<_, Float32Type>(|x: f32| {
221                            if x == 0_f32 { 0_f32 } else { x.trunc() }
222                        }),
223                ) as ArrayRef)
224            }
225            ColumnarValue::Array(precision) => {
226                let num_array = num.as_primitive::<Float32Type>();
227                let precision_array = precision.as_primitive::<Int64Type>();
228                let result: PrimitiveArray<Float32Type> =
229                    arrow::compute::binary(num_array, precision_array, |x, y| {
230                        compute_truncate32(x, y)
231                    })?;
232
233                Ok(Arc::new(result) as ArrayRef)
234            }
235            _ => exec_err!("trunc function requires a scalar or array for precision"),
236        },
237        other => exec_err!("Unsupported data type {other:?} for function trunc"),
238    }
239}
240
/// Truncates `x` toward zero, keeping `y` digits after the decimal point.
///
/// A negative `y` zeroes digits to the left of the decimal point instead
/// (for example `y == -2` keeps whole hundreds).
///
/// Cases where the scale factor `10^y` is not usable are handled explicitly
/// so they no longer produce `NaN` or a wrapped precision:
/// * NaN and ±infinity are returned unchanged.
/// * `y` outside the `i32` range is saturated instead of wrapping around.
/// * If `10^y` or `x * 10^y` overflows, `x` is returned unchanged.
/// * If `10^y` underflows to zero, the result is a zero carrying `x`'s sign.
fn compute_truncate32(x: f32, y: i64) -> f32 {
    if !x.is_finite() {
        return x;
    }

    // `y as i32` would wrap (e.g. 2^32 -> 0); saturate to keep the intent.
    let exponent = y.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32;
    let factor = 10.0_f32.powi(exponent);

    // Scale underflowed: every digit of `x` is below the requested position.
    if factor == 0.0 {
        return 0.0_f32.copysign(x);
    }

    let scaled = x * factor;
    // Scale (or the scaled value) overflowed: nothing left to cut off.
    if !scaled.is_finite() {
        return x;
    }

    scaled.trunc() / factor
}
245
/// Truncates `x` toward zero, keeping `y` digits after the decimal point.
///
/// A negative `y` zeroes digits to the left of the decimal point instead
/// (for example `y == -2` keeps whole hundreds).
///
/// Cases where the scale factor `10^y` is not usable are handled explicitly
/// so they no longer produce `NaN` or a wrapped precision:
/// * NaN and ±infinity are returned unchanged.
/// * `y` outside the `i32` range is saturated instead of wrapping around.
/// * If `10^y` or `x * 10^y` overflows, `x` is returned unchanged.
/// * If `10^y` underflows to zero, the result is a zero carrying `x`'s sign.
fn compute_truncate64(x: f64, y: i64) -> f64 {
    if !x.is_finite() {
        return x;
    }

    // `y as i32` would wrap (e.g. 2^32 -> 0); saturate to keep the intent.
    let exponent = y.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32;
    let factor = 10.0_f64.powi(exponent);

    // Scale underflowed: every digit of `x` is below the requested position.
    if factor == 0.0 {
        return 0.0_f64.copysign(x);
    }

    let scaled = x * factor;
    // Scale (or the scaled value) overflowed: nothing left to cut off.
    if !scaled.is_finite() {
        return x;
    }

    scaled.trunc() / factor
}
250
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::math::trunc::trunc;

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
    use datafusion_common::cast::{as_float32_array, as_float64_array};

    /// Float32 values truncated with a per-row number of decimal places.
    #[test]
    fn test_truncate_32() {
        let values: ArrayRef = Arc::new(Float32Array::from(vec![
            15.0,
            1_234.267_8,
            1_233.123_4,
            3.312_979_2,
            -21.123_4,
        ]));
        let places: ArrayRef = Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6]));

        let result = trunc(&[values, places])
            .expect("failed to initialize function truncate");
        let floats =
            as_float32_array(&result).expect("failed to initialize function truncate");

        let expected: [f32; 5] = [15.0, 1_234.267, 1_233.12, 3.312_97, -21.123_4];
        assert_eq!(floats.len(), 5);
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(floats.value(row), *want);
        }
    }

    /// Float64 values truncated with a per-row number of decimal places.
    #[test]
    fn test_truncate_64() {
        let values: ArrayRef = Arc::new(Float64Array::from(vec![
            5.0,
            234.267_812_176,
            123.123_456_789,
            123.312_979_313_2,
            -321.123_1,
        ]));
        let places: ArrayRef = Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6]));

        let result = trunc(&[values, places])
            .expect("failed to initialize function truncate");
        let floats =
            as_float64_array(&result).expect("failed to initialize function truncate");

        let expected: [f64; 5] = [5.0, 234.267, 123.12, 123.312_97, -321.123_1];
        assert_eq!(floats.len(), 5);
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(floats.value(row), *want);
        }
    }

    /// Without a precision argument every value is cut to a whole number.
    #[test]
    fn test_truncate_64_one_arg() {
        let values: ArrayRef = Arc::new(Float64Array::from(vec![
            5.0,
            234.267_812,
            123.123_45,
            123.312_979_313_2,
            -321.123,
        ]));

        let result =
            trunc(&[values]).expect("failed to initialize function truncate");
        let floats =
            as_float64_array(&result).expect("failed to initialize function truncate");

        let expected: [f64; 5] = [5.0, 234.0, 123.0, 123.0, -321.0];
        assert_eq!(floats.len(), 5);
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(floats.value(row), *want);
        }
    }
}
331}