Skip to main content

datafusion_functions_aggregate/
variance.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`VarianceSample`]: variance sample aggregations.
19//! [`VariancePopulation`]: variance population aggregations.
20
21use arrow::datatypes::{FieldRef, Float64Type};
22use arrow::{
23    array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array},
24    buffer::NullBuffer,
25    datatypes::{DataType, Field},
26};
27use datafusion_common::cast::{as_float64_array, as_uint64_array};
28use datafusion_common::{Result, ScalarValue};
29use datafusion_expr::{
30    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
31    Volatility,
32    function::{AccumulatorArgs, StateFieldsArgs},
33    utils::format_state_name,
34};
35use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
36use datafusion_functions_aggregate_common::{
37    aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType,
38};
39use datafusion_macros::user_doc;
40use std::mem::{size_of, size_of_val};
41use std::{fmt::Debug, sync::Arc};
42
// Declares the public helpers for `VarianceSample` through the
// `make_udaf_expr_and_func!` macro, which is defined outside this file.
// NOTE(review): presumably this emits a `var_sample(expression)` expression
// builder and a `var_samp_udaf()` accessor — confirm against the macro definition.
make_udaf_expr_and_func!(
    VarianceSample,
    var_sample,
    expression,
    "Computes the sample variance.",
    var_samp_udaf
);
50
// Declares the public helpers for `VariancePopulation` through the
// `make_udaf_expr_and_func!` macro, which is defined outside this file.
// NOTE(review): presumably this emits a `var_pop(expression)` expression
// builder and a `var_pop_udaf()` accessor — confirm against the macro definition.
make_udaf_expr_and_func!(
    VariancePopulation,
    var_pop,
    expression,
    "Computes the population variance.",
    var_pop_udaf
);
58
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the statistical sample variance of a set of numbers.",
    syntax_example = "var(expression)",
    standard_argument(name = "expression", prefix = "Numeric")
)]
/// Sample variance aggregate function; its `name()` is `var`.
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct VarianceSample {
    /// Argument signature handed out by `signature()`; built in [`VarianceSample::new`].
    signature: Signature,
    /// Additional names handed out by `aliases()`.
    aliases: Vec<String>,
}
70
impl Default for VarianceSample {
    /// Equivalent to [`VarianceSample::new`].
    fn default() -> Self {
        Self::new()
    }
}
76
77impl VarianceSample {
78    pub fn new() -> Self {
79        Self {
80            aliases: vec![String::from("var_sample"), String::from("var_samp")],
81            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
82        }
83    }
84}
85
86impl AggregateUDFImpl for VarianceSample {
87    fn name(&self) -> &str {
88        "var"
89    }
90
91    fn signature(&self) -> &Signature {
92        &self.signature
93    }
94
95    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
96        Ok(DataType::Float64)
97    }
98
99    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
100        let name = args.name;
101        match args.is_distinct {
102            false => Ok(vec![
103                Field::new(format_state_name(name, "count"), DataType::UInt64, true),
104                Field::new(format_state_name(name, "mean"), DataType::Float64, true),
105                Field::new(format_state_name(name, "m2"), DataType::Float64, true),
106            ]
107            .into_iter()
108            .map(Arc::new)
109            .collect()),
110            true => {
111                let field = Field::new_list_field(DataType::Float64, true);
112                let state_name = "distinct_var";
113                Ok(vec![
114                    Field::new(
115                        format_state_name(name, state_name),
116                        DataType::List(Arc::new(field)),
117                        true,
118                    )
119                    .into(),
120                ])
121            }
122        }
123    }
124
125    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
126        if acc_args.is_distinct {
127            return Ok(Box::new(DistinctVarianceAccumulator::new(
128                StatsType::Sample,
129            )));
130        }
131
132        Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
133    }
134
135    fn aliases(&self) -> &[String] {
136        &self.aliases
137    }
138
139    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
140        !acc_args.is_distinct
141    }
142
143    fn create_groups_accumulator(
144        &self,
145        _args: AccumulatorArgs,
146    ) -> Result<Box<dyn GroupsAccumulator>> {
147        Ok(Box::new(VarianceGroupsAccumulator::new(StatsType::Sample)))
148    }
149
150    fn documentation(&self) -> Option<&Documentation> {
151        self.doc()
152    }
153}
154
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the statistical population variance of a set of numbers.",
    syntax_example = "var_pop(expression)",
    standard_argument(name = "expression", prefix = "Numeric")
)]
/// Population variance aggregate function; its `name()` is `var_pop`.
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct VariancePopulation {
    /// Argument signature handed out by `signature()`; built in [`VariancePopulation::new`].
    signature: Signature,
    /// Additional names handed out by `aliases()`.
    aliases: Vec<String>,
}
166
impl Default for VariancePopulation {
    /// Equivalent to [`VariancePopulation::new`].
    fn default() -> Self {
        Self::new()
    }
}
172
173impl VariancePopulation {
174    pub fn new() -> Self {
175        Self {
176            aliases: vec![String::from("var_population")],
177            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
178        }
179    }
180}
181
182impl AggregateUDFImpl for VariancePopulation {
183    fn name(&self) -> &str {
184        "var_pop"
185    }
186
187    fn signature(&self) -> &Signature {
188        &self.signature
189    }
190
191    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
192        Ok(DataType::Float64)
193    }
194
195    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
196        match args.is_distinct {
197            false => {
198                let name = args.name;
199                Ok(vec![
200                    Field::new(format_state_name(name, "count"), DataType::UInt64, true),
201                    Field::new(format_state_name(name, "mean"), DataType::Float64, true),
202                    Field::new(format_state_name(name, "m2"), DataType::Float64, true),
203                ]
204                .into_iter()
205                .map(Arc::new)
206                .collect())
207            }
208            true => {
209                let field = Field::new_list_field(DataType::Float64, true);
210                let state_name = "distinct_var";
211                Ok(vec![
212                    Field::new(
213                        format_state_name(args.name, state_name),
214                        DataType::List(Arc::new(field)),
215                        true,
216                    )
217                    .into(),
218                ])
219            }
220        }
221    }
222
223    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
224        if acc_args.is_distinct {
225            return Ok(Box::new(DistinctVarianceAccumulator::new(
226                StatsType::Population,
227            )));
228        }
229
230        Ok(Box::new(VarianceAccumulator::try_new(
231            StatsType::Population,
232        )?))
233    }
234
235    fn aliases(&self) -> &[String] {
236        &self.aliases
237    }
238
239    fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
240        !acc_args.is_distinct
241    }
242
243    fn create_groups_accumulator(
244        &self,
245        _args: AccumulatorArgs,
246    ) -> Result<Box<dyn GroupsAccumulator>> {
247        Ok(Box::new(VarianceGroupsAccumulator::new(
248            StatsType::Population,
249        )))
250    }
251
252    fn documentation(&self) -> Option<&Documentation> {
253        self.doc()
254    }
255}
256
/// An accumulator to compute variance.
///
/// The algorithm used is an online implementation and numerically stable. It is based on this paper:
/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".
/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577.
///
/// The algorithm has been analyzed here:
/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances".
/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
#[derive(Debug)]
pub struct VarianceAccumulator {
    /// Running sum of squared deviations from the running mean.
    m2: f64,
    /// Running mean of the values seen so far.
    mean: f64,
    /// Number of non-null values currently accounted for.
    count: u64,
    /// Selects the divisor used in `evaluate`: `count` for population,
    /// `count - 1` for sample.
    stats_type: StatsType,
}
273
274impl VarianceAccumulator {
275    /// Creates a new `VarianceAccumulator`
276    pub fn try_new(s_type: StatsType) -> Result<Self> {
277        Ok(Self {
278            m2: 0_f64,
279            mean: 0_f64,
280            count: 0_u64,
281            stats_type: s_type,
282        })
283    }
284
285    pub fn get_count(&self) -> u64 {
286        self.count
287    }
288
289    pub fn get_mean(&self) -> f64 {
290        self.mean
291    }
292
293    pub fn get_m2(&self) -> f64 {
294        self.m2
295    }
296}
297
/// Combines two partial variance states, each given as `(count, mean, m2)`,
/// into the state describing the union of both inputs.
///
/// At least one of the two counts must be non-zero; with two empty states the
/// divisions below are `0 / 0` (checked by a `debug_assert!`).
#[inline]
fn merge(
    count: u64,
    mean: f64,
    m2: f64,
    count2: u64,
    mean2: f64,
    m22: f64,
) -> (u64, f64, f64) {
    debug_assert!(count != 0 || count2 != 0, "Cannot merge two empty states");
    let total = count + count2;
    let (n1, n2, n) = (count as f64, count2 as f64, total as f64);

    // Count-weighted average of the two means.
    let merged_mean = mean * n1 / n + mean2 * n2 / n;
    // The cross term accounts for the distance between the two partial means.
    let mean_gap = mean - mean2;
    let merged_m2 = m2 + m22 + mean_gap * mean_gap * n1 * n2 / n;

    (total, merged_mean, merged_m2)
}
317
/// Folds one more `value` into the running `(count, mean, m2)` state and
/// returns the updated state.
#[inline]
fn update(count: u64, mean: f64, m2: f64, value: f64) -> (u64, f64, f64) {
    let next_count = count + 1;
    // Deviation from the mean before and after the value is included.
    let before = value - mean;
    let next_mean = before / next_count as f64 + mean;
    let after = value - next_mean;

    (next_count, next_mean, m2 + before * after)
}
328
329impl Accumulator for VarianceAccumulator {
330    fn state(&mut self) -> Result<Vec<ScalarValue>> {
331        Ok(vec![
332            ScalarValue::from(self.count),
333            ScalarValue::from(self.mean),
334            ScalarValue::from(self.m2),
335        ])
336    }
337
338    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
339        let arr = as_float64_array(&values[0])?;
340        for value in arr.iter().flatten() {
341            (self.count, self.mean, self.m2) =
342                update(self.count, self.mean, self.m2, value)
343        }
344
345        Ok(())
346    }
347
348    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
349        let arr = as_float64_array(&values[0])?;
350        for value in arr.iter().flatten() {
351            let new_count = self.count - 1;
352            let delta1 = self.mean - value;
353            let new_mean = delta1 / new_count as f64 + self.mean;
354            let delta2 = new_mean - value;
355            let new_m2 = self.m2 - delta1 * delta2;
356
357            self.count -= 1;
358            self.mean = new_mean;
359            self.m2 = new_m2;
360        }
361
362        Ok(())
363    }
364
365    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
366        let counts = as_uint64_array(&states[0])?;
367        let means = as_float64_array(&states[1])?;
368        let m2s = as_float64_array(&states[2])?;
369
370        for i in 0..counts.len() {
371            let c = counts.value(i);
372            if c == 0_u64 {
373                continue;
374            }
375            (self.count, self.mean, self.m2) = merge(
376                self.count,
377                self.mean,
378                self.m2,
379                c,
380                means.value(i),
381                m2s.value(i),
382            )
383        }
384        Ok(())
385    }
386
387    fn evaluate(&mut self) -> Result<ScalarValue> {
388        let count = match self.stats_type {
389            StatsType::Population => self.count,
390            StatsType::Sample => {
391                if self.count > 0 {
392                    self.count - 1
393                } else {
394                    self.count
395                }
396            }
397        };
398
399        Ok(ScalarValue::Float64(match self.count {
400            0 => None,
401            1 => {
402                if let StatsType::Population = self.stats_type {
403                    Some(0.0)
404                } else {
405                    None
406                }
407            }
408            _ => Some(self.m2 / count as f64),
409        }))
410    }
411
412    fn size(&self) -> usize {
413        size_of_val(self)
414    }
415
416    fn supports_retract_batch(&self) -> bool {
417        true
418    }
419}
420
/// Grouped variance accumulator: keeps one `(count, mean, m2)` state per
/// group, stored column-wise and indexed by group index.
#[derive(Debug)]
pub struct VarianceGroupsAccumulator {
    /// Per-group sum of squared deviations from the group's mean.
    m2s: Vec<f64>,
    /// Per-group running mean.
    means: Vec<f64>,
    /// Per-group number of accumulated values.
    counts: Vec<u64>,
    /// Selects the sample or population divisor in `variance`.
    stats_type: StatsType,
}
428
429impl VarianceGroupsAccumulator {
430    pub fn new(s_type: StatsType) -> Self {
431        Self {
432            m2s: Vec::new(),
433            means: Vec::new(),
434            counts: Vec::new(),
435            stats_type: s_type,
436        }
437    }
438
439    fn resize(&mut self, total_num_groups: usize) {
440        self.m2s.resize(total_num_groups, 0.0);
441        self.means.resize(total_num_groups, 0.0);
442        self.counts.resize(total_num_groups, 0);
443    }
444
445    fn merge<F>(
446        group_indices: &[usize],
447        counts: &UInt64Array,
448        means: &Float64Array,
449        m2s: &Float64Array,
450        _opt_filter: Option<&BooleanArray>,
451        mut value_fn: F,
452    ) where
453        F: FnMut(usize, u64, f64, f64) + Send,
454    {
455        assert_eq!(counts.null_count(), 0);
456        assert_eq!(means.null_count(), 0);
457        assert_eq!(m2s.null_count(), 0);
458
459        group_indices
460            .iter()
461            .zip(counts.values().iter())
462            .zip(means.values().iter())
463            .zip(m2s.values().iter())
464            .for_each(|(((&group_index, &count), &mean), &m2)| {
465                value_fn(group_index, count, mean, m2);
466            });
467    }
468
469    pub fn variance(
470        &mut self,
471        emit_to: datafusion_expr::EmitTo,
472    ) -> (Vec<f64>, NullBuffer) {
473        let mut counts = emit_to.take_needed(&mut self.counts);
474        // means are only needed for updating m2s and are not needed for the final result.
475        // But we still need to take them to ensure the internal state is consistent.
476        let _ = emit_to.take_needed(&mut self.means);
477        let m2s = emit_to.take_needed(&mut self.m2s);
478
479        if let StatsType::Sample = self.stats_type {
480            counts.iter_mut().for_each(|count| {
481                *count = count.saturating_sub(1);
482            });
483        }
484        let nulls = NullBuffer::from_iter(counts.iter().map(|&count| count != 0));
485        let variance = m2s
486            .iter()
487            .zip(counts)
488            .map(|(m2, count)| m2 / count as f64)
489            .collect();
490        (variance, nulls)
491    }
492}
493
494impl GroupsAccumulator for VarianceGroupsAccumulator {
495    fn update_batch(
496        &mut self,
497        values: &[ArrayRef],
498        group_indices: &[usize],
499        opt_filter: Option<&BooleanArray>,
500        total_num_groups: usize,
501    ) -> Result<()> {
502        assert_eq!(values.len(), 1, "single argument to update_batch");
503        let values = as_float64_array(&values[0])?;
504
505        self.resize(total_num_groups);
506        accumulate(group_indices, values, opt_filter, |group_index, value| {
507            let (new_count, new_mean, new_m2) = update(
508                self.counts[group_index],
509                self.means[group_index],
510                self.m2s[group_index],
511                value,
512            );
513            self.counts[group_index] = new_count;
514            self.means[group_index] = new_mean;
515            self.m2s[group_index] = new_m2;
516        });
517        Ok(())
518    }
519
520    fn merge_batch(
521        &mut self,
522        values: &[ArrayRef],
523        group_indices: &[usize],
524        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
525        _opt_filter: Option<&BooleanArray>,
526        total_num_groups: usize,
527    ) -> Result<()> {
528        assert_eq!(values.len(), 3, "two arguments to merge_batch");
529        // first batch is counts, second is partial means, third is partial m2s
530        let partial_counts = as_uint64_array(&values[0])?;
531        let partial_means = as_float64_array(&values[1])?;
532        let partial_m2s = as_float64_array(&values[2])?;
533
534        self.resize(total_num_groups);
535        Self::merge(
536            group_indices,
537            partial_counts,
538            partial_means,
539            partial_m2s,
540            None,
541            |group_index, partial_count, partial_mean, partial_m2| {
542                if partial_count == 0 {
543                    return;
544                }
545                let (new_count, new_mean, new_m2) = merge(
546                    self.counts[group_index],
547                    self.means[group_index],
548                    self.m2s[group_index],
549                    partial_count,
550                    partial_mean,
551                    partial_m2,
552                );
553                self.counts[group_index] = new_count;
554                self.means[group_index] = new_mean;
555                self.m2s[group_index] = new_m2;
556            },
557        );
558        Ok(())
559    }
560
561    fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
562        let (variances, nulls) = self.variance(emit_to);
563        Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
564    }
565
566    fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
567        let counts = emit_to.take_needed(&mut self.counts);
568        let means = emit_to.take_needed(&mut self.means);
569        let m2s = emit_to.take_needed(&mut self.m2s);
570
571        Ok(vec![
572            Arc::new(UInt64Array::new(counts.into(), None)),
573            Arc::new(Float64Array::new(means.into(), None)),
574            Arc::new(Float64Array::new(m2s.into(), None)),
575        ])
576    }
577
578    fn size(&self) -> usize {
579        self.m2s.capacity() * size_of::<f64>()
580            + self.means.capacity() * size_of::<f64>()
581            + self.counts.capacity() * size_of::<u64>()
582    }
583}
584
/// Accumulator for `DISTINCT` variance: buffers the distinct input values and
/// computes the variance over them in `evaluate`.
#[derive(Debug)]
pub struct DistinctVarianceAccumulator {
    /// Buffer holding the distinct `Float64` values seen so far.
    distinct_values: GenericDistinctBuffer<Float64Type>,
    /// Selects the sample or population divisor in `evaluate`.
    stat_type: StatsType,
}
590
591impl DistinctVarianceAccumulator {
592    pub fn new(stat_type: StatsType) -> Self {
593        Self {
594            distinct_values: GenericDistinctBuffer::<Float64Type>::new(DataType::Float64),
595            stat_type,
596        }
597    }
598}
599
600impl Accumulator for DistinctVarianceAccumulator {
601    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
602        self.distinct_values.update_batch(values)
603    }
604
605    fn evaluate(&mut self) -> Result<ScalarValue> {
606        let values = self
607            .distinct_values
608            .values
609            .iter()
610            .map(|v| v.0)
611            .collect::<Vec<_>>();
612
613        let count = match self.stat_type {
614            StatsType::Sample => {
615                if !values.is_empty() {
616                    values.len() - 1
617                } else {
618                    0
619                }
620            }
621            StatsType::Population => values.len(),
622        };
623
624        let mean = values.iter().sum::<f64>() / values.len() as f64;
625        let m2 = values.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>();
626
627        Ok(ScalarValue::Float64(match values.len() {
628            0 => None,
629            1 => match self.stat_type {
630                StatsType::Population => Some(0.0),
631                StatsType::Sample => None,
632            },
633            _ => Some(m2 / count as f64),
634        }))
635    }
636
637    fn size(&self) -> usize {
638        size_of_val(self) + self.distinct_values.size()
639    }
640
641    fn state(&mut self) -> Result<Vec<ScalarValue>> {
642        self.distinct_values.state()
643    }
644
645    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
646        self.distinct_values.merge_batch(states)
647    }
648}
649
#[cfg(test)]
mod tests {
    use datafusion_expr::EmitTo;

    use super::*;

    /// Merging an empty partial state (`count == 0`) before a populated one
    /// must leave the result determined by the populated state alone:
    /// with `count = 2` and `m2 = 1.0` the sample variance is `1.0 / (2 - 1)`.
    #[test]
    fn test_groups_accumulator_merge_empty_states() -> Result<()> {
        let partial_state = |count: u64, mean: f64, m2: f64| {
            vec![
                Arc::new(UInt64Array::from(vec![count])) as ArrayRef,
                Arc::new(Float64Array::from(vec![mean])),
                Arc::new(Float64Array::from(vec![m2])),
            ]
        };
        let empty = partial_state(0, 0.0, 0.0);
        let populated = partial_state(2, 1.0, 1.0);

        let mut acc = VarianceGroupsAccumulator::new(StatsType::Sample);
        for state in [&empty, &populated] {
            acc.merge_batch(state, &[0], None, 1)?;
        }

        let evaluated = acc.evaluate(EmitTo::All)?;
        let variances = evaluated.as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(variances.len(), 1);
        assert_eq!(variances.value(0), 1.0);
        Ok(())
    }
}