datafusion_functions_aggregate/
covariance.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`CovarianceSample`] and [`CovariancePopulation`]: sample and population covariance aggregations.
19
20use arrow::datatypes::FieldRef;
21use arrow::{
22    array::{ArrayRef, Float64Array, UInt64Array},
23    compute::kernels::cast,
24    datatypes::{DataType, Field},
25};
26use datafusion_common::{
27    Result, ScalarValue, downcast_value, plan_err, unwrap_or_internal_err,
28};
29use datafusion_expr::{
30    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
31    function::{AccumulatorArgs, StateFieldsArgs},
32    type_coercion::aggregates::NUMERICS,
33    utils::format_state_name,
34};
35use datafusion_functions_aggregate_common::stats::StatsType;
36use datafusion_macros::user_doc;
37use std::fmt::Debug;
38use std::mem::size_of_val;
39use std::sync::Arc;
40
41make_udaf_expr_and_func!(
42    CovarianceSample,
43    covar_samp,
44    y x,
45    "Computes the sample covariance.",
46    covar_samp_udaf
47);
48
49make_udaf_expr_and_func!(
50    CovariancePopulation,
51    covar_pop,
52    y x,
53    "Computes the population covariance.",
54    covar_pop_udaf
55);
56
57#[user_doc(
58    doc_section(label = "Statistical Functions"),
59    description = "Returns the sample covariance of a set of number pairs.",
60    syntax_example = "covar_samp(expression1, expression2)",
61    sql_example = r#"```sql
62> SELECT covar_samp(column1, column2) FROM table_name;
63+-----------------------------------+
64| covar_samp(column1, column2)      |
65+-----------------------------------+
66| 8.25                              |
67+-----------------------------------+
68```"#,
69    standard_argument(name = "expression1", prefix = "First"),
70    standard_argument(name = "expression2", prefix = "Second")
71)]
72#[derive(PartialEq, Eq, Hash)]
73pub struct CovarianceSample {
74    signature: Signature,
75    aliases: Vec<String>,
76}
77
78impl Debug for CovarianceSample {
79    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
80        f.debug_struct("CovarianceSample")
81            .field("name", &self.name())
82            .field("signature", &self.signature)
83            .finish()
84    }
85}
86
87impl Default for CovarianceSample {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93impl CovarianceSample {
94    pub fn new() -> Self {
95        Self {
96            aliases: vec![String::from("covar")],
97            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
98        }
99    }
100}
101
102impl AggregateUDFImpl for CovarianceSample {
103    fn as_any(&self) -> &dyn std::any::Any {
104        self
105    }
106
107    fn name(&self) -> &str {
108        "covar_samp"
109    }
110
111    fn signature(&self) -> &Signature {
112        &self.signature
113    }
114
115    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
116        if !arg_types[0].is_numeric() {
117            return plan_err!("Covariance requires numeric input types");
118        }
119
120        Ok(DataType::Float64)
121    }
122
123    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
124        let name = args.name;
125        Ok(vec![
126            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
127            Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
128            Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
129            Field::new(
130                format_state_name(name, "algo_const"),
131                DataType::Float64,
132                true,
133            ),
134        ]
135        .into_iter()
136        .map(Arc::new)
137        .collect())
138    }
139
140    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
141        Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?))
142    }
143
144    fn aliases(&self) -> &[String] {
145        &self.aliases
146    }
147
148    fn documentation(&self) -> Option<&Documentation> {
149        self.doc()
150    }
151}
152
153#[user_doc(
154    doc_section(label = "Statistical Functions"),
155    description = "Returns the sample covariance of a set of number pairs.",
156    syntax_example = "covar_samp(expression1, expression2)",
157    sql_example = r#"```sql
158> SELECT covar_samp(column1, column2) FROM table_name;
159+-----------------------------------+
160| covar_samp(column1, column2)      |
161+-----------------------------------+
162| 8.25                              |
163+-----------------------------------+
164```"#,
165    standard_argument(name = "expression1", prefix = "First"),
166    standard_argument(name = "expression2", prefix = "Second")
167)]
168#[derive(PartialEq, Eq, Hash)]
169pub struct CovariancePopulation {
170    signature: Signature,
171}
172
173impl Debug for CovariancePopulation {
174    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
175        f.debug_struct("CovariancePopulation")
176            .field("name", &self.name())
177            .field("signature", &self.signature)
178            .finish()
179    }
180}
181
182impl Default for CovariancePopulation {
183    fn default() -> Self {
184        Self::new()
185    }
186}
187
188impl CovariancePopulation {
189    pub fn new() -> Self {
190        Self {
191            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
192        }
193    }
194}
195
196impl AggregateUDFImpl for CovariancePopulation {
197    fn as_any(&self) -> &dyn std::any::Any {
198        self
199    }
200
201    fn name(&self) -> &str {
202        "covar_pop"
203    }
204
205    fn signature(&self) -> &Signature {
206        &self.signature
207    }
208
209    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
210        if !arg_types[0].is_numeric() {
211            return plan_err!("Covariance requires numeric input types");
212        }
213
214        Ok(DataType::Float64)
215    }
216
217    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
218        let name = args.name;
219        Ok(vec![
220            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
221            Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
222            Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
223            Field::new(
224                format_state_name(name, "algo_const"),
225                DataType::Float64,
226                true,
227            ),
228        ]
229        .into_iter()
230        .map(Arc::new)
231        .collect())
232    }
233
234    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
235        Ok(Box::new(CovarianceAccumulator::try_new(
236            StatsType::Population,
237        )?))
238    }
239
240    fn documentation(&self) -> Option<&Documentation> {
241        self.doc()
242    }
243}
244
245/// An accumulator to compute covariance
246/// The algorithm used is an online implementation and numerically stable. It is derived from the following paper
247/// for calculating variance:
248/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".
249/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577.
250///
251/// The algorithm has been analyzed here:
252/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances".
253/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
254///
255/// Though it is not covered in the original paper but is based on the same idea, as a result the algorithm is online,
256/// parallelize and numerically stable.
257
258#[derive(Debug)]
259pub struct CovarianceAccumulator {
260    algo_const: f64,
261    mean1: f64,
262    mean2: f64,
263    count: u64,
264    stats_type: StatsType,
265}
266
267impl CovarianceAccumulator {
268    /// Creates a new `CovarianceAccumulator`
269    pub fn try_new(s_type: StatsType) -> Result<Self> {
270        Ok(Self {
271            algo_const: 0_f64,
272            mean1: 0_f64,
273            mean2: 0_f64,
274            count: 0_u64,
275            stats_type: s_type,
276        })
277    }
278
279    pub fn get_count(&self) -> u64 {
280        self.count
281    }
282
283    pub fn get_mean1(&self) -> f64 {
284        self.mean1
285    }
286
287    pub fn get_mean2(&self) -> f64 {
288        self.mean2
289    }
290
291    pub fn get_algo_const(&self) -> f64 {
292        self.algo_const
293    }
294}
295
296impl Accumulator for CovarianceAccumulator {
297    fn state(&mut self) -> Result<Vec<ScalarValue>> {
298        Ok(vec![
299            ScalarValue::from(self.count),
300            ScalarValue::from(self.mean1),
301            ScalarValue::from(self.mean2),
302            ScalarValue::from(self.algo_const),
303        ])
304    }
305
306    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
307        let values1 = &cast(&values[0], &DataType::Float64)?;
308        let values2 = &cast(&values[1], &DataType::Float64)?;
309
310        let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
311        let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
312
313        for i in 0..values1.len() {
314            let value1 = if values1.is_valid(i) {
315                arr1.next()
316            } else {
317                None
318            };
319            let value2 = if values2.is_valid(i) {
320                arr2.next()
321            } else {
322                None
323            };
324
325            if value1.is_none() || value2.is_none() {
326                continue;
327            }
328
329            let value1 = unwrap_or_internal_err!(value1);
330            let value2 = unwrap_or_internal_err!(value2);
331            let new_count = self.count + 1;
332            let delta1 = value1 - self.mean1;
333            let new_mean1 = delta1 / new_count as f64 + self.mean1;
334            let delta2 = value2 - self.mean2;
335            let new_mean2 = delta2 / new_count as f64 + self.mean2;
336            let new_c = delta1 * (value2 - new_mean2) + self.algo_const;
337
338            self.count += 1;
339            self.mean1 = new_mean1;
340            self.mean2 = new_mean2;
341            self.algo_const = new_c;
342        }
343
344        Ok(())
345    }
346
347    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
348        let values1 = &cast(&values[0], &DataType::Float64)?;
349        let values2 = &cast(&values[1], &DataType::Float64)?;
350        let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
351        let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
352
353        for i in 0..values1.len() {
354            let value1 = if values1.is_valid(i) {
355                arr1.next()
356            } else {
357                None
358            };
359            let value2 = if values2.is_valid(i) {
360                arr2.next()
361            } else {
362                None
363            };
364
365            if value1.is_none() || value2.is_none() {
366                continue;
367            }
368
369            let value1 = unwrap_or_internal_err!(value1);
370            let value2 = unwrap_or_internal_err!(value2);
371
372            let new_count = self.count - 1;
373            let delta1 = self.mean1 - value1;
374            let new_mean1 = delta1 / new_count as f64 + self.mean1;
375            let delta2 = self.mean2 - value2;
376            let new_mean2 = delta2 / new_count as f64 + self.mean2;
377            let new_c = self.algo_const - delta1 * (new_mean2 - value2);
378
379            self.count -= 1;
380            self.mean1 = new_mean1;
381            self.mean2 = new_mean2;
382            self.algo_const = new_c;
383        }
384
385        Ok(())
386    }
387
388    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
389        let counts = downcast_value!(states[0], UInt64Array);
390        let means1 = downcast_value!(states[1], Float64Array);
391        let means2 = downcast_value!(states[2], Float64Array);
392        let cs = downcast_value!(states[3], Float64Array);
393
394        for i in 0..counts.len() {
395            let c = counts.value(i);
396            if c == 0_u64 {
397                continue;
398            }
399            let new_count = self.count + c;
400            let new_mean1 = self.mean1 * self.count as f64 / new_count as f64
401                + means1.value(i) * c as f64 / new_count as f64;
402            let new_mean2 = self.mean2 * self.count as f64 / new_count as f64
403                + means2.value(i) * c as f64 / new_count as f64;
404            let delta1 = self.mean1 - means1.value(i);
405            let delta2 = self.mean2 - means2.value(i);
406            let new_c = self.algo_const
407                + cs.value(i)
408                + delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64;
409
410            self.count = new_count;
411            self.mean1 = new_mean1;
412            self.mean2 = new_mean2;
413            self.algo_const = new_c;
414        }
415        Ok(())
416    }
417
418    fn evaluate(&mut self) -> Result<ScalarValue> {
419        let count = match self.stats_type {
420            StatsType::Population => self.count,
421            StatsType::Sample => {
422                if self.count > 0 {
423                    self.count - 1
424                } else {
425                    self.count
426                }
427            }
428        };
429
430        if count == 0 {
431            Ok(ScalarValue::Float64(None))
432        } else {
433            Ok(ScalarValue::Float64(Some(self.algo_const / count as f64)))
434        }
435    }
436
437    fn size(&self) -> usize {
438        size_of_val(self)
439    }
440}