datafusion_functions_aggregate/
variance.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Variance aggregate functions:
//!
//! - [`VarianceSample`]: sample variance aggregation.
//! - [`VariancePopulation`]: population variance aggregation.
20
21use arrow::datatypes::FieldRef;
22use arrow::{
23    array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array},
24    buffer::NullBuffer,
25    compute::kernels::cast,
26    datatypes::{DataType, Field},
27};
28use datafusion_common::{downcast_value, not_impl_err, plan_err, Result, ScalarValue};
29use datafusion_expr::{
30    function::{AccumulatorArgs, StateFieldsArgs},
31    utils::format_state_name,
32    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
33    Volatility,
34};
35use datafusion_functions_aggregate_common::{
36    aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType,
37};
38use datafusion_macros::user_doc;
39use std::mem::{size_of, size_of_val};
40use std::{fmt::Debug, sync::Arc};
41
42make_udaf_expr_and_func!(
43    VarianceSample,
44    var_sample,
45    expression,
46    "Computes the sample variance.",
47    var_samp_udaf
48);
49
50make_udaf_expr_and_func!(
51    VariancePopulation,
52    var_pop,
53    expression,
54    "Computes the population variance.",
55    var_pop_udaf
56);
57
58#[user_doc(
59    doc_section(label = "General Functions"),
60    description = "Returns the statistical sample variance of a set of numbers.",
61    syntax_example = "var(expression)",
62    standard_argument(name = "expression", prefix = "Numeric")
63)]
64pub struct VarianceSample {
65    signature: Signature,
66    aliases: Vec<String>,
67}
68
69impl Debug for VarianceSample {
70    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
71        f.debug_struct("VarianceSample")
72            .field("name", &self.name())
73            .field("signature", &self.signature)
74            .finish()
75    }
76}
77
78impl Default for VarianceSample {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl VarianceSample {
85    pub fn new() -> Self {
86        Self {
87            aliases: vec![String::from("var_sample"), String::from("var_samp")],
88            signature: Signature::numeric(1, Volatility::Immutable),
89        }
90    }
91}
92
93impl AggregateUDFImpl for VarianceSample {
94    fn as_any(&self) -> &dyn std::any::Any {
95        self
96    }
97
98    fn name(&self) -> &str {
99        "var"
100    }
101
102    fn signature(&self) -> &Signature {
103        &self.signature
104    }
105
106    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
107        Ok(DataType::Float64)
108    }
109
110    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
111        let name = args.name;
112        Ok(vec![
113            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
114            Field::new(format_state_name(name, "mean"), DataType::Float64, true),
115            Field::new(format_state_name(name, "m2"), DataType::Float64, true),
116        ]
117        .into_iter()
118        .map(Arc::new)
119        .collect())
120    }
121
122    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
123        if acc_args.is_distinct {
124            return not_impl_err!("VAR(DISTINCT) aggregations are not available");
125        }
126
127        Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
128    }
129
130    fn aliases(&self) -> &[String] {
131        &self.aliases
132    }
133
134    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
135        !acc_args.is_distinct
136    }
137
138    fn create_groups_accumulator(
139        &self,
140        _args: AccumulatorArgs,
141    ) -> Result<Box<dyn GroupsAccumulator>> {
142        Ok(Box::new(VarianceGroupsAccumulator::new(StatsType::Sample)))
143    }
144
145    fn documentation(&self) -> Option<&Documentation> {
146        self.doc()
147    }
148}
149
150#[user_doc(
151    doc_section(label = "General Functions"),
152    description = "Returns the statistical population variance of a set of numbers.",
153    syntax_example = "var_pop(expression)",
154    standard_argument(name = "expression", prefix = "Numeric")
155)]
156pub struct VariancePopulation {
157    signature: Signature,
158    aliases: Vec<String>,
159}
160
161impl Debug for VariancePopulation {
162    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
163        f.debug_struct("VariancePopulation")
164            .field("name", &self.name())
165            .field("signature", &self.signature)
166            .finish()
167    }
168}
169
170impl Default for VariancePopulation {
171    fn default() -> Self {
172        Self::new()
173    }
174}
175
176impl VariancePopulation {
177    pub fn new() -> Self {
178        Self {
179            aliases: vec![String::from("var_population")],
180            signature: Signature::numeric(1, Volatility::Immutable),
181        }
182    }
183}
184
185impl AggregateUDFImpl for VariancePopulation {
186    fn as_any(&self) -> &dyn std::any::Any {
187        self
188    }
189
190    fn name(&self) -> &str {
191        "var_pop"
192    }
193
194    fn signature(&self) -> &Signature {
195        &self.signature
196    }
197
198    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
199        if !arg_types[0].is_numeric() {
200            return plan_err!("Variance requires numeric input types");
201        }
202
203        Ok(DataType::Float64)
204    }
205
206    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
207        let name = args.name;
208        Ok(vec![
209            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
210            Field::new(format_state_name(name, "mean"), DataType::Float64, true),
211            Field::new(format_state_name(name, "m2"), DataType::Float64, true),
212        ]
213        .into_iter()
214        .map(Arc::new)
215        .collect())
216    }
217
218    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
219        if acc_args.is_distinct {
220            return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
221        }
222
223        Ok(Box::new(VarianceAccumulator::try_new(
224            StatsType::Population,
225        )?))
226    }
227
228    fn aliases(&self) -> &[String] {
229        &self.aliases
230    }
231
232    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
233        !acc_args.is_distinct
234    }
235
236    fn create_groups_accumulator(
237        &self,
238        _args: AccumulatorArgs,
239    ) -> Result<Box<dyn GroupsAccumulator>> {
240        Ok(Box::new(VarianceGroupsAccumulator::new(
241            StatsType::Population,
242        )))
243    }
244    fn documentation(&self) -> Option<&Documentation> {
245        self.doc()
246    }
247}
248
249/// An accumulator to compute variance
250/// The algorithm used is an online implementation and numerically stable. It is based on this paper:
251/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".
252/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577.
253///
254/// The algorithm has been analyzed here:
255/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances".
256/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
257
258#[derive(Debug)]
259pub struct VarianceAccumulator {
260    m2: f64,
261    mean: f64,
262    count: u64,
263    stats_type: StatsType,
264}
265
266impl VarianceAccumulator {
267    /// Creates a new `VarianceAccumulator`
268    pub fn try_new(s_type: StatsType) -> Result<Self> {
269        Ok(Self {
270            m2: 0_f64,
271            mean: 0_f64,
272            count: 0_u64,
273            stats_type: s_type,
274        })
275    }
276
277    pub fn get_count(&self) -> u64 {
278        self.count
279    }
280
281    pub fn get_mean(&self) -> f64 {
282        self.mean
283    }
284
285    pub fn get_m2(&self) -> f64 {
286        self.m2
287    }
288}
289
/// Combines two partial `(count, mean, m2)` states into one.
///
/// Returns the combined `(count, mean, m2)`. At least one side must be
/// non-empty: with both counts at zero the divisions below are by zero, which
/// the debug assertion guards against.
#[inline]
fn merge(
    count: u64,
    mean: f64,
    m2: f64,
    count2: u64,
    mean2: f64,
    m22: f64,
) -> (u64, f64, f64) {
    debug_assert!(count != 0 || count2 != 0, "Cannot merge two empty states");

    let total = count + count2;
    let (n1, n2, n) = (count as f64, count2 as f64, total as f64);

    // Count-weighted average of the two means.
    let merged_mean = mean * n1 / n + mean2 * n2 / n;

    // Both m2 terms plus a correction for the gap between the two means.
    let mean_gap = mean - mean2;
    let merged_m2 = m2 + m22 + mean_gap * mean_gap * n1 * n2 / n;

    (total, merged_mean, merged_m2)
}
309
/// Folds one `value` into a running `(count, mean, m2)` state and returns the
/// new state.
#[inline]
fn update(count: u64, mean: f64, m2: f64, value: f64) -> (u64, f64, f64) {
    let n = count + 1;
    // Deviation from the mean before, then after, the mean is moved.
    let before = value - mean;
    let next_mean = before / n as f64 + mean;
    let after = value - next_mean;

    (n, next_mean, m2 + before * after)
}
320
321impl Accumulator for VarianceAccumulator {
322    fn state(&mut self) -> Result<Vec<ScalarValue>> {
323        Ok(vec![
324            ScalarValue::from(self.count),
325            ScalarValue::from(self.mean),
326            ScalarValue::from(self.m2),
327        ])
328    }
329
330    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
331        let values = &cast(&values[0], &DataType::Float64)?;
332        let arr = downcast_value!(values, Float64Array).iter().flatten();
333
334        for value in arr {
335            (self.count, self.mean, self.m2) =
336                update(self.count, self.mean, self.m2, value)
337        }
338
339        Ok(())
340    }
341
342    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
343        let values = &cast(&values[0], &DataType::Float64)?;
344        let arr = downcast_value!(values, Float64Array).iter().flatten();
345
346        for value in arr {
347            let new_count = self.count - 1;
348            let delta1 = self.mean - value;
349            let new_mean = delta1 / new_count as f64 + self.mean;
350            let delta2 = new_mean - value;
351            let new_m2 = self.m2 - delta1 * delta2;
352
353            self.count -= 1;
354            self.mean = new_mean;
355            self.m2 = new_m2;
356        }
357
358        Ok(())
359    }
360
361    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
362        let counts = downcast_value!(states[0], UInt64Array);
363        let means = downcast_value!(states[1], Float64Array);
364        let m2s = downcast_value!(states[2], Float64Array);
365
366        for i in 0..counts.len() {
367            let c = counts.value(i);
368            if c == 0_u64 {
369                continue;
370            }
371            (self.count, self.mean, self.m2) = merge(
372                self.count,
373                self.mean,
374                self.m2,
375                c,
376                means.value(i),
377                m2s.value(i),
378            )
379        }
380        Ok(())
381    }
382
383    fn evaluate(&mut self) -> Result<ScalarValue> {
384        let count = match self.stats_type {
385            StatsType::Population => self.count,
386            StatsType::Sample => {
387                if self.count > 0 {
388                    self.count - 1
389                } else {
390                    self.count
391                }
392            }
393        };
394
395        Ok(ScalarValue::Float64(match self.count {
396            0 => None,
397            1 => {
398                if let StatsType::Population = self.stats_type {
399                    Some(0.0)
400                } else {
401                    None
402                }
403            }
404            _ => Some(self.m2 / count as f64),
405        }))
406    }
407
408    fn size(&self) -> usize {
409        size_of_val(self)
410    }
411
412    fn supports_retract_batch(&self) -> bool {
413        true
414    }
415}
416
417#[derive(Debug)]
418pub struct VarianceGroupsAccumulator {
419    m2s: Vec<f64>,
420    means: Vec<f64>,
421    counts: Vec<u64>,
422    stats_type: StatsType,
423}
424
425impl VarianceGroupsAccumulator {
426    pub fn new(s_type: StatsType) -> Self {
427        Self {
428            m2s: Vec::new(),
429            means: Vec::new(),
430            counts: Vec::new(),
431            stats_type: s_type,
432        }
433    }
434
435    fn resize(&mut self, total_num_groups: usize) {
436        self.m2s.resize(total_num_groups, 0.0);
437        self.means.resize(total_num_groups, 0.0);
438        self.counts.resize(total_num_groups, 0);
439    }
440
441    fn merge<F>(
442        group_indices: &[usize],
443        counts: &UInt64Array,
444        means: &Float64Array,
445        m2s: &Float64Array,
446        _opt_filter: Option<&BooleanArray>,
447        mut value_fn: F,
448    ) where
449        F: FnMut(usize, u64, f64, f64) + Send,
450    {
451        assert_eq!(counts.null_count(), 0);
452        assert_eq!(means.null_count(), 0);
453        assert_eq!(m2s.null_count(), 0);
454
455        group_indices
456            .iter()
457            .zip(counts.values().iter())
458            .zip(means.values().iter())
459            .zip(m2s.values().iter())
460            .for_each(|(((&group_index, &count), &mean), &m2)| {
461                value_fn(group_index, count, mean, m2);
462            });
463    }
464
465    pub fn variance(
466        &mut self,
467        emit_to: datafusion_expr::EmitTo,
468    ) -> (Vec<f64>, NullBuffer) {
469        let mut counts = emit_to.take_needed(&mut self.counts);
470        // means are only needed for updating m2s and are not needed for the final result.
471        // But we still need to take them to ensure the internal state is consistent.
472        let _ = emit_to.take_needed(&mut self.means);
473        let m2s = emit_to.take_needed(&mut self.m2s);
474
475        if let StatsType::Sample = self.stats_type {
476            counts.iter_mut().for_each(|count| {
477                *count = count.saturating_sub(1);
478            });
479        }
480        let nulls = NullBuffer::from_iter(counts.iter().map(|&count| count != 0));
481        let variance = m2s
482            .iter()
483            .zip(counts)
484            .map(|(m2, count)| m2 / count as f64)
485            .collect();
486        (variance, nulls)
487    }
488}
489
490impl GroupsAccumulator for VarianceGroupsAccumulator {
491    fn update_batch(
492        &mut self,
493        values: &[ArrayRef],
494        group_indices: &[usize],
495        opt_filter: Option<&BooleanArray>,
496        total_num_groups: usize,
497    ) -> Result<()> {
498        assert_eq!(values.len(), 1, "single argument to update_batch");
499        let values = &cast(&values[0], &DataType::Float64)?;
500        let values = downcast_value!(values, Float64Array);
501
502        self.resize(total_num_groups);
503        accumulate(group_indices, values, opt_filter, |group_index, value| {
504            let (new_count, new_mean, new_m2) = update(
505                self.counts[group_index],
506                self.means[group_index],
507                self.m2s[group_index],
508                value,
509            );
510            self.counts[group_index] = new_count;
511            self.means[group_index] = new_mean;
512            self.m2s[group_index] = new_m2;
513        });
514        Ok(())
515    }
516
517    fn merge_batch(
518        &mut self,
519        values: &[ArrayRef],
520        group_indices: &[usize],
521        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
522        _opt_filter: Option<&BooleanArray>,
523        total_num_groups: usize,
524    ) -> Result<()> {
525        assert_eq!(values.len(), 3, "two arguments to merge_batch");
526        // first batch is counts, second is partial means, third is partial m2s
527        let partial_counts = downcast_value!(values[0], UInt64Array);
528        let partial_means = downcast_value!(values[1], Float64Array);
529        let partial_m2s = downcast_value!(values[2], Float64Array);
530
531        self.resize(total_num_groups);
532        Self::merge(
533            group_indices,
534            partial_counts,
535            partial_means,
536            partial_m2s,
537            None,
538            |group_index, partial_count, partial_mean, partial_m2| {
539                if partial_count == 0 {
540                    return;
541                }
542                let (new_count, new_mean, new_m2) = merge(
543                    self.counts[group_index],
544                    self.means[group_index],
545                    self.m2s[group_index],
546                    partial_count,
547                    partial_mean,
548                    partial_m2,
549                );
550                self.counts[group_index] = new_count;
551                self.means[group_index] = new_mean;
552                self.m2s[group_index] = new_m2;
553            },
554        );
555        Ok(())
556    }
557
558    fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
559        let (variances, nulls) = self.variance(emit_to);
560        Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
561    }
562
563    fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
564        let counts = emit_to.take_needed(&mut self.counts);
565        let means = emit_to.take_needed(&mut self.means);
566        let m2s = emit_to.take_needed(&mut self.m2s);
567
568        Ok(vec![
569            Arc::new(UInt64Array::new(counts.into(), None)),
570            Arc::new(Float64Array::new(means.into(), None)),
571            Arc::new(Float64Array::new(m2s.into(), None)),
572        ])
573    }
574
575    fn size(&self) -> usize {
576        self.m2s.capacity() * size_of::<f64>()
577            + self.means.capacity() * size_of::<f64>()
578            + self.counts.capacity() * size_of::<u64>()
579    }
580}
581
#[cfg(test)]
mod tests {
    use datafusion_expr::EmitTo;

    use super::*;

    /// Builds a one-row partial state in `[count, mean, m2]` column order.
    fn partial_state(count: u64, mean: f64, m2: f64) -> Vec<ArrayRef> {
        vec![
            Arc::new(UInt64Array::from(vec![count])) as ArrayRef,
            Arc::new(Float64Array::from(vec![mean])),
            Arc::new(Float64Array::from(vec![m2])),
        ]
    }

    /// Merging an empty partial state and then a non-empty one into the same
    /// group must yield the variance of the non-empty state alone.
    #[test]
    fn test_groups_accumulator_merge_empty_states() -> Result<()> {
        let mut acc = VarianceGroupsAccumulator::new(StatsType::Sample);
        for state in [partial_state(0, 0.0, 0.0), partial_state(2, 1.0, 1.0)] {
            acc.merge_batch(&state, &[0], None, 1)?;
        }

        let evaluated = acc.evaluate(EmitTo::All)?;
        let variances = evaluated.as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(variances.len(), 1);
        assert_eq!(variances.value(0), 1.0);
        Ok(())
    }
}
609}