datafusion_comet_spark_expr/agg_funcs/
variance.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::datatypes::FieldRef;
19use arrow::{
20    array::{ArrayRef, Float64Array},
21    datatypes::{DataType, Field},
22};
23use datafusion::common::{downcast_value, Result, ScalarValue};
24use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
25use datafusion::logical_expr::Volatility::Immutable;
26use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature};
27use datafusion::physical_expr::expressions::format_state_name;
28use datafusion::physical_expr::expressions::StatsType;
29use std::any::Any;
30use std::sync::Arc;
31
/// VAR_SAMP and VAR_POP aggregate expression
/// The implementation mostly is the same as the DataFusion's implementation. The reason
/// we have our own implementation is that DataFusion has UInt64 for state_field `count`,
/// while Spark has Double for count. Also we have added `null_on_divide_by_zero`
/// to be consistent with Spark's implementation.
#[derive(Debug)]
pub struct Variance {
    /// Display name; used to label the intermediate state fields (see `state_fields`).
    name: String,
    /// Accepts exactly one numeric argument.
    signature: Signature,
    /// Population vs. sample statistics; selects the divisor used in `evaluate`.
    stats_type: StatsType,
    /// When true, a sample variance over a single row evaluates to NULL
    /// instead of NaN (matches Spark's `nullOnDivideByZero`).
    null_on_divide_by_zero: bool,
}
44
45impl Variance {
46    /// Create a new VARIANCE aggregate function
47    pub fn new(
48        name: impl Into<String>,
49        data_type: DataType,
50        stats_type: StatsType,
51        null_on_divide_by_zero: bool,
52    ) -> Self {
53        // the result of variance just support FLOAT64 data type.
54        assert!(matches!(data_type, DataType::Float64));
55        Self {
56            name: name.into(),
57            signature: Signature::numeric(1, Immutable),
58            stats_type,
59            null_on_divide_by_zero,
60        }
61    }
62}
63
64impl AggregateUDFImpl for Variance {
65    /// Return a reference to Any that can be used for downcasting
66    fn as_any(&self) -> &dyn Any {
67        self
68    }
69
70    fn name(&self) -> &str {
71        &self.name
72    }
73
74    fn signature(&self) -> &Signature {
75        &self.signature
76    }
77
78    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
79        Ok(DataType::Float64)
80    }
81
82    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
83        Ok(Box::new(VarianceAccumulator::try_new(
84            self.stats_type,
85            self.null_on_divide_by_zero,
86        )?))
87    }
88
89    fn create_sliding_accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
90        Ok(Box::new(VarianceAccumulator::try_new(
91            self.stats_type,
92            self.null_on_divide_by_zero,
93        )?))
94    }
95
96    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
97        Ok(vec![
98            Arc::new(Field::new(
99                format_state_name(&self.name, "count"),
100                DataType::Float64,
101                true,
102            )),
103            Arc::new(Field::new(
104                format_state_name(&self.name, "mean"),
105                DataType::Float64,
106                true,
107            )),
108            Arc::new(Field::new(
109                format_state_name(&self.name, "m2"),
110                DataType::Float64,
111                true,
112            )),
113        ])
114    }
115
116    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
117        Ok(ScalarValue::Float64(None))
118    }
119}
120
/// An accumulator to compute variance
#[derive(Debug)]
pub struct VarianceAccumulator {
    // Sum of squared deviations from the running mean (Welford's M2).
    m2: f64,
    // Running mean of the non-null values accumulated so far.
    mean: f64,
    // Number of non-null values accumulated; f64 (not u64) to match Spark's state layout.
    count: f64,
    // Population vs. sample statistics; selects the divisor in `evaluate`.
    stats_type: StatsType,
    // When true, sample variance over a single value yields NULL instead of NaN.
    null_on_divide_by_zero: bool,
}
130
131impl VarianceAccumulator {
132    /// Creates a new `VarianceAccumulator`
133    pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> Result<Self> {
134        Ok(Self {
135            m2: 0_f64,
136            mean: 0_f64,
137            count: 0_f64,
138            stats_type: s_type,
139            null_on_divide_by_zero,
140        })
141    }
142
143    pub fn get_count(&self) -> f64 {
144        self.count
145    }
146
147    pub fn get_mean(&self) -> f64 {
148        self.mean
149    }
150
151    pub fn get_m2(&self) -> f64 {
152        self.m2
153    }
154}
155
impl Accumulator for VarianceAccumulator {
    /// Serializes the intermediate state as `(count, mean, m2)` scalars,
    /// matching the schema declared in `Variance::state_fields`.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![
            ScalarValue::from(self.count),
            ScalarValue::from(self.mean),
            ScalarValue::from(self.m2),
        ])
    }

    /// Folds a batch of Float64 input values into the running statistics
    /// using Welford's online algorithm. Nulls are skipped (`flatten`).
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // `flatten` drops the None entries, so only non-null values count.
        let arr = downcast_value!(&values[0], Float64Array).iter().flatten();

        for value in arr {
            let new_count = self.count + 1.0;
            // delta1 is measured against the old mean, delta2 against the
            // new one; their product is the Welford increment for m2.
            let delta1 = value - self.mean;
            let new_mean = delta1 / new_count + self.mean;
            let delta2 = value - new_mean;
            let new_m2 = self.m2 + delta1 * delta2;

            self.count += 1.0;
            self.mean = new_mean;
            self.m2 = new_m2;
        }

        Ok(())
    }

    /// Removes a batch of values from the running statistics — the algebraic
    /// inverse of `update_batch`, used by sliding-window execution.
    ///
    /// NOTE(review): retracting the final value makes `new_count` zero and
    /// the division produces NaN mean/m2 — presumably the window exec never
    /// retracts more values than were added; confirm against callers.
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let arr = downcast_value!(&values[0], Float64Array).iter().flatten();

        for value in arr {
            let new_count = self.count - 1.0;
            let delta1 = self.mean - value;
            let new_mean = delta1 / new_count + self.mean;
            let delta2 = new_mean - value;
            let new_m2 = self.m2 - delta1 * delta2;

            self.count -= 1.0;
            self.mean = new_mean;
            self.m2 = new_m2;
        }

        Ok(())
    }

    /// Merges partial `(count, mean, m2)` states produced by other
    /// accumulators, combining them pairwise with the parallel-variance
    /// formula (weighted means plus a cross term for the mean gap).
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let counts = downcast_value!(states[0], Float64Array);
        let means = downcast_value!(states[1], Float64Array);
        let m2s = downcast_value!(states[2], Float64Array);

        for i in 0..counts.len() {
            let c = counts.value(i);
            // An empty partial state contributes nothing; skipping it also
            // avoids a 0/0 when both sides are empty.
            if c == 0_f64 {
                continue;
            }
            let new_count = self.count + c;
            // Count-weighted average of the two means.
            let new_mean = self.mean * self.count / new_count + means.value(i) * c / new_count;
            let delta = self.mean - means.value(i);
            // delta^2 * n1 * n2 / (n1 + n2) accounts for the distance
            // between the two partial means.
            let new_m2 = self.m2 + m2s.value(i) + delta * delta * self.count * c / new_count;

            self.count = new_count;
            self.mean = new_mean;
            self.m2 = new_m2;
        }
        Ok(())
    }

    /// Produces the final variance: `m2 / count` for population statistics,
    /// `m2 / (count - 1)` for sample statistics.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        // Divisor: population uses n; sample uses n - 1 (left at n <= 0 so
        // the zero-row case falls through to the NULL arm below).
        let count = match self.stats_type {
            StatsType::Population => self.count,
            StatsType::Sample => {
                if self.count > 0.0 {
                    self.count - 1.0
                } else {
                    self.count
                }
            }
        };

        Ok(ScalarValue::Float64(match self.count {
            // No input rows -> NULL.
            0.0 => None,
            // Sample variance of a single row divides by zero: Spark-style
            // behavior returns NULL when `null_on_divide_by_zero` is set,
            // NaN otherwise.
            count if count == 1.0 && StatsType::Sample == self.stats_type => {
                if self.null_on_divide_by_zero {
                    None
                } else {
                    Some(f64::NAN)
                }
            }
            _ => Some(self.m2 / count),
        }))
    }

    /// In-memory size of the accumulator; all fields are inline, so the
    /// struct size is exact.
    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }
}