datafusion_spark/function/math/
rint.rs1use std::sync::Arc;
19
20use arrow::array::{Array, ArrayRef, AsArray};
21use arrow::compute::cast;
22use arrow::datatypes::DataType::{
23 Float32, Float64, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64,
24};
25use arrow::datatypes::{DataType, Float32Type, Float64Type};
26use datafusion_common::{Result, assert_eq_or_internal_err, exec_err};
27use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
28use datafusion_expr::{
29 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
30};
31use datafusion_functions::utils::make_scalar_function;
32
33#[derive(Debug, PartialEq, Eq, Hash)]
34pub struct SparkRint {
35 signature: Signature,
36}
37
38impl Default for SparkRint {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl SparkRint {
45 pub fn new() -> Self {
46 Self {
47 signature: Signature::numeric(1, Volatility::Immutable),
48 }
49 }
50}
51
52impl ScalarUDFImpl for SparkRint {
53 fn name(&self) -> &str {
54 "rint"
55 }
56
57 fn signature(&self) -> &Signature {
58 &self.signature
59 }
60
61 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
62 Ok(Float64)
63 }
64
65 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
66 make_scalar_function(spark_rint, vec![])(&args.args)
67 }
68
69 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
70 if input.len() == 1 {
72 let value = &input[0];
73 Ok(value.sort_properties)
74 } else {
75 Ok(SortProperties::default())
76 }
77 }
78}
79
80pub fn spark_rint(args: &[ArrayRef]) -> Result<ArrayRef> {
81 assert_eq_or_internal_err!(args.len(), 1, "`rint` expects exactly one argument");
82
83 let array: &dyn Array = args[0].as_ref();
84 match args[0].data_type() {
85 Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => {
86 Ok(cast(array, &Float64)?)
87 }
88 Float64 => {
89 let array = array
90 .as_primitive::<Float64Type>()
91 .unary::<_, Float64Type>(|value: f64| value.round_ties_even());
92 Ok(Arc::new(array))
93 }
94 Float32 => {
95 let array = array
96 .as_primitive::<Float32Type>()
97 .unary::<_, Float64Type>(|value: f32| value.round_ties_even() as f64);
98 Ok(Arc::new(array))
99 }
100 _ => {
101 exec_err!(
102 "rint expects a numeric argument, got {}",
103 args[0].data_type()
104 )
105 }
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        Float32Array, Float64Array, Int32Array, Int64Array, StringArray, UInt8Array,
    };

    #[test]
    fn test_rint_positive_decimals() {
        let result = spark_rint(&[Arc::new(Float64Array::from(vec![12.3456]))]).unwrap();
        assert_eq!(result.as_ref(), &Float64Array::from(vec![12.0]));

        // Ties round to the even neighbour.
        let result = spark_rint(&[Arc::new(Float64Array::from(vec![2.5]))]).unwrap();
        assert_eq!(result.as_ref(), &Float64Array::from(vec![2.0]));

        let result = spark_rint(&[Arc::new(Float64Array::from(vec![3.5]))]).unwrap();
        assert_eq!(result.as_ref(), &Float64Array::from(vec![4.0]));
    }

    #[test]
    fn test_rint_negative_decimals() {
        let result = spark_rint(&[Arc::new(Float64Array::from(vec![-12.3456]))]).unwrap();
        assert_eq!(result.as_ref(), &Float64Array::from(vec![-12.0]));

        let result = spark_rint(&[Arc::new(Float64Array::from(vec![-2.5]))]).unwrap();
        assert_eq!(result.as_ref(), &Float64Array::from(vec![-2.0]));
    }

    /// Integral `Float64` values pass through unchanged.
    #[test]
    fn test_rint_integers() {
        let result = spark_rint(&[Arc::new(Float64Array::from(vec![42.0]))]).unwrap();
        assert_eq!(result.as_ref(), &Float64Array::from(vec![42.0]));
    }

    #[test]
    fn test_rint_null() {
        let result = spark_rint(&[Arc::new(Float64Array::from(vec![None]))]).unwrap();
        assert_eq!(result.as_ref(), &Float64Array::from(vec![None]));
    }

    #[test]
    fn test_rint_zero() {
        let result = spark_rint(&[Arc::new(Float64Array::from(vec![0.0]))]).unwrap();
        assert_eq!(result.as_ref(), &Float64Array::from(vec![0.0]));
    }

    /// `Float32` input is rounded ties-to-even and widened to `Float64`.
    #[test]
    fn test_rint_float32() {
        let input = Float32Array::from(vec![
            Some(2.5_f32),
            Some(3.5),
            Some(-1.5),
            Some(1.2),
            None,
        ]);
        let result = spark_rint(&[Arc::new(input)]).unwrap();
        assert_eq!(
            result.as_ref(),
            &Float64Array::from(vec![Some(2.0), Some(4.0), Some(-2.0), Some(1.0), None])
        );
    }

    /// Integer inputs of various widths/signedness are widened to `Float64`,
    /// keeping nulls.
    #[test]
    fn test_rint_integer_types() {
        let input = Int32Array::from(vec![Some(-3_i32), None, Some(7)]);
        let result = spark_rint(&[Arc::new(input)]).unwrap();
        assert_eq!(
            result.as_ref(),
            &Float64Array::from(vec![Some(-3.0), None, Some(7.0)])
        );

        let result = spark_rint(&[Arc::new(Int64Array::from(vec![42_i64, -42]))]).unwrap();
        assert_eq!(result.as_ref(), &Float64Array::from(vec![42.0, -42.0]));

        let result = spark_rint(&[Arc::new(UInt8Array::from(vec![0_u8, 255]))]).unwrap();
        assert_eq!(result.as_ref(), &Float64Array::from(vec![0.0, 255.0]));
    }

    /// NaN and infinities are preserved; -0.4 rounds to negative zero.
    #[test]
    fn test_rint_special_values() {
        let input = Float64Array::from(vec![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, -0.4]);
        let result = spark_rint(&[Arc::new(input)]).unwrap();
        let values = result.as_primitive::<Float64Type>();
        assert!(values.value(0).is_nan());
        assert_eq!(values.value(1), f64::INFINITY);
        assert_eq!(values.value(2), f64::NEG_INFINITY);
        assert_eq!(values.value(3), 0.0);
        assert!(values.value(3).is_sign_negative());
    }

    /// Non-numeric input is rejected with an error rather than a panic.
    #[test]
    fn test_rint_rejects_non_numeric() {
        let result = spark_rint(&[Arc::new(StringArray::from(vec!["1.5"]))]);
        assert!(result.is_err());
    }
}