datafusion_functions_aggregate/
variance.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`VarianceSample`]: variance sample aggregations.
19//! [`VariancePopulation`]: variance population aggregations.
20
21use arrow::{
22    array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array},
23    buffer::NullBuffer,
24    compute::kernels::cast,
25    datatypes::{DataType, Field},
26};
27use std::mem::{size_of, size_of_val};
28use std::{fmt::Debug, sync::Arc};
29
30use datafusion_common::{downcast_value, not_impl_err, plan_err, Result, ScalarValue};
31use datafusion_expr::{
32    function::{AccumulatorArgs, StateFieldsArgs},
33    utils::format_state_name,
34    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
35    Volatility,
36};
37use datafusion_functions_aggregate_common::{
38    aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType,
39};
40use datafusion_macros::user_doc;
41
// Invokes `make_udaf_expr_and_func!` for `VarianceSample`, passing `var_sample`,
// the single argument name `expression`, a one-line description, and
// `var_samp_udaf`.
// NOTE(review): the macro is defined outside this file; presumably `var_sample`
// becomes the expression-builder function and `var_samp_udaf` the UDAF
// accessor — confirm against the macro definition.
make_udaf_expr_and_func!(
    VarianceSample,
    var_sample,
    expression,
    "Computes the sample variance.",
    var_samp_udaf
);
49
// Invokes `make_udaf_expr_and_func!` for `VariancePopulation`, passing `var_pop`,
// the single argument name `expression`, a one-line description, and
// `var_pop_udaf`.
// NOTE(review): the macro is defined outside this file; presumably `var_pop`
// becomes the expression-builder function and `var_pop_udaf` the UDAF
// accessor — confirm against the macro definition.
make_udaf_expr_and_func!(
    VariancePopulation,
    var_pop,
    expression,
    "Computes the population variance.",
    var_pop_udaf
);
57
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the statistical sample variance of a set of numbers.",
    syntax_example = "var(expression)",
    standard_argument(name = "expression", prefix = "Numeric")
)]
/// Sample variance aggregate function, exposed under the name `var` with the
/// aliases `var_sample` and `var_samp`.
pub struct VarianceSample {
    /// Accepted argument signature: exactly one numeric argument.
    signature: Signature,
    /// Alternative names under which this function can be looked up.
    aliases: Vec<String>,
}
68
69impl Debug for VarianceSample {
70    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
71        f.debug_struct("VarianceSample")
72            .field("name", &self.name())
73            .field("signature", &self.signature)
74            .finish()
75    }
76}
77
78impl Default for VarianceSample {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl VarianceSample {
85    pub fn new() -> Self {
86        Self {
87            aliases: vec![String::from("var_sample"), String::from("var_samp")],
88            signature: Signature::numeric(1, Volatility::Immutable),
89        }
90    }
91}
92
93impl AggregateUDFImpl for VarianceSample {
94    fn as_any(&self) -> &dyn std::any::Any {
95        self
96    }
97
98    fn name(&self) -> &str {
99        "var"
100    }
101
102    fn signature(&self) -> &Signature {
103        &self.signature
104    }
105
106    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
107        Ok(DataType::Float64)
108    }
109
110    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
111        let name = args.name;
112        Ok(vec![
113            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
114            Field::new(format_state_name(name, "mean"), DataType::Float64, true),
115            Field::new(format_state_name(name, "m2"), DataType::Float64, true),
116        ])
117    }
118
119    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
120        if acc_args.is_distinct {
121            return not_impl_err!("VAR(DISTINCT) aggregations are not available");
122        }
123
124        Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
125    }
126
127    fn aliases(&self) -> &[String] {
128        &self.aliases
129    }
130
131    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
132        !acc_args.is_distinct
133    }
134
135    fn create_groups_accumulator(
136        &self,
137        _args: AccumulatorArgs,
138    ) -> Result<Box<dyn GroupsAccumulator>> {
139        Ok(Box::new(VarianceGroupsAccumulator::new(StatsType::Sample)))
140    }
141
142    fn documentation(&self) -> Option<&Documentation> {
143        self.doc()
144    }
145}
146
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the statistical population variance of a set of numbers.",
    syntax_example = "var_pop(expression)",
    standard_argument(name = "expression", prefix = "Numeric")
)]
/// Population variance aggregate function, exposed under the name `var_pop`
/// with the alias `var_population`.
pub struct VariancePopulation {
    /// Accepted argument signature: exactly one numeric argument.
    signature: Signature,
    /// Alternative names under which this function can be looked up.
    aliases: Vec<String>,
}
157
158impl Debug for VariancePopulation {
159    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
160        f.debug_struct("VariancePopulation")
161            .field("name", &self.name())
162            .field("signature", &self.signature)
163            .finish()
164    }
165}
166
167impl Default for VariancePopulation {
168    fn default() -> Self {
169        Self::new()
170    }
171}
172
173impl VariancePopulation {
174    pub fn new() -> Self {
175        Self {
176            aliases: vec![String::from("var_population")],
177            signature: Signature::numeric(1, Volatility::Immutable),
178        }
179    }
180}
181
182impl AggregateUDFImpl for VariancePopulation {
183    fn as_any(&self) -> &dyn std::any::Any {
184        self
185    }
186
187    fn name(&self) -> &str {
188        "var_pop"
189    }
190
191    fn signature(&self) -> &Signature {
192        &self.signature
193    }
194
195    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
196        if !arg_types[0].is_numeric() {
197            return plan_err!("Variance requires numeric input types");
198        }
199
200        Ok(DataType::Float64)
201    }
202
203    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
204        let name = args.name;
205        Ok(vec![
206            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
207            Field::new(format_state_name(name, "mean"), DataType::Float64, true),
208            Field::new(format_state_name(name, "m2"), DataType::Float64, true),
209        ])
210    }
211
212    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
213        if acc_args.is_distinct {
214            return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
215        }
216
217        Ok(Box::new(VarianceAccumulator::try_new(
218            StatsType::Population,
219        )?))
220    }
221
222    fn aliases(&self) -> &[String] {
223        &self.aliases
224    }
225
226    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
227        !acc_args.is_distinct
228    }
229
230    fn create_groups_accumulator(
231        &self,
232        _args: AccumulatorArgs,
233    ) -> Result<Box<dyn GroupsAccumulator>> {
234        Ok(Box::new(VarianceGroupsAccumulator::new(
235            StatsType::Population,
236        )))
237    }
238    fn documentation(&self) -> Option<&Documentation> {
239        self.doc()
240    }
241}
242
/// An accumulator to compute variance
///
/// The algorithm used is an online implementation and numerically stable. It is based on this paper:
/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".
/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577.
///
/// The algorithm has been analyzed here:
/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances".
/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
#[derive(Debug)]
pub struct VarianceAccumulator {
    /// Running sum of squared deviations from the mean; the variance is this
    /// value divided by `count` (population) or `count - 1` (sample).
    m2: f64,
    /// Running mean of the values seen so far.
    mean: f64,
    /// Number of non-null values accumulated so far.
    count: u64,
    /// Selects between sample and population variance at evaluation time.
    stats_type: StatsType,
}
259
260impl VarianceAccumulator {
261    /// Creates a new `VarianceAccumulator`
262    pub fn try_new(s_type: StatsType) -> Result<Self> {
263        Ok(Self {
264            m2: 0_f64,
265            mean: 0_f64,
266            count: 0_u64,
267            stats_type: s_type,
268        })
269    }
270
271    pub fn get_count(&self) -> u64 {
272        self.count
273    }
274
275    pub fn get_mean(&self) -> f64 {
276        self.mean
277    }
278
279    pub fn get_m2(&self) -> f64 {
280        self.m2
281    }
282}
283
/// Combines two partial variance states `(count, mean, m2)` into one.
///
/// The merged mean is the count-weighted average of both means, and the merged
/// `m2` is the sum of both `m2` values plus a correction term for the distance
/// between the two means. At least one of the two states must be non-empty
/// (checked in debug builds only); merging two empty states divides by zero.
#[inline]
fn merge(
    count: u64,
    mean: f64,
    m2: f64,
    count2: u64,
    mean2: f64,
    m22: f64,
) -> (u64, f64, f64) {
    debug_assert!(
        !(count == 0 && count2 == 0),
        "Cannot merge two empty states"
    );
    let total = count + count2;
    let (n1, n2, n) = (count as f64, count2 as f64, total as f64);

    let weighted_mean = mean * n1 / n + mean2 * n2 / n;
    let mean_gap = mean - mean2;
    let between_groups = mean_gap * mean_gap * n1 * n2 / n;

    (total, weighted_mean, m2 + m22 + between_groups)
}
303
/// Folds one `value` into a running variance state `(count, mean, m2)`.
///
/// Returns the new state: the count is incremented, the mean moves towards
/// `value`, and `m2` grows by the product of the deviation from the old mean
/// and the deviation from the new mean.
#[inline]
fn update(count: u64, mean: f64, m2: f64, value: f64) -> (u64, f64, f64) {
    let n = count + 1;
    let dev_before = value - mean;
    let mean_next = dev_before / n as f64 + mean;
    let dev_after = value - mean_next;

    (n, mean_next, m2 + dev_before * dev_after)
}
314
315impl Accumulator for VarianceAccumulator {
316    fn state(&mut self) -> Result<Vec<ScalarValue>> {
317        Ok(vec![
318            ScalarValue::from(self.count),
319            ScalarValue::from(self.mean),
320            ScalarValue::from(self.m2),
321        ])
322    }
323
324    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
325        let values = &cast(&values[0], &DataType::Float64)?;
326        let arr = downcast_value!(values, Float64Array).iter().flatten();
327
328        for value in arr {
329            (self.count, self.mean, self.m2) =
330                update(self.count, self.mean, self.m2, value)
331        }
332
333        Ok(())
334    }
335
336    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
337        let values = &cast(&values[0], &DataType::Float64)?;
338        let arr = downcast_value!(values, Float64Array).iter().flatten();
339
340        for value in arr {
341            let new_count = self.count - 1;
342            let delta1 = self.mean - value;
343            let new_mean = delta1 / new_count as f64 + self.mean;
344            let delta2 = new_mean - value;
345            let new_m2 = self.m2 - delta1 * delta2;
346
347            self.count -= 1;
348            self.mean = new_mean;
349            self.m2 = new_m2;
350        }
351
352        Ok(())
353    }
354
355    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
356        let counts = downcast_value!(states[0], UInt64Array);
357        let means = downcast_value!(states[1], Float64Array);
358        let m2s = downcast_value!(states[2], Float64Array);
359
360        for i in 0..counts.len() {
361            let c = counts.value(i);
362            if c == 0_u64 {
363                continue;
364            }
365            (self.count, self.mean, self.m2) = merge(
366                self.count,
367                self.mean,
368                self.m2,
369                c,
370                means.value(i),
371                m2s.value(i),
372            )
373        }
374        Ok(())
375    }
376
377    fn evaluate(&mut self) -> Result<ScalarValue> {
378        let count = match self.stats_type {
379            StatsType::Population => self.count,
380            StatsType::Sample => {
381                if self.count > 0 {
382                    self.count - 1
383                } else {
384                    self.count
385                }
386            }
387        };
388
389        Ok(ScalarValue::Float64(match self.count {
390            0 => None,
391            1 => {
392                if let StatsType::Population = self.stats_type {
393                    Some(0.0)
394                } else {
395                    None
396                }
397            }
398            _ => Some(self.m2 / count as f64),
399        }))
400    }
401
402    fn size(&self) -> usize {
403        size_of_val(self)
404    }
405
406    fn supports_retract_batch(&self) -> bool {
407        true
408    }
409}
410
/// Grouped variance accumulator that keeps one `(count, mean, m2)` state per
/// group, stored column-wise and indexed by group index.
#[derive(Debug)]
pub struct VarianceGroupsAccumulator {
    /// Per-group running sum of squared deviations from the mean.
    m2s: Vec<f64>,
    /// Per-group running mean.
    means: Vec<f64>,
    /// Per-group number of accumulated values.
    counts: Vec<u64>,
    /// Selects between sample and population variance at evaluation time.
    stats_type: StatsType,
}
418
419impl VarianceGroupsAccumulator {
420    pub fn new(s_type: StatsType) -> Self {
421        Self {
422            m2s: Vec::new(),
423            means: Vec::new(),
424            counts: Vec::new(),
425            stats_type: s_type,
426        }
427    }
428
429    fn resize(&mut self, total_num_groups: usize) {
430        self.m2s.resize(total_num_groups, 0.0);
431        self.means.resize(total_num_groups, 0.0);
432        self.counts.resize(total_num_groups, 0);
433    }
434
435    fn merge<F>(
436        group_indices: &[usize],
437        counts: &UInt64Array,
438        means: &Float64Array,
439        m2s: &Float64Array,
440        _opt_filter: Option<&BooleanArray>,
441        mut value_fn: F,
442    ) where
443        F: FnMut(usize, u64, f64, f64) + Send,
444    {
445        assert_eq!(counts.null_count(), 0);
446        assert_eq!(means.null_count(), 0);
447        assert_eq!(m2s.null_count(), 0);
448
449        group_indices
450            .iter()
451            .zip(counts.values().iter())
452            .zip(means.values().iter())
453            .zip(m2s.values().iter())
454            .for_each(|(((&group_index, &count), &mean), &m2)| {
455                value_fn(group_index, count, mean, m2);
456            });
457    }
458
459    pub fn variance(
460        &mut self,
461        emit_to: datafusion_expr::EmitTo,
462    ) -> (Vec<f64>, NullBuffer) {
463        let mut counts = emit_to.take_needed(&mut self.counts);
464        // means are only needed for updating m2s and are not needed for the final result.
465        // But we still need to take them to ensure the internal state is consistent.
466        let _ = emit_to.take_needed(&mut self.means);
467        let m2s = emit_to.take_needed(&mut self.m2s);
468
469        if let StatsType::Sample = self.stats_type {
470            counts.iter_mut().for_each(|count| {
471                *count = count.saturating_sub(1);
472            });
473        }
474        let nulls = NullBuffer::from_iter(counts.iter().map(|&count| count != 0));
475        let variance = m2s
476            .iter()
477            .zip(counts)
478            .map(|(m2, count)| m2 / count as f64)
479            .collect();
480        (variance, nulls)
481    }
482}
483
484impl GroupsAccumulator for VarianceGroupsAccumulator {
485    fn update_batch(
486        &mut self,
487        values: &[ArrayRef],
488        group_indices: &[usize],
489        opt_filter: Option<&BooleanArray>,
490        total_num_groups: usize,
491    ) -> Result<()> {
492        assert_eq!(values.len(), 1, "single argument to update_batch");
493        let values = &cast(&values[0], &DataType::Float64)?;
494        let values = downcast_value!(values, Float64Array);
495
496        self.resize(total_num_groups);
497        accumulate(group_indices, values, opt_filter, |group_index, value| {
498            let (new_count, new_mean, new_m2) = update(
499                self.counts[group_index],
500                self.means[group_index],
501                self.m2s[group_index],
502                value,
503            );
504            self.counts[group_index] = new_count;
505            self.means[group_index] = new_mean;
506            self.m2s[group_index] = new_m2;
507        });
508        Ok(())
509    }
510
511    fn merge_batch(
512        &mut self,
513        values: &[ArrayRef],
514        group_indices: &[usize],
515        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
516        _opt_filter: Option<&BooleanArray>,
517        total_num_groups: usize,
518    ) -> Result<()> {
519        assert_eq!(values.len(), 3, "two arguments to merge_batch");
520        // first batch is counts, second is partial means, third is partial m2s
521        let partial_counts = downcast_value!(values[0], UInt64Array);
522        let partial_means = downcast_value!(values[1], Float64Array);
523        let partial_m2s = downcast_value!(values[2], Float64Array);
524
525        self.resize(total_num_groups);
526        Self::merge(
527            group_indices,
528            partial_counts,
529            partial_means,
530            partial_m2s,
531            None,
532            |group_index, partial_count, partial_mean, partial_m2| {
533                if partial_count == 0 {
534                    return;
535                }
536                let (new_count, new_mean, new_m2) = merge(
537                    self.counts[group_index],
538                    self.means[group_index],
539                    self.m2s[group_index],
540                    partial_count,
541                    partial_mean,
542                    partial_m2,
543                );
544                self.counts[group_index] = new_count;
545                self.means[group_index] = new_mean;
546                self.m2s[group_index] = new_m2;
547            },
548        );
549        Ok(())
550    }
551
552    fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
553        let (variances, nulls) = self.variance(emit_to);
554        Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
555    }
556
557    fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
558        let counts = emit_to.take_needed(&mut self.counts);
559        let means = emit_to.take_needed(&mut self.means);
560        let m2s = emit_to.take_needed(&mut self.m2s);
561
562        Ok(vec![
563            Arc::new(UInt64Array::new(counts.into(), None)),
564            Arc::new(Float64Array::new(means.into(), None)),
565            Arc::new(Float64Array::new(m2s.into(), None)),
566        ])
567    }
568
569    fn size(&self) -> usize {
570        self.m2s.capacity() * size_of::<f64>()
571            + self.means.capacity() * size_of::<f64>()
572            + self.counts.capacity() * size_of::<u64>()
573    }
574}
575
#[cfg(test)]
mod tests {
    use datafusion_expr::EmitTo;

    use super::*;

    /// Builds a one-row partial state made of `count`, `mean` and `m2` columns.
    fn single_group_state(count: u64, mean: f64, m2: f64) -> Vec<ArrayRef> {
        vec![
            Arc::new(UInt64Array::from(vec![count])) as ArrayRef,
            Arc::new(Float64Array::from(vec![mean])),
            Arc::new(Float64Array::from(vec![m2])),
        ]
    }

    /// Merging an empty partial state before a non-empty one must not disturb
    /// the result for the group.
    #[test]
    fn test_groups_accumulator_merge_empty_states() -> Result<()> {
        let empty_state = single_group_state(0, 0.0, 0.0);
        let filled_state = single_group_state(2, 1.0, 1.0);

        let mut acc = VarianceGroupsAccumulator::new(StatsType::Sample);
        for partial in [&empty_state, &filled_state] {
            acc.merge_batch(partial, &[0], None, 1)?;
        }

        let evaluated = acc.evaluate(EmitTo::All)?;
        let variances = evaluated.as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(variances.len(), 1);
        assert_eq!(variances.value(0), 1.0);
        Ok(())
    }
}