datafusion_comet_spark_expr/math_funcs/internal/
normalize_nan.rs1use arrow::datatypes::{DataType, Schema};
19use arrow::{
20 array::{as_primitive_array, ArrayAccessor, ArrayIter, Float32Array, Float64Array},
21 datatypes::{ArrowNativeType, Float32Type, Float64Type},
22 record_batch::RecordBatch,
23};
24use datafusion::logical_expr::ColumnarValue;
25use datafusion::physical_expr::PhysicalExpr;
26use std::hash::Hash;
27use std::{
28 any::Any,
29 fmt::{Display, Formatter},
30 sync::Arc,
31};
32
33#[derive(Debug, Eq)]
34pub struct NormalizeNaNAndZero {
35 pub data_type: DataType,
36 pub child: Arc<dyn PhysicalExpr>,
37}
38
39impl PartialEq for NormalizeNaNAndZero {
40 fn eq(&self, other: &Self) -> bool {
41 self.child.eq(&other.child) && self.data_type.eq(&other.data_type)
42 }
43}
44
45impl Hash for NormalizeNaNAndZero {
46 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
47 self.child.hash(state);
48 self.data_type.hash(state);
49 }
50}
51
52impl NormalizeNaNAndZero {
53 pub fn new(data_type: DataType, child: Arc<dyn PhysicalExpr>) -> Self {
54 Self { data_type, child }
55 }
56}
57
58impl PhysicalExpr for NormalizeNaNAndZero {
59 fn as_any(&self) -> &dyn Any {
60 self
61 }
62
63 fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
64 unimplemented!()
65 }
66
67 fn data_type(&self, input_schema: &Schema) -> datafusion::common::Result<DataType> {
68 self.child.data_type(input_schema)
69 }
70
71 fn nullable(&self, input_schema: &Schema) -> datafusion::common::Result<bool> {
72 self.child.nullable(input_schema)
73 }
74
75 fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
76 let cv = self.child.evaluate(batch)?;
77 let array = cv.into_array(batch.num_rows())?;
78
79 match &self.data_type {
80 DataType::Float32 => {
81 let v = eval_typed(as_primitive_array::<Float32Type>(&array));
82 let new_array = Float32Array::from(v);
83 Ok(ColumnarValue::Array(Arc::new(new_array)))
84 }
85 DataType::Float64 => {
86 let v = eval_typed(as_primitive_array::<Float64Type>(&array));
87 let new_array = Float64Array::from(v);
88 Ok(ColumnarValue::Array(Arc::new(new_array)))
89 }
90 dt => panic!("Unexpected data type {dt:?}"),
91 }
92 }
93
94 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
95 self.child.children()
96 }
97
98 fn with_new_children(
99 self: Arc<Self>,
100 children: Vec<Arc<dyn PhysicalExpr>>,
101 ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
102 Ok(Arc::new(NormalizeNaNAndZero::new(
103 self.data_type.clone(),
104 Arc::clone(&children[0]),
105 )))
106 }
107}
108
109fn eval_typed<V: FloatDouble, T: ArrayAccessor<Item = V>>(input: T) -> Vec<Option<V>> {
110 let iter = ArrayIter::new(input);
111 iter.map(|o| {
112 o.map(|v| {
113 if v.is_nan() {
114 v.nan()
115 } else if v.is_neg_zero() {
116 v.zero()
117 } else {
118 v
119 }
120 })
121 })
122 .collect()
123}
124
125impl Display for NormalizeNaNAndZero {
126 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
127 write!(f, "FloatNormalize [child: {}]", self.child)
128 }
129}
130
131trait FloatDouble: ArrowNativeType {
132 fn is_nan(&self) -> bool;
133 fn nan(&self) -> Self;
134 fn is_neg_zero(&self) -> bool;
135 fn zero(&self) -> Self;
136}
137
138impl FloatDouble for f32 {
139 fn is_nan(&self) -> bool {
140 f32::is_nan(*self)
141 }
142 fn nan(&self) -> Self {
143 f32::NAN
144 }
145 fn is_neg_zero(&self) -> bool {
146 *self == -0.0
147 }
148 fn zero(&self) -> Self {
149 0.0
150 }
151}
152impl FloatDouble for f64 {
153 fn is_nan(&self) -> bool {
154 f64::is_nan(*self)
155 }
156 fn nan(&self) -> Self {
157 f64::NAN
158 }
159 fn is_neg_zero(&self) -> bool {
160 *self == -0.0
161 }
162 fn zero(&self) -> Self {
163 0.0
164 }
165}