// datafusion_comet_spark_expr/agg_funcs/variance.rs

use arrow::datatypes::FieldRef;
use arrow::{
    array::{ArrayRef, Float64Array},
    datatypes::{DataType, Field},
};
use datafusion::common::{downcast_value, Result, ScalarValue};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::Volatility::Immutable;
use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature};
use datafusion::physical_expr::expressions::format_state_name;
use datafusion::physical_expr::expressions::StatsType;
use std::any::Any;
use std::sync::Arc;
31
/// Spark-compatible variance aggregate function (sample or population,
/// selected by `stats_type`).
///
/// Delegates the actual computation to [`VarianceAccumulator`];
/// `null_on_divide_by_zero` selects whether a single-row sample variance
/// yields NULL (Spark default) or NaN.
#[derive(Debug)]
pub struct Variance {
    // Function name, used to label the intermediate state fields.
    name: String,
    // Accepts exactly one numeric argument; immutable (deterministic).
    signature: Signature,
    // Sample (divide by n-1) vs. population (divide by n) variance.
    stats_type: StatsType,
    // When true, return NULL instead of NaN on a zero divisor (sample, n == 1).
    null_on_divide_by_zero: bool,
}
44
45impl Variance {
46 pub fn new(
48 name: impl Into<String>,
49 data_type: DataType,
50 stats_type: StatsType,
51 null_on_divide_by_zero: bool,
52 ) -> Self {
53 assert!(matches!(data_type, DataType::Float64));
55 Self {
56 name: name.into(),
57 signature: Signature::numeric(1, Immutable),
58 stats_type,
59 null_on_divide_by_zero,
60 }
61 }
62}
63
64impl AggregateUDFImpl for Variance {
65 fn as_any(&self) -> &dyn Any {
67 self
68 }
69
70 fn name(&self) -> &str {
71 &self.name
72 }
73
74 fn signature(&self) -> &Signature {
75 &self.signature
76 }
77
78 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
79 Ok(DataType::Float64)
80 }
81
82 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
83 Ok(Box::new(VarianceAccumulator::try_new(
84 self.stats_type,
85 self.null_on_divide_by_zero,
86 )?))
87 }
88
89 fn create_sliding_accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
90 Ok(Box::new(VarianceAccumulator::try_new(
91 self.stats_type,
92 self.null_on_divide_by_zero,
93 )?))
94 }
95
96 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
97 Ok(vec![
98 Arc::new(Field::new(
99 format_state_name(&self.name, "count"),
100 DataType::Float64,
101 true,
102 )),
103 Arc::new(Field::new(
104 format_state_name(&self.name, "mean"),
105 DataType::Float64,
106 true,
107 )),
108 Arc::new(Field::new(
109 format_state_name(&self.name, "m2"),
110 DataType::Float64,
111 true,
112 )),
113 ])
114 }
115
116 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
117 Ok(ScalarValue::Float64(None))
118 }
119}
120
/// Accumulator computing variance via Welford's online algorithm.
///
/// Tracks the running count, mean, and M2 (sum of squared deviations from
/// the current mean); the variance is `m2 / divisor` at evaluation time.
#[derive(Debug)]
pub struct VarianceAccumulator {
    // Sum of squared deviations from the running mean.
    m2: f64,
    // Running mean of the accumulated values.
    mean: f64,
    // Number of accumulated (non-null) values, kept as f64 for the math below.
    count: f64,
    // Sample vs. population divisor selection.
    stats_type: StatsType,
    // When true, return NULL instead of NaN when the sample divisor is zero.
    null_on_divide_by_zero: bool,
}
130
131impl VarianceAccumulator {
132 pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> Result<Self> {
134 Ok(Self {
135 m2: 0_f64,
136 mean: 0_f64,
137 count: 0_f64,
138 stats_type: s_type,
139 null_on_divide_by_zero,
140 })
141 }
142
143 pub fn get_count(&self) -> f64 {
144 self.count
145 }
146
147 pub fn get_mean(&self) -> f64 {
148 self.mean
149 }
150
151 pub fn get_m2(&self) -> f64 {
152 self.m2
153 }
154}
155
impl Accumulator for VarianceAccumulator {
    /// Serializes the intermediate state as `[count, mean, m2]` — the order
    /// must match `Variance::state_fields`.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![
            ScalarValue::from(self.count),
            ScalarValue::from(self.mean),
            ScalarValue::from(self.m2),
        ])
    }

    /// Folds a batch of values into the running state using Welford's
    /// online update. `flatten()` skips NULL slots, so nulls are ignored.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let arr = downcast_value!(&values[0], Float64Array).iter().flatten();

        for value in arr {
            // Standard Welford step: m2 grows by delta1 * delta2, where the
            // deltas are taken against the old and the new mean respectively.
            let new_count = self.count + 1.0;
            let delta1 = value - self.mean;
            let new_mean = delta1 / new_count + self.mean;
            let delta2 = value - new_mean;
            let new_m2 = self.m2 + delta1 * delta2;

            self.count += 1.0;
            self.mean = new_mean;
            self.m2 = new_m2;
        }

        Ok(())
    }

    /// Removes a batch of values from the running state (inverse Welford
    /// step), used by sliding-window execution. NULL slots are ignored.
    ///
    /// NOTE(review): retracting the last remaining value makes
    /// `new_count == 0`, so `delta1 / new_count` divides by zero and the
    /// mean becomes inf/NaN — presumably the window machinery never
    /// evaluates in that state; confirm against callers.
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let arr = downcast_value!(&values[0], Float64Array).iter().flatten();

        for value in arr {
            let new_count = self.count - 1.0;
            let delta1 = self.mean - value;
            let new_mean = delta1 / new_count + self.mean;
            let delta2 = new_mean - value;
            let new_m2 = self.m2 - delta1 * delta2;

            self.count -= 1.0;
            self.mean = new_mean;
            self.m2 = new_m2;
        }

        Ok(())
    }

    /// Combines partial states (Chan et al. parallel variance formula):
    /// the merged m2 gains `delta^2 * n1 * n2 / (n1 + n2)` per partition.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let counts = downcast_value!(states[0], Float64Array);
        let means = downcast_value!(states[1], Float64Array);
        let m2s = downcast_value!(states[2], Float64Array);

        for i in 0..counts.len() {
            let c = counts.value(i);
            // Empty partial states contribute nothing (and would otherwise
            // divide by zero when self.count is also 0).
            if c == 0_f64 {
                continue;
            }
            let new_count = self.count + c;
            let new_mean = self.mean * self.count / new_count + means.value(i) * c / new_count;
            let delta = self.mean - means.value(i);
            let new_m2 = self.m2 + m2s.value(i) + delta * delta * self.count * c / new_count;

            self.count = new_count;
            self.mean = new_mean;
            self.m2 = new_m2;
        }
        Ok(())
    }

    /// Produces the final variance: `m2 / n` (population) or `m2 / (n - 1)`
    /// (sample). Empty input yields NULL; a single-row sample yields NULL or
    /// NaN depending on `null_on_divide_by_zero` (Spark semantics).
    fn evaluate(&mut self) -> Result<ScalarValue> {
        // Pick the divisor; for Sample with count == 0 keep 0 so the
        // 0.0 match arm below still returns NULL.
        let count = match self.stats_type {
            StatsType::Population => self.count,
            StatsType::Sample => {
                if self.count > 0.0 {
                    self.count - 1.0
                } else {
                    self.count
                }
            }
        };

        Ok(ScalarValue::Float64(match self.count {
            0.0 => None,
            count if count == 1.0 && StatsType::Sample == self.stats_type => {
                // Sample variance of a single value: divisor is zero.
                if self.null_on_divide_by_zero {
                    None
                } else {
                    Some(f64::NAN)
                }
            }
            _ => Some(self.m2 / count),
        }))
    }

    /// In-memory size of the accumulator (flat struct, no heap data).
    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }
}