Skip to main content

datafusion_functions_aggregate/
covariance.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`CovarianceSample`] and [`CovariancePopulation`]: sample and population covariance aggregations.
19
20use arrow::array::ArrayRef;
21use arrow::datatypes::{DataType, Field, FieldRef};
22use datafusion_common::cast::{as_float64_array, as_uint64_array};
23use datafusion_common::{Result, ScalarValue};
24use datafusion_expr::{
25    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
26    function::{AccumulatorArgs, StateFieldsArgs},
27    utils::format_state_name,
28};
29use datafusion_functions_aggregate_common::stats::StatsType;
30use datafusion_macros::user_doc;
31use std::fmt::Debug;
32use std::mem::size_of_val;
33use std::sync::Arc;
34
// Wires `CovarianceSample` up under the name `covar_samp`, with arguments
// named `y` and `x` and the description below. The macro is defined elsewhere
// in this crate; presumably it generates a `covar_samp` expression helper and a
// `covar_samp_udaf` accessor — TODO confirm against the macro definition.
make_udaf_expr_and_func!(
    CovarianceSample,
    covar_samp,
    y x,
    "Computes the sample covariance.",
    covar_samp_udaf
);
42
// Wires `CovariancePopulation` up under the name `covar_pop`, with arguments
// named `y` and `x` and the description below. The macro is defined elsewhere
// in this crate; presumably it generates a `covar_pop` expression helper and a
// `covar_pop_udaf` accessor — TODO confirm against the macro definition.
make_udaf_expr_and_func!(
    CovariancePopulation,
    covar_pop,
    y x,
    "Computes the population covariance.",
    covar_pop_udaf
);
50
51#[user_doc(
52    doc_section(label = "Statistical Functions"),
53    description = "Returns the sample covariance of a set of number pairs.",
54    syntax_example = "covar_samp(expression1, expression2)",
55    sql_example = r#"```sql
56> SELECT covar_samp(column1, column2) FROM table_name;
57+-----------------------------------+
58| covar_samp(column1, column2)      |
59+-----------------------------------+
60| 8.25                              |
61+-----------------------------------+
62```"#,
63    standard_argument(name = "expression1", prefix = "First"),
64    standard_argument(name = "expression2", prefix = "Second")
65)]
66#[derive(PartialEq, Eq, Hash, Debug)]
67pub struct CovarianceSample {
68    signature: Signature,
69    aliases: Vec<String>,
70}
71
72impl Default for CovarianceSample {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
78impl CovarianceSample {
79    pub fn new() -> Self {
80        Self {
81            aliases: vec![String::from("covar")],
82            signature: Signature::exact(
83                vec![DataType::Float64, DataType::Float64],
84                Volatility::Immutable,
85            ),
86        }
87    }
88}
89
90impl AggregateUDFImpl for CovarianceSample {
91    fn as_any(&self) -> &dyn std::any::Any {
92        self
93    }
94
95    fn name(&self) -> &str {
96        "covar_samp"
97    }
98
99    fn signature(&self) -> &Signature {
100        &self.signature
101    }
102
103    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
104        Ok(DataType::Float64)
105    }
106
107    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
108        let name = args.name;
109        Ok(vec![
110            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
111            Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
112            Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
113            Field::new(
114                format_state_name(name, "algo_const"),
115                DataType::Float64,
116                true,
117            ),
118        ]
119        .into_iter()
120        .map(Arc::new)
121        .collect())
122    }
123
124    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
125        Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?))
126    }
127
128    fn aliases(&self) -> &[String] {
129        &self.aliases
130    }
131
132    fn documentation(&self) -> Option<&Documentation> {
133        self.doc()
134    }
135}
136
137#[user_doc(
138    doc_section(label = "Statistical Functions"),
139    description = "Returns the sample covariance of a set of number pairs.",
140    syntax_example = "covar_samp(expression1, expression2)",
141    sql_example = r#"```sql
142> SELECT covar_samp(column1, column2) FROM table_name;
143+-----------------------------------+
144| covar_samp(column1, column2)      |
145+-----------------------------------+
146| 8.25                              |
147+-----------------------------------+
148```"#,
149    standard_argument(name = "expression1", prefix = "First"),
150    standard_argument(name = "expression2", prefix = "Second")
151)]
152#[derive(PartialEq, Eq, Hash, Debug)]
153pub struct CovariancePopulation {
154    signature: Signature,
155}
156
157impl Default for CovariancePopulation {
158    fn default() -> Self {
159        Self::new()
160    }
161}
162
163impl CovariancePopulation {
164    pub fn new() -> Self {
165        Self {
166            signature: Signature::exact(
167                vec![DataType::Float64, DataType::Float64],
168                Volatility::Immutable,
169            ),
170        }
171    }
172}
173
174impl AggregateUDFImpl for CovariancePopulation {
175    fn as_any(&self) -> &dyn std::any::Any {
176        self
177    }
178
179    fn name(&self) -> &str {
180        "covar_pop"
181    }
182
183    fn signature(&self) -> &Signature {
184        &self.signature
185    }
186
187    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
188        Ok(DataType::Float64)
189    }
190
191    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
192        let name = args.name;
193        Ok(vec![
194            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
195            Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
196            Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
197            Field::new(
198                format_state_name(name, "algo_const"),
199                DataType::Float64,
200                true,
201            ),
202        ]
203        .into_iter()
204        .map(Arc::new)
205        .collect())
206    }
207
208    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
209        Ok(Box::new(CovarianceAccumulator::try_new(
210            StatsType::Population,
211        )?))
212    }
213
214    fn documentation(&self) -> Option<&Documentation> {
215        self.doc()
216    }
217}
218
219/// An accumulator to compute covariance
220/// The algorithm used is an online implementation and numerically stable. It is derived from the following paper
221/// for calculating variance:
222/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".
223/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577.
224///
225/// The algorithm has been analyzed here:
226/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances".
227/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
228///
229/// Though it is not covered in the original paper but is based on the same idea, as a result the algorithm is online,
230/// parallelize and numerically stable.
231
232#[derive(Debug)]
233pub struct CovarianceAccumulator {
234    algo_const: f64,
235    mean1: f64,
236    mean2: f64,
237    count: u64,
238    stats_type: StatsType,
239}
240
241impl CovarianceAccumulator {
242    /// Creates a new `CovarianceAccumulator`
243    pub fn try_new(s_type: StatsType) -> Result<Self> {
244        Ok(Self {
245            algo_const: 0_f64,
246            mean1: 0_f64,
247            mean2: 0_f64,
248            count: 0_u64,
249            stats_type: s_type,
250        })
251    }
252
253    pub fn get_count(&self) -> u64 {
254        self.count
255    }
256
257    pub fn get_mean1(&self) -> f64 {
258        self.mean1
259    }
260
261    pub fn get_mean2(&self) -> f64 {
262        self.mean2
263    }
264
265    pub fn get_algo_const(&self) -> f64 {
266        self.algo_const
267    }
268}
269
270impl Accumulator for CovarianceAccumulator {
271    fn state(&mut self) -> Result<Vec<ScalarValue>> {
272        Ok(vec![
273            ScalarValue::from(self.count),
274            ScalarValue::from(self.mean1),
275            ScalarValue::from(self.mean2),
276            ScalarValue::from(self.algo_const),
277        ])
278    }
279
280    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
281        let values1 = as_float64_array(&values[0])?;
282        let values2 = as_float64_array(&values[1])?;
283
284        for (value1, value2) in values1.iter().zip(values2) {
285            let (value1, value2) = match (value1, value2) {
286                (Some(a), Some(b)) => (a, b),
287                _ => continue,
288            };
289
290            let new_count = self.count + 1;
291            let delta1 = value1 - self.mean1;
292            let new_mean1 = delta1 / new_count as f64 + self.mean1;
293            let delta2 = value2 - self.mean2;
294            let new_mean2 = delta2 / new_count as f64 + self.mean2;
295            let new_c = delta1 * (value2 - new_mean2) + self.algo_const;
296
297            self.count += 1;
298            self.mean1 = new_mean1;
299            self.mean2 = new_mean2;
300            self.algo_const = new_c;
301        }
302
303        Ok(())
304    }
305
306    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
307        let values1 = as_float64_array(&values[0])?;
308        let values2 = as_float64_array(&values[1])?;
309
310        for (value1, value2) in values1.iter().zip(values2) {
311            let (value1, value2) = match (value1, value2) {
312                (Some(a), Some(b)) => (a, b),
313                _ => continue,
314            };
315
316            let new_count = self.count - 1;
317            let delta1 = self.mean1 - value1;
318            let new_mean1 = delta1 / new_count as f64 + self.mean1;
319            let delta2 = self.mean2 - value2;
320            let new_mean2 = delta2 / new_count as f64 + self.mean2;
321            let new_c = self.algo_const - delta1 * (new_mean2 - value2);
322
323            self.count -= 1;
324            self.mean1 = new_mean1;
325            self.mean2 = new_mean2;
326            self.algo_const = new_c;
327        }
328
329        Ok(())
330    }
331
332    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
333        let counts = as_uint64_array(&states[0])?;
334        let means1 = as_float64_array(&states[1])?;
335        let means2 = as_float64_array(&states[2])?;
336        let cs = as_float64_array(&states[3])?;
337
338        for i in 0..counts.len() {
339            let c = counts.value(i);
340            if c == 0_u64 {
341                continue;
342            }
343            let new_count = self.count + c;
344            let new_mean1 = self.mean1 * self.count as f64 / new_count as f64
345                + means1.value(i) * c as f64 / new_count as f64;
346            let new_mean2 = self.mean2 * self.count as f64 / new_count as f64
347                + means2.value(i) * c as f64 / new_count as f64;
348            let delta1 = self.mean1 - means1.value(i);
349            let delta2 = self.mean2 - means2.value(i);
350            let new_c = self.algo_const
351                + cs.value(i)
352                + delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64;
353
354            self.count = new_count;
355            self.mean1 = new_mean1;
356            self.mean2 = new_mean2;
357            self.algo_const = new_c;
358        }
359        Ok(())
360    }
361
362    fn evaluate(&mut self) -> Result<ScalarValue> {
363        let count = match self.stats_type {
364            StatsType::Population => self.count,
365            StatsType::Sample => {
366                if self.count > 0 {
367                    self.count - 1
368                } else {
369                    self.count
370                }
371            }
372        };
373
374        if count == 0 {
375            Ok(ScalarValue::Float64(None))
376        } else {
377            Ok(ScalarValue::Float64(Some(self.algo_const / count as f64)))
378        }
379    }
380
381    fn size(&self) -> usize {
382        size_of_val(self)
383    }
384}