// datafusion_spark/function/math/rint.rs
use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{Array, ArrayRef, AsArray};
22use arrow::compute::cast;
23use arrow::datatypes::DataType::{
24 Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8,
25};
26use arrow::datatypes::{DataType, Float32Type, Float64Type};
27use datafusion_common::{exec_err, Result};
28use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
29use datafusion_expr::{
30 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
31};
32use datafusion_functions::utils::make_scalar_function;
33
/// Spark-compatible `rint` scalar function.
///
/// Rounds its single numeric argument to the nearest integer value, with
/// ties resolved to the nearest even integer (IEEE 754 roundTiesToEven,
/// matching Java's `Math.rint`), always returning a `Float64`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkRint {
    // One numeric argument, immutable volatility (see `SparkRint::new`).
    signature: Signature,
}
38
39impl Default for SparkRint {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl SparkRint {
46 pub fn new() -> Self {
47 Self {
48 signature: Signature::numeric(1, Volatility::Immutable),
49 }
50 }
51}
52
53impl ScalarUDFImpl for SparkRint {
54 fn as_any(&self) -> &dyn Any {
55 self
56 }
57
58 fn name(&self) -> &str {
59 "rint"
60 }
61
62 fn signature(&self) -> &Signature {
63 &self.signature
64 }
65
66 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
67 Ok(Float64)
68 }
69
70 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
71 make_scalar_function(spark_rint, vec![])(&args.args)
72 }
73
74 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
75 if input.len() == 1 {
77 let value = &input[0];
78 Ok(value.sort_properties)
79 } else {
80 Ok(SortProperties::default())
81 }
82 }
83}
84
85pub fn spark_rint(args: &[ArrayRef]) -> Result<ArrayRef> {
86 if args.len() != 1 {
87 return exec_err!("rint expects exactly 1 argument, got {}", args.len());
88 }
89
90 let array: &dyn Array = args[0].as_ref();
91 match args[0].data_type() {
92 Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => {
93 Ok(cast(array, &Float64)?)
94 }
95 Float64 => {
96 let array = array
97 .as_primitive::<Float64Type>()
98 .unary::<_, Float64Type>(|value: f64| value.round_ties_even());
99 Ok(Arc::new(array))
100 }
101 Float32 => {
102 let array = array
103 .as_primitive::<Float32Type>()
104 .unary::<_, Float64Type>(|value: f32| value.round_ties_even() as f64);
105 Ok(Arc::new(array))
106 }
107 _ => {
108 exec_err!(
109 "rint expects a numeric argument, got {}",
110 args[0].data_type()
111 )
112 }
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Float64Array;

    /// Runs `spark_rint` over a single Float64 column and asserts the result.
    fn assert_rint(input: Vec<Option<f64>>, expected: Vec<Option<f64>>) {
        let result = spark_rint(&[Arc::new(Float64Array::from(input))]).unwrap();
        assert_eq!(result.as_ref(), &Float64Array::from(expected));
    }

    #[test]
    fn test_rint_positive_decimals() {
        assert_rint(vec![Some(12.3456)], vec![Some(12.0)]);
        // Ties go to the nearest even integer (banker's rounding).
        assert_rint(vec![Some(2.5)], vec![Some(2.0)]);
        assert_rint(vec![Some(3.5)], vec![Some(4.0)]);
    }

    #[test]
    fn test_rint_negative_decimals() {
        assert_rint(vec![Some(-12.3456)], vec![Some(-12.0)]);
        // Negative tie also rounds toward the even integer.
        assert_rint(vec![Some(-2.5)], vec![Some(-2.0)]);
    }

    #[test]
    fn test_rint_integers() {
        assert_rint(vec![Some(42.0)], vec![Some(42.0)]);
    }

    #[test]
    fn test_rint_null() {
        assert_rint(vec![None], vec![None]);
    }

    #[test]
    fn test_rint_zero() {
        assert_rint(vec![Some(0.0)], vec![Some(0.0)]);
    }
}
165}