datafusion_comet_spark_expr/agg_funcs/variance.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::any::Any;

use arrow::{
    array::{ArrayRef, Float64Array},
    datatypes::{DataType, Field},
};
use datafusion::logical_expr::Accumulator;
use datafusion_common::{downcast_value, Result, ScalarValue};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::Volatility::Immutable;
use datafusion_expr::{AggregateUDFImpl, Signature};
use datafusion_physical_expr::expressions::format_state_name;
use datafusion_physical_expr::expressions::StatsType;

/// VAR_SAMP and VAR_POP aggregate expression
/// The implementation is mostly the same as DataFusion's. The reason we keep
/// our own is that DataFusion uses UInt64 for the `count` state field, while
/// Spark uses Double. We have also added `null_on_divide_by_zero` to be
/// consistent with Spark's implementation.
#[derive(Debug)]
pub struct Variance {
    name: String,
    signature: Signature,
    /// Population (VAR_POP) or sample (VAR_SAMP) statistic
    stats_type: StatsType,
    /// Return NULL instead of NaN when the divisor is zero, as Spark does
    null_on_divide_by_zero: bool,
}

impl Variance {
    /// Create a new VARIANCE aggregate function
    pub fn new(
        name: impl Into<String>,
        data_type: DataType,
        stats_type: StatsType,
        null_on_divide_by_zero: bool,
    ) -> Self {
        // The variance result only supports the Float64 data type.
        assert!(matches!(data_type, DataType::Float64));
        Self {
            name: name.into(),
            signature: Signature::numeric(1, Immutable),
            stats_type,
            null_on_divide_by_zero,
        }
    }
}
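
// A minimal usage sketch (for illustration; the wrapper call below is an
// assumption based on the standard `AggregateUDFImpl` workflow, not code
// taken from this file):
//
//     let var_samp = datafusion_expr::AggregateUDF::new_from_impl(Variance::new(
//         "var_samp",
//         DataType::Float64,
//         StatsType::Sample,
//         true, // null_on_divide_by_zero, matching Spark
//     ));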

impl AggregateUDFImpl for Variance {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Float64)
    }

    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(VarianceAccumulator::try_new(
            self.stats_type,
            self.null_on_divide_by_zero,
        )?))
    }

    fn create_sliding_accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(VarianceAccumulator::try_new(
            self.stats_type,
            self.null_on_divide_by_zero,
        )?))
    }

    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
        // `count` is Float64 to match Spark's Double-typed count; DataFusion's
        // built-in variance uses UInt64 here.
        Ok(vec![
            Field::new(
                format_state_name(&self.name, "count"),
                DataType::Float64,
                true,
            ),
            Field::new(
                format_state_name(&self.name, "mean"),
                DataType::Float64,
                true,
            ),
            Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true),
        ])
    }

    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
        Ok(ScalarValue::Float64(None))
    }
}

/// An accumulator to compute variance
#[derive(Debug)]
pub struct VarianceAccumulator {
    /// Sum of squared deviations from the current mean (Welford's M2)
    m2: f64,
    /// Running mean of the values seen so far
    mean: f64,
    /// Number of non-null values seen so far (Double, to match Spark)
    count: f64,
    stats_type: StatsType,
    null_on_divide_by_zero: bool,
}

impl VarianceAccumulator {
    /// Creates a new `VarianceAccumulator`
    pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> Result<Self> {
        Ok(Self {
            m2: 0_f64,
            mean: 0_f64,
            count: 0_f64,
            stats_type: s_type,
            null_on_divide_by_zero,
        })
    }

    pub fn get_count(&self) -> f64 {
        self.count
    }

    pub fn get_mean(&self) -> f64 {
        self.mean
    }

    pub fn get_m2(&self) -> f64 {
        self.m2
    }
}

impl Accumulator for VarianceAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![
            ScalarValue::from(self.count),
            ScalarValue::from(self.mean),
            ScalarValue::from(self.m2),
        ])
    }

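    // The loop below is Welford's single-pass update: for each non-null
    // value x,
    //   n'    = n + 1
    //   mean' = mean + (x - mean) / n'
    //   M2'   = M2 + (x - mean) * (x - mean')
    // which maintains the running mean and the sum of squared deviations
    // from the mean (M2) in a numerically stable way.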
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let arr = downcast_value!(&values[0], Float64Array).iter().flatten();

        for value in arr {
            let new_count = self.count + 1.0;
            let delta1 = value - self.mean;
            let new_mean = delta1 / new_count + self.mean;
            let delta2 = value - new_mean;
            let new_m2 = self.m2 + delta1 * delta2;

            self.count += 1.0;
            self.mean = new_mean;
            self.m2 = new_m2;
        }

        Ok(())
    }

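    // The algebraic inverse of `update_batch`, used by the sliding-window
    // accumulator to remove values that leave the frame.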
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let arr = downcast_value!(&values[0], Float64Array).iter().flatten();

        for value in arr {
            let new_count = self.count - 1.0;
            let delta1 = self.mean - value;
            let new_mean = delta1 / new_count + self.mean;
            let delta2 = new_mean - value;
            let new_m2 = self.m2 - delta1 * delta2;

            self.count -= 1.0;
            self.mean = new_mean;
            self.m2 = new_m2;
        }

        Ok(())
    }

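    // Combines partial (count, mean, m2) states using the standard parallel
    // merge rule (Chan et al.):
    //   n    = n_a + n_b
    //   mean = mean_a * n_a / n + mean_b * n_b / n
    //   M2   = M2_a + M2_b + delta^2 * n_a * n_b / n, where delta = mean_a - mean_b
    // Empty partial states (count == 0) are skipped.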
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let counts = downcast_value!(states[0], Float64Array);
        let means = downcast_value!(states[1], Float64Array);
        let m2s = downcast_value!(states[2], Float64Array);

        for i in 0..counts.len() {
            let c = counts.value(i);
            if c == 0_f64 {
                continue;
            }
            let new_count = self.count + c;
            let new_mean = self.mean * self.count / new_count + means.value(i) * c / new_count;
            let delta = self.mean - means.value(i);
            let new_m2 = self.m2 + m2s.value(i) + delta * delta * self.count * c / new_count;

            self.count = new_count;
            self.mean = new_mean;
            self.m2 = new_m2;
        }
        Ok(())
    }

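    // Produces the final value with Spark-compatible semantics: zero rows
    // yield NULL; a single row under VAR_SAMP yields NULL when
    // null_on_divide_by_zero is set and NaN otherwise; all other cases
    // return m2 / count, with count reduced by one for the sample statistic.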
    fn evaluate(&mut self) -> Result<ScalarValue> {
        let count = match self.stats_type {
            StatsType::Population => self.count,
            StatsType::Sample => {
                if self.count > 0.0 {
                    self.count - 1.0
                } else {
                    self.count
                }
            }
        };

        Ok(ScalarValue::Float64(match self.count {
            0.0 => None,
            count if count == 1.0 && StatsType::Sample == self.stats_type => {
                if self.null_on_divide_by_zero {
                    None
                } else {
                    Some(f64::NAN)
                }
            }
            _ => Some(self.m2 / count),
        }))
    }

    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }
}
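
// A minimal test sketch (added for illustration; the test names and expected
// values are hand-computed, not taken from the upstream test suite). It drives
// the accumulator directly: the sample variance of [1, 2, 3] is 1.0, and a
// single row under VAR_SAMP with null_on_divide_by_zero yields NULL as Spark
// does.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn sample_variance_of_three_values() -> Result<()> {
        let mut acc = VarianceAccumulator::try_new(StatsType::Sample, false)?;
        let values: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]));
        acc.update_batch(&[values])?;
        // mean = 2.0 and m2 = 2.0; dividing by (count - 1) = 2 gives 1.0
        assert_eq!(acc.evaluate()?, ScalarValue::Float64(Some(1.0)));
        Ok(())
    }

    #[test]
    fn single_row_var_samp_is_null_on_divide_by_zero() -> Result<()> {
        let mut acc = VarianceAccumulator::try_new(StatsType::Sample, true)?;
        let values: ArrayRef = Arc::new(Float64Array::from(vec![42.0]));
        acc.update_batch(&[values])?;
        // count - 1 == 0, so the Spark-compatible path returns NULL, not NaN
        assert_eq!(acc.evaluate()?, ScalarValue::Float64(None));
        Ok(())
    }
}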