// datafusion_comet_spark_expr/agg_funcs/covariance.rs

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19
20use arrow::datatypes::FieldRef;
21use arrow::{
22    array::{ArrayRef, Float64Array},
23    compute::cast,
24    datatypes::{DataType, Field},
25};
26use datafusion::common::{
27    downcast_value, unwrap_or_internal_err, DataFusionError, Result, ScalarValue,
28};
29use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
30use datafusion::logical_expr::type_coercion::aggregates::NUMERICS;
31use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
32use datafusion::physical_expr::expressions::format_state_name;
33use datafusion::physical_expr::expressions::StatsType;
34use std::any::Any;
35use std::sync::Arc;
36
/// COVAR_SAMP and COVAR_POP aggregate expression
/// The implementation mostly is the same as the DataFusion's implementation. The reason
/// we have our own implementation is that DataFusion has UInt64 for state_field count,
/// while Spark has Double for count.
#[derive(Debug, Clone)]
pub struct Covariance {
    // Name used for display and for prefixing state-field names (see `state_fields`).
    name: String,
    // Exactly two arguments, uniformly coercible to any numeric type.
    signature: Signature,
    // Population vs. sample covariance; selects the divisor in `evaluate`.
    stats_type: StatsType,
    // For sample covariance with a single row: true -> NULL, false -> NaN.
    null_on_divide_by_zero: bool,
}
48
49impl Covariance {
50    /// Create a new COVAR aggregate function
51    pub fn new(
52        name: impl Into<String>,
53        data_type: DataType,
54        stats_type: StatsType,
55        null_on_divide_by_zero: bool,
56    ) -> Self {
57        // the result of covariance just support FLOAT64 data type.
58        assert!(matches!(data_type, DataType::Float64));
59        Self {
60            name: name.into(),
61            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
62            stats_type,
63            null_on_divide_by_zero,
64        }
65    }
66}
67
68impl AggregateUDFImpl for Covariance {
69    /// Return a reference to Any that can be used for downcasting
70    fn as_any(&self) -> &dyn Any {
71        self
72    }
73
74    fn name(&self) -> &str {
75        &self.name
76    }
77
78    fn signature(&self) -> &Signature {
79        &self.signature
80    }
81
82    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
83        Ok(DataType::Float64)
84    }
85    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
86        Ok(ScalarValue::Float64(None))
87    }
88
89    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
90        Ok(Box::new(CovarianceAccumulator::try_new(
91            self.stats_type,
92            self.null_on_divide_by_zero,
93        )?))
94    }
95
96    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
97        Ok(vec![
98            Arc::new(Field::new(
99                format_state_name(&self.name, "count"),
100                DataType::Float64,
101                true,
102            )),
103            Arc::new(Field::new(
104                format_state_name(&self.name, "mean1"),
105                DataType::Float64,
106                true,
107            )),
108            Arc::new(Field::new(
109                format_state_name(&self.name, "mean2"),
110                DataType::Float64,
111                true,
112            )),
113            Arc::new(Field::new(
114                format_state_name(&self.name, "algo_const"),
115                DataType::Float64,
116                true,
117            )),
118        ])
119    }
120}
121
/// An accumulator to compute covariance
#[derive(Debug)]
pub struct CovarianceAccumulator {
    // Running co-moment C = sum((x - mean_x) * (y - mean_y)) over accumulated pairs.
    algo_const: f64,
    // Running mean of the first argument.
    mean1: f64,
    // Running mean of the second argument.
    mean2: f64,
    // Number of accumulated (non-null) pairs; f64 to match Spark's Double count.
    count: f64,
    // Population vs. sample; selects the divisor in `evaluate`.
    stats_type: StatsType,
    // For sample covariance with count == 1: true -> NULL, false -> NaN.
    null_on_divide_by_zero: bool,
}
132
133impl CovarianceAccumulator {
134    /// Creates a new `CovarianceAccumulator`
135    pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> Result<Self> {
136        Ok(Self {
137            algo_const: 0_f64,
138            mean1: 0_f64,
139            mean2: 0_f64,
140            count: 0_f64,
141            stats_type: s_type,
142            null_on_divide_by_zero,
143        })
144    }
145
146    pub fn get_count(&self) -> f64 {
147        self.count
148    }
149
150    pub fn get_mean1(&self) -> f64 {
151        self.mean1
152    }
153
154    pub fn get_mean2(&self) -> f64 {
155        self.mean2
156    }
157
158    pub fn get_algo_const(&self) -> f64 {
159        self.algo_const
160    }
161}
162
impl Accumulator for CovarianceAccumulator {
    /// Serialize the running state as (count, mean1, mean2, algo_const) — all
    /// Float64, matching the schema declared in `Covariance::state_fields`.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![
            ScalarValue::from(self.count),
            ScalarValue::from(self.mean1),
            ScalarValue::from(self.mean2),
            ScalarValue::from(self.algo_const),
        ])
    }

    /// Fold a batch of (x, y) pairs into the running statistics.
    /// Pairs where either side is null are skipped entirely.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // Arguments may be any numeric type per the signature; normalize to Float64.
        let values1 = &cast(&values[0], &DataType::Float64)?;
        let values2 = &cast(&values[1], &DataType::Float64)?;

        // `flatten()` yields only the non-null values; the `is_valid` checks in
        // the loop below keep these iterators aligned with position `i`.
        let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
        let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();

        for i in 0..values1.len() {
            let value1 = if values1.is_valid(i) {
                arr1.next()
            } else {
                None
            };
            let value2 = if values2.is_valid(i) {
                arr2.next()
            } else {
                None
            };

            // A pair contributes only when both sides are non-null.
            if value1.is_none() || value2.is_none() {
                continue;
            }

            let value1 = unwrap_or_internal_err!(value1);
            let value2 = unwrap_or_internal_err!(value2);
            // Welford-style single-pass update of both means and the co-moment
            // C += (x - old_mean_x) * (y - new_mean_y); avoids accumulating raw
            // sums of products. Statement order matters: new_c must use the old
            // delta1 together with the updated mean2.
            let new_count = self.count + 1.0;
            let delta1 = value1 - self.mean1;
            let new_mean1 = delta1 / new_count + self.mean1;
            let delta2 = value2 - self.mean2;
            let new_mean2 = delta2 / new_count + self.mean2;
            let new_c = delta1 * (value2 - new_mean2) + self.algo_const;

            self.count += 1.0;
            self.mean1 = new_mean1;
            self.mean2 = new_mean2;
            self.algo_const = new_c;
        }

        Ok(())
    }

    /// Inverse of `update_batch`: remove a batch of pairs from the running
    /// statistics (used for sliding-window evaluation). Null handling mirrors
    /// `update_batch`.
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let values1 = &cast(&values[0], &DataType::Float64)?;
        let values2 = &cast(&values[1], &DataType::Float64)?;
        let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
        let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();

        for i in 0..values1.len() {
            let value1 = if values1.is_valid(i) {
                arr1.next()
            } else {
                None
            };
            let value2 = if values2.is_valid(i) {
                arr2.next()
            } else {
                None
            };

            if value1.is_none() || value2.is_none() {
                continue;
            }

            let value1 = unwrap_or_internal_err!(value1);
            let value2 = unwrap_or_internal_err!(value2);

            // NOTE(review): when retracting the only remaining pair, new_count is
            // 0 and the divisions below produce non-finite means. `evaluate`
            // guards on count == 0 so the reported result is still NULL, but the
            // means themselves are poisoned — confirm callers never update again
            // after retracting down to an empty window.
            let new_count = self.count - 1.0;
            let delta1 = self.mean1 - value1;
            let new_mean1 = delta1 / new_count + self.mean1;
            let delta2 = self.mean2 - value2;
            let new_mean2 = delta2 / new_count + self.mean2;
            let new_c = self.algo_const - delta1 * (new_mean2 - value2);

            self.count -= 1.0;
            self.mean1 = new_mean1;
            self.mean2 = new_mean2;
            self.algo_const = new_c;
        }

        Ok(())
    }

    /// Combine partial states produced by other accumulators' `state()`.
    /// Uses the pairwise (Chan et al.) combination: means are count-weighted and
    /// the co-moments add with a correction term for the difference in means.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let counts = downcast_value!(states[0], Float64Array);
        let means1 = downcast_value!(states[1], Float64Array);
        let means2 = downcast_value!(states[2], Float64Array);
        let cs = downcast_value!(states[3], Float64Array);

        for i in 0..counts.len() {
            let c = counts.value(i);
            // Skip empty partial states; merging them would be a no-op and, when
            // self.count is also 0, would divide by zero below.
            if c == 0.0 {
                continue;
            }
            let new_count = self.count + c;
            let new_mean1 = self.mean1 * self.count / new_count + means1.value(i) * c / new_count;
            let new_mean2 = self.mean2 * self.count / new_count + means2.value(i) * c / new_count;
            let delta1 = self.mean1 - means1.value(i);
            let delta2 = self.mean2 - means2.value(i);
            let new_c =
                self.algo_const + cs.value(i) + delta1 * delta2 * self.count * c / new_count;

            self.count = new_count;
            self.mean1 = new_mean1;
            self.mean2 = new_mean2;
            self.algo_const = new_c;
        }
        Ok(())
    }

    /// Final result: co-moment divided by n (population) or n - 1 (sample).
    /// Empty input -> NULL. A single pair under sample statistics -> NULL or
    /// NaN depending on `null_on_divide_by_zero` (Spark-compatible behavior).
    fn evaluate(&mut self) -> Result<ScalarValue> {
        if self.count == 0.0 {
            return Ok(ScalarValue::Float64(None));
        }

        let count = match self.stats_type {
            StatsType::Population => self.count,
            StatsType::Sample if self.count > 1.0 => self.count - 1.0,
            StatsType::Sample => {
                // self.count == 1.0
                return if self.null_on_divide_by_zero {
                    Ok(ScalarValue::Float64(None))
                } else {
                    Ok(ScalarValue::Float64(Some(f64::NAN)))
                };
            }
        };

        Ok(ScalarValue::Float64(Some(self.algo_const / count)))
    }

    /// In-memory size for the memory accounting framework; all fields are
    /// inline scalars, so the shallow size is the full size.
    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }
}