datafusion_functions_aggregate/
covariance.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`CovarianceSample`] and [`CovariancePopulation`]: sample and population covariance aggregations.
19
20use arrow::datatypes::FieldRef;
21use arrow::{
22    array::{ArrayRef, Float64Array, UInt64Array},
23    compute::kernels::cast,
24    datatypes::{DataType, Field},
25};
26use datafusion_common::{
27    downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result,
28    ScalarValue,
29};
30use datafusion_expr::{
31    function::{AccumulatorArgs, StateFieldsArgs},
32    type_coercion::aggregates::NUMERICS,
33    utils::format_state_name,
34    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
35};
36use datafusion_functions_aggregate_common::stats::StatsType;
37use datafusion_macros::user_doc;
38use std::fmt::Debug;
39use std::mem::size_of_val;
40use std::sync::Arc;
41
42make_udaf_expr_and_func!(
43    CovarianceSample,
44    covar_samp,
45    y x,
46    "Computes the sample covariance.",
47    covar_samp_udaf
48);
49
50make_udaf_expr_and_func!(
51    CovariancePopulation,
52    covar_pop,
53    y x,
54    "Computes the population covariance.",
55    covar_pop_udaf
56);
57
58#[user_doc(
59    doc_section(label = "Statistical Functions"),
60    description = "Returns the sample covariance of a set of number pairs.",
61    syntax_example = "covar_samp(expression1, expression2)",
62    sql_example = r#"```sql
63> SELECT covar_samp(column1, column2) FROM table_name;
64+-----------------------------------+
65| covar_samp(column1, column2)      |
66+-----------------------------------+
67| 8.25                              |
68+-----------------------------------+
69```"#,
70    standard_argument(name = "expression1", prefix = "First"),
71    standard_argument(name = "expression2", prefix = "Second")
72)]
73#[derive(PartialEq, Eq, Hash)]
74pub struct CovarianceSample {
75    signature: Signature,
76    aliases: Vec<String>,
77}
78
79impl Debug for CovarianceSample {
80    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
81        f.debug_struct("CovarianceSample")
82            .field("name", &self.name())
83            .field("signature", &self.signature)
84            .finish()
85    }
86}
87
88impl Default for CovarianceSample {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
94impl CovarianceSample {
95    pub fn new() -> Self {
96        Self {
97            aliases: vec![String::from("covar")],
98            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
99        }
100    }
101}
102
103impl AggregateUDFImpl for CovarianceSample {
104    fn as_any(&self) -> &dyn std::any::Any {
105        self
106    }
107
108    fn name(&self) -> &str {
109        "covar_samp"
110    }
111
112    fn signature(&self) -> &Signature {
113        &self.signature
114    }
115
116    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
117        if !arg_types[0].is_numeric() {
118            return plan_err!("Covariance requires numeric input types");
119        }
120
121        Ok(DataType::Float64)
122    }
123
124    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
125        let name = args.name;
126        Ok(vec![
127            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
128            Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
129            Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
130            Field::new(
131                format_state_name(name, "algo_const"),
132                DataType::Float64,
133                true,
134            ),
135        ]
136        .into_iter()
137        .map(Arc::new)
138        .collect())
139    }
140
141    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
142        Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?))
143    }
144
145    fn aliases(&self) -> &[String] {
146        &self.aliases
147    }
148
149    fn documentation(&self) -> Option<&Documentation> {
150        self.doc()
151    }
152}
153
154#[user_doc(
155    doc_section(label = "Statistical Functions"),
156    description = "Returns the sample covariance of a set of number pairs.",
157    syntax_example = "covar_samp(expression1, expression2)",
158    sql_example = r#"```sql
159> SELECT covar_samp(column1, column2) FROM table_name;
160+-----------------------------------+
161| covar_samp(column1, column2)      |
162+-----------------------------------+
163| 8.25                              |
164+-----------------------------------+
165```"#,
166    standard_argument(name = "expression1", prefix = "First"),
167    standard_argument(name = "expression2", prefix = "Second")
168)]
169#[derive(PartialEq, Eq, Hash)]
170pub struct CovariancePopulation {
171    signature: Signature,
172}
173
174impl Debug for CovariancePopulation {
175    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
176        f.debug_struct("CovariancePopulation")
177            .field("name", &self.name())
178            .field("signature", &self.signature)
179            .finish()
180    }
181}
182
183impl Default for CovariancePopulation {
184    fn default() -> Self {
185        Self::new()
186    }
187}
188
189impl CovariancePopulation {
190    pub fn new() -> Self {
191        Self {
192            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
193        }
194    }
195}
196
197impl AggregateUDFImpl for CovariancePopulation {
198    fn as_any(&self) -> &dyn std::any::Any {
199        self
200    }
201
202    fn name(&self) -> &str {
203        "covar_pop"
204    }
205
206    fn signature(&self) -> &Signature {
207        &self.signature
208    }
209
210    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
211        if !arg_types[0].is_numeric() {
212            return plan_err!("Covariance requires numeric input types");
213        }
214
215        Ok(DataType::Float64)
216    }
217
218    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
219        let name = args.name;
220        Ok(vec![
221            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
222            Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
223            Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
224            Field::new(
225                format_state_name(name, "algo_const"),
226                DataType::Float64,
227                true,
228            ),
229        ]
230        .into_iter()
231        .map(Arc::new)
232        .collect())
233    }
234
235    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
236        Ok(Box::new(CovarianceAccumulator::try_new(
237            StatsType::Population,
238        )?))
239    }
240
241    fn documentation(&self) -> Option<&Documentation> {
242        self.doc()
243    }
244}
245
246/// An accumulator to compute covariance
247/// The algorithm used is an online implementation and numerically stable. It is derived from the following paper
248/// for calculating variance:
249/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".
250/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577.
251///
252/// The algorithm has been analyzed here:
253/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances".
254/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
255///
256/// Though it is not covered in the original paper but is based on the same idea, as a result the algorithm is online,
257/// parallelize and numerically stable.
258
259#[derive(Debug)]
260pub struct CovarianceAccumulator {
261    algo_const: f64,
262    mean1: f64,
263    mean2: f64,
264    count: u64,
265    stats_type: StatsType,
266}
267
268impl CovarianceAccumulator {
269    /// Creates a new `CovarianceAccumulator`
270    pub fn try_new(s_type: StatsType) -> Result<Self> {
271        Ok(Self {
272            algo_const: 0_f64,
273            mean1: 0_f64,
274            mean2: 0_f64,
275            count: 0_u64,
276            stats_type: s_type,
277        })
278    }
279
280    pub fn get_count(&self) -> u64 {
281        self.count
282    }
283
284    pub fn get_mean1(&self) -> f64 {
285        self.mean1
286    }
287
288    pub fn get_mean2(&self) -> f64 {
289        self.mean2
290    }
291
292    pub fn get_algo_const(&self) -> f64 {
293        self.algo_const
294    }
295}
296
297impl Accumulator for CovarianceAccumulator {
298    fn state(&mut self) -> Result<Vec<ScalarValue>> {
299        Ok(vec![
300            ScalarValue::from(self.count),
301            ScalarValue::from(self.mean1),
302            ScalarValue::from(self.mean2),
303            ScalarValue::from(self.algo_const),
304        ])
305    }
306
307    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
308        let values1 = &cast(&values[0], &DataType::Float64)?;
309        let values2 = &cast(&values[1], &DataType::Float64)?;
310
311        let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
312        let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
313
314        for i in 0..values1.len() {
315            let value1 = if values1.is_valid(i) {
316                arr1.next()
317            } else {
318                None
319            };
320            let value2 = if values2.is_valid(i) {
321                arr2.next()
322            } else {
323                None
324            };
325
326            if value1.is_none() || value2.is_none() {
327                continue;
328            }
329
330            let value1 = unwrap_or_internal_err!(value1);
331            let value2 = unwrap_or_internal_err!(value2);
332            let new_count = self.count + 1;
333            let delta1 = value1 - self.mean1;
334            let new_mean1 = delta1 / new_count as f64 + self.mean1;
335            let delta2 = value2 - self.mean2;
336            let new_mean2 = delta2 / new_count as f64 + self.mean2;
337            let new_c = delta1 * (value2 - new_mean2) + self.algo_const;
338
339            self.count += 1;
340            self.mean1 = new_mean1;
341            self.mean2 = new_mean2;
342            self.algo_const = new_c;
343        }
344
345        Ok(())
346    }
347
348    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
349        let values1 = &cast(&values[0], &DataType::Float64)?;
350        let values2 = &cast(&values[1], &DataType::Float64)?;
351        let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
352        let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
353
354        for i in 0..values1.len() {
355            let value1 = if values1.is_valid(i) {
356                arr1.next()
357            } else {
358                None
359            };
360            let value2 = if values2.is_valid(i) {
361                arr2.next()
362            } else {
363                None
364            };
365
366            if value1.is_none() || value2.is_none() {
367                continue;
368            }
369
370            let value1 = unwrap_or_internal_err!(value1);
371            let value2 = unwrap_or_internal_err!(value2);
372
373            let new_count = self.count - 1;
374            let delta1 = self.mean1 - value1;
375            let new_mean1 = delta1 / new_count as f64 + self.mean1;
376            let delta2 = self.mean2 - value2;
377            let new_mean2 = delta2 / new_count as f64 + self.mean2;
378            let new_c = self.algo_const - delta1 * (new_mean2 - value2);
379
380            self.count -= 1;
381            self.mean1 = new_mean1;
382            self.mean2 = new_mean2;
383            self.algo_const = new_c;
384        }
385
386        Ok(())
387    }
388
389    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
390        let counts = downcast_value!(states[0], UInt64Array);
391        let means1 = downcast_value!(states[1], Float64Array);
392        let means2 = downcast_value!(states[2], Float64Array);
393        let cs = downcast_value!(states[3], Float64Array);
394
395        for i in 0..counts.len() {
396            let c = counts.value(i);
397            if c == 0_u64 {
398                continue;
399            }
400            let new_count = self.count + c;
401            let new_mean1 = self.mean1 * self.count as f64 / new_count as f64
402                + means1.value(i) * c as f64 / new_count as f64;
403            let new_mean2 = self.mean2 * self.count as f64 / new_count as f64
404                + means2.value(i) * c as f64 / new_count as f64;
405            let delta1 = self.mean1 - means1.value(i);
406            let delta2 = self.mean2 - means2.value(i);
407            let new_c = self.algo_const
408                + cs.value(i)
409                + delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64;
410
411            self.count = new_count;
412            self.mean1 = new_mean1;
413            self.mean2 = new_mean2;
414            self.algo_const = new_c;
415        }
416        Ok(())
417    }
418
419    fn evaluate(&mut self) -> Result<ScalarValue> {
420        let count = match self.stats_type {
421            StatsType::Population => self.count,
422            StatsType::Sample => {
423                if self.count > 0 {
424                    self.count - 1
425                } else {
426                    self.count
427                }
428            }
429        };
430
431        if count == 0 {
432            Ok(ScalarValue::Float64(None))
433        } else {
434            Ok(ScalarValue::Float64(Some(self.algo_const / count as f64)))
435        }
436    }
437
438    fn size(&self) -> usize {
439        size_of_val(self)
440    }
441}