1use arrow::{
22 array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array},
23 buffer::NullBuffer,
24 compute::kernels::cast,
25 datatypes::{DataType, Field},
26};
27use std::mem::{size_of, size_of_val};
28use std::{fmt::Debug, sync::Arc};
29
30use datafusion_common::{downcast_value, not_impl_err, plan_err, Result, ScalarValue};
31use datafusion_expr::{
32 function::{AccumulatorArgs, StateFieldsArgs},
33 utils::format_state_name,
34 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
35 Volatility,
36};
37use datafusion_functions_aggregate_common::{
38 aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType,
39};
40use datafusion_macros::user_doc;
41
// Wires `VarianceSample` into the crate's public helper API.
// NOTE(review): `make_udaf_expr_and_func!` is defined outside this file;
// presumably it emits an expression-builder function `var_sample(expression)`
// documented with the string below, plus a `var_samp_udaf()` accessor for the
// UDAF instance — confirm against the macro definition.
make_udaf_expr_and_func!(
    VarianceSample,
    var_sample,
    expression,
    "Computes the sample variance.",
    var_samp_udaf
);
49
// Wires `VariancePopulation` into the crate's public helper API.
// NOTE(review): `make_udaf_expr_and_func!` is defined outside this file;
// presumably it emits an expression-builder function `var_pop(expression)`
// documented with the string below, plus a `var_pop_udaf()` accessor for the
// UDAF instance — confirm against the macro definition.
make_udaf_expr_and_func!(
    VariancePopulation,
    var_pop,
    expression,
    "Computes the population variance.",
    var_pop_udaf
);
57
58#[user_doc(
59 doc_section(label = "General Functions"),
60 description = "Returns the statistical sample variance of a set of numbers.",
61 syntax_example = "var(expression)",
62 standard_argument(name = "expression", prefix = "Numeric")
63)]
64pub struct VarianceSample {
65 signature: Signature,
66 aliases: Vec<String>,
67}
68
69impl Debug for VarianceSample {
70 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
71 f.debug_struct("VarianceSample")
72 .field("name", &self.name())
73 .field("signature", &self.signature)
74 .finish()
75 }
76}
77
78impl Default for VarianceSample {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl VarianceSample {
85 pub fn new() -> Self {
86 Self {
87 aliases: vec![String::from("var_sample"), String::from("var_samp")],
88 signature: Signature::numeric(1, Volatility::Immutable),
89 }
90 }
91}
92
93impl AggregateUDFImpl for VarianceSample {
94 fn as_any(&self) -> &dyn std::any::Any {
95 self
96 }
97
98 fn name(&self) -> &str {
99 "var"
100 }
101
102 fn signature(&self) -> &Signature {
103 &self.signature
104 }
105
106 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
107 Ok(DataType::Float64)
108 }
109
110 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
111 let name = args.name;
112 Ok(vec![
113 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
114 Field::new(format_state_name(name, "mean"), DataType::Float64, true),
115 Field::new(format_state_name(name, "m2"), DataType::Float64, true),
116 ])
117 }
118
119 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
120 if acc_args.is_distinct {
121 return not_impl_err!("VAR(DISTINCT) aggregations are not available");
122 }
123
124 Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
125 }
126
127 fn aliases(&self) -> &[String] {
128 &self.aliases
129 }
130
131 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
132 !acc_args.is_distinct
133 }
134
135 fn create_groups_accumulator(
136 &self,
137 _args: AccumulatorArgs,
138 ) -> Result<Box<dyn GroupsAccumulator>> {
139 Ok(Box::new(VarianceGroupsAccumulator::new(StatsType::Sample)))
140 }
141
142 fn documentation(&self) -> Option<&Documentation> {
143 self.doc()
144 }
145}
146
147#[user_doc(
148 doc_section(label = "General Functions"),
149 description = "Returns the statistical population variance of a set of numbers.",
150 syntax_example = "var_pop(expression)",
151 standard_argument(name = "expression", prefix = "Numeric")
152)]
153pub struct VariancePopulation {
154 signature: Signature,
155 aliases: Vec<String>,
156}
157
158impl Debug for VariancePopulation {
159 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
160 f.debug_struct("VariancePopulation")
161 .field("name", &self.name())
162 .field("signature", &self.signature)
163 .finish()
164 }
165}
166
167impl Default for VariancePopulation {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
173impl VariancePopulation {
174 pub fn new() -> Self {
175 Self {
176 aliases: vec![String::from("var_population")],
177 signature: Signature::numeric(1, Volatility::Immutable),
178 }
179 }
180}
181
182impl AggregateUDFImpl for VariancePopulation {
183 fn as_any(&self) -> &dyn std::any::Any {
184 self
185 }
186
187 fn name(&self) -> &str {
188 "var_pop"
189 }
190
191 fn signature(&self) -> &Signature {
192 &self.signature
193 }
194
195 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
196 if !arg_types[0].is_numeric() {
197 return plan_err!("Variance requires numeric input types");
198 }
199
200 Ok(DataType::Float64)
201 }
202
203 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
204 let name = args.name;
205 Ok(vec![
206 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
207 Field::new(format_state_name(name, "mean"), DataType::Float64, true),
208 Field::new(format_state_name(name, "m2"), DataType::Float64, true),
209 ])
210 }
211
212 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
213 if acc_args.is_distinct {
214 return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
215 }
216
217 Ok(Box::new(VarianceAccumulator::try_new(
218 StatsType::Population,
219 )?))
220 }
221
222 fn aliases(&self) -> &[String] {
223 &self.aliases
224 }
225
226 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
227 !acc_args.is_distinct
228 }
229
230 fn create_groups_accumulator(
231 &self,
232 _args: AccumulatorArgs,
233 ) -> Result<Box<dyn GroupsAccumulator>> {
234 Ok(Box::new(VarianceGroupsAccumulator::new(
235 StatsType::Population,
236 )))
237 }
238 fn documentation(&self) -> Option<&Documentation> {
239 self.doc()
240 }
241}
242
243#[derive(Debug)]
253pub struct VarianceAccumulator {
254 m2: f64,
255 mean: f64,
256 count: u64,
257 stats_type: StatsType,
258}
259
260impl VarianceAccumulator {
261 pub fn try_new(s_type: StatsType) -> Result<Self> {
263 Ok(Self {
264 m2: 0_f64,
265 mean: 0_f64,
266 count: 0_u64,
267 stats_type: s_type,
268 })
269 }
270
271 pub fn get_count(&self) -> u64 {
272 self.count
273 }
274
275 pub fn get_mean(&self) -> f64 {
276 self.mean
277 }
278
279 pub fn get_m2(&self) -> f64 {
280 self.m2
281 }
282}
283
/// Combines two partial variance states into one.
///
/// Each state is a `(count, mean, m2)` triple, where `m2` is the sum of squared
/// deviations from that state's mean. The result is the triple describing the
/// union of both inputs, using the pairwise update
/// `m2 = m2_a + m2_b + delta^2 * n_a * n_b / (n_a + n_b)`.
///
/// At least one of the two states must be non-empty; merging two empty states
/// would divide by zero and is rejected by a debug assertion.
#[inline]
fn merge(
    count: u64,
    mean: f64,
    m2: f64,
    count2: u64,
    mean2: f64,
    m22: f64,
) -> (u64, f64, f64) {
    debug_assert!(
        !(count == 0 && count2 == 0),
        "Cannot merge two empty states"
    );

    let total = count + count2;
    // Float views of the three counts, named once instead of cast inline.
    let (n1, n2, n) = (count as f64, count2 as f64, total as f64);

    // Count-weighted average of the two means.
    let merged_mean = mean * n1 / n + mean2 * n2 / n;

    // Correction term accounting for the gap between the two means.
    let mean_gap = mean - mean2;
    let merged_m2 = m2 + m22 + mean_gap * mean_gap * n1 * n2 / n;

    (total, merged_mean, merged_m2)
}
303
/// Folds one more `value` into a `(count, mean, m2)` variance state and
/// returns the updated triple.
///
/// This is the single-pass (Welford-style) update: the mean moves by
/// `(value - mean) / new_count`, and `m2` grows by the product of the
/// deviations from the old and the new mean.
#[inline]
fn update(count: u64, mean: f64, m2: f64, value: f64) -> (u64, f64, f64) {
    let n = count + 1;

    // Deviation from the mean before this value is included.
    let dev_before = value - mean;
    let next_mean = dev_before / n as f64 + mean;
    // Deviation from the mean after this value is included.
    let dev_after = value - next_mean;

    (n, next_mean, m2 + dev_before * dev_after)
}
314
315impl Accumulator for VarianceAccumulator {
316 fn state(&mut self) -> Result<Vec<ScalarValue>> {
317 Ok(vec![
318 ScalarValue::from(self.count),
319 ScalarValue::from(self.mean),
320 ScalarValue::from(self.m2),
321 ])
322 }
323
324 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
325 let values = &cast(&values[0], &DataType::Float64)?;
326 let arr = downcast_value!(values, Float64Array).iter().flatten();
327
328 for value in arr {
329 (self.count, self.mean, self.m2) =
330 update(self.count, self.mean, self.m2, value)
331 }
332
333 Ok(())
334 }
335
336 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
337 let values = &cast(&values[0], &DataType::Float64)?;
338 let arr = downcast_value!(values, Float64Array).iter().flatten();
339
340 for value in arr {
341 let new_count = self.count - 1;
342 let delta1 = self.mean - value;
343 let new_mean = delta1 / new_count as f64 + self.mean;
344 let delta2 = new_mean - value;
345 let new_m2 = self.m2 - delta1 * delta2;
346
347 self.count -= 1;
348 self.mean = new_mean;
349 self.m2 = new_m2;
350 }
351
352 Ok(())
353 }
354
355 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
356 let counts = downcast_value!(states[0], UInt64Array);
357 let means = downcast_value!(states[1], Float64Array);
358 let m2s = downcast_value!(states[2], Float64Array);
359
360 for i in 0..counts.len() {
361 let c = counts.value(i);
362 if c == 0_u64 {
363 continue;
364 }
365 (self.count, self.mean, self.m2) = merge(
366 self.count,
367 self.mean,
368 self.m2,
369 c,
370 means.value(i),
371 m2s.value(i),
372 )
373 }
374 Ok(())
375 }
376
377 fn evaluate(&mut self) -> Result<ScalarValue> {
378 let count = match self.stats_type {
379 StatsType::Population => self.count,
380 StatsType::Sample => {
381 if self.count > 0 {
382 self.count - 1
383 } else {
384 self.count
385 }
386 }
387 };
388
389 Ok(ScalarValue::Float64(match self.count {
390 0 => None,
391 1 => {
392 if let StatsType::Population = self.stats_type {
393 Some(0.0)
394 } else {
395 None
396 }
397 }
398 _ => Some(self.m2 / count as f64),
399 }))
400 }
401
402 fn size(&self) -> usize {
403 size_of_val(self)
404 }
405
406 fn supports_retract_batch(&self) -> bool {
407 true
408 }
409}
410
411#[derive(Debug)]
412pub struct VarianceGroupsAccumulator {
413 m2s: Vec<f64>,
414 means: Vec<f64>,
415 counts: Vec<u64>,
416 stats_type: StatsType,
417}
418
419impl VarianceGroupsAccumulator {
420 pub fn new(s_type: StatsType) -> Self {
421 Self {
422 m2s: Vec::new(),
423 means: Vec::new(),
424 counts: Vec::new(),
425 stats_type: s_type,
426 }
427 }
428
429 fn resize(&mut self, total_num_groups: usize) {
430 self.m2s.resize(total_num_groups, 0.0);
431 self.means.resize(total_num_groups, 0.0);
432 self.counts.resize(total_num_groups, 0);
433 }
434
435 fn merge<F>(
436 group_indices: &[usize],
437 counts: &UInt64Array,
438 means: &Float64Array,
439 m2s: &Float64Array,
440 _opt_filter: Option<&BooleanArray>,
441 mut value_fn: F,
442 ) where
443 F: FnMut(usize, u64, f64, f64) + Send,
444 {
445 assert_eq!(counts.null_count(), 0);
446 assert_eq!(means.null_count(), 0);
447 assert_eq!(m2s.null_count(), 0);
448
449 group_indices
450 .iter()
451 .zip(counts.values().iter())
452 .zip(means.values().iter())
453 .zip(m2s.values().iter())
454 .for_each(|(((&group_index, &count), &mean), &m2)| {
455 value_fn(group_index, count, mean, m2);
456 });
457 }
458
459 pub fn variance(
460 &mut self,
461 emit_to: datafusion_expr::EmitTo,
462 ) -> (Vec<f64>, NullBuffer) {
463 let mut counts = emit_to.take_needed(&mut self.counts);
464 let _ = emit_to.take_needed(&mut self.means);
467 let m2s = emit_to.take_needed(&mut self.m2s);
468
469 if let StatsType::Sample = self.stats_type {
470 counts.iter_mut().for_each(|count| {
471 *count = count.saturating_sub(1);
472 });
473 }
474 let nulls = NullBuffer::from_iter(counts.iter().map(|&count| count != 0));
475 let variance = m2s
476 .iter()
477 .zip(counts)
478 .map(|(m2, count)| m2 / count as f64)
479 .collect();
480 (variance, nulls)
481 }
482}
483
484impl GroupsAccumulator for VarianceGroupsAccumulator {
485 fn update_batch(
486 &mut self,
487 values: &[ArrayRef],
488 group_indices: &[usize],
489 opt_filter: Option<&BooleanArray>,
490 total_num_groups: usize,
491 ) -> Result<()> {
492 assert_eq!(values.len(), 1, "single argument to update_batch");
493 let values = &cast(&values[0], &DataType::Float64)?;
494 let values = downcast_value!(values, Float64Array);
495
496 self.resize(total_num_groups);
497 accumulate(group_indices, values, opt_filter, |group_index, value| {
498 let (new_count, new_mean, new_m2) = update(
499 self.counts[group_index],
500 self.means[group_index],
501 self.m2s[group_index],
502 value,
503 );
504 self.counts[group_index] = new_count;
505 self.means[group_index] = new_mean;
506 self.m2s[group_index] = new_m2;
507 });
508 Ok(())
509 }
510
511 fn merge_batch(
512 &mut self,
513 values: &[ArrayRef],
514 group_indices: &[usize],
515 _opt_filter: Option<&BooleanArray>,
517 total_num_groups: usize,
518 ) -> Result<()> {
519 assert_eq!(values.len(), 3, "two arguments to merge_batch");
520 let partial_counts = downcast_value!(values[0], UInt64Array);
522 let partial_means = downcast_value!(values[1], Float64Array);
523 let partial_m2s = downcast_value!(values[2], Float64Array);
524
525 self.resize(total_num_groups);
526 Self::merge(
527 group_indices,
528 partial_counts,
529 partial_means,
530 partial_m2s,
531 None,
532 |group_index, partial_count, partial_mean, partial_m2| {
533 if partial_count == 0 {
534 return;
535 }
536 let (new_count, new_mean, new_m2) = merge(
537 self.counts[group_index],
538 self.means[group_index],
539 self.m2s[group_index],
540 partial_count,
541 partial_mean,
542 partial_m2,
543 );
544 self.counts[group_index] = new_count;
545 self.means[group_index] = new_mean;
546 self.m2s[group_index] = new_m2;
547 },
548 );
549 Ok(())
550 }
551
552 fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
553 let (variances, nulls) = self.variance(emit_to);
554 Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
555 }
556
557 fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
558 let counts = emit_to.take_needed(&mut self.counts);
559 let means = emit_to.take_needed(&mut self.means);
560 let m2s = emit_to.take_needed(&mut self.m2s);
561
562 Ok(vec![
563 Arc::new(UInt64Array::new(counts.into(), None)),
564 Arc::new(Float64Array::new(means.into(), None)),
565 Arc::new(Float64Array::new(m2s.into(), None)),
566 ])
567 }
568
569 fn size(&self) -> usize {
570 self.m2s.capacity() * size_of::<f64>()
571 + self.means.capacity() * size_of::<f64>()
572 + self.counts.capacity() * size_of::<u64>()
573 }
574}
575
#[cfg(test)]
mod tests {
    use datafusion_expr::EmitTo;

    use super::*;

    /// Builds a one-row partial state laid out as `[count, mean, m2]`.
    fn single_group_state(count: u64, mean: f64, m2: f64) -> Vec<ArrayRef> {
        vec![
            Arc::new(UInt64Array::from(vec![count])) as ArrayRef,
            Arc::new(Float64Array::from(vec![mean])),
            Arc::new(Float64Array::from(vec![m2])),
        ]
    }

    /// Merging an empty partial state into an empty group must leave the
    /// group untouched, so that a later non-empty state still yields a
    /// finite sample variance.
    #[test]
    fn test_groups_accumulator_merge_empty_states() -> Result<()> {
        let empty_state = single_group_state(0, 0.0, 0.0);
        let populated_state = single_group_state(2, 1.0, 1.0);

        let mut acc = VarianceGroupsAccumulator::new(StatsType::Sample);
        for state in [&empty_state, &populated_state] {
            acc.merge_batch(state, &[0], None, 1)?;
        }

        let evaluated = acc.evaluate(EmitTo::All)?;
        let variances = evaluated.as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(variances.len(), 1);
        // m2 = 1.0 over (count - 1) = 1 gives a sample variance of 1.0.
        assert_eq!(variances.value(0), 1.0);
        Ok(())
    }
}
603}