datafusion_comet_spark_expr/math_funcs/internal/normalize_nan.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::{
19    array::{as_primitive_array, ArrayAccessor, ArrayIter, Float32Array, Float64Array},
20    datatypes::{ArrowNativeType, Float32Type, Float64Type},
21    record_batch::RecordBatch,
22};
23use arrow_schema::{DataType, Schema};
24use datafusion::logical_expr::ColumnarValue;
25use datafusion_physical_expr::PhysicalExpr;
26use std::hash::Hash;
27use std::{
28    any::Any,
29    fmt::{Display, Formatter},
30    sync::Arc,
31};
32
33#[derive(Debug, Eq)]
34pub struct NormalizeNaNAndZero {
35    pub data_type: DataType,
36    pub child: Arc<dyn PhysicalExpr>,
37}
38
39impl PartialEq for NormalizeNaNAndZero {
40    fn eq(&self, other: &Self) -> bool {
41        self.child.eq(&other.child) && self.data_type.eq(&other.data_type)
42    }
43}
44
45impl Hash for NormalizeNaNAndZero {
46    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
47        self.child.hash(state);
48        self.data_type.hash(state);
49    }
50}
51
52impl NormalizeNaNAndZero {
53    pub fn new(data_type: DataType, child: Arc<dyn PhysicalExpr>) -> Self {
54        Self { data_type, child }
55    }
56}
57
58impl PhysicalExpr for NormalizeNaNAndZero {
59    fn as_any(&self) -> &dyn Any {
60        self
61    }
62
63    fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result<DataType> {
64        self.child.data_type(input_schema)
65    }
66
67    fn nullable(&self, input_schema: &Schema) -> datafusion_common::Result<bool> {
68        self.child.nullable(input_schema)
69    }
70
71    fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> {
72        let cv = self.child.evaluate(batch)?;
73        let array = cv.into_array(batch.num_rows())?;
74
75        match &self.data_type {
76            DataType::Float32 => {
77                let v = eval_typed(as_primitive_array::<Float32Type>(&array));
78                let new_array = Float32Array::from(v);
79                Ok(ColumnarValue::Array(Arc::new(new_array)))
80            }
81            DataType::Float64 => {
82                let v = eval_typed(as_primitive_array::<Float64Type>(&array));
83                let new_array = Float64Array::from(v);
84                Ok(ColumnarValue::Array(Arc::new(new_array)))
85            }
86            dt => panic!("Unexpected data type {:?}", dt),
87        }
88    }
89
90    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
91        self.child.children()
92    }
93
94    fn with_new_children(
95        self: Arc<Self>,
96        children: Vec<Arc<dyn PhysicalExpr>>,
97    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
98        Ok(Arc::new(NormalizeNaNAndZero::new(
99            self.data_type.clone(),
100            Arc::clone(&children[0]),
101        )))
102    }
103}
104
105fn eval_typed<V: FloatDouble, T: ArrayAccessor<Item = V>>(input: T) -> Vec<Option<V>> {
106    let iter = ArrayIter::new(input);
107    iter.map(|o| {
108        o.map(|v| {
109            if v.is_nan() {
110                v.nan()
111            } else if v.is_neg_zero() {
112                v.zero()
113            } else {
114                v
115            }
116        })
117    })
118    .collect()
119}
120
121impl Display for NormalizeNaNAndZero {
122    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
123        write!(f, "FloatNormalize [child: {}]", self.child)
124    }
125}
126
127trait FloatDouble: ArrowNativeType {
128    fn is_nan(&self) -> bool;
129    fn nan(&self) -> Self;
130    fn is_neg_zero(&self) -> bool;
131    fn zero(&self) -> Self;
132}
133
134impl FloatDouble for f32 {
135    fn is_nan(&self) -> bool {
136        f32::is_nan(*self)
137    }
138    fn nan(&self) -> Self {
139        f32::NAN
140    }
141    fn is_neg_zero(&self) -> bool {
142        *self == -0.0
143    }
144    fn zero(&self) -> Self {
145        0.0
146    }
147}
148impl FloatDouble for f64 {
149    fn is_nan(&self) -> bool {
150        f64::is_nan(*self)
151    }
152    fn nan(&self) -> Self {
153        f64::NAN
154    }
155    fn is_neg_zero(&self) -> bool {
156        *self == -0.0
157    }
158    fn zero(&self) -> Self {
159        0.0
160    }
161}