datafusion_comet_spark_expr/math_funcs/internal/
normalize_nan.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::datatypes::{DataType, Schema};
19use arrow::{
20    array::{as_primitive_array, ArrayAccessor, ArrayIter, Float32Array, Float64Array},
21    datatypes::{ArrowNativeType, Float32Type, Float64Type},
22    record_batch::RecordBatch,
23};
24use datafusion::logical_expr::ColumnarValue;
25use datafusion::physical_expr::PhysicalExpr;
26use std::hash::Hash;
27use std::{
28    any::Any,
29    fmt::{Display, Formatter},
30    sync::Arc,
31};
32
33#[derive(Debug, Eq)]
34pub struct NormalizeNaNAndZero {
35    pub data_type: DataType,
36    pub child: Arc<dyn PhysicalExpr>,
37}
38
39impl PartialEq for NormalizeNaNAndZero {
40    fn eq(&self, other: &Self) -> bool {
41        self.child.eq(&other.child) && self.data_type.eq(&other.data_type)
42    }
43}
44
45impl Hash for NormalizeNaNAndZero {
46    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
47        self.child.hash(state);
48        self.data_type.hash(state);
49    }
50}
51
52impl NormalizeNaNAndZero {
53    pub fn new(data_type: DataType, child: Arc<dyn PhysicalExpr>) -> Self {
54        Self { data_type, child }
55    }
56}
57
58impl PhysicalExpr for NormalizeNaNAndZero {
59    fn as_any(&self) -> &dyn Any {
60        self
61    }
62
63    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
64        unimplemented!()
65    }
66
67    fn data_type(&self, input_schema: &Schema) -> datafusion::common::Result<DataType> {
68        self.child.data_type(input_schema)
69    }
70
71    fn nullable(&self, input_schema: &Schema) -> datafusion::common::Result<bool> {
72        self.child.nullable(input_schema)
73    }
74
75    fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
76        let cv = self.child.evaluate(batch)?;
77        let array = cv.into_array(batch.num_rows())?;
78
79        match &self.data_type {
80            DataType::Float32 => {
81                let v = eval_typed(as_primitive_array::<Float32Type>(&array));
82                let new_array = Float32Array::from(v);
83                Ok(ColumnarValue::Array(Arc::new(new_array)))
84            }
85            DataType::Float64 => {
86                let v = eval_typed(as_primitive_array::<Float64Type>(&array));
87                let new_array = Float64Array::from(v);
88                Ok(ColumnarValue::Array(Arc::new(new_array)))
89            }
90            dt => panic!("Unexpected data type {dt:?}"),
91        }
92    }
93
94    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
95        self.child.children()
96    }
97
98    fn with_new_children(
99        self: Arc<Self>,
100        children: Vec<Arc<dyn PhysicalExpr>>,
101    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
102        Ok(Arc::new(NormalizeNaNAndZero::new(
103            self.data_type.clone(),
104            Arc::clone(&children[0]),
105        )))
106    }
107}
108
109fn eval_typed<V: FloatDouble, T: ArrayAccessor<Item = V>>(input: T) -> Vec<Option<V>> {
110    let iter = ArrayIter::new(input);
111    iter.map(|o| {
112        o.map(|v| {
113            if v.is_nan() {
114                v.nan()
115            } else if v.is_neg_zero() {
116                v.zero()
117            } else {
118                v
119            }
120        })
121    })
122    .collect()
123}
124
125impl Display for NormalizeNaNAndZero {
126    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
127        write!(f, "FloatNormalize [child: {}]", self.child)
128    }
129}
130
131trait FloatDouble: ArrowNativeType {
132    fn is_nan(&self) -> bool;
133    fn nan(&self) -> Self;
134    fn is_neg_zero(&self) -> bool;
135    fn zero(&self) -> Self;
136}
137
138impl FloatDouble for f32 {
139    fn is_nan(&self) -> bool {
140        f32::is_nan(*self)
141    }
142    fn nan(&self) -> Self {
143        f32::NAN
144    }
145    fn is_neg_zero(&self) -> bool {
146        *self == -0.0
147    }
148    fn zero(&self) -> Self {
149        0.0
150    }
151}
152impl FloatDouble for f64 {
153    fn is_nan(&self) -> bool {
154        f64::is_nan(*self)
155    }
156    fn nan(&self) -> Self {
157        f64::NAN
158    }
159    fn is_neg_zero(&self) -> bool {
160        *self == -0.0
161    }
162    fn zero(&self) -> Self {
163        0.0
164    }
165}