1use arrow::datatypes::{FieldRef, Float64Type};
22use arrow::{
23 array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array},
24 buffer::NullBuffer,
25 datatypes::{DataType, Field},
26};
27use datafusion_common::cast::{as_float64_array, as_uint64_array};
28use datafusion_common::{Result, ScalarValue};
29use datafusion_expr::{
30 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
31 Volatility,
32 function::{AccumulatorArgs, StateFieldsArgs},
33 utils::format_state_name,
34};
35use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
36use datafusion_functions_aggregate_common::{
37 aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType,
38};
39use datafusion_macros::user_doc;
40use std::mem::{size_of, size_of_val};
41use std::{fmt::Debug, sync::Arc};
42
// Wires `VarianceSample` up through the project macro `make_udaf_expr_and_func!`
// (defined elsewhere). Judging by the arguments, this presumably generates an
// expression helper named `var_sample` taking one `expression` argument and a
// UDAF accessor named `var_samp_udaf` — confirm against the macro definition.
make_udaf_expr_and_func!(
    VarianceSample,
    var_sample,
    expression,
    "Computes the sample variance.",
    var_samp_udaf
);

// Same wiring for `VariancePopulation`; presumably yields a `var_pop`
// expression helper and a `var_pop_udaf` accessor — confirm against the macro.
make_udaf_expr_and_func!(
    VariancePopulation,
    var_pop,
    expression,
    "Computes the population variance.",
    var_pop_udaf
);
58
// Sample variance aggregate UDF, exposed as `var` (see `name()`), with the
// aliases configured in `new()`. For non-distinct input it is backed by
// `VarianceAccumulator` / `VarianceGroupsAccumulator` using `StatsType::Sample`
// (divisor `count - 1`); for DISTINCT input by `DistinctVarianceAccumulator`.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the statistical sample variance of a set of numbers.",
    syntax_example = "var(expression)",
    standard_argument(name = "expression", prefix = "Numeric")
)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct VarianceSample {
    // Accepted argument types (a single `Float64`, set in `new()`).
    signature: Signature,
    // Additional names under which the function can be called.
    aliases: Vec<String>,
}
70
71impl Default for VarianceSample {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77impl VarianceSample {
78 pub fn new() -> Self {
79 Self {
80 aliases: vec![String::from("var_sample"), String::from("var_samp")],
81 signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
82 }
83 }
84}
85
86impl AggregateUDFImpl for VarianceSample {
87 fn name(&self) -> &str {
88 "var"
89 }
90
91 fn signature(&self) -> &Signature {
92 &self.signature
93 }
94
95 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
96 Ok(DataType::Float64)
97 }
98
99 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
100 let name = args.name;
101 match args.is_distinct {
102 false => Ok(vec![
103 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
104 Field::new(format_state_name(name, "mean"), DataType::Float64, true),
105 Field::new(format_state_name(name, "m2"), DataType::Float64, true),
106 ]
107 .into_iter()
108 .map(Arc::new)
109 .collect()),
110 true => {
111 let field = Field::new_list_field(DataType::Float64, true);
112 let state_name = "distinct_var";
113 Ok(vec![
114 Field::new(
115 format_state_name(name, state_name),
116 DataType::List(Arc::new(field)),
117 true,
118 )
119 .into(),
120 ])
121 }
122 }
123 }
124
125 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
126 if acc_args.is_distinct {
127 return Ok(Box::new(DistinctVarianceAccumulator::new(
128 StatsType::Sample,
129 )));
130 }
131
132 Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
133 }
134
135 fn aliases(&self) -> &[String] {
136 &self.aliases
137 }
138
139 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
140 !acc_args.is_distinct
141 }
142
143 fn create_groups_accumulator(
144 &self,
145 _args: AccumulatorArgs,
146 ) -> Result<Box<dyn GroupsAccumulator>> {
147 Ok(Box::new(VarianceGroupsAccumulator::new(StatsType::Sample)))
148 }
149
150 fn documentation(&self) -> Option<&Documentation> {
151 self.doc()
152 }
153}
154
// Population variance aggregate UDF, exposed as `var_pop` (see `name()`), with
// the aliases configured in `new()`. For non-distinct input it is backed by
// `VarianceAccumulator` / `VarianceGroupsAccumulator` using
// `StatsType::Population` (divisor `count`); for DISTINCT input by
// `DistinctVarianceAccumulator`.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the statistical population variance of a set of numbers.",
    syntax_example = "var_pop(expression)",
    standard_argument(name = "expression", prefix = "Numeric")
)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct VariancePopulation {
    // Accepted argument types (a single `Float64`, set in `new()`).
    signature: Signature,
    // Additional names under which the function can be called.
    aliases: Vec<String>,
}
166
167impl Default for VariancePopulation {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
173impl VariancePopulation {
174 pub fn new() -> Self {
175 Self {
176 aliases: vec![String::from("var_population")],
177 signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
178 }
179 }
180}
181
182impl AggregateUDFImpl for VariancePopulation {
183 fn name(&self) -> &str {
184 "var_pop"
185 }
186
187 fn signature(&self) -> &Signature {
188 &self.signature
189 }
190
191 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
192 Ok(DataType::Float64)
193 }
194
195 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
196 match args.is_distinct {
197 false => {
198 let name = args.name;
199 Ok(vec![
200 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
201 Field::new(format_state_name(name, "mean"), DataType::Float64, true),
202 Field::new(format_state_name(name, "m2"), DataType::Float64, true),
203 ]
204 .into_iter()
205 .map(Arc::new)
206 .collect())
207 }
208 true => {
209 let field = Field::new_list_field(DataType::Float64, true);
210 let state_name = "distinct_var";
211 Ok(vec![
212 Field::new(
213 format_state_name(args.name, state_name),
214 DataType::List(Arc::new(field)),
215 true,
216 )
217 .into(),
218 ])
219 }
220 }
221 }
222
223 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
224 if acc_args.is_distinct {
225 return Ok(Box::new(DistinctVarianceAccumulator::new(
226 StatsType::Population,
227 )));
228 }
229
230 Ok(Box::new(VarianceAccumulator::try_new(
231 StatsType::Population,
232 )?))
233 }
234
235 fn aliases(&self) -> &[String] {
236 &self.aliases
237 }
238
239 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
240 !acc_args.is_distinct
241 }
242
243 fn create_groups_accumulator(
244 &self,
245 _args: AccumulatorArgs,
246 ) -> Result<Box<dyn GroupsAccumulator>> {
247 Ok(Box::new(VarianceGroupsAccumulator::new(
248 StatsType::Population,
249 )))
250 }
251
252 fn documentation(&self) -> Option<&Documentation> {
253 self.doc()
254 }
255}
256
/// Single-group variance accumulator maintaining Welford-style running
/// statistics (`count`, `mean`, `m2`).
///
/// `evaluate` returns `m2 / count` for [`StatsType::Population`] and
/// `m2 / (count - 1)` for [`StatsType::Sample`]. The accumulator also
/// implements `retract_batch` (it reports `supports_retract_batch() == true`).
#[derive(Debug)]
pub struct VarianceAccumulator {
    /// Sum of squared differences from the current mean.
    m2: f64,
    /// Running mean of the values accumulated so far.
    mean: f64,
    /// Number of values accumulated (nulls are skipped by `update_batch`).
    count: u64,
    /// Selects sample vs. population variance in `evaluate`.
    stats_type: StatsType,
}
273
274impl VarianceAccumulator {
275 pub fn try_new(s_type: StatsType) -> Result<Self> {
277 Ok(Self {
278 m2: 0_f64,
279 mean: 0_f64,
280 count: 0_u64,
281 stats_type: s_type,
282 })
283 }
284
285 pub fn get_count(&self) -> u64 {
286 self.count
287 }
288
289 pub fn get_mean(&self) -> f64 {
290 self.mean
291 }
292
293 pub fn get_m2(&self) -> f64 {
294 self.m2
295 }
296}
297
/// Combines two partial variance states `(count, mean, m2)` into one using the
/// pairwise update of Chan et al.:
///
/// * `n = count + count2`
/// * `mean' = mean·count/n + mean2·count2/n`
/// * `m2' = m2 + m22 + (mean − mean2)²·count·count2/n`
///
/// At least one side must be non-empty (checked in debug builds), otherwise
/// the divisions would be by zero.
#[inline]
fn merge(
    count: u64,
    mean: f64,
    m2: f64,
    count2: u64,
    mean2: f64,
    m22: f64,
) -> (u64, f64, f64) {
    debug_assert!(count != 0 || count2 != 0, "Cannot merge two empty states");
    let total = count + count2;
    let (weight_a, weight_b, weight_total) = (count as f64, count2 as f64, total as f64);

    let combined_mean = mean * weight_a / weight_total + mean2 * weight_b / weight_total;
    let gap = mean - mean2;
    let combined_m2 = m2 + m22 + gap * gap * weight_a * weight_b / weight_total;

    (total, combined_mean, combined_m2)
}
317
/// Folds one `value` into running statistics `(count, mean, m2)` with a
/// single Welford step and returns the updated triple.
#[inline]
fn update(count: u64, mean: f64, m2: f64, value: f64) -> (u64, f64, f64) {
    let count = count + 1;
    // Deviation from the mean *before* including `value`.
    let before = value - mean;
    let mean = before / count as f64 + mean;
    // Deviation from the mean *after* including `value`.
    let after = value - mean;

    (count, mean, m2 + before * after)
}
328
329impl Accumulator for VarianceAccumulator {
330 fn state(&mut self) -> Result<Vec<ScalarValue>> {
331 Ok(vec![
332 ScalarValue::from(self.count),
333 ScalarValue::from(self.mean),
334 ScalarValue::from(self.m2),
335 ])
336 }
337
338 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
339 let arr = as_float64_array(&values[0])?;
340 for value in arr.iter().flatten() {
341 (self.count, self.mean, self.m2) =
342 update(self.count, self.mean, self.m2, value)
343 }
344
345 Ok(())
346 }
347
348 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
349 let arr = as_float64_array(&values[0])?;
350 for value in arr.iter().flatten() {
351 let new_count = self.count - 1;
352 let delta1 = self.mean - value;
353 let new_mean = delta1 / new_count as f64 + self.mean;
354 let delta2 = new_mean - value;
355 let new_m2 = self.m2 - delta1 * delta2;
356
357 self.count -= 1;
358 self.mean = new_mean;
359 self.m2 = new_m2;
360 }
361
362 Ok(())
363 }
364
365 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
366 let counts = as_uint64_array(&states[0])?;
367 let means = as_float64_array(&states[1])?;
368 let m2s = as_float64_array(&states[2])?;
369
370 for i in 0..counts.len() {
371 let c = counts.value(i);
372 if c == 0_u64 {
373 continue;
374 }
375 (self.count, self.mean, self.m2) = merge(
376 self.count,
377 self.mean,
378 self.m2,
379 c,
380 means.value(i),
381 m2s.value(i),
382 )
383 }
384 Ok(())
385 }
386
387 fn evaluate(&mut self) -> Result<ScalarValue> {
388 let count = match self.stats_type {
389 StatsType::Population => self.count,
390 StatsType::Sample => {
391 if self.count > 0 {
392 self.count - 1
393 } else {
394 self.count
395 }
396 }
397 };
398
399 Ok(ScalarValue::Float64(match self.count {
400 0 => None,
401 1 => {
402 if let StatsType::Population = self.stats_type {
403 Some(0.0)
404 } else {
405 None
406 }
407 }
408 _ => Some(self.m2 / count as f64),
409 }))
410 }
411
412 fn size(&self) -> usize {
413 size_of_val(self)
414 }
415
416 fn supports_retract_batch(&self) -> bool {
417 true
418 }
419}
420
/// Grouped variance accumulator.
///
/// Keeps `count`/`mean`/`m2` statistics for every group in parallel vectors
/// indexed by group index. All vectors are resized together and are always
/// the same length.
#[derive(Debug)]
pub struct VarianceGroupsAccumulator {
    /// Per-group sum of squared differences from the group mean.
    m2s: Vec<f64>,
    /// Per-group running mean.
    means: Vec<f64>,
    /// Per-group number of accumulated values.
    counts: Vec<u64>,
    /// Selects sample vs. population variance when evaluating.
    stats_type: StatsType,
}
428
429impl VarianceGroupsAccumulator {
430 pub fn new(s_type: StatsType) -> Self {
431 Self {
432 m2s: Vec::new(),
433 means: Vec::new(),
434 counts: Vec::new(),
435 stats_type: s_type,
436 }
437 }
438
439 fn resize(&mut self, total_num_groups: usize) {
440 self.m2s.resize(total_num_groups, 0.0);
441 self.means.resize(total_num_groups, 0.0);
442 self.counts.resize(total_num_groups, 0);
443 }
444
445 fn merge<F>(
446 group_indices: &[usize],
447 counts: &UInt64Array,
448 means: &Float64Array,
449 m2s: &Float64Array,
450 _opt_filter: Option<&BooleanArray>,
451 mut value_fn: F,
452 ) where
453 F: FnMut(usize, u64, f64, f64) + Send,
454 {
455 assert_eq!(counts.null_count(), 0);
456 assert_eq!(means.null_count(), 0);
457 assert_eq!(m2s.null_count(), 0);
458
459 group_indices
460 .iter()
461 .zip(counts.values().iter())
462 .zip(means.values().iter())
463 .zip(m2s.values().iter())
464 .for_each(|(((&group_index, &count), &mean), &m2)| {
465 value_fn(group_index, count, mean, m2);
466 });
467 }
468
469 pub fn variance(
470 &mut self,
471 emit_to: datafusion_expr::EmitTo,
472 ) -> (Vec<f64>, NullBuffer) {
473 let mut counts = emit_to.take_needed(&mut self.counts);
474 let _ = emit_to.take_needed(&mut self.means);
477 let m2s = emit_to.take_needed(&mut self.m2s);
478
479 if let StatsType::Sample = self.stats_type {
480 counts.iter_mut().for_each(|count| {
481 *count = count.saturating_sub(1);
482 });
483 }
484 let nulls = NullBuffer::from_iter(counts.iter().map(|&count| count != 0));
485 let variance = m2s
486 .iter()
487 .zip(counts)
488 .map(|(m2, count)| m2 / count as f64)
489 .collect();
490 (variance, nulls)
491 }
492}
493
494impl GroupsAccumulator for VarianceGroupsAccumulator {
495 fn update_batch(
496 &mut self,
497 values: &[ArrayRef],
498 group_indices: &[usize],
499 opt_filter: Option<&BooleanArray>,
500 total_num_groups: usize,
501 ) -> Result<()> {
502 assert_eq!(values.len(), 1, "single argument to update_batch");
503 let values = as_float64_array(&values[0])?;
504
505 self.resize(total_num_groups);
506 accumulate(group_indices, values, opt_filter, |group_index, value| {
507 let (new_count, new_mean, new_m2) = update(
508 self.counts[group_index],
509 self.means[group_index],
510 self.m2s[group_index],
511 value,
512 );
513 self.counts[group_index] = new_count;
514 self.means[group_index] = new_mean;
515 self.m2s[group_index] = new_m2;
516 });
517 Ok(())
518 }
519
520 fn merge_batch(
521 &mut self,
522 values: &[ArrayRef],
523 group_indices: &[usize],
524 _opt_filter: Option<&BooleanArray>,
526 total_num_groups: usize,
527 ) -> Result<()> {
528 assert_eq!(values.len(), 3, "two arguments to merge_batch");
529 let partial_counts = as_uint64_array(&values[0])?;
531 let partial_means = as_float64_array(&values[1])?;
532 let partial_m2s = as_float64_array(&values[2])?;
533
534 self.resize(total_num_groups);
535 Self::merge(
536 group_indices,
537 partial_counts,
538 partial_means,
539 partial_m2s,
540 None,
541 |group_index, partial_count, partial_mean, partial_m2| {
542 if partial_count == 0 {
543 return;
544 }
545 let (new_count, new_mean, new_m2) = merge(
546 self.counts[group_index],
547 self.means[group_index],
548 self.m2s[group_index],
549 partial_count,
550 partial_mean,
551 partial_m2,
552 );
553 self.counts[group_index] = new_count;
554 self.means[group_index] = new_mean;
555 self.m2s[group_index] = new_m2;
556 },
557 );
558 Ok(())
559 }
560
561 fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
562 let (variances, nulls) = self.variance(emit_to);
563 Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
564 }
565
566 fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
567 let counts = emit_to.take_needed(&mut self.counts);
568 let means = emit_to.take_needed(&mut self.means);
569 let m2s = emit_to.take_needed(&mut self.m2s);
570
571 Ok(vec![
572 Arc::new(UInt64Array::new(counts.into(), None)),
573 Arc::new(Float64Array::new(means.into(), None)),
574 Arc::new(Float64Array::new(m2s.into(), None)),
575 ])
576 }
577
578 fn size(&self) -> usize {
579 self.m2s.capacity() * size_of::<f64>()
580 + self.means.capacity() * size_of::<f64>()
581 + self.counts.capacity() * size_of::<u64>()
582 }
583}
584
/// Variance over the *distinct* input values (`VAR(DISTINCT x)` and friends).
///
/// Values are buffered in a `GenericDistinctBuffer`. The variance is computed
/// in `evaluate` with a two-pass mean / squared-deviation computation over the
/// buffered values.
#[derive(Debug)]
pub struct DistinctVarianceAccumulator {
    /// Distinct input values collected so far.
    distinct_values: GenericDistinctBuffer<Float64Type>,
    /// Selects sample vs. population variance in `evaluate`.
    stat_type: StatsType,
}
590
591impl DistinctVarianceAccumulator {
592 pub fn new(stat_type: StatsType) -> Self {
593 Self {
594 distinct_values: GenericDistinctBuffer::<Float64Type>::new(DataType::Float64),
595 stat_type,
596 }
597 }
598}
599
600impl Accumulator for DistinctVarianceAccumulator {
601 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
602 self.distinct_values.update_batch(values)
603 }
604
605 fn evaluate(&mut self) -> Result<ScalarValue> {
606 let values = self
607 .distinct_values
608 .values
609 .iter()
610 .map(|v| v.0)
611 .collect::<Vec<_>>();
612
613 let count = match self.stat_type {
614 StatsType::Sample => {
615 if !values.is_empty() {
616 values.len() - 1
617 } else {
618 0
619 }
620 }
621 StatsType::Population => values.len(),
622 };
623
624 let mean = values.iter().sum::<f64>() / values.len() as f64;
625 let m2 = values.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>();
626
627 Ok(ScalarValue::Float64(match values.len() {
628 0 => None,
629 1 => match self.stat_type {
630 StatsType::Population => Some(0.0),
631 StatsType::Sample => None,
632 },
633 _ => Some(m2 / count as f64),
634 }))
635 }
636
637 fn size(&self) -> usize {
638 size_of_val(self) + self.distinct_values.size()
639 }
640
641 fn state(&mut self) -> Result<Vec<ScalarValue>> {
642 self.distinct_values.state()
643 }
644
645 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
646 self.distinct_values.merge_batch(states)
647 }
648}
649
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::EmitTo;

    /// Builds a one-row partial state `[count, mean, m2]` in the layout that
    /// `VarianceGroupsAccumulator::merge_batch` expects.
    fn single_row_state(count: u64, mean: f64, m2: f64) -> Vec<ArrayRef> {
        vec![
            Arc::new(UInt64Array::from(vec![count])) as ArrayRef,
            Arc::new(Float64Array::from(vec![mean])) as ArrayRef,
            Arc::new(Float64Array::from(vec![m2])) as ArrayRef,
        ]
    }

    /// Merging an empty partial state (count 0) must be a no-op, so the
    /// result equals the sample variance of the non-empty state alone:
    /// m2 = 1 over (2 - 1) = 1.
    #[test]
    fn test_groups_accumulator_merge_empty_states() -> Result<()> {
        let mut acc = VarianceGroupsAccumulator::new(StatsType::Sample);
        for partial in [single_row_state(0, 0.0, 0.0), single_row_state(2, 1.0, 1.0)] {
            acc.merge_batch(&partial, &[0], None, 1)?;
        }

        let evaluated = acc.evaluate(EmitTo::All)?;
        let variances = evaluated
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("variance result should be a Float64Array");
        assert_eq!(variances.len(), 1);
        assert_eq!(variances.value(0), 1.0);
        Ok(())
    }
}