datafusion_functions/math/round.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::make_scalar_function;
22
23use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
24use arrow::compute::{cast_with_options, CastOptions};
25use arrow::datatypes::DataType::{Float32, Float64, Int32};
26use arrow::datatypes::{DataType, Float32Type, Float64Type, Int32Type};
27use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
28use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
29use datafusion_expr::TypeSignature::Exact;
30use datafusion_expr::{
31    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
32    Volatility,
33};
34use datafusion_macros::user_doc;
35
36#[user_doc(
37    doc_section(label = "Math Functions"),
38    description = "Rounds a number to the nearest integer.",
39    syntax_example = "round(numeric_expression[, decimal_places])",
40    standard_argument(name = "numeric_expression", prefix = "Numeric"),
41    argument(
42        name = "decimal_places",
43        description = "Optional. The number of decimal places to round to. Defaults to 0."
44    ),
45    sql_example = r#"```sql
46> SELECT round(3.14159);
47+--------------+
48| round(3.14159)|
49+--------------+
50| 3.0          |
51+--------------+
52```"#
53)]
54#[derive(Debug, PartialEq, Eq, Hash)]
55pub struct RoundFunc {
56    signature: Signature,
57}
58
59impl Default for RoundFunc {
60    fn default() -> Self {
61        RoundFunc::new()
62    }
63}
64
65impl RoundFunc {
66    pub fn new() -> Self {
67        use DataType::*;
68        Self {
69            signature: Signature::one_of(
70                vec![
71                    Exact(vec![Float64, Int64]),
72                    Exact(vec![Float32, Int64]),
73                    Exact(vec![Float64]),
74                    Exact(vec![Float32]),
75                ],
76                Volatility::Immutable,
77            ),
78        }
79    }
80}
81
82impl ScalarUDFImpl for RoundFunc {
83    fn as_any(&self) -> &dyn Any {
84        self
85    }
86
87    fn name(&self) -> &str {
88        "round"
89    }
90
91    fn signature(&self) -> &Signature {
92        &self.signature
93    }
94
95    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
96        match arg_types[0] {
97            Float32 => Ok(Float32),
98            _ => Ok(Float64),
99        }
100    }
101
102    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
103        make_scalar_function(round, vec![])(&args.args)
104    }
105
106    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
107        // round preserves the order of the first argument
108        let value = &input[0];
109        let precision = input.get(1);
110
111        if precision
112            .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
113            .unwrap_or(true)
114        {
115            Ok(value.sort_properties)
116        } else {
117            Ok(SortProperties::Unordered)
118        }
119    }
120
121    fn documentation(&self) -> Option<&Documentation> {
122        self.doc()
123    }
124}
125
126/// Round SQL function
127pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
128    if args.len() != 1 && args.len() != 2 {
129        return exec_err!(
130            "round function requires one or two arguments, got {}",
131            args.len()
132        );
133    }
134
135    let mut decimal_places = ColumnarValue::Scalar(ScalarValue::Int64(Some(0)));
136
137    if args.len() == 2 {
138        decimal_places = ColumnarValue::Array(Arc::clone(&args[1]));
139    }
140
141    match args[0].data_type() {
142        Float64 => match decimal_places {
143            ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => {
144                let decimal_places: i32 = decimal_places.try_into().map_err(|e| {
145                    exec_datafusion_err!(
146                        "Invalid value for decimal places: {decimal_places}: {e}"
147                    )
148                })?;
149
150                let result = args[0]
151                    .as_primitive::<Float64Type>()
152                    .unary::<_, Float64Type>(|value: f64| {
153                        (value * 10.0_f64.powi(decimal_places)).round()
154                            / 10.0_f64.powi(decimal_places)
155                    });
156                Ok(Arc::new(result) as _)
157            }
158            ColumnarValue::Array(decimal_places) => {
159                let options = CastOptions {
160                    safe: false, // raise error if the cast is not possible
161                    ..Default::default()
162                };
163                let decimal_places = cast_with_options(&decimal_places, &Int32, &options)
164                    .map_err(|e| {
165                        exec_datafusion_err!("Invalid values for decimal places: {e}")
166                    })?;
167
168                let values = args[0].as_primitive::<Float64Type>();
169                let decimal_places = decimal_places.as_primitive::<Int32Type>();
170                let result = arrow::compute::binary::<_, _, _, Float64Type>(
171                    values,
172                    decimal_places,
173                    |value, decimal_places| {
174                        (value * 10.0_f64.powi(decimal_places)).round()
175                            / 10.0_f64.powi(decimal_places)
176                    },
177                )?;
178                Ok(Arc::new(result) as _)
179            }
180            _ => {
181                exec_err!("round function requires a scalar or array for decimal_places")
182            }
183        },
184
185        Float32 => match decimal_places {
186            ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => {
187                let decimal_places: i32 = decimal_places.try_into().map_err(|e| {
188                    exec_datafusion_err!(
189                        "Invalid value for decimal places: {decimal_places}: {e}"
190                    )
191                })?;
192                let result = args[0]
193                    .as_primitive::<Float32Type>()
194                    .unary::<_, Float32Type>(|value: f32| {
195                        (value * 10.0_f32.powi(decimal_places)).round()
196                            / 10.0_f32.powi(decimal_places)
197                    });
198                Ok(Arc::new(result) as _)
199            }
200            ColumnarValue::Array(_) => {
201                let ColumnarValue::Array(decimal_places) =
202                    decimal_places.cast_to(&Int32, None).map_err(|e| {
203                        exec_datafusion_err!("Invalid values for decimal places: {e}")
204                    })?
205                else {
206                    panic!("Unexpected result of ColumnarValue::Array.cast")
207                };
208
209                let values = args[0].as_primitive::<Float32Type>();
210                let decimal_places = decimal_places.as_primitive::<Int32Type>();
211                let result: PrimitiveArray<Float32Type> = arrow::compute::binary(
212                    values,
213                    decimal_places,
214                    |value, decimal_places| {
215                        (value * 10.0_f32.powi(decimal_places)).round()
216                            / 10.0_f32.powi(decimal_places)
217                    },
218                )?;
219                Ok(Arc::new(result) as _)
220            }
221            _ => {
222                exec_err!("round function requires a scalar or array for decimal_places")
223            }
224        },
225
226        other => exec_err!("Unsupported data type {other:?} for function round"),
227    }
228}
229
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::math::round::round;

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_common::DataFusionError;

    #[test]
    fn test_round_f32() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![125.2345; 10])), // input
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places
        ];

        let result = round(&args).expect("round should succeed");
        let floats = as_float32_array(&result).expect("round should return Float32");

        let expected = Float32Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f64() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![125.2345; 10])), // input
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places
        ];

        let result = round(&args).expect("round should succeed");
        let floats = as_float64_array(&result).expect("round should return Float64");

        let expected = Float64Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f32_one_input() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input
        ];

        let result = round(&args).expect("round should succeed");
        let floats = as_float32_array(&result).expect("round should return Float32");

        let expected = Float32Array::from(vec![125.0, 12.0, 1.0, 0.0]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f64_one_input() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input
        ];

        let result = round(&args).expect("round should succeed");
        let floats = as_float64_array(&result).expect("round should return Float64");

        let expected = Float64Array::from(vec![125.0, 12.0, 1.0, 0.0]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f32_cast_fail() {
        // Exercise the Float32 path: the decimal places do not fit in an Int32.
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![125.2345])), // input
            Arc::new(Int64Array::from(vec![2147483648])), // decimal_places > i32::MAX
        ];

        let result = round(&args);

        assert!(matches!(result, Err(DataFusionError::Execution(_))));
    }

    #[test]
    fn test_round_f64_cast_fail() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![125.2345])), // input
            Arc::new(Int64Array::from(vec![2147483648])), // decimal_places > i32::MAX
        ];

        let result = round(&args);

        assert!(matches!(result, Err(DataFusionError::Execution(_))));
    }

    #[test]
    fn test_round_nulls() {
        // A NULL in either the values or the decimal places yields NULL.
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![Some(1.25), None, Some(2.5)])), // input
            Arc::new(Int64Array::from(vec![Some(1), Some(1), None])), // decimal_places
        ];

        let result = round(&args).expect("round should succeed");
        let floats = as_float64_array(&result).expect("round should return Float64");

        let expected = Float64Array::from(vec![Some(1.3), None, None]);

        assert_eq!(floats, &expected);
    }
}