1use arrow::datatypes::FieldRef;
22use arrow::{
23 array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array},
24 buffer::NullBuffer,
25 compute::kernels::cast,
26 datatypes::{DataType, Field},
27};
28use datafusion_common::{downcast_value, not_impl_err, plan_err, Result, ScalarValue};
29use datafusion_expr::{
30 function::{AccumulatorArgs, StateFieldsArgs},
31 utils::format_state_name,
32 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
33 Volatility,
34};
35use datafusion_functions_aggregate_common::{
36 aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType,
37};
38use datafusion_macros::user_doc;
39use std::mem::{size_of, size_of_val};
40use std::{fmt::Debug, sync::Arc};
41
// Expression-builder / UDAF-accessor glue for the sample variance aggregate
// (`VarianceSample`), documented as "Computes the sample variance.".
// NOTE(review): the concrete items emitted (presumably a `var_sample` expr
// function and a `var_samp_udaf` accessor) are defined by the
// `make_udaf_expr_and_func!` macro itself — confirm there.
make_udaf_expr_and_func!(
    VarianceSample,
    var_sample,
    expression,
    "Computes the sample variance.",
    var_samp_udaf
);

// Same glue for the population variance aggregate (`VariancePopulation`),
// presumably yielding `var_pop` / `var_pop_udaf` — confirm in the macro.
make_udaf_expr_and_func!(
    VariancePopulation,
    var_pop,
    expression,
    "Computes the population variance.",
    var_pop_udaf
);
57
// Sample variance aggregate UDAF. The user-facing documentation is supplied
// through the `user_doc` attribute below (surfaced via `self.doc()`).
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the statistical sample variance of a set of numbers.",
    syntax_example = "var(expression)",
    standard_argument(name = "expression", prefix = "Numeric")
)]
pub struct VarianceSample {
    /// Accepted argument types and volatility of the function.
    signature: Signature,
    /// Alternative SQL names returned by `aliases()`.
    aliases: Vec<String>,
}
68
69impl Debug for VarianceSample {
70 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
71 f.debug_struct("VarianceSample")
72 .field("name", &self.name())
73 .field("signature", &self.signature)
74 .finish()
75 }
76}
77
78impl Default for VarianceSample {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl VarianceSample {
85 pub fn new() -> Self {
86 Self {
87 aliases: vec![String::from("var_sample"), String::from("var_samp")],
88 signature: Signature::numeric(1, Volatility::Immutable),
89 }
90 }
91}
92
93impl AggregateUDFImpl for VarianceSample {
94 fn as_any(&self) -> &dyn std::any::Any {
95 self
96 }
97
98 fn name(&self) -> &str {
99 "var"
100 }
101
102 fn signature(&self) -> &Signature {
103 &self.signature
104 }
105
106 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
107 Ok(DataType::Float64)
108 }
109
110 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
111 let name = args.name;
112 Ok(vec![
113 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
114 Field::new(format_state_name(name, "mean"), DataType::Float64, true),
115 Field::new(format_state_name(name, "m2"), DataType::Float64, true),
116 ]
117 .into_iter()
118 .map(Arc::new)
119 .collect())
120 }
121
122 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
123 if acc_args.is_distinct {
124 return not_impl_err!("VAR(DISTINCT) aggregations are not available");
125 }
126
127 Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
128 }
129
130 fn aliases(&self) -> &[String] {
131 &self.aliases
132 }
133
134 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
135 !acc_args.is_distinct
136 }
137
138 fn create_groups_accumulator(
139 &self,
140 _args: AccumulatorArgs,
141 ) -> Result<Box<dyn GroupsAccumulator>> {
142 Ok(Box::new(VarianceGroupsAccumulator::new(StatsType::Sample)))
143 }
144
145 fn documentation(&self) -> Option<&Documentation> {
146 self.doc()
147 }
148}
149
// Population variance aggregate UDAF. The user-facing documentation is
// supplied through the `user_doc` attribute below (surfaced via `self.doc()`).
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the statistical population variance of a set of numbers.",
    syntax_example = "var_pop(expression)",
    standard_argument(name = "expression", prefix = "Numeric")
)]
pub struct VariancePopulation {
    /// Accepted argument types and volatility of the function.
    signature: Signature,
    /// Alternative SQL names returned by `aliases()`.
    aliases: Vec<String>,
}
160
161impl Debug for VariancePopulation {
162 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
163 f.debug_struct("VariancePopulation")
164 .field("name", &self.name())
165 .field("signature", &self.signature)
166 .finish()
167 }
168}
169
170impl Default for VariancePopulation {
171 fn default() -> Self {
172 Self::new()
173 }
174}
175
176impl VariancePopulation {
177 pub fn new() -> Self {
178 Self {
179 aliases: vec![String::from("var_population")],
180 signature: Signature::numeric(1, Volatility::Immutable),
181 }
182 }
183}
184
185impl AggregateUDFImpl for VariancePopulation {
186 fn as_any(&self) -> &dyn std::any::Any {
187 self
188 }
189
190 fn name(&self) -> &str {
191 "var_pop"
192 }
193
194 fn signature(&self) -> &Signature {
195 &self.signature
196 }
197
198 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
199 if !arg_types[0].is_numeric() {
200 return plan_err!("Variance requires numeric input types");
201 }
202
203 Ok(DataType::Float64)
204 }
205
206 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
207 let name = args.name;
208 Ok(vec![
209 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
210 Field::new(format_state_name(name, "mean"), DataType::Float64, true),
211 Field::new(format_state_name(name, "m2"), DataType::Float64, true),
212 ]
213 .into_iter()
214 .map(Arc::new)
215 .collect())
216 }
217
218 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
219 if acc_args.is_distinct {
220 return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
221 }
222
223 Ok(Box::new(VarianceAccumulator::try_new(
224 StatsType::Population,
225 )?))
226 }
227
228 fn aliases(&self) -> &[String] {
229 &self.aliases
230 }
231
232 fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
233 !acc_args.is_distinct
234 }
235
236 fn create_groups_accumulator(
237 &self,
238 _args: AccumulatorArgs,
239 ) -> Result<Box<dyn GroupsAccumulator>> {
240 Ok(Box::new(VarianceGroupsAccumulator::new(
241 StatsType::Population,
242 )))
243 }
244 fn documentation(&self) -> Option<&Documentation> {
245 self.doc()
246 }
247}
248
/// Single-group accumulator shared by `var` (sample) and `var_pop`
/// (population), tracking Welford-style running statistics.
#[derive(Debug)]
pub struct VarianceAccumulator {
    /// Running sum of squared deviations from the current mean.
    m2: f64,
    /// Running mean of the values accumulated so far.
    mean: f64,
    /// Number of non-null values currently accumulated.
    count: u64,
    /// Selects the divisor in `evaluate`: `count` (population) or
    /// `count - 1` (sample).
    stats_type: StatsType,
}
265
266impl VarianceAccumulator {
267 pub fn try_new(s_type: StatsType) -> Result<Self> {
269 Ok(Self {
270 m2: 0_f64,
271 mean: 0_f64,
272 count: 0_u64,
273 stats_type: s_type,
274 })
275 }
276
277 pub fn get_count(&self) -> u64 {
278 self.count
279 }
280
281 pub fn get_mean(&self) -> f64 {
282 self.mean
283 }
284
285 pub fn get_m2(&self) -> f64 {
286 self.m2
287 }
288}
289
/// Combines two partial `(count, mean, m2)` variance states into one, using
/// the pairwise combination formula (Chan et al.):
/// `m2 = m2_a + m2_b + (mean_a - mean_b)^2 * n_a * n_b / (n_a + n_b)`.
///
/// At least one state must be non-empty. Merging two empty states would
/// divide zero by zero; this is only caught by the `debug_assert!` in debug
/// builds.
#[inline]
fn merge(
    count: u64,
    mean: f64,
    m2: f64,
    count2: u64,
    mean2: f64,
    m22: f64,
) -> (u64, f64, f64) {
    debug_assert!(count != 0 || count2 != 0, "Cannot merge two empty states");
    let total = count + count2;
    let (n_a, n_b, n) = (count as f64, count2 as f64, total as f64);

    // Weighted mean; each term is divided separately, exactly as before.
    let combined_mean = mean * n_a / n + mean2 * n_b / n;

    let gap = mean - mean2;
    let combined_m2 = m2 + m22 + gap * gap * n_a * n_b / n;

    (total, combined_mean, combined_m2)
}
309
/// Folds a single `value` into a running `(count, mean, m2)` state with
/// Welford's online algorithm and returns the new state.
#[inline]
fn update(count: u64, mean: f64, m2: f64, value: f64) -> (u64, f64, f64) {
    let n = count + 1;
    // Deviation from the mean before and after incorporating `value`.
    let before = value - mean;
    let next_mean = mean + before / n as f64;
    let after = value - next_mean;
    (n, next_mean, m2 + before * after)
}
320
321impl Accumulator for VarianceAccumulator {
322 fn state(&mut self) -> Result<Vec<ScalarValue>> {
323 Ok(vec![
324 ScalarValue::from(self.count),
325 ScalarValue::from(self.mean),
326 ScalarValue::from(self.m2),
327 ])
328 }
329
330 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
331 let values = &cast(&values[0], &DataType::Float64)?;
332 let arr = downcast_value!(values, Float64Array).iter().flatten();
333
334 for value in arr {
335 (self.count, self.mean, self.m2) =
336 update(self.count, self.mean, self.m2, value)
337 }
338
339 Ok(())
340 }
341
342 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
343 let values = &cast(&values[0], &DataType::Float64)?;
344 let arr = downcast_value!(values, Float64Array).iter().flatten();
345
346 for value in arr {
347 let new_count = self.count - 1;
348 let delta1 = self.mean - value;
349 let new_mean = delta1 / new_count as f64 + self.mean;
350 let delta2 = new_mean - value;
351 let new_m2 = self.m2 - delta1 * delta2;
352
353 self.count -= 1;
354 self.mean = new_mean;
355 self.m2 = new_m2;
356 }
357
358 Ok(())
359 }
360
361 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
362 let counts = downcast_value!(states[0], UInt64Array);
363 let means = downcast_value!(states[1], Float64Array);
364 let m2s = downcast_value!(states[2], Float64Array);
365
366 for i in 0..counts.len() {
367 let c = counts.value(i);
368 if c == 0_u64 {
369 continue;
370 }
371 (self.count, self.mean, self.m2) = merge(
372 self.count,
373 self.mean,
374 self.m2,
375 c,
376 means.value(i),
377 m2s.value(i),
378 )
379 }
380 Ok(())
381 }
382
383 fn evaluate(&mut self) -> Result<ScalarValue> {
384 let count = match self.stats_type {
385 StatsType::Population => self.count,
386 StatsType::Sample => {
387 if self.count > 0 {
388 self.count - 1
389 } else {
390 self.count
391 }
392 }
393 };
394
395 Ok(ScalarValue::Float64(match self.count {
396 0 => None,
397 1 => {
398 if let StatsType::Population = self.stats_type {
399 Some(0.0)
400 } else {
401 None
402 }
403 }
404 _ => Some(self.m2 / count as f64),
405 }))
406 }
407
408 fn size(&self) -> usize {
409 size_of_val(self)
410 }
411
412 fn supports_retract_batch(&self) -> bool {
413 true
414 }
415}
416
/// Grouped (vectorized) variance accumulator. It keeps one
/// `(count, mean, m2)` triple per group, indexed by group index.
#[derive(Debug)]
pub struct VarianceGroupsAccumulator {
    /// Per-group running sum of squared deviations from the mean.
    m2s: Vec<f64>,
    /// Per-group running mean.
    means: Vec<f64>,
    /// Per-group number of accumulated values.
    counts: Vec<u64>,
    /// Sample vs. population divisor used by `variance`.
    stats_type: StatsType,
}
424
425impl VarianceGroupsAccumulator {
426 pub fn new(s_type: StatsType) -> Self {
427 Self {
428 m2s: Vec::new(),
429 means: Vec::new(),
430 counts: Vec::new(),
431 stats_type: s_type,
432 }
433 }
434
435 fn resize(&mut self, total_num_groups: usize) {
436 self.m2s.resize(total_num_groups, 0.0);
437 self.means.resize(total_num_groups, 0.0);
438 self.counts.resize(total_num_groups, 0);
439 }
440
441 fn merge<F>(
442 group_indices: &[usize],
443 counts: &UInt64Array,
444 means: &Float64Array,
445 m2s: &Float64Array,
446 _opt_filter: Option<&BooleanArray>,
447 mut value_fn: F,
448 ) where
449 F: FnMut(usize, u64, f64, f64) + Send,
450 {
451 assert_eq!(counts.null_count(), 0);
452 assert_eq!(means.null_count(), 0);
453 assert_eq!(m2s.null_count(), 0);
454
455 group_indices
456 .iter()
457 .zip(counts.values().iter())
458 .zip(means.values().iter())
459 .zip(m2s.values().iter())
460 .for_each(|(((&group_index, &count), &mean), &m2)| {
461 value_fn(group_index, count, mean, m2);
462 });
463 }
464
465 pub fn variance(
466 &mut self,
467 emit_to: datafusion_expr::EmitTo,
468 ) -> (Vec<f64>, NullBuffer) {
469 let mut counts = emit_to.take_needed(&mut self.counts);
470 let _ = emit_to.take_needed(&mut self.means);
473 let m2s = emit_to.take_needed(&mut self.m2s);
474
475 if let StatsType::Sample = self.stats_type {
476 counts.iter_mut().for_each(|count| {
477 *count = count.saturating_sub(1);
478 });
479 }
480 let nulls = NullBuffer::from_iter(counts.iter().map(|&count| count != 0));
481 let variance = m2s
482 .iter()
483 .zip(counts)
484 .map(|(m2, count)| m2 / count as f64)
485 .collect();
486 (variance, nulls)
487 }
488}
489
490impl GroupsAccumulator for VarianceGroupsAccumulator {
491 fn update_batch(
492 &mut self,
493 values: &[ArrayRef],
494 group_indices: &[usize],
495 opt_filter: Option<&BooleanArray>,
496 total_num_groups: usize,
497 ) -> Result<()> {
498 assert_eq!(values.len(), 1, "single argument to update_batch");
499 let values = &cast(&values[0], &DataType::Float64)?;
500 let values = downcast_value!(values, Float64Array);
501
502 self.resize(total_num_groups);
503 accumulate(group_indices, values, opt_filter, |group_index, value| {
504 let (new_count, new_mean, new_m2) = update(
505 self.counts[group_index],
506 self.means[group_index],
507 self.m2s[group_index],
508 value,
509 );
510 self.counts[group_index] = new_count;
511 self.means[group_index] = new_mean;
512 self.m2s[group_index] = new_m2;
513 });
514 Ok(())
515 }
516
517 fn merge_batch(
518 &mut self,
519 values: &[ArrayRef],
520 group_indices: &[usize],
521 _opt_filter: Option<&BooleanArray>,
523 total_num_groups: usize,
524 ) -> Result<()> {
525 assert_eq!(values.len(), 3, "two arguments to merge_batch");
526 let partial_counts = downcast_value!(values[0], UInt64Array);
528 let partial_means = downcast_value!(values[1], Float64Array);
529 let partial_m2s = downcast_value!(values[2], Float64Array);
530
531 self.resize(total_num_groups);
532 Self::merge(
533 group_indices,
534 partial_counts,
535 partial_means,
536 partial_m2s,
537 None,
538 |group_index, partial_count, partial_mean, partial_m2| {
539 if partial_count == 0 {
540 return;
541 }
542 let (new_count, new_mean, new_m2) = merge(
543 self.counts[group_index],
544 self.means[group_index],
545 self.m2s[group_index],
546 partial_count,
547 partial_mean,
548 partial_m2,
549 );
550 self.counts[group_index] = new_count;
551 self.means[group_index] = new_mean;
552 self.m2s[group_index] = new_m2;
553 },
554 );
555 Ok(())
556 }
557
558 fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
559 let (variances, nulls) = self.variance(emit_to);
560 Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
561 }
562
563 fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
564 let counts = emit_to.take_needed(&mut self.counts);
565 let means = emit_to.take_needed(&mut self.means);
566 let m2s = emit_to.take_needed(&mut self.m2s);
567
568 Ok(vec![
569 Arc::new(UInt64Array::new(counts.into(), None)),
570 Arc::new(Float64Array::new(means.into(), None)),
571 Arc::new(Float64Array::new(m2s.into(), None)),
572 ])
573 }
574
575 fn size(&self) -> usize {
576 self.m2s.capacity() * size_of::<f64>()
577 + self.means.capacity() * size_of::<f64>()
578 + self.counts.capacity() * size_of::<u64>()
579 }
580}
581
#[cfg(test)]
mod tests {
    use datafusion_expr::EmitTo;

    use super::*;

    /// Builds a one-row `[count, mean, m2]` partial state.
    fn single_row_state(count: u64, mean: f64, m2: f64) -> Vec<ArrayRef> {
        vec![
            Arc::new(UInt64Array::from(vec![count])) as ArrayRef,
            Arc::new(Float64Array::from(vec![mean])),
            Arc::new(Float64Array::from(vec![m2])),
        ]
    }

    /// Merging an empty partial state (count 0) must be a no-op: the result
    /// equals the sample variance of the populated state alone.
    #[test]
    fn test_groups_accumulator_merge_empty_states() -> Result<()> {
        let empty = single_row_state(0, 0.0, 0.0);
        let populated = single_row_state(2, 1.0, 1.0);

        let mut acc = VarianceGroupsAccumulator::new(StatsType::Sample);
        acc.merge_batch(&empty, &[0], None, 1)?;
        acc.merge_batch(&populated, &[0], None, 1)?;

        let evaluated = acc.evaluate(EmitTo::All)?;
        let variances = evaluated
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(variances.len(), 1);
        assert_eq!(variances.value(0), 1.0);
        Ok(())
    }
}