datafusion_comet_spark_expr/math_funcs/internal/
normalize_nan.rs1use arrow::{
19 array::{as_primitive_array, ArrayAccessor, ArrayIter, Float32Array, Float64Array},
20 datatypes::{ArrowNativeType, Float32Type, Float64Type},
21 record_batch::RecordBatch,
22};
23use arrow_schema::{DataType, Schema};
24use datafusion::logical_expr::ColumnarValue;
25use datafusion_physical_expr::PhysicalExpr;
26use std::hash::Hash;
27use std::{
28 any::Any,
29 fmt::{Display, Formatter},
30 sync::Arc,
31};
32
33#[derive(Debug, Eq)]
34pub struct NormalizeNaNAndZero {
35 pub data_type: DataType,
36 pub child: Arc<dyn PhysicalExpr>,
37}
38
39impl PartialEq for NormalizeNaNAndZero {
40 fn eq(&self, other: &Self) -> bool {
41 self.child.eq(&other.child) && self.data_type.eq(&other.data_type)
42 }
43}
44
45impl Hash for NormalizeNaNAndZero {
46 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
47 self.child.hash(state);
48 self.data_type.hash(state);
49 }
50}
51
52impl NormalizeNaNAndZero {
53 pub fn new(data_type: DataType, child: Arc<dyn PhysicalExpr>) -> Self {
54 Self { data_type, child }
55 }
56}
57
58impl PhysicalExpr for NormalizeNaNAndZero {
59 fn as_any(&self) -> &dyn Any {
60 self
61 }
62
63 fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result<DataType> {
64 self.child.data_type(input_schema)
65 }
66
67 fn nullable(&self, input_schema: &Schema) -> datafusion_common::Result<bool> {
68 self.child.nullable(input_schema)
69 }
70
71 fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> {
72 let cv = self.child.evaluate(batch)?;
73 let array = cv.into_array(batch.num_rows())?;
74
75 match &self.data_type {
76 DataType::Float32 => {
77 let v = eval_typed(as_primitive_array::<Float32Type>(&array));
78 let new_array = Float32Array::from(v);
79 Ok(ColumnarValue::Array(Arc::new(new_array)))
80 }
81 DataType::Float64 => {
82 let v = eval_typed(as_primitive_array::<Float64Type>(&array));
83 let new_array = Float64Array::from(v);
84 Ok(ColumnarValue::Array(Arc::new(new_array)))
85 }
86 dt => panic!("Unexpected data type {:?}", dt),
87 }
88 }
89
90 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
91 self.child.children()
92 }
93
94 fn with_new_children(
95 self: Arc<Self>,
96 children: Vec<Arc<dyn PhysicalExpr>>,
97 ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
98 Ok(Arc::new(NormalizeNaNAndZero::new(
99 self.data_type.clone(),
100 Arc::clone(&children[0]),
101 )))
102 }
103}
104
105fn eval_typed<V: FloatDouble, T: ArrayAccessor<Item = V>>(input: T) -> Vec<Option<V>> {
106 let iter = ArrayIter::new(input);
107 iter.map(|o| {
108 o.map(|v| {
109 if v.is_nan() {
110 v.nan()
111 } else if v.is_neg_zero() {
112 v.zero()
113 } else {
114 v
115 }
116 })
117 })
118 .collect()
119}
120
121impl Display for NormalizeNaNAndZero {
122 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
123 write!(f, "FloatNormalize [child: {}]", self.child)
124 }
125}
126
127trait FloatDouble: ArrowNativeType {
128 fn is_nan(&self) -> bool;
129 fn nan(&self) -> Self;
130 fn is_neg_zero(&self) -> bool;
131 fn zero(&self) -> Self;
132}
133
134impl FloatDouble for f32 {
135 fn is_nan(&self) -> bool {
136 f32::is_nan(*self)
137 }
138 fn nan(&self) -> Self {
139 f32::NAN
140 }
141 fn is_neg_zero(&self) -> bool {
142 *self == -0.0
143 }
144 fn zero(&self) -> Self {
145 0.0
146 }
147}
148impl FloatDouble for f64 {
149 fn is_nan(&self) -> bool {
150 f64::is_nan(*self)
151 }
152 fn nan(&self) -> Self {
153 f64::NAN
154 }
155 fn is_neg_zero(&self) -> bool {
156 *self == -0.0
157 }
158 fn zero(&self) -> Self {
159 0.0
160 }
161}