// datafusion_functions/math/round.rs
use std::any::Any;
use std::sync::Arc;

use crate::utils::make_scalar_function;

use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
use arrow::compute::{cast_with_options, CastOptions};
use arrow::datatypes::DataType::{Float32, Float64, Int32};
use arrow::datatypes::{DataType, Float32Type, Float64Type, Int32Type};
use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{
    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
    Volatility,
};
use datafusion_macros::user_doc;

36#[user_doc(
37 doc_section(label = "Math Functions"),
38 description = "Rounds a number to the nearest integer.",
39 syntax_example = "round(numeric_expression[, decimal_places])",
40 standard_argument(name = "numeric_expression", prefix = "Numeric"),
41 argument(
42 name = "decimal_places",
43 description = "Optional. The number of decimal places to round to. Defaults to 0."
44 ),
45 sql_example = r#"```sql
46> SELECT round(3.14159);
47+--------------+
48| round(3.14159)|
49+--------------+
50| 3.0 |
51+--------------+
52```"#
53)]
54#[derive(Debug, PartialEq, Eq, Hash)]
55pub struct RoundFunc {
56 signature: Signature,
57}
58
59impl Default for RoundFunc {
60 fn default() -> Self {
61 RoundFunc::new()
62 }
63}
64
65impl RoundFunc {
66 pub fn new() -> Self {
67 use DataType::*;
68 Self {
69 signature: Signature::one_of(
70 vec![
71 Exact(vec![Float64, Int64]),
72 Exact(vec![Float32, Int64]),
73 Exact(vec![Float64]),
74 Exact(vec![Float32]),
75 ],
76 Volatility::Immutable,
77 ),
78 }
79 }
80}
81
82impl ScalarUDFImpl for RoundFunc {
83 fn as_any(&self) -> &dyn Any {
84 self
85 }
86
87 fn name(&self) -> &str {
88 "round"
89 }
90
91 fn signature(&self) -> &Signature {
92 &self.signature
93 }
94
95 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
96 match arg_types[0] {
97 Float32 => Ok(Float32),
98 _ => Ok(Float64),
99 }
100 }
101
102 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
103 make_scalar_function(round, vec![])(&args.args)
104 }
105
106 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
107 let value = &input[0];
109 let precision = input.get(1);
110
111 if precision
112 .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
113 .unwrap_or(true)
114 {
115 Ok(value.sort_properties)
116 } else {
117 Ok(SortProperties::Unordered)
118 }
119 }
120
121 fn documentation(&self) -> Option<&Documentation> {
122 self.doc()
123 }
124}
125
126pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
128 if args.len() != 1 && args.len() != 2 {
129 return exec_err!(
130 "round function requires one or two arguments, got {}",
131 args.len()
132 );
133 }
134
135 let mut decimal_places = ColumnarValue::Scalar(ScalarValue::Int64(Some(0)));
136
137 if args.len() == 2 {
138 decimal_places = ColumnarValue::Array(Arc::clone(&args[1]));
139 }
140
141 match args[0].data_type() {
142 Float64 => match decimal_places {
143 ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => {
144 let decimal_places: i32 = decimal_places.try_into().map_err(|e| {
145 exec_datafusion_err!(
146 "Invalid value for decimal places: {decimal_places}: {e}"
147 )
148 })?;
149
150 let result = args[0]
151 .as_primitive::<Float64Type>()
152 .unary::<_, Float64Type>(|value: f64| {
153 (value * 10.0_f64.powi(decimal_places)).round()
154 / 10.0_f64.powi(decimal_places)
155 });
156 Ok(Arc::new(result) as _)
157 }
158 ColumnarValue::Array(decimal_places) => {
159 let options = CastOptions {
160 safe: false, ..Default::default()
162 };
163 let decimal_places = cast_with_options(&decimal_places, &Int32, &options)
164 .map_err(|e| {
165 exec_datafusion_err!("Invalid values for decimal places: {e}")
166 })?;
167
168 let values = args[0].as_primitive::<Float64Type>();
169 let decimal_places = decimal_places.as_primitive::<Int32Type>();
170 let result = arrow::compute::binary::<_, _, _, Float64Type>(
171 values,
172 decimal_places,
173 |value, decimal_places| {
174 (value * 10.0_f64.powi(decimal_places)).round()
175 / 10.0_f64.powi(decimal_places)
176 },
177 )?;
178 Ok(Arc::new(result) as _)
179 }
180 _ => {
181 exec_err!("round function requires a scalar or array for decimal_places")
182 }
183 },
184
185 Float32 => match decimal_places {
186 ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => {
187 let decimal_places: i32 = decimal_places.try_into().map_err(|e| {
188 exec_datafusion_err!(
189 "Invalid value for decimal places: {decimal_places}: {e}"
190 )
191 })?;
192 let result = args[0]
193 .as_primitive::<Float32Type>()
194 .unary::<_, Float32Type>(|value: f32| {
195 (value * 10.0_f32.powi(decimal_places)).round()
196 / 10.0_f32.powi(decimal_places)
197 });
198 Ok(Arc::new(result) as _)
199 }
200 ColumnarValue::Array(_) => {
201 let ColumnarValue::Array(decimal_places) =
202 decimal_places.cast_to(&Int32, None).map_err(|e| {
203 exec_datafusion_err!("Invalid values for decimal places: {e}")
204 })?
205 else {
206 panic!("Unexpected result of ColumnarValue::Array.cast")
207 };
208
209 let values = args[0].as_primitive::<Float32Type>();
210 let decimal_places = decimal_places.as_primitive::<Int32Type>();
211 let result: PrimitiveArray<Float32Type> = arrow::compute::binary(
212 values,
213 decimal_places,
214 |value, decimal_places| {
215 (value * 10.0_f32.powi(decimal_places)).round()
216 / 10.0_f32.powi(decimal_places)
217 },
218 )?;
219 Ok(Arc::new(result) as _)
220 }
221 _ => {
222 exec_err!("round function requires a scalar or array for decimal_places")
223 }
224 },
225
226 other => exec_err!("Unsupported data type {other:?} for function round"),
227 }
228}
229
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::math::round::round;

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_common::DataFusionError;

    #[test]
    fn test_round_f32() {
        let args: Vec<ArrayRef> = vec![
            // values to round
            Arc::new(Float32Array::from(vec![125.2345; 10])),
            // decimal places; negative ones round left of the decimal point
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])),
        ];

        let result = round(&args).expect("round should succeed");
        let floats = as_float32_array(&result).expect("round should return Float32");

        let expected = Float32Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f64() {
        let args: Vec<ArrayRef> = vec![
            // values to round
            Arc::new(Float64Array::from(vec![125.2345; 10])),
            // decimal places; negative ones round left of the decimal point
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])),
        ];

        let result = round(&args).expect("round should succeed");
        let floats = as_float64_array(&result).expect("round should return Float64");

        let expected = Float64Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f32_one_input() {
        let args: Vec<ArrayRef> = vec![Arc::new(Float32Array::from(vec![
            125.2345, 12.345, 1.234, 0.1234,
        ]))];

        let result = round(&args).expect("round should succeed");
        let floats = as_float32_array(&result).expect("round should return Float32");

        let expected = Float32Array::from(vec![125.0, 12.0, 1.0, 0.0]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f64_one_input() {
        let args: Vec<ArrayRef> = vec![Arc::new(Float64Array::from(vec![
            125.2345, 12.345, 1.234, 0.1234,
        ]))];

        let result = round(&args).expect("round should succeed");
        let floats = as_float64_array(&result).expect("round should return Float64");

        let expected = Float64Array::from(vec![125.0, 12.0, 1.0, 0.0]);

        assert_eq!(floats, &expected);
    }

    /// `2147483648` does not fit in an `i32`, so casting the decimal places
    /// must fail with an execution error instead of silently becoming `NULL`.
    fn assert_decimal_places_cast_fails(values: ArrayRef) {
        let args: Vec<ArrayRef> =
            vec![values, Arc::new(Int64Array::from(vec![2147483648]))];
        let result = round(&args);
        assert!(
            matches!(result, Err(DataFusionError::Execution(_))),
            "expected an execution error, got {result:?}"
        );
    }

    #[test]
    fn test_round_f64_cast_fail() {
        assert_decimal_places_cast_fails(Arc::new(Float64Array::from(vec![125.2345])));
    }

    #[test]
    fn test_round_f32_cast_fail() {
        assert_decimal_places_cast_fails(Arc::new(Float32Array::from(vec![125.2345])));
    }

    #[test]
    fn test_round_propagates_nulls() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![Some(1.25), None, Some(2.5)])),
            Arc::new(Int64Array::from(vec![Some(1), Some(1), None])),
        ];

        let result = round(&args).expect("round should succeed");
        let floats = as_float64_array(&result).expect("round should return Float64");

        assert_eq!(floats, &Float64Array::from(vec![Some(1.3), None, None]));
    }

    #[test]
    fn test_round_wrong_argument_count() {
        assert!(matches!(round(&[]), Err(DataFusionError::Execution(_))));

        let value: ArrayRef = Arc::new(Float64Array::from(vec![1.5]));
        let args = vec![Arc::clone(&value), Arc::clone(&value), value];
        assert!(matches!(round(&args), Err(DataFusionError::Execution(_))));
    }

    #[test]
    fn test_round_unsupported_type() {
        let args: Vec<ArrayRef> = vec![Arc::new(Int64Array::from(vec![1, 2, 3]))];
        assert!(matches!(round(&args), Err(DataFusionError::Execution(_))));
    }
}