datafusion_functions_aggregate/
variance.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`VarianceSample`]: variance sample aggregations.
19//! [`VariancePopulation`]: variance population aggregations.
20
21use arrow::datatypes::FieldRef;
22use arrow::{
23    array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array},
24    buffer::NullBuffer,
25    compute::kernels::cast,
26    datatypes::{DataType, Field},
27};
28use datafusion_common::{downcast_value, not_impl_err, plan_err, Result, ScalarValue};
29use datafusion_expr::{
30    function::{AccumulatorArgs, StateFieldsArgs},
31    utils::format_state_name,
32    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
33    Volatility,
34};
35use datafusion_functions_aggregate_common::{
36    aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType,
37};
38use datafusion_macros::user_doc;
39use std::mem::{size_of, size_of_val};
40use std::{fmt::Debug, sync::Arc};
41
// Expression/UDAF entry points for `VarianceSample`.
// NOTE(review): the macro is defined elsewhere in this crate; it presumably
// generates an expression builder named `var_sample(expression)` and a shared
// UDAF accessor `var_samp_udaf()`, with the string used as their doc text —
// confirm against the macro definition.
make_udaf_expr_and_func!(
    VarianceSample,
    var_sample,
    expression,
    "Computes the sample variance.",
    var_samp_udaf
);

// Same as above for `VariancePopulation`; presumably yields `var_pop(expression)`
// and `var_pop_udaf()` — confirm against the macro definition.
make_udaf_expr_and_func!(
    VariancePopulation,
    var_pop,
    expression,
    "Computes the population variance.",
    var_pop_udaf
);
57
58#[user_doc(
59    doc_section(label = "General Functions"),
60    description = "Returns the statistical sample variance of a set of numbers.",
61    syntax_example = "var(expression)",
62    standard_argument(name = "expression", prefix = "Numeric")
63)]
64#[derive(PartialEq, Eq, Hash)]
65pub struct VarianceSample {
66    signature: Signature,
67    aliases: Vec<String>,
68}
69
70impl Debug for VarianceSample {
71    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
72        f.debug_struct("VarianceSample")
73            .field("name", &self.name())
74            .field("signature", &self.signature)
75            .finish()
76    }
77}
78
79impl Default for VarianceSample {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl VarianceSample {
86    pub fn new() -> Self {
87        Self {
88            aliases: vec![String::from("var_sample"), String::from("var_samp")],
89            signature: Signature::numeric(1, Volatility::Immutable),
90        }
91    }
92}
93
94impl AggregateUDFImpl for VarianceSample {
95    fn as_any(&self) -> &dyn std::any::Any {
96        self
97    }
98
99    fn name(&self) -> &str {
100        "var"
101    }
102
103    fn signature(&self) -> &Signature {
104        &self.signature
105    }
106
107    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
108        Ok(DataType::Float64)
109    }
110
111    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
112        let name = args.name;
113        Ok(vec![
114            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
115            Field::new(format_state_name(name, "mean"), DataType::Float64, true),
116            Field::new(format_state_name(name, "m2"), DataType::Float64, true),
117        ]
118        .into_iter()
119        .map(Arc::new)
120        .collect())
121    }
122
123    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
124        if acc_args.is_distinct {
125            return not_impl_err!("VAR(DISTINCT) aggregations are not available");
126        }
127
128        Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
129    }
130
131    fn aliases(&self) -> &[String] {
132        &self.aliases
133    }
134
135    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
136        !acc_args.is_distinct
137    }
138
139    fn create_groups_accumulator(
140        &self,
141        _args: AccumulatorArgs,
142    ) -> Result<Box<dyn GroupsAccumulator>> {
143        Ok(Box::new(VarianceGroupsAccumulator::new(StatsType::Sample)))
144    }
145
146    fn documentation(&self) -> Option<&Documentation> {
147        self.doc()
148    }
149}
150
151#[user_doc(
152    doc_section(label = "General Functions"),
153    description = "Returns the statistical population variance of a set of numbers.",
154    syntax_example = "var_pop(expression)",
155    standard_argument(name = "expression", prefix = "Numeric")
156)]
157#[derive(PartialEq, Eq, Hash)]
158pub struct VariancePopulation {
159    signature: Signature,
160    aliases: Vec<String>,
161}
162
163impl Debug for VariancePopulation {
164    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
165        f.debug_struct("VariancePopulation")
166            .field("name", &self.name())
167            .field("signature", &self.signature)
168            .finish()
169    }
170}
171
172impl Default for VariancePopulation {
173    fn default() -> Self {
174        Self::new()
175    }
176}
177
178impl VariancePopulation {
179    pub fn new() -> Self {
180        Self {
181            aliases: vec![String::from("var_population")],
182            signature: Signature::numeric(1, Volatility::Immutable),
183        }
184    }
185}
186
187impl AggregateUDFImpl for VariancePopulation {
188    fn as_any(&self) -> &dyn std::any::Any {
189        self
190    }
191
192    fn name(&self) -> &str {
193        "var_pop"
194    }
195
196    fn signature(&self) -> &Signature {
197        &self.signature
198    }
199
200    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
201        if !arg_types[0].is_numeric() {
202            return plan_err!("Variance requires numeric input types");
203        }
204
205        Ok(DataType::Float64)
206    }
207
208    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
209        let name = args.name;
210        Ok(vec![
211            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
212            Field::new(format_state_name(name, "mean"), DataType::Float64, true),
213            Field::new(format_state_name(name, "m2"), DataType::Float64, true),
214        ]
215        .into_iter()
216        .map(Arc::new)
217        .collect())
218    }
219
220    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
221        if acc_args.is_distinct {
222            return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
223        }
224
225        Ok(Box::new(VarianceAccumulator::try_new(
226            StatsType::Population,
227        )?))
228    }
229
230    fn aliases(&self) -> &[String] {
231        &self.aliases
232    }
233
234    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
235        !acc_args.is_distinct
236    }
237
238    fn create_groups_accumulator(
239        &self,
240        _args: AccumulatorArgs,
241    ) -> Result<Box<dyn GroupsAccumulator>> {
242        Ok(Box::new(VarianceGroupsAccumulator::new(
243            StatsType::Population,
244        )))
245    }
246    fn documentation(&self) -> Option<&Documentation> {
247        self.doc()
248    }
249}
250
251/// An accumulator to compute variance
252/// The algorithm used is an online implementation and numerically stable. It is based on this paper:
253/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".
254/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577.
255///
256/// The algorithm has been analyzed here:
257/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances".
258/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
259
260#[derive(Debug)]
261pub struct VarianceAccumulator {
262    m2: f64,
263    mean: f64,
264    count: u64,
265    stats_type: StatsType,
266}
267
268impl VarianceAccumulator {
269    /// Creates a new `VarianceAccumulator`
270    pub fn try_new(s_type: StatsType) -> Result<Self> {
271        Ok(Self {
272            m2: 0_f64,
273            mean: 0_f64,
274            count: 0_u64,
275            stats_type: s_type,
276        })
277    }
278
279    pub fn get_count(&self) -> u64 {
280        self.count
281    }
282
283    pub fn get_mean(&self) -> f64 {
284        self.mean
285    }
286
287    pub fn get_m2(&self) -> f64 {
288        self.m2
289    }
290}
291
/// Combines two partial `(count, mean, m2)` states into one.
///
/// This is the pairwise (Chan et al.) combination step. The merged mean is the
/// count-weighted average of both means. The merged `m2` is the sum of both
/// `m2`s plus a correction proportional to the squared gap between the means.
///
/// In debug builds, panics if both states are empty.
#[inline]
fn merge(
    count: u64,
    mean: f64,
    m2: f64,
    count2: u64,
    mean2: f64,
    m22: f64,
) -> (u64, f64, f64) {
    debug_assert!(count != 0 || count2 != 0, "Cannot merge two empty states");
    let total = count + count2;
    let (n_a, n_b, n) = (count as f64, count2 as f64, total as f64);

    let weighted_a = mean * n_a / n;
    let weighted_b = mean2 * n_b / n;

    let gap = mean - mean2;
    let correction = gap * gap * n_a * n_b / n;

    (total, weighted_a + weighted_b, m2 + m22 + correction)
}
311
/// Folds one value into a `(count, mean, m2)` state (one step of Welford's
/// online algorithm) and returns the new state.
#[inline]
fn update(count: u64, mean: f64, m2: f64, value: f64) -> (u64, f64, f64) {
    let count = count + 1;
    // Deviation from the mean before and after incorporating `value`.
    let before = value - mean;
    let mean = mean + before / count as f64;
    let after = value - mean;
    (count, mean, m2 + before * after)
}
322
323impl Accumulator for VarianceAccumulator {
324    fn state(&mut self) -> Result<Vec<ScalarValue>> {
325        Ok(vec![
326            ScalarValue::from(self.count),
327            ScalarValue::from(self.mean),
328            ScalarValue::from(self.m2),
329        ])
330    }
331
332    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
333        let values = &cast(&values[0], &DataType::Float64)?;
334        let arr = downcast_value!(values, Float64Array).iter().flatten();
335
336        for value in arr {
337            (self.count, self.mean, self.m2) =
338                update(self.count, self.mean, self.m2, value)
339        }
340
341        Ok(())
342    }
343
344    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
345        let values = &cast(&values[0], &DataType::Float64)?;
346        let arr = downcast_value!(values, Float64Array).iter().flatten();
347
348        for value in arr {
349            let new_count = self.count - 1;
350            let delta1 = self.mean - value;
351            let new_mean = delta1 / new_count as f64 + self.mean;
352            let delta2 = new_mean - value;
353            let new_m2 = self.m2 - delta1 * delta2;
354
355            self.count -= 1;
356            self.mean = new_mean;
357            self.m2 = new_m2;
358        }
359
360        Ok(())
361    }
362
363    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
364        let counts = downcast_value!(states[0], UInt64Array);
365        let means = downcast_value!(states[1], Float64Array);
366        let m2s = downcast_value!(states[2], Float64Array);
367
368        for i in 0..counts.len() {
369            let c = counts.value(i);
370            if c == 0_u64 {
371                continue;
372            }
373            (self.count, self.mean, self.m2) = merge(
374                self.count,
375                self.mean,
376                self.m2,
377                c,
378                means.value(i),
379                m2s.value(i),
380            )
381        }
382        Ok(())
383    }
384
385    fn evaluate(&mut self) -> Result<ScalarValue> {
386        let count = match self.stats_type {
387            StatsType::Population => self.count,
388            StatsType::Sample => {
389                if self.count > 0 {
390                    self.count - 1
391                } else {
392                    self.count
393                }
394            }
395        };
396
397        Ok(ScalarValue::Float64(match self.count {
398            0 => None,
399            1 => {
400                if let StatsType::Population = self.stats_type {
401                    Some(0.0)
402                } else {
403                    None
404                }
405            }
406            _ => Some(self.m2 / count as f64),
407        }))
408    }
409
410    fn size(&self) -> usize {
411        size_of_val(self)
412    }
413
414    fn supports_retract_batch(&self) -> bool {
415        true
416    }
417}
418
419#[derive(Debug)]
420pub struct VarianceGroupsAccumulator {
421    m2s: Vec<f64>,
422    means: Vec<f64>,
423    counts: Vec<u64>,
424    stats_type: StatsType,
425}
426
427impl VarianceGroupsAccumulator {
428    pub fn new(s_type: StatsType) -> Self {
429        Self {
430            m2s: Vec::new(),
431            means: Vec::new(),
432            counts: Vec::new(),
433            stats_type: s_type,
434        }
435    }
436
437    fn resize(&mut self, total_num_groups: usize) {
438        self.m2s.resize(total_num_groups, 0.0);
439        self.means.resize(total_num_groups, 0.0);
440        self.counts.resize(total_num_groups, 0);
441    }
442
443    fn merge<F>(
444        group_indices: &[usize],
445        counts: &UInt64Array,
446        means: &Float64Array,
447        m2s: &Float64Array,
448        _opt_filter: Option<&BooleanArray>,
449        mut value_fn: F,
450    ) where
451        F: FnMut(usize, u64, f64, f64) + Send,
452    {
453        assert_eq!(counts.null_count(), 0);
454        assert_eq!(means.null_count(), 0);
455        assert_eq!(m2s.null_count(), 0);
456
457        group_indices
458            .iter()
459            .zip(counts.values().iter())
460            .zip(means.values().iter())
461            .zip(m2s.values().iter())
462            .for_each(|(((&group_index, &count), &mean), &m2)| {
463                value_fn(group_index, count, mean, m2);
464            });
465    }
466
467    pub fn variance(
468        &mut self,
469        emit_to: datafusion_expr::EmitTo,
470    ) -> (Vec<f64>, NullBuffer) {
471        let mut counts = emit_to.take_needed(&mut self.counts);
472        // means are only needed for updating m2s and are not needed for the final result.
473        // But we still need to take them to ensure the internal state is consistent.
474        let _ = emit_to.take_needed(&mut self.means);
475        let m2s = emit_to.take_needed(&mut self.m2s);
476
477        if let StatsType::Sample = self.stats_type {
478            counts.iter_mut().for_each(|count| {
479                *count = count.saturating_sub(1);
480            });
481        }
482        let nulls = NullBuffer::from_iter(counts.iter().map(|&count| count != 0));
483        let variance = m2s
484            .iter()
485            .zip(counts)
486            .map(|(m2, count)| m2 / count as f64)
487            .collect();
488        (variance, nulls)
489    }
490}
491
492impl GroupsAccumulator for VarianceGroupsAccumulator {
493    fn update_batch(
494        &mut self,
495        values: &[ArrayRef],
496        group_indices: &[usize],
497        opt_filter: Option<&BooleanArray>,
498        total_num_groups: usize,
499    ) -> Result<()> {
500        assert_eq!(values.len(), 1, "single argument to update_batch");
501        let values = &cast(&values[0], &DataType::Float64)?;
502        let values = downcast_value!(values, Float64Array);
503
504        self.resize(total_num_groups);
505        accumulate(group_indices, values, opt_filter, |group_index, value| {
506            let (new_count, new_mean, new_m2) = update(
507                self.counts[group_index],
508                self.means[group_index],
509                self.m2s[group_index],
510                value,
511            );
512            self.counts[group_index] = new_count;
513            self.means[group_index] = new_mean;
514            self.m2s[group_index] = new_m2;
515        });
516        Ok(())
517    }
518
519    fn merge_batch(
520        &mut self,
521        values: &[ArrayRef],
522        group_indices: &[usize],
523        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
524        _opt_filter: Option<&BooleanArray>,
525        total_num_groups: usize,
526    ) -> Result<()> {
527        assert_eq!(values.len(), 3, "two arguments to merge_batch");
528        // first batch is counts, second is partial means, third is partial m2s
529        let partial_counts = downcast_value!(values[0], UInt64Array);
530        let partial_means = downcast_value!(values[1], Float64Array);
531        let partial_m2s = downcast_value!(values[2], Float64Array);
532
533        self.resize(total_num_groups);
534        Self::merge(
535            group_indices,
536            partial_counts,
537            partial_means,
538            partial_m2s,
539            None,
540            |group_index, partial_count, partial_mean, partial_m2| {
541                if partial_count == 0 {
542                    return;
543                }
544                let (new_count, new_mean, new_m2) = merge(
545                    self.counts[group_index],
546                    self.means[group_index],
547                    self.m2s[group_index],
548                    partial_count,
549                    partial_mean,
550                    partial_m2,
551                );
552                self.counts[group_index] = new_count;
553                self.means[group_index] = new_mean;
554                self.m2s[group_index] = new_m2;
555            },
556        );
557        Ok(())
558    }
559
560    fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
561        let (variances, nulls) = self.variance(emit_to);
562        Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
563    }
564
565    fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
566        let counts = emit_to.take_needed(&mut self.counts);
567        let means = emit_to.take_needed(&mut self.means);
568        let m2s = emit_to.take_needed(&mut self.m2s);
569
570        Ok(vec![
571            Arc::new(UInt64Array::new(counts.into(), None)),
572            Arc::new(Float64Array::new(means.into(), None)),
573            Arc::new(Float64Array::new(m2s.into(), None)),
574        ])
575    }
576
577    fn size(&self) -> usize {
578        self.m2s.capacity() * size_of::<f64>()
579            + self.means.capacity() * size_of::<f64>()
580            + self.counts.capacity() * size_of::<u64>()
581    }
582}
583
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::EmitTo;

    /// Builds a one-row partial state batch `[count, mean, m2]`.
    fn single_row_state(count: u64, mean: f64, m2: f64) -> Vec<ArrayRef> {
        vec![
            Arc::new(UInt64Array::from(vec![count])) as ArrayRef,
            Arc::new(Float64Array::from(vec![mean])),
            Arc::new(Float64Array::from(vec![m2])),
        ]
    }

    #[test]
    fn test_groups_accumulator_merge_empty_states() -> Result<()> {
        let mut acc = VarianceGroupsAccumulator::new(StatsType::Sample);
        // An empty partial state (count == 0) must not disturb the result.
        acc.merge_batch(&single_row_state(0, 0.0, 0.0), &[0], None, 1)?;
        acc.merge_batch(&single_row_state(2, 1.0, 1.0), &[0], None, 1)?;

        let evaluated = acc.evaluate(EmitTo::All)?;
        let variances = evaluated
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("variance evaluates to a Float64Array");
        assert_eq!(variances.len(), 1);
        assert_eq!(variances.value(0), 1.0);
        Ok(())
    }
}