1use std::any::Any;
19
20use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};
21
22use arrow::array::ArrayRef;
23use arrow::datatypes::DataType::{
24 Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64,
25};
26use arrow::datatypes::{
27 ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
28 Decimal256Type, Float32Type, Float64Type, Int32Type,
29};
30use arrow::error::ArrowError;
31use datafusion_common::types::{
32 NativeType, logical_float32, logical_float64, logical_int32,
33};
34use datafusion_common::{Result, ScalarValue, exec_err};
35use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
36use datafusion_expr::{
37 Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
38 TypeSignature, TypeSignatureClass, Volatility,
39};
40use datafusion_macros::user_doc;
41
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Rounds a number to the nearest integer.",
    syntax_example = "round(numeric_expression[, decimal_places])",
    standard_argument(name = "numeric_expression", prefix = "Numeric"),
    argument(
        name = "decimal_places",
        description = "Optional. The number of decimal places to round to. Defaults to 0."
    ),
    sql_example = r#"```sql
> SELECT round(3.14159);
+--------------+
| round(3.14159)|
+--------------+
| 3.0 |
+--------------+
```"#
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// Scalar UDF registered under the name `round`.
///
/// Accepts a float or decimal value plus an optional integer number of
/// decimal places; the user-facing documentation is generated from the
/// `user_doc` attribute above.
pub struct RoundFunc {
    /// Accepted argument combinations and volatility of the function.
    signature: Signature,
}
64
65impl Default for RoundFunc {
66 fn default() -> Self {
67 RoundFunc::new()
68 }
69}
70
71impl RoundFunc {
72 pub fn new() -> Self {
73 let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
74 let decimal_places = Coercion::new_implicit(
75 TypeSignatureClass::Native(logical_int32()),
76 vec![TypeSignatureClass::Integer],
77 NativeType::Int32,
78 );
79 let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32()));
80 let float64 = Coercion::new_implicit(
81 TypeSignatureClass::Native(logical_float64()),
82 vec![TypeSignatureClass::Numeric],
83 NativeType::Float64,
84 );
85 Self {
86 signature: Signature::one_of(
87 vec![
88 TypeSignature::Coercible(vec![
89 decimal.clone(),
90 decimal_places.clone(),
91 ]),
92 TypeSignature::Coercible(vec![decimal]),
93 TypeSignature::Coercible(vec![
94 float32.clone(),
95 decimal_places.clone(),
96 ]),
97 TypeSignature::Coercible(vec![float32]),
98 TypeSignature::Coercible(vec![float64.clone(), decimal_places]),
99 TypeSignature::Coercible(vec![float64]),
100 ],
101 Volatility::Immutable,
102 ),
103 }
104 }
105}
106
107impl ScalarUDFImpl for RoundFunc {
108 fn as_any(&self) -> &dyn Any {
109 self
110 }
111
112 fn name(&self) -> &str {
113 "round"
114 }
115
116 fn signature(&self) -> &Signature {
117 &self.signature
118 }
119
120 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
121 Ok(match arg_types[0].clone() {
122 Float32 => Float32,
123 dt @ Decimal128(_, _)
124 | dt @ Decimal256(_, _)
125 | dt @ Decimal32(_, _)
126 | dt @ Decimal64(_, _) => dt,
127 _ => Float64,
128 })
129 }
130
131 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
132 if args.arg_fields.iter().any(|a| a.data_type().is_null()) {
133 return ColumnarValue::Scalar(ScalarValue::Null)
134 .cast_to(args.return_type(), None);
135 }
136
137 let default_decimal_places = ColumnarValue::Scalar(ScalarValue::Int32(Some(0)));
138 let decimal_places = if args.args.len() == 2 {
139 &args.args[1]
140 } else {
141 &default_decimal_places
142 };
143
144 round_columnar(&args.args[0], decimal_places, args.number_rows)
145 }
146
147 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
148 let value = &input[0];
150 let precision = input.get(1);
151
152 if precision
153 .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
154 .unwrap_or(true)
155 {
156 Ok(value.sort_properties)
157 } else {
158 Ok(SortProperties::Unordered)
159 }
160 }
161
162 fn documentation(&self) -> Option<&Documentation> {
163 self.doc()
164 }
165}
166
167fn round_columnar(
168 value: &ColumnarValue,
169 decimal_places: &ColumnarValue,
170 number_rows: usize,
171) -> Result<ColumnarValue> {
172 let value_array = value.to_array(number_rows)?;
173 let both_scalars = matches!(value, ColumnarValue::Scalar(_))
174 && matches!(decimal_places, ColumnarValue::Scalar(_));
175
176 let arr: ArrayRef = match value_array.data_type() {
177 Float64 => {
178 let result = calculate_binary_math::<Float64Type, Int32Type, Float64Type, _>(
179 value_array.as_ref(),
180 decimal_places,
181 round_float::<f64>,
182 )?;
183 result as _
184 }
185 Float32 => {
186 let result = calculate_binary_math::<Float32Type, Int32Type, Float32Type, _>(
187 value_array.as_ref(),
188 decimal_places,
189 round_float::<f32>,
190 )?;
191 result as _
192 }
193 Decimal32(precision, scale) => {
194 let result = calculate_binary_decimal_math::<
195 Decimal32Type,
196 Int32Type,
197 Decimal32Type,
198 _,
199 >(
200 value_array.as_ref(),
201 decimal_places,
202 |v, dp| round_decimal(v, *scale, dp),
203 *precision,
204 *scale,
205 )?;
206 result as _
207 }
208 Decimal64(precision, scale) => {
209 let result = calculate_binary_decimal_math::<
210 Decimal64Type,
211 Int32Type,
212 Decimal64Type,
213 _,
214 >(
215 value_array.as_ref(),
216 decimal_places,
217 |v, dp| round_decimal(v, *scale, dp),
218 *precision,
219 *scale,
220 )?;
221 result as _
222 }
223 Decimal128(precision, scale) => {
224 let result = calculate_binary_decimal_math::<
225 Decimal128Type,
226 Int32Type,
227 Decimal128Type,
228 _,
229 >(
230 value_array.as_ref(),
231 decimal_places,
232 |v, dp| round_decimal(v, *scale, dp),
233 *precision,
234 *scale,
235 )?;
236 result as _
237 }
238 Decimal256(precision, scale) => {
239 let result = calculate_binary_decimal_math::<
240 Decimal256Type,
241 Int32Type,
242 Decimal256Type,
243 _,
244 >(
245 value_array.as_ref(),
246 decimal_places,
247 |v, dp| round_decimal(v, *scale, dp),
248 *precision,
249 *scale,
250 )?;
251 result as _
252 }
253 other => exec_err!("Unsupported data type {other:?} for function round")?,
254 };
255
256 if both_scalars {
257 ScalarValue::try_from_array(&arr, 0).map(ColumnarValue::Scalar)
258 } else {
259 Ok(ColumnarValue::Array(arr))
260 }
261}
262
263fn round_float<T>(value: T, decimal_places: i32) -> Result<T, ArrowError>
264where
265 T: num_traits::Float,
266{
267 let factor = T::from(10_f64.powi(decimal_places)).ok_or_else(|| {
268 ArrowError::ComputeError(format!(
269 "Invalid value for decimal places: {decimal_places}"
270 ))
271 })?;
272 Ok((value * factor).round() / factor)
273}
274
275fn round_decimal<V: ArrowNativeTypeOp>(
276 value: V,
277 scale: i8,
278 decimal_places: i32,
279) -> Result<V, ArrowError> {
280 let diff = i64::from(scale) - i64::from(decimal_places);
281 if diff <= 0 {
282 return Ok(value);
283 }
284
285 let diff: u32 = diff.try_into().map_err(|e| {
286 ArrowError::ComputeError(format!(
287 "Invalid value for decimal places: {decimal_places}: {e}"
288 ))
289 })?;
290
291 let one = V::ONE;
292 let two = V::from_usize(2).ok_or_else(|| {
293 ArrowError::ComputeError("Internal error: could not create constant 2".into())
294 })?;
295 let ten = V::from_usize(10).ok_or_else(|| {
296 ArrowError::ComputeError("Internal error: could not create constant 10".into())
297 })?;
298
299 let factor = ten.pow_checked(diff).map_err(|_| {
300 ArrowError::ComputeError(format!(
301 "Overflow while rounding decimal with scale {scale} and decimal places {decimal_places}"
302 ))
303 })?;
304
305 let mut quotient = value.div_wrapping(factor);
306 let remainder = value.mod_wrapping(factor);
307
308 let threshold = factor.div_wrapping(two);
310 if remainder >= threshold {
311 quotient = quotient.add_checked(one).map_err(|_| {
312 ArrowError::ComputeError("Overflow while rounding decimal".into())
313 })?;
314 } else if remainder <= threshold.neg_wrapping() {
315 quotient = quotient.sub_checked(one).map_err(|_| {
316 ArrowError::ComputeError("Overflow while rounding decimal".into())
317 })?;
318 }
319
320 quotient
321 .mul_checked(factor)
322 .map_err(|_| ArrowError::ComputeError("Overflow while rounding decimal".into()))
323}
324
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
    use datafusion_common::DataFusionError;
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_expr::ColumnarValue;

    /// Calls `round_columnar` with an array value and either an array of
    /// decimal places or, when `None`, the scalar default of `0`. The result
    /// is always handed back as an array.
    fn round_arrays(
        value: ArrayRef,
        decimal_places: Option<ArrayRef>,
    ) -> Result<ArrayRef, DataFusionError> {
        let number_rows = value.len();
        let places = match decimal_places {
            Some(array) => ColumnarValue::Array(array),
            None => ColumnarValue::Scalar(ScalarValue::Int32(Some(0))),
        };
        let value = ColumnarValue::Array(value);

        match super::round_columnar(&value, &places, number_rows)? {
            ColumnarValue::Array(array) => Ok(array),
            ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1),
        }
    }

    #[test]
    fn test_round_f32() {
        // Same value in every row, rounded to a different number of places.
        let values: ArrayRef = Arc::new(Float32Array::from(vec![125.2345; 10]));
        let places: ArrayRef =
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4]));

        let result = round_arrays(values, Some(places))
            .expect("failed to initialize function round");
        let floats =
            as_float32_array(&result).expect("failed to initialize function round");

        let expected = Float32Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f64() {
        // Same value in every row, rounded to a different number of places.
        let values: ArrayRef = Arc::new(Float64Array::from(vec![125.2345; 10]));
        let places: ArrayRef =
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4]));

        let result = round_arrays(values, Some(places))
            .expect("failed to initialize function round");
        let floats =
            as_float64_array(&result).expect("failed to initialize function round");

        let expected = Float64Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f32_one_input() {
        let values: ArrayRef =
            Arc::new(Float32Array::from(vec![125.2345, 12.345, 1.234, 0.1234]));

        let result =
            round_arrays(values, None).expect("failed to initialize function round");
        let floats =
            as_float32_array(&result).expect("failed to initialize function round");

        let expected = Float32Array::from(vec![125.0, 12.0, 1.0, 0.0]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f64_one_input() {
        let values: ArrayRef =
            Arc::new(Float64Array::from(vec![125.2345, 12.345, 1.234, 0.1234]));

        let result =
            round_arrays(values, None).expect("failed to initialize function round");
        let floats =
            as_float64_array(&result).expect("failed to initialize function round");

        let expected = Float64Array::from(vec![125.0, 12.0, 1.0, 0.0]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f32_cast_fail() {
        // 2147483648 is one past `i32::MAX`, so it cannot become an Int32
        // number of decimal places.
        let values: ArrayRef = Arc::new(Float64Array::from(vec![125.2345]));
        let places: ArrayRef = Arc::new(Int64Array::from(vec![2147483648]));

        let result = round_arrays(values, Some(places));

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(DataFusionError::ArrowError(_, _)) | Err(DataFusionError::Execution(_))
        ));
    }
}