Skip to main content

datafusion_functions/math/
trunc.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::make_scalar_function;
22
23use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
24use arrow::datatypes::DataType::{Float32, Float64};
25use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type};
26use datafusion_common::ScalarValue::Int64;
27use datafusion_common::{Result, ScalarValue, exec_err};
28use datafusion_expr::TypeSignature::Exact;
29use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
30use datafusion_expr::{
31    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
32    Volatility,
33};
34use datafusion_macros::user_doc;
35
// Scalar UDF implementing the SQL `trunc` math function.
//
// The `user_doc` attribute below carries the user-facing documentation text
// (section, description, syntax, arguments and a SQL example); the
// `documentation()` method of the `ScalarUDFImpl` impl returns it via
// `self.doc()`.
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Truncates a number to a whole number or truncated to the specified decimal places.",
    syntax_example = "trunc(numeric_expression[, decimal_places])",
    standard_argument(name = "numeric_expression", prefix = "Numeric"),
    argument(
        name = "decimal_places",
        description = r#"Optional. The number of decimal places to
  truncate to. Defaults to 0 (truncate to a whole number). If
  `decimal_places` is a positive integer, truncates digits to the
  right of the decimal point. If `decimal_places` is a negative
  integer, replaces digits to the left of the decimal point with `0`."#
    ),
    sql_example = r#"
  ```sql
  > SELECT trunc(42.738);
  +----------------+
  | trunc(42.738)  |
  +----------------+
  | 42             |
  +----------------+
  ```"#
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TruncFunc {
    // Accepted argument-type combinations and volatility; populated by
    // `TruncFunc::new` and exposed through `ScalarUDFImpl::signature`.
    signature: Signature,
}
63
64impl Default for TruncFunc {
65    fn default() -> Self {
66        TruncFunc::new()
67    }
68}
69
70impl TruncFunc {
71    pub fn new() -> Self {
72        use DataType::*;
73        Self {
74            // math expressions expect 1 argument of type f64 or f32
75            // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
76            // return the best approximation for it (in f64).
77            // We accept f32 because in this case it is clear that the best approximation
78            // will be as good as the number of digits in the number
79            signature: Signature::one_of(
80                vec![
81                    Exact(vec![Float32, Int64]),
82                    Exact(vec![Float64, Int64]),
83                    Exact(vec![Float64]),
84                    Exact(vec![Float32]),
85                ],
86                Volatility::Immutable,
87            ),
88        }
89    }
90}
91
92impl ScalarUDFImpl for TruncFunc {
93    fn as_any(&self) -> &dyn Any {
94        self
95    }
96
97    fn name(&self) -> &str {
98        "trunc"
99    }
100
101    fn signature(&self) -> &Signature {
102        &self.signature
103    }
104
105    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
106        match arg_types[0] {
107            Float32 => Ok(Float32),
108            _ => Ok(Float64),
109        }
110    }
111
112    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
113        // Extract precision from second argument (default 0)
114        let precision = match args.args.get(1) {
115            Some(ColumnarValue::Scalar(Int64(Some(p)))) => Some(*p),
116            Some(ColumnarValue::Scalar(Int64(None))) => None, // null precision
117            Some(ColumnarValue::Array(_)) => {
118                // Precision is an array - use array path
119                return make_scalar_function(trunc, vec![])(&args.args);
120            }
121            None => Some(0), // default precision
122            Some(cv) => {
123                return exec_err!(
124                    "trunc function requires precision to be Int64, got {:?}",
125                    cv.data_type()
126                );
127            }
128        };
129
130        // Scalar fast path using tuple matching for (value, precision)
131        match (&args.args[0], precision) {
132            // Null cases
133            (ColumnarValue::Scalar(sv), _) if sv.is_null() => {
134                ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None)
135            }
136            (_, None) => {
137                ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None)
138            }
139            // Scalar cases
140            (ColumnarValue::Scalar(ScalarValue::Float64(Some(v))), Some(p)) => Ok(
141                ColumnarValue::Scalar(ScalarValue::Float64(Some(if p == 0 {
142                    v.trunc()
143                } else {
144                    compute_truncate64(*v, p)
145                }))),
146            ),
147            (ColumnarValue::Scalar(ScalarValue::Float32(Some(v))), Some(p)) => Ok(
148                ColumnarValue::Scalar(ScalarValue::Float32(Some(if p == 0 {
149                    v.trunc()
150                } else {
151                    compute_truncate32(*v, p)
152                }))),
153            ),
154            // Array path for everything else
155            _ => make_scalar_function(trunc, vec![])(&args.args),
156        }
157    }
158
159    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
160        // trunc preserves the order of the first argument
161        let value = &input[0];
162        let precision = input.get(1);
163
164        if precision
165            .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
166            .unwrap_or(true)
167        {
168            Ok(value.sort_properties)
169        } else {
170            Ok(SortProperties::Unordered)
171        }
172    }
173
174    fn documentation(&self) -> Option<&Documentation> {
175        self.doc()
176    }
177}
178
179/// Truncate(numeric, decimalPrecision) and trunc(numeric) SQL function
180fn trunc(args: &[ArrayRef]) -> Result<ArrayRef> {
181    if args.len() != 1 && args.len() != 2 {
182        return exec_err!(
183            "truncate function requires one or two arguments, got {}",
184            args.len()
185        );
186    }
187
188    // If only one arg then invoke toolchain trunc(num) and precision = 0 by default
189    // or then invoke the compute_truncate method to process precision
190    let num = &args[0];
191    let precision = if args.len() == 1 {
192        ColumnarValue::Scalar(Int64(Some(0)))
193    } else {
194        ColumnarValue::Array(Arc::clone(&args[1]))
195    };
196
197    match num.data_type() {
198        Float64 => match precision {
199            ColumnarValue::Scalar(Int64(Some(0))) => {
200                Ok(Arc::new(
201                    args[0]
202                        .as_primitive::<Float64Type>()
203                        .unary::<_, Float64Type>(|x: f64| {
204                            if x == 0_f64 { 0_f64 } else { x.trunc() }
205                        }),
206                ) as ArrayRef)
207            }
208            ColumnarValue::Array(precision) => {
209                let num_array = num.as_primitive::<Float64Type>();
210                let precision_array = precision.as_primitive::<Int64Type>();
211                let result: PrimitiveArray<Float64Type> =
212                    arrow::compute::binary(num_array, precision_array, |x, y| {
213                        compute_truncate64(x, y)
214                    })?;
215
216                Ok(Arc::new(result) as ArrayRef)
217            }
218            _ => exec_err!("trunc function requires a scalar or array for precision"),
219        },
220        Float32 => match precision {
221            ColumnarValue::Scalar(Int64(Some(0))) => {
222                Ok(Arc::new(
223                    args[0]
224                        .as_primitive::<Float32Type>()
225                        .unary::<_, Float32Type>(|x: f32| {
226                            if x == 0_f32 { 0_f32 } else { x.trunc() }
227                        }),
228                ) as ArrayRef)
229            }
230            ColumnarValue::Array(precision) => {
231                let num_array = num.as_primitive::<Float32Type>();
232                let precision_array = precision.as_primitive::<Int64Type>();
233                let result: PrimitiveArray<Float32Type> =
234                    arrow::compute::binary(num_array, precision_array, |x, y| {
235                        compute_truncate32(x, y)
236                    })?;
237
238                Ok(Arc::new(result) as ArrayRef)
239            }
240            _ => exec_err!("trunc function requires a scalar or array for precision"),
241        },
242        other => exec_err!("Unsupported data type {other:?} for function trunc"),
243    }
244}
245
/// Truncates `x` toward zero, keeping `y` digits after the decimal point.
///
/// A negative `y` instead zeroes out digits to the left of the decimal point
/// (for example `y = -2` truncates to a multiple of 100).
///
/// Extreme precisions no longer produce `NaN` or infinities for finite input:
/// * `y` is saturated to the `i32` range before being passed to `powi`, so an
///   out-of-range precision is not silently wrapped into a small one.
/// * If `10^y` underflows to zero, the result is a zero carrying the sign of
///   `x`.
/// * If scaling overflows (or `x` is `NaN` / infinite), `x` is returned
///   unchanged.
fn compute_truncate32(x: f32, y: i64) -> f32 {
    // `powi` takes an `i32`; clamp first so the cast below cannot wrap.
    let exponent = y.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32;
    let factor = 10.0_f32.powi(exponent);

    // `10^y` underflowed: one unit at the requested position is larger than
    // any finite `f32`, so every digit of a finite `x` is truncated away.
    // Without this guard the final division would be `0 / 0 = NaN`.
    if factor == 0.0 && x.is_finite() {
        return 0.0_f32.copysign(x);
    }

    let scaled = x * factor;
    if !scaled.is_finite() {
        // Either `x` itself is NaN / infinite, or the scaled value does not
        // fit in an `f32`. Hand `x` back instead of letting `inf / factor`
        // turn a finite input into `inf` or `NaN`.
        return x;
    }

    scaled.trunc() / factor
}
250
/// Truncates `x` toward zero, keeping `y` digits after the decimal point.
///
/// A negative `y` instead zeroes out digits to the left of the decimal point
/// (for example `y = -2` truncates to a multiple of 100).
///
/// Extreme precisions no longer produce `NaN` or infinities for finite input:
/// * `y` is saturated to the `i32` range before being passed to `powi`, so an
///   out-of-range precision is not silently wrapped into a small one.
/// * If `10^y` underflows to zero, the result is a zero carrying the sign of
///   `x`.
/// * If scaling overflows (or `x` is `NaN` / infinite), `x` is returned
///   unchanged.
fn compute_truncate64(x: f64, y: i64) -> f64 {
    // `powi` takes an `i32`; clamp first so the cast below cannot wrap.
    let exponent = y.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32;
    let factor = 10.0_f64.powi(exponent);

    // `10^y` underflowed: one unit at the requested position is larger than
    // any finite `f64`, so every digit of a finite `x` is truncated away.
    // Without this guard the final division would be `0 / 0 = NaN`.
    if factor == 0.0 && x.is_finite() {
        return 0.0_f64.copysign(x);
    }

    let scaled = x * factor;
    if !scaled.is_finite() {
        // Either `x` itself is NaN / infinite, or the scaled value does not
        // fit in an `f64`. Hand `x` back instead of letting `inf / factor`
        // turn a finite input into `inf` or `NaN`.
        return x;
    }

    scaled.trunc() / factor
}
255
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::math::trunc::trunc;

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
    use datafusion_common::cast::{as_float32_array, as_float64_array};

    /// Two-argument form on `Float32` values with a per-row precision.
    #[test]
    fn test_truncate_32() {
        let values: ArrayRef = Arc::new(Float32Array::from(vec![
            15.0,
            1_234.267_8,
            1_233.123_4,
            3.312_979_2,
            -21.123_4,
        ]));
        let places: ArrayRef = Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6]));

        let result =
            trunc(&[values, places]).expect("failed to initialize function truncate");
        let floats =
            as_float32_array(&result).expect("failed to initialize function truncate");

        let expected = [15.0_f32, 1_234.267, 1_233.12, 3.312_97, -21.123_4];
        assert_eq!(floats.len(), 5);
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(floats.value(row), *want);
        }
    }

    /// Two-argument form on `Float64` values with a per-row precision.
    #[test]
    fn test_truncate_64() {
        let values: ArrayRef = Arc::new(Float64Array::from(vec![
            5.0,
            234.267_812_176,
            123.123_456_789,
            123.312_979_313_2,
            -321.123_1,
        ]));
        let places: ArrayRef = Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6]));

        let result =
            trunc(&[values, places]).expect("failed to initialize function truncate");
        let floats =
            as_float64_array(&result).expect("failed to initialize function truncate");

        let expected = [5.0_f64, 234.267, 123.12, 123.312_97, -321.123_1];
        assert_eq!(floats.len(), 5);
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(floats.value(row), *want);
        }
    }

    /// One-argument form: every value is truncated to a whole number.
    #[test]
    fn test_truncate_64_one_arg() {
        let values: ArrayRef = Arc::new(Float64Array::from(vec![
            5.0,
            234.267_812,
            123.123_45,
            123.312_979_313_2,
            -321.123,
        ]));

        let result = trunc(&[values]).expect("failed to initialize function truncate");
        let floats =
            as_float64_array(&result).expect("failed to initialize function truncate");

        let expected = [5.0_f64, 234.0, 123.0, 123.0, -321.0];
        assert_eq!(floats.len(), 5);
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(floats.value(row), *want);
        }
    }
}
336}