Skip to main content

datafusion_functions_aggregate/
variance.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`VarianceSample`]: variance sample aggregations.
19//! [`VariancePopulation`]: variance population aggregations.
20
21use arrow::datatypes::{FieldRef, Float64Type};
22use arrow::{
23    array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array},
24    buffer::NullBuffer,
25    datatypes::{DataType, Field},
26};
27use datafusion_common::cast::{as_float64_array, as_uint64_array};
28use datafusion_common::{Result, ScalarValue};
29use datafusion_expr::{
30    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
31    Volatility,
32    function::{AccumulatorArgs, StateFieldsArgs},
33    utils::format_state_name,
34};
35use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
36use datafusion_functions_aggregate_common::{
37    aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType,
38};
39use datafusion_macros::user_doc;
40use std::mem::{size_of, size_of_val};
41use std::{fmt::Debug, sync::Arc};
42
// Registers the sample-variance aggregate.
// NOTE(review): `make_udaf_expr_and_func!` is defined outside this file; presumably it
// generates the `var_sample(expression)` expression builder and the `var_samp_udaf()`
// accessor for `VarianceSample` — confirm against the macro definition.
make_udaf_expr_and_func!(
    VarianceSample,
    var_sample,
    expression,
    "Computes the sample variance.",
    var_samp_udaf
);

// Registers the population-variance aggregate.
// NOTE(review): presumably generates the `var_pop(expression)` expression builder and the
// `var_pop_udaf()` accessor for `VariancePopulation` — confirm against the macro definition.
make_udaf_expr_and_func!(
    VariancePopulation,
    var_pop,
    expression,
    "Computes the population variance.",
    var_pop_udaf
);
58
/// Aggregate UDF implementing the sample variance (`var`, with aliases
/// `var_sample` and `var_samp`).
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the statistical sample variance of a set of numbers.",
    syntax_example = "var(expression)",
    standard_argument(name = "expression", prefix = "Numeric")
)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct VarianceSample {
    /// Accepted argument types; `new` sets this to exactly one `Float64` argument.
    signature: Signature,
    /// Alternative names under which the function is exposed.
    aliases: Vec<String>,
}
70
71impl Default for VarianceSample {
72    fn default() -> Self {
73        Self::new()
74    }
75}
76
77impl VarianceSample {
78    pub fn new() -> Self {
79        Self {
80            aliases: vec![String::from("var_sample"), String::from("var_samp")],
81            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
82        }
83    }
84}
85
86impl AggregateUDFImpl for VarianceSample {
87    fn as_any(&self) -> &dyn std::any::Any {
88        self
89    }
90
91    fn name(&self) -> &str {
92        "var"
93    }
94
95    fn signature(&self) -> &Signature {
96        &self.signature
97    }
98
99    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
100        Ok(DataType::Float64)
101    }
102
103    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
104        let name = args.name;
105        match args.is_distinct {
106            false => Ok(vec![
107                Field::new(format_state_name(name, "count"), DataType::UInt64, true),
108                Field::new(format_state_name(name, "mean"), DataType::Float64, true),
109                Field::new(format_state_name(name, "m2"), DataType::Float64, true),
110            ]
111            .into_iter()
112            .map(Arc::new)
113            .collect()),
114            true => {
115                let field = Field::new_list_field(DataType::Float64, true);
116                let state_name = "distinct_var";
117                Ok(vec![
118                    Field::new(
119                        format_state_name(name, state_name),
120                        DataType::List(Arc::new(field)),
121                        true,
122                    )
123                    .into(),
124                ])
125            }
126        }
127    }
128
129    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
130        if acc_args.is_distinct {
131            return Ok(Box::new(DistinctVarianceAccumulator::new(
132                StatsType::Sample,
133            )));
134        }
135
136        Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
137    }
138
139    fn aliases(&self) -> &[String] {
140        &self.aliases
141    }
142
143    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
144        !acc_args.is_distinct
145    }
146
147    fn create_groups_accumulator(
148        &self,
149        _args: AccumulatorArgs,
150    ) -> Result<Box<dyn GroupsAccumulator>> {
151        Ok(Box::new(VarianceGroupsAccumulator::new(StatsType::Sample)))
152    }
153
154    fn documentation(&self) -> Option<&Documentation> {
155        self.doc()
156    }
157}
158
/// Aggregate UDF implementing the population variance (`var_pop`, with alias
/// `var_population`).
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the statistical population variance of a set of numbers.",
    syntax_example = "var_pop(expression)",
    standard_argument(name = "expression", prefix = "Numeric")
)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct VariancePopulation {
    /// Accepted argument types; `new` sets this to exactly one `Float64` argument.
    signature: Signature,
    /// Alternative names under which the function is exposed.
    aliases: Vec<String>,
}
170
171impl Default for VariancePopulation {
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
177impl VariancePopulation {
178    pub fn new() -> Self {
179        Self {
180            aliases: vec![String::from("var_population")],
181            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
182        }
183    }
184}
185
186impl AggregateUDFImpl for VariancePopulation {
187    fn as_any(&self) -> &dyn std::any::Any {
188        self
189    }
190
191    fn name(&self) -> &str {
192        "var_pop"
193    }
194
195    fn signature(&self) -> &Signature {
196        &self.signature
197    }
198
199    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
200        Ok(DataType::Float64)
201    }
202
203    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
204        match args.is_distinct {
205            false => {
206                let name = args.name;
207                Ok(vec![
208                    Field::new(format_state_name(name, "count"), DataType::UInt64, true),
209                    Field::new(format_state_name(name, "mean"), DataType::Float64, true),
210                    Field::new(format_state_name(name, "m2"), DataType::Float64, true),
211                ]
212                .into_iter()
213                .map(Arc::new)
214                .collect())
215            }
216            true => {
217                let field = Field::new_list_field(DataType::Float64, true);
218                let state_name = "distinct_var";
219                Ok(vec![
220                    Field::new(
221                        format_state_name(args.name, state_name),
222                        DataType::List(Arc::new(field)),
223                        true,
224                    )
225                    .into(),
226                ])
227            }
228        }
229    }
230
231    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
232        if acc_args.is_distinct {
233            return Ok(Box::new(DistinctVarianceAccumulator::new(
234                StatsType::Population,
235            )));
236        }
237
238        Ok(Box::new(VarianceAccumulator::try_new(
239            StatsType::Population,
240        )?))
241    }
242
243    fn aliases(&self) -> &[String] {
244        &self.aliases
245    }
246
247    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
248        !acc_args.is_distinct
249    }
250
251    fn create_groups_accumulator(
252        &self,
253        _args: AccumulatorArgs,
254    ) -> Result<Box<dyn GroupsAccumulator>> {
255        Ok(Box::new(VarianceGroupsAccumulator::new(
256            StatsType::Population,
257        )))
258    }
259
260    fn documentation(&self) -> Option<&Documentation> {
261        self.doc()
262    }
263}
264
/// An accumulator to compute variance
/// The algorithm used is an online implementation and numerically stable. It is based on this paper:
/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".
/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577.
///
/// The algorithm has been analyzed here:
/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances".
/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
#[derive(Debug)]
pub struct VarianceAccumulator {
    /// Running sum of squared deviations from the mean; divided by the
    /// (adjusted) count in `evaluate` to obtain the variance.
    m2: f64,
    /// Running mean of the non-null values seen so far.
    mean: f64,
    /// Number of non-null values seen so far.
    count: u64,
    /// Whether `evaluate` divides by `count` (population) or `count - 1` (sample).
    stats_type: StatsType,
}
281
282impl VarianceAccumulator {
283    /// Creates a new `VarianceAccumulator`
284    pub fn try_new(s_type: StatsType) -> Result<Self> {
285        Ok(Self {
286            m2: 0_f64,
287            mean: 0_f64,
288            count: 0_u64,
289            stats_type: s_type,
290        })
291    }
292
293    pub fn get_count(&self) -> u64 {
294        self.count
295    }
296
297    pub fn get_mean(&self) -> f64 {
298        self.mean
299    }
300
301    pub fn get_m2(&self) -> f64 {
302        self.m2
303    }
304}
305
/// Combines two partial variance states `(count, mean, m2)` into the state
/// of the union of both inputs.
///
/// If either side is empty (`count == 0`) the other side is returned
/// verbatim. This is exact, and it avoids two failure modes of the general
/// formula: a division by zero when both sides are empty, and a NaN `m2`
/// when `delta * delta` overflows to infinity and is then multiplied by a
/// zero count.
///
/// Returns `(new_count, new_mean, new_m2)`.
#[inline]
fn merge(
    count: u64,
    mean: f64,
    m2: f64,
    count2: u64,
    mean2: f64,
    m22: f64,
) -> (u64, f64, f64) {
    // Nothing to add: keep the left state untouched (also covers both-empty).
    if count2 == 0 {
        return (count, mean, m2);
    }
    // Left side is empty: the result is exactly the right state.
    if count == 0 {
        return (count2, mean2, m22);
    }

    let new_count = count + count2;
    // Count-weighted average of the two means.
    let new_mean =
        mean * count as f64 / new_count as f64 + mean2 * count2 as f64 / new_count as f64;
    let delta = mean - mean2;
    // Sum of both m2 values plus a correction for the distance between the means.
    let new_m2 =
        m2 + m22 + delta * delta * count as f64 * count2 as f64 / new_count as f64;

    (new_count, new_mean, new_m2)
}
325
/// Folds one observation into a running `(count, mean, m2)` state and
/// returns the updated triple.
///
/// The deviation of `value` is taken once against the old mean and once
/// against the new mean; their product is the increment of `m2`.
#[inline]
fn update(count: u64, mean: f64, m2: f64, value: f64) -> (u64, f64, f64) {
    let n = count + 1;
    let dev_before = value - mean;
    let shifted_mean = dev_before / n as f64 + mean;
    let dev_after = value - shifted_mean;

    (n, shifted_mean, m2 + dev_before * dev_after)
}
336
337impl Accumulator for VarianceAccumulator {
338    fn state(&mut self) -> Result<Vec<ScalarValue>> {
339        Ok(vec![
340            ScalarValue::from(self.count),
341            ScalarValue::from(self.mean),
342            ScalarValue::from(self.m2),
343        ])
344    }
345
346    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
347        let arr = as_float64_array(&values[0])?;
348        for value in arr.iter().flatten() {
349            (self.count, self.mean, self.m2) =
350                update(self.count, self.mean, self.m2, value)
351        }
352
353        Ok(())
354    }
355
356    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
357        let arr = as_float64_array(&values[0])?;
358        for value in arr.iter().flatten() {
359            let new_count = self.count - 1;
360            let delta1 = self.mean - value;
361            let new_mean = delta1 / new_count as f64 + self.mean;
362            let delta2 = new_mean - value;
363            let new_m2 = self.m2 - delta1 * delta2;
364
365            self.count -= 1;
366            self.mean = new_mean;
367            self.m2 = new_m2;
368        }
369
370        Ok(())
371    }
372
373    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
374        let counts = as_uint64_array(&states[0])?;
375        let means = as_float64_array(&states[1])?;
376        let m2s = as_float64_array(&states[2])?;
377
378        for i in 0..counts.len() {
379            let c = counts.value(i);
380            if c == 0_u64 {
381                continue;
382            }
383            (self.count, self.mean, self.m2) = merge(
384                self.count,
385                self.mean,
386                self.m2,
387                c,
388                means.value(i),
389                m2s.value(i),
390            )
391        }
392        Ok(())
393    }
394
395    fn evaluate(&mut self) -> Result<ScalarValue> {
396        let count = match self.stats_type {
397            StatsType::Population => self.count,
398            StatsType::Sample => {
399                if self.count > 0 {
400                    self.count - 1
401                } else {
402                    self.count
403                }
404            }
405        };
406
407        Ok(ScalarValue::Float64(match self.count {
408            0 => None,
409            1 => {
410                if let StatsType::Population = self.stats_type {
411                    Some(0.0)
412                } else {
413                    None
414                }
415            }
416            _ => Some(self.m2 / count as f64),
417        }))
418    }
419
420    fn size(&self) -> usize {
421        size_of_val(self)
422    }
423
424    fn supports_retract_batch(&self) -> bool {
425        true
426    }
427}
428
/// Variance accumulator that keeps one `(count, mean, m2)` state per group,
/// stored as three parallel vectors indexed by group index.
#[derive(Debug)]
pub struct VarianceGroupsAccumulator {
    /// Per-group running sum of squared deviations from the mean.
    m2s: Vec<f64>,
    /// Per-group running mean.
    means: Vec<f64>,
    /// Per-group number of accumulated values.
    counts: Vec<u64>,
    /// Whether the result is the sample or the population variance.
    stats_type: StatsType,
}
436
437impl VarianceGroupsAccumulator {
438    pub fn new(s_type: StatsType) -> Self {
439        Self {
440            m2s: Vec::new(),
441            means: Vec::new(),
442            counts: Vec::new(),
443            stats_type: s_type,
444        }
445    }
446
447    fn resize(&mut self, total_num_groups: usize) {
448        self.m2s.resize(total_num_groups, 0.0);
449        self.means.resize(total_num_groups, 0.0);
450        self.counts.resize(total_num_groups, 0);
451    }
452
453    fn merge<F>(
454        group_indices: &[usize],
455        counts: &UInt64Array,
456        means: &Float64Array,
457        m2s: &Float64Array,
458        _opt_filter: Option<&BooleanArray>,
459        mut value_fn: F,
460    ) where
461        F: FnMut(usize, u64, f64, f64) + Send,
462    {
463        assert_eq!(counts.null_count(), 0);
464        assert_eq!(means.null_count(), 0);
465        assert_eq!(m2s.null_count(), 0);
466
467        group_indices
468            .iter()
469            .zip(counts.values().iter())
470            .zip(means.values().iter())
471            .zip(m2s.values().iter())
472            .for_each(|(((&group_index, &count), &mean), &m2)| {
473                value_fn(group_index, count, mean, m2);
474            });
475    }
476
477    pub fn variance(
478        &mut self,
479        emit_to: datafusion_expr::EmitTo,
480    ) -> (Vec<f64>, NullBuffer) {
481        let mut counts = emit_to.take_needed(&mut self.counts);
482        // means are only needed for updating m2s and are not needed for the final result.
483        // But we still need to take them to ensure the internal state is consistent.
484        let _ = emit_to.take_needed(&mut self.means);
485        let m2s = emit_to.take_needed(&mut self.m2s);
486
487        if let StatsType::Sample = self.stats_type {
488            counts.iter_mut().for_each(|count| {
489                *count = count.saturating_sub(1);
490            });
491        }
492        let nulls = NullBuffer::from_iter(counts.iter().map(|&count| count != 0));
493        let variance = m2s
494            .iter()
495            .zip(counts)
496            .map(|(m2, count)| m2 / count as f64)
497            .collect();
498        (variance, nulls)
499    }
500}
501
502impl GroupsAccumulator for VarianceGroupsAccumulator {
503    fn update_batch(
504        &mut self,
505        values: &[ArrayRef],
506        group_indices: &[usize],
507        opt_filter: Option<&BooleanArray>,
508        total_num_groups: usize,
509    ) -> Result<()> {
510        assert_eq!(values.len(), 1, "single argument to update_batch");
511        let values = as_float64_array(&values[0])?;
512
513        self.resize(total_num_groups);
514        accumulate(group_indices, values, opt_filter, |group_index, value| {
515            let (new_count, new_mean, new_m2) = update(
516                self.counts[group_index],
517                self.means[group_index],
518                self.m2s[group_index],
519                value,
520            );
521            self.counts[group_index] = new_count;
522            self.means[group_index] = new_mean;
523            self.m2s[group_index] = new_m2;
524        });
525        Ok(())
526    }
527
528    fn merge_batch(
529        &mut self,
530        values: &[ArrayRef],
531        group_indices: &[usize],
532        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
533        _opt_filter: Option<&BooleanArray>,
534        total_num_groups: usize,
535    ) -> Result<()> {
536        assert_eq!(values.len(), 3, "two arguments to merge_batch");
537        // first batch is counts, second is partial means, third is partial m2s
538        let partial_counts = as_uint64_array(&values[0])?;
539        let partial_means = as_float64_array(&values[1])?;
540        let partial_m2s = as_float64_array(&values[2])?;
541
542        self.resize(total_num_groups);
543        Self::merge(
544            group_indices,
545            partial_counts,
546            partial_means,
547            partial_m2s,
548            None,
549            |group_index, partial_count, partial_mean, partial_m2| {
550                if partial_count == 0 {
551                    return;
552                }
553                let (new_count, new_mean, new_m2) = merge(
554                    self.counts[group_index],
555                    self.means[group_index],
556                    self.m2s[group_index],
557                    partial_count,
558                    partial_mean,
559                    partial_m2,
560                );
561                self.counts[group_index] = new_count;
562                self.means[group_index] = new_mean;
563                self.m2s[group_index] = new_m2;
564            },
565        );
566        Ok(())
567    }
568
569    fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
570        let (variances, nulls) = self.variance(emit_to);
571        Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
572    }
573
574    fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
575        let counts = emit_to.take_needed(&mut self.counts);
576        let means = emit_to.take_needed(&mut self.means);
577        let m2s = emit_to.take_needed(&mut self.m2s);
578
579        Ok(vec![
580            Arc::new(UInt64Array::new(counts.into(), None)),
581            Arc::new(Float64Array::new(means.into(), None)),
582            Arc::new(Float64Array::new(m2s.into(), None)),
583        ])
584    }
585
586    fn size(&self) -> usize {
587        self.m2s.capacity() * size_of::<f64>()
588            + self.means.capacity() * size_of::<f64>()
589            + self.counts.capacity() * size_of::<u64>()
590    }
591}
592
/// Accumulator for `DISTINCT` variance: it collects the distinct input values
/// and computes the variance over them in `evaluate`.
#[derive(Debug)]
pub struct DistinctVarianceAccumulator {
    /// Buffer holding the distinct `Float64` values seen so far.
    distinct_values: GenericDistinctBuffer<Float64Type>,
    /// Whether the result is the sample or the population variance.
    stat_type: StatsType,
}
598
599impl DistinctVarianceAccumulator {
600    pub fn new(stat_type: StatsType) -> Self {
601        Self {
602            distinct_values: GenericDistinctBuffer::<Float64Type>::new(DataType::Float64),
603            stat_type,
604        }
605    }
606}
607
608impl Accumulator for DistinctVarianceAccumulator {
609    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
610        self.distinct_values.update_batch(values)
611    }
612
613    fn evaluate(&mut self) -> Result<ScalarValue> {
614        let values = self
615            .distinct_values
616            .values
617            .iter()
618            .map(|v| v.0)
619            .collect::<Vec<_>>();
620
621        let count = match self.stat_type {
622            StatsType::Sample => {
623                if !values.is_empty() {
624                    values.len() - 1
625                } else {
626                    0
627                }
628            }
629            StatsType::Population => values.len(),
630        };
631
632        let mean = values.iter().sum::<f64>() / values.len() as f64;
633        let m2 = values.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>();
634
635        Ok(ScalarValue::Float64(match values.len() {
636            0 => None,
637            1 => match self.stat_type {
638                StatsType::Population => Some(0.0),
639                StatsType::Sample => None,
640            },
641            _ => Some(m2 / count as f64),
642        }))
643    }
644
645    fn size(&self) -> usize {
646        size_of_val(self) + self.distinct_values.size()
647    }
648
649    fn state(&mut self) -> Result<Vec<ScalarValue>> {
650        self.distinct_values.state()
651    }
652
653    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
654        self.distinct_values.merge_batch(states)
655    }
656}
657
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::EmitTo;

    /// Merging an empty partial state (count 0) before a populated one must
    /// not disturb the result for the group.
    #[test]
    fn test_groups_accumulator_merge_empty_states() -> Result<()> {
        // Partial state with no rows: count 0, mean 0.0, m2 0.0.
        let empty_partial = vec![
            Arc::new(UInt64Array::from(vec![0])) as ArrayRef,
            Arc::new(Float64Array::from(vec![0.0])),
            Arc::new(Float64Array::from(vec![0.0])),
        ];
        // Partial state with two rows: count 2, mean 1.0, m2 1.0.
        let filled_partial = vec![
            Arc::new(UInt64Array::from(vec![2])) as ArrayRef,
            Arc::new(Float64Array::from(vec![1.0])),
            Arc::new(Float64Array::from(vec![1.0])),
        ];

        let mut groups_acc = VarianceGroupsAccumulator::new(StatsType::Sample);
        for partial in [&empty_partial, &filled_partial] {
            groups_acc.merge_batch(partial, &[0], None, 1)?;
        }

        let emitted = groups_acc.evaluate(EmitTo::All)?;
        let variances = emitted.as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(variances.len(), 1);
        assert_eq!(variances.value(0), 1.0);
        Ok(())
    }
}
685}