// datafusion_comet_spark_expr/agg_funcs/variance.rs

use std::any::Any;
19
20use arrow::{
21 array::{ArrayRef, Float64Array},
22 datatypes::{DataType, Field},
23};
24use datafusion::logical_expr::Accumulator;
25use datafusion_common::{downcast_value, Result, ScalarValue};
26use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
27use datafusion_expr::Volatility::Immutable;
28use datafusion_expr::{AggregateUDFImpl, Signature};
29use datafusion_physical_expr::expressions::format_state_name;
30use datafusion_physical_expr::expressions::StatsType;
31
#[derive(Debug)]
/// Spark-compatible variance aggregate UDF (covers both population and
/// sample variance, selected via `stats_type`).
pub struct Variance {
    // Display name; also used as the prefix for the partial-state field names.
    name: String,
    // Single numeric argument, immutable volatility (see `Variance::new`).
    signature: Signature,
    // Population vs. sample variance — decides the divisor in `evaluate`.
    stats_type: StatsType,
    // When true, a sample variance over a single value yields NULL instead
    // of NaN (see `VarianceAccumulator::evaluate`).
    null_on_divide_by_zero: bool,
}
44
45impl Variance {
46 pub fn new(
48 name: impl Into<String>,
49 data_type: DataType,
50 stats_type: StatsType,
51 null_on_divide_by_zero: bool,
52 ) -> Self {
53 assert!(matches!(data_type, DataType::Float64));
55 Self {
56 name: name.into(),
57 signature: Signature::numeric(1, Immutable),
58 stats_type,
59 null_on_divide_by_zero,
60 }
61 }
62}
63
64impl AggregateUDFImpl for Variance {
65 fn as_any(&self) -> &dyn Any {
67 self
68 }
69
70 fn name(&self) -> &str {
71 &self.name
72 }
73
74 fn signature(&self) -> &Signature {
75 &self.signature
76 }
77
78 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
79 Ok(DataType::Float64)
80 }
81
82 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
83 Ok(Box::new(VarianceAccumulator::try_new(
84 self.stats_type,
85 self.null_on_divide_by_zero,
86 )?))
87 }
88
89 fn create_sliding_accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
90 Ok(Box::new(VarianceAccumulator::try_new(
91 self.stats_type,
92 self.null_on_divide_by_zero,
93 )?))
94 }
95
96 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
97 Ok(vec![
98 Field::new(
99 format_state_name(&self.name, "count"),
100 DataType::Float64,
101 true,
102 ),
103 Field::new(
104 format_state_name(&self.name, "mean"),
105 DataType::Float64,
106 true,
107 ),
108 Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true),
109 ])
110 }
111
112 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
113 Ok(ScalarValue::Float64(None))
114 }
115}
116
#[derive(Debug)]
/// Running state for variance, maintained with Welford's online algorithm.
pub struct VarianceAccumulator {
    // Sum of squared deviations from the current mean.
    m2: f64,
    // Running mean of the values seen so far.
    mean: f64,
    // Number of non-null values folded in (kept as f64 for the arithmetic
    // in update/merge and because state is serialized as Float64).
    count: f64,
    // Population vs. sample — decides the divisor used in `evaluate`.
    stats_type: StatsType,
    // When true, sample variance over a single value yields NULL, not NaN.
    null_on_divide_by_zero: bool,
}
126
127impl VarianceAccumulator {
128 pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> Result<Self> {
130 Ok(Self {
131 m2: 0_f64,
132 mean: 0_f64,
133 count: 0_f64,
134 stats_type: s_type,
135 null_on_divide_by_zero,
136 })
137 }
138
139 pub fn get_count(&self) -> f64 {
140 self.count
141 }
142
143 pub fn get_mean(&self) -> f64 {
144 self.mean
145 }
146
147 pub fn get_m2(&self) -> f64 {
148 self.m2
149 }
150}
151
impl Accumulator for VarianceAccumulator {
    /// Serializes the running state as (count, mean, m2) — must stay in the
    /// same order as `Variance::state_fields`.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![
            ScalarValue::from(self.count),
            ScalarValue::from(self.mean),
            ScalarValue::from(self.m2),
        ])
    }

    /// Folds a batch of input values into the running statistics using
    /// Welford's online algorithm. Nulls are skipped via `flatten()`.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let arr = downcast_value!(&values[0], Float64Array).iter().flatten();

        for value in arr {
            let new_count = self.count + 1.0;
            // delta1 uses the old mean, delta2 the new mean; their product is
            // the standard Welford increment for m2.
            let delta1 = value - self.mean;
            let new_mean = delta1 / new_count + self.mean;
            let delta2 = value - new_mean;
            let new_m2 = self.m2 + delta1 * delta2;

            self.count += 1.0;
            self.mean = new_mean;
            self.m2 = new_m2;
        }

        Ok(())
    }

    /// Removes a batch of values from the running statistics (inverse of
    /// `update_batch`), enabling sliding-window aggregation.
    ///
    /// NOTE(review): if this retracts the last remaining value, `new_count`
    /// is 0 and the division produces a non-finite mean — presumably the
    /// window machinery never retracts more values than were added; confirm.
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let arr = downcast_value!(&values[0], Float64Array).iter().flatten();

        for value in arr {
            let new_count = self.count - 1.0;
            let delta1 = self.mean - value;
            let new_mean = delta1 / new_count + self.mean;
            let delta2 = new_mean - value;
            let new_m2 = self.m2 - delta1 * delta2;

            self.count -= 1.0;
            self.mean = new_mean;
            self.m2 = new_m2;
        }

        Ok(())
    }

    /// Combines partial states from other accumulators using the pairwise
    /// (parallel) variance combination: the merged m2 adds a correction term
    /// proportional to the squared difference of the two means.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let counts = downcast_value!(states[0], Float64Array);
        let means = downcast_value!(states[1], Float64Array);
        let m2s = downcast_value!(states[2], Float64Array);

        for i in 0..counts.len() {
            let c = counts.value(i);
            // Skip empty partitions — merging them would divide by new_count
            // with a meaningless mean contribution.
            if c == 0_f64 {
                continue;
            }
            let new_count = self.count + c;
            // Weighted average of the two means.
            let new_mean = self.mean * self.count / new_count + means.value(i) * c / new_count;
            let delta = self.mean - means.value(i);
            let new_m2 = self.m2 + m2s.value(i) + delta * delta * self.count * c / new_count;

            self.count = new_count;
            self.mean = new_mean;
            self.m2 = new_m2;
        }
        Ok(())
    }

    /// Produces the final variance: m2 divided by N (population) or N-1
    /// (sample). Returns NULL for zero rows; for a sample variance over a
    /// single row, returns NULL when `null_on_divide_by_zero` is set and NaN
    /// otherwise.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        // Divisor: N for population, N-1 for sample (left as N when count is
        // 0 so the divisor is never negative; that case returns None below).
        let count = match self.stats_type {
            StatsType::Population => self.count,
            StatsType::Sample => {
                if self.count > 0.0 {
                    self.count - 1.0
                } else {
                    self.count
                }
            }
        };

        Ok(ScalarValue::Float64(match self.count {
            0.0 => None,
            // Sample variance of a single value: divisor would be 0.
            count if count == 1.0 && StatsType::Sample == self.stats_type => {
                if self.null_on_divide_by_zero {
                    None
                } else {
                    Some(f64::NAN)
                }
            }
            _ => Some(self.m2 / count),
        }))
    }

    /// In-memory size of the accumulator (no heap-owned data).
    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }
}