// datafusion_spark/function/math/rint.rs
use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{Array, ArrayRef, AsArray};
22use arrow::compute::cast;
23use arrow::datatypes::DataType::{
24 Float32, Float64, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64,
25};
26use arrow::datatypes::{DataType, Float32Type, Float64Type};
27use datafusion_common::{Result, assert_eq_or_internal_err, exec_err};
28use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
29use datafusion_expr::{
30 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
31};
32use datafusion_functions::utils::make_scalar_function;
33
/// Spark-compatible `rint` scalar function.
///
/// Returns the `Float64` value closest to the argument, rounding ties to
/// the even integer ("banker's rounding", see `round_ties_even` in the
/// kernel below), matching Spark's `rint` semantics.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkRint {
    // Accepts exactly one numeric argument; see `SparkRint::new`.
    signature: Signature,
}
38
39impl Default for SparkRint {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl SparkRint {
46 pub fn new() -> Self {
47 Self {
48 signature: Signature::numeric(1, Volatility::Immutable),
49 }
50 }
51}
52
53impl ScalarUDFImpl for SparkRint {
54 fn as_any(&self) -> &dyn Any {
55 self
56 }
57
58 fn name(&self) -> &str {
59 "rint"
60 }
61
62 fn signature(&self) -> &Signature {
63 &self.signature
64 }
65
66 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
67 Ok(Float64)
68 }
69
70 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
71 make_scalar_function(spark_rint, vec![])(&args.args)
72 }
73
74 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
75 if input.len() == 1 {
77 let value = &input[0];
78 Ok(value.sort_properties)
79 } else {
80 Ok(SortProperties::default())
81 }
82 }
83}
84
85pub fn spark_rint(args: &[ArrayRef]) -> Result<ArrayRef> {
86 assert_eq_or_internal_err!(args.len(), 1, "`rint` expects exactly one argument");
87
88 let array: &dyn Array = args[0].as_ref();
89 match args[0].data_type() {
90 Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => {
91 Ok(cast(array, &Float64)?)
92 }
93 Float64 => {
94 let array = array
95 .as_primitive::<Float64Type>()
96 .unary::<_, Float64Type>(|value: f64| value.round_ties_even());
97 Ok(Arc::new(array))
98 }
99 Float32 => {
100 let array = array
101 .as_primitive::<Float32Type>()
102 .unary::<_, Float64Type>(|value: f32| value.round_ties_even() as f64);
103 Ok(Arc::new(array))
104 }
105 _ => {
106 exec_err!(
107 "rint expects a numeric argument, got {}",
108 args[0].data_type()
109 )
110 }
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Float64Array;

    /// Runs `spark_rint` over a single f64 column and compares the result
    /// against the expected rounded column.
    fn assert_rint(input: Vec<Option<f64>>, expected: Vec<Option<f64>>) {
        let result = spark_rint(&[Arc::new(Float64Array::from(input))]).unwrap();
        assert_eq!(result.as_ref(), &Float64Array::from(expected));
    }

    #[test]
    fn test_rint_positive_decimals() {
        assert_rint(vec![Some(12.3456)], vec![Some(12.0)]);
        // Ties round to the nearest even integer.
        assert_rint(vec![Some(2.5)], vec![Some(2.0)]);
        assert_rint(vec![Some(3.5)], vec![Some(4.0)]);
    }

    #[test]
    fn test_rint_negative_decimals() {
        assert_rint(vec![Some(-12.3456)], vec![Some(-12.0)]);
        assert_rint(vec![Some(-2.5)], vec![Some(-2.0)]);
    }

    #[test]
    fn test_rint_integers() {
        assert_rint(vec![Some(42.0)], vec![Some(42.0)]);
    }

    #[test]
    fn test_rint_null() {
        assert_rint(vec![None], vec![None]);
    }

    #[test]
    fn test_rint_zero() {
        assert_rint(vec![Some(0.0)], vec![Some(0.0)]);
    }
}
163}