datafusion_comet_spark_expr/agg_funcs/correlation.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::compute::{and, filter, is_not_null};

use std::{any::Any, sync::Arc};

use crate::agg_funcs::covariance::CovarianceAccumulator;
use crate::agg_funcs::stddev::StddevAccumulator;
use arrow::datatypes::FieldRef;
use arrow::{
    array::ArrayRef,
    datatypes::{DataType, Field},
};
use datafusion::common::{Result, ScalarValue};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::type_coercion::aggregates::NUMERICS;
use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
use datafusion::physical_expr::expressions::format_state_name;
use datafusion::physical_expr::expressions::StatsType;

/// CORR aggregate expression
///
/// The implementation is mostly the same as DataFusion's. We maintain our own version
/// because DataFusion uses UInt64 for the `count` state field, while Spark uses Double.
/// We also add `null_on_divide_by_zero` to be consistent with Spark's implementation.
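///
/// Correlation is evaluated as covar_pop(x, y) / (stddev_pop(x) * stddev_pop(y)).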
#[derive(Debug)]
pub struct Correlation {
    name: String,
    signature: Signature,
    null_on_divide_by_zero: bool,
}

impl Correlation {
    pub fn new(name: impl Into<String>, data_type: DataType, null_on_divide_by_zero: bool) -> Self {
        // The result of correlation only supports the Float64 data type.
        assert!(matches!(data_type, DataType::Float64));
        Self {
            name: name.into(),
            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
            null_on_divide_by_zero,
        }
    }
}

impl AggregateUDFImpl for Correlation {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Float64)
    }

    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
        Ok(ScalarValue::Float64(None))
    }

    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(CorrelationAccumulator::try_new(
            self.null_on_divide_by_zero,
        )?))
    }

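    // Spark tracks the aggregation count as a Double, so every state field is
    // declared as Float64 (DataFusion's built-in corr uses UInt64 for the count).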
    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        Ok(vec![
            Arc::new(Field::new(
                format_state_name(&self.name, "count"),
                DataType::Float64,
                true,
            )),
            Arc::new(Field::new(
                format_state_name(&self.name, "mean1"),
                DataType::Float64,
                true,
            )),
            Arc::new(Field::new(
                format_state_name(&self.name, "mean2"),
                DataType::Float64,
                true,
            )),
            Arc::new(Field::new(
                format_state_name(&self.name, "algo_const"),
                DataType::Float64,
                true,
            )),
            Arc::new(Field::new(
                format_state_name(&self.name, "m2_1"),
                DataType::Float64,
                true,
            )),
            Arc::new(Field::new(
                format_state_name(&self.name, "m2_2"),
                DataType::Float64,
                true,
            )),
        ])
    }
}

/// An accumulator to compute correlation
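///
/// It combines a population covariance accumulator with two population stddev
/// accumulators; their internal state is laid out as
/// `count, mean1, mean2, algo_const, m2_1, m2_2`.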
#[derive(Debug)]
pub struct CorrelationAccumulator {
    covar: CovarianceAccumulator,
    stddev1: StddevAccumulator,
    stddev2: StddevAccumulator,
    null_on_divide_by_zero: bool,
}

impl CorrelationAccumulator {
    /// Creates a new `CorrelationAccumulator`
    pub fn try_new(null_on_divide_by_zero: bool) -> Result<Self> {
        Ok(Self {
            covar: CovarianceAccumulator::try_new(StatsType::Population, null_on_divide_by_zero)?,
            stddev1: StddevAccumulator::try_new(StatsType::Population, null_on_divide_by_zero)?,
            stddev2: StddevAccumulator::try_new(StatsType::Population, null_on_divide_by_zero)?,
            null_on_divide_by_zero,
        })
    }
}

impl Accumulator for CorrelationAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
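        // The order must match the fields declared in `state_fields`:
        // count, mean1, mean2, algo_const, m2_1, m2_2.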
        Ok(vec![
            ScalarValue::from(self.covar.get_count()),
            ScalarValue::from(self.covar.get_mean1()),
            ScalarValue::from(self.covar.get_mean2()),
            ScalarValue::from(self.covar.get_algo_const()),
            ScalarValue::from(self.stddev1.get_m2()),
            ScalarValue::from(self.stddev2.get_m2()),
        ])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
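        // Keep only the rows where both inputs are non-null so that each pair
        // contributes to the covariance and both standard deviations consistently.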
        let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
            let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
            let values1 = filter(&values[0], &mask)?;
            let values2 = filter(&values[1], &mask)?;

            vec![values1, values2]
        } else {
            values.to_vec()
        };

        if !values[0].is_empty() && !values[1].is_empty() {
            self.covar.update_batch(&values)?;
            self.stddev1.update_batch(&values[0..1])?;
            self.stddev2.update_batch(&values[1..2])?;
        }

        Ok(())
    }

    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
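        // Apply the same null filtering as `update_batch` so retraction stays symmetric.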
        let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
            let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
            let values1 = filter(&values[0], &mask)?;
            let values2 = filter(&values[1], &mask)?;

            vec![values1, values2]
        } else {
            values.to_vec()
        };

        self.covar.retract_batch(&values)?;
        self.stddev1.retract_batch(&values[0..1])?;
        self.stddev2.retract_batch(&values[1..2])?;
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
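        // Split the combined state into the covariance state
        // [count, mean1, mean2, algo_const] and the two stddev states
        // [count, mean, m2] before merging into the child accumulators.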
        let states_c = [
            Arc::clone(&states[0]),
            Arc::clone(&states[1]),
            Arc::clone(&states[2]),
            Arc::clone(&states[3]),
        ];
        let states_s1 = [
            Arc::clone(&states[0]),
            Arc::clone(&states[1]),
            Arc::clone(&states[4]),
        ];
        let states_s2 = [
            Arc::clone(&states[0]),
            Arc::clone(&states[2]),
            Arc::clone(&states[5]),
        ];

        if !states[0].is_empty() && !states[1].is_empty() && !states[2].is_empty() {
            self.covar.merge_batch(&states_c)?;
            self.stddev1.merge_batch(&states_s1)?;
            self.stddev2.merge_batch(&states_s2)?;
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
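        // Pearson correlation: covar_pop(x, y) / (stddev_pop(x) * stddev_pop(y)).
        // A zero standard deviation yields NULL when `null_on_divide_by_zero` is
        // set; otherwise a single-row group evaluates to NaN and anything else to NULL.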
        let covar = self.covar.evaluate()?;
        let stddev1 = self.stddev1.evaluate()?;
        let stddev2 = self.stddev2.evaluate()?;

        match (covar, stddev1, stddev2) {
            (
                ScalarValue::Float64(Some(c)),
                ScalarValue::Float64(Some(s1)),
                ScalarValue::Float64(Some(s2)),
            ) if s1 != 0.0 && s2 != 0.0 => Ok(ScalarValue::Float64(Some(c / (s1 * s2)))),
            _ if self.null_on_divide_by_zero => Ok(ScalarValue::Float64(None)),
            _ => {
                if self.covar.get_count() == 1.0 {
                    return Ok(ScalarValue::Float64(Some(f64::NAN)));
                }
                Ok(ScalarValue::Float64(None))
            }
        }
    }

    fn size(&self) -> usize {
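        // Report the size of this struct plus the heap usage of the nested
        // accumulators, replacing their inline sizes with their reported sizes.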
        std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar) + self.covar.size()
            - std::mem::size_of_val(&self.stddev1)
            + self.stddev1.size()
            - std::mem::size_of_val(&self.stddev2)
            + self.stddev2.size()
    }
}
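
// A minimal usage sketch: drive the accumulator directly with two Float64
// columns. It assumes the underlying covariance and stddev accumulators accept
// Float64 input arrays, as DataFusion's implementations do.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Float64Array;

    #[test]
    fn perfectly_correlated_columns_evaluate_to_one() -> Result<()> {
        let mut acc = CorrelationAccumulator::try_new(false)?;
        let x: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]));
        // y = 2 * x, so the Pearson correlation should be 1.0 (up to rounding).
        let y: ArrayRef = Arc::new(Float64Array::from(vec![2.0, 4.0, 6.0, 8.0]));
        acc.update_batch(&[x, y])?;
        match acc.evaluate()? {
            ScalarValue::Float64(Some(corr)) => assert!((corr - 1.0).abs() < 1e-10),
            other => panic!("expected a Float64 correlation, got {other:?}"),
        }
        Ok(())
    }
}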