datafusion_functions_aggregate/
correlation.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`Correlation`]: aggregate function that computes the correlation coefficient between two numeric inputs.
19
20use std::any::Any;
21use std::fmt::Debug;
22use std::mem::size_of_val;
23use std::sync::Arc;
24
25use arrow::array::{
26    downcast_array, Array, AsArray, BooleanArray, Float64Array, NullBufferBuilder,
27    UInt64Array,
28};
29use arrow::compute::{and, filter, is_not_null};
30use arrow::datatypes::{FieldRef, Float64Type, UInt64Type};
31use arrow::{
32    array::ArrayRef,
33    datatypes::{DataType, Field},
34};
35use datafusion_expr::{EmitTo, GroupsAccumulator};
36use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_multiple;
37use log::debug;
38
39use crate::covariance::CovarianceAccumulator;
40use crate::stddev::StddevAccumulator;
41use datafusion_common::{Result, ScalarValue};
42use datafusion_expr::{
43    function::{AccumulatorArgs, StateFieldsArgs},
44    utils::format_state_name,
45    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
46};
47use datafusion_functions_aggregate_common::stats::StatsType;
48use datafusion_macros::user_doc;
49
50make_udaf_expr_and_func!(
51    Correlation,
52    corr,
53    y x,
54    "Correlation between two numeric values.",
55    corr_udaf
56);
57
58#[user_doc(
59    doc_section(label = "Statistical Functions"),
60    description = "Returns the coefficient of correlation between two numeric values.",
61    syntax_example = "corr(expression1, expression2)",
62    sql_example = r#"```sql
63> SELECT corr(column1, column2) FROM table_name;
64+--------------------------------+
65| corr(column1, column2)         |
66+--------------------------------+
67| 0.85                           |
68+--------------------------------+
69```"#,
70    standard_argument(name = "expression1", prefix = "First"),
71    standard_argument(name = "expression2", prefix = "Second")
72)]
73#[derive(Debug, PartialEq, Eq, Hash)]
74pub struct Correlation {
75    signature: Signature,
76}
77
78impl Default for Correlation {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl Correlation {
85    /// Create a new CORR aggregate function
86    pub fn new() -> Self {
87        Self {
88            signature: Signature::exact(
89                vec![DataType::Float64, DataType::Float64],
90                Volatility::Immutable,
91            )
92            .with_parameter_names(vec!["y".to_string(), "x".to_string()])
93            .expect("valid parameter names for corr"),
94        }
95    }
96}
97
98impl AggregateUDFImpl for Correlation {
99    /// Return a reference to Any that can be used for downcasting
100    fn as_any(&self) -> &dyn Any {
101        self
102    }
103
104    fn name(&self) -> &str {
105        "corr"
106    }
107
108    fn signature(&self) -> &Signature {
109        &self.signature
110    }
111
112    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
113        Ok(DataType::Float64)
114    }
115
116    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
117        Ok(Box::new(CorrelationAccumulator::try_new()?))
118    }
119
120    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
121        let name = args.name;
122        Ok(vec![
123            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
124            Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
125            Field::new(format_state_name(name, "m2_1"), DataType::Float64, true),
126            Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
127            Field::new(format_state_name(name, "m2_2"), DataType::Float64, true),
128            Field::new(
129                format_state_name(name, "algo_const"),
130                DataType::Float64,
131                true,
132            ),
133        ]
134        .into_iter()
135        .map(Arc::new)
136        .collect())
137    }
138
139    fn documentation(&self) -> Option<&Documentation> {
140        self.doc()
141    }
142
143    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
144        true
145    }
146
147    fn create_groups_accumulator(
148        &self,
149        _args: AccumulatorArgs,
150    ) -> Result<Box<dyn GroupsAccumulator>> {
151        debug!("GroupsAccumulator is created for aggregate function `corr(c1, c2)`");
152        Ok(Box::new(CorrelationGroupsAccumulator::new()))
153    }
154}
155
156/// An accumulator to compute correlation
157#[derive(Debug)]
158pub struct CorrelationAccumulator {
159    covar: CovarianceAccumulator,
160    stddev1: StddevAccumulator,
161    stddev2: StddevAccumulator,
162}
163
164impl CorrelationAccumulator {
165    /// Creates a new `CorrelationAccumulator`
166    pub fn try_new() -> Result<Self> {
167        Ok(Self {
168            covar: CovarianceAccumulator::try_new(StatsType::Population)?,
169            stddev1: StddevAccumulator::try_new(StatsType::Population)?,
170            stddev2: StddevAccumulator::try_new(StatsType::Population)?,
171        })
172    }
173}
174
175impl Accumulator for CorrelationAccumulator {
176    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
177        // TODO: null input skipping logic duplicated across Correlation
178        // and its children accumulators.
179        // This could be simplified by splitting up input filtering and
180        // calculation logic in children accumulators, and calling only
181        // calculation part from Correlation
182        let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
183            let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
184            let values1 = filter(&values[0], &mask)?;
185            let values2 = filter(&values[1], &mask)?;
186
187            vec![values1, values2]
188        } else {
189            values.to_vec()
190        };
191
192        self.covar.update_batch(&values)?;
193        self.stddev1.update_batch(&values[0..1])?;
194        self.stddev2.update_batch(&values[1..2])?;
195        Ok(())
196    }
197
198    fn evaluate(&mut self) -> Result<ScalarValue> {
199        let n = self.covar.get_count();
200        if n < 2 {
201            return Ok(ScalarValue::Float64(None));
202        }
203
204        let covar = self.covar.evaluate()?;
205        let stddev1 = self.stddev1.evaluate()?;
206        let stddev2 = self.stddev2.evaluate()?;
207
208        if let ScalarValue::Float64(Some(c)) = covar {
209            if let ScalarValue::Float64(Some(s1)) = stddev1 {
210                if let ScalarValue::Float64(Some(s2)) = stddev2 {
211                    if s1 == 0_f64 || s2 == 0_f64 {
212                        return Ok(ScalarValue::Float64(None));
213                    } else {
214                        return Ok(ScalarValue::Float64(Some(c / s1 / s2)));
215                    }
216                }
217            }
218        }
219
220        Ok(ScalarValue::Float64(None))
221    }
222
223    fn size(&self) -> usize {
224        size_of_val(self) - size_of_val(&self.covar) + self.covar.size()
225            - size_of_val(&self.stddev1)
226            + self.stddev1.size()
227            - size_of_val(&self.stddev2)
228            + self.stddev2.size()
229    }
230
231    fn state(&mut self) -> Result<Vec<ScalarValue>> {
232        Ok(vec![
233            ScalarValue::from(self.covar.get_count()),
234            ScalarValue::from(self.covar.get_mean1()),
235            ScalarValue::from(self.stddev1.get_m2()),
236            ScalarValue::from(self.covar.get_mean2()),
237            ScalarValue::from(self.stddev2.get_m2()),
238            ScalarValue::from(self.covar.get_algo_const()),
239        ])
240    }
241
242    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
243        let states_c = [
244            Arc::clone(&states[0]),
245            Arc::clone(&states[1]),
246            Arc::clone(&states[3]),
247            Arc::clone(&states[5]),
248        ];
249        let states_s1 = [
250            Arc::clone(&states[0]),
251            Arc::clone(&states[1]),
252            Arc::clone(&states[2]),
253        ];
254        let states_s2 = [
255            Arc::clone(&states[0]),
256            Arc::clone(&states[3]),
257            Arc::clone(&states[4]),
258        ];
259
260        self.covar.merge_batch(&states_c)?;
261        self.stddev1.merge_batch(&states_s1)?;
262        self.stddev2.merge_batch(&states_s2)?;
263        Ok(())
264    }
265
266    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
267        let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
268            let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
269            let values1 = filter(&values[0], &mask)?;
270            let values2 = filter(&values[1], &mask)?;
271
272            vec![values1, values2]
273        } else {
274            values.to_vec()
275        };
276
277        self.covar.retract_batch(&values)?;
278        self.stddev1.retract_batch(&values[0..1])?;
279        self.stddev2.retract_batch(&values[1..2])?;
280        Ok(())
281    }
282}
283
/// Per-group running sums used by the grouped implementation of `corr`.
///
/// Every vector is indexed by group index. `x` denotes the first input column
/// and `y` the second.
#[derive(Default)]
pub struct CorrelationGroupsAccumulator {
    /// Number of accumulated rows for each group.
    /// This is also used to track nulls: a group with fewer than two
    /// accumulated rows evaluates to null.
    count: Vec<u64>,
    /// Sum of x values for each group
    sum_x: Vec<f64>,
    /// Sum of y
    sum_y: Vec<f64>,
    /// Sum of x*y
    sum_xy: Vec<f64>,
    /// Sum of x^2
    sum_xx: Vec<f64>,
    /// Sum of y^2
    sum_yy: Vec<f64>,
}

impl CorrelationGroupsAccumulator {
    /// Creates an accumulator that tracks no groups yet.
    pub fn new() -> Self {
        Self::default()
    }
}
307
308/// Specialized version of `accumulate_multiple` for correlation's merge_batch
309///
310/// Note: Arrays in `state_arrays` should not have null values, because they are all
311/// intermediate states created within the accumulator, instead of inputs from
312/// outside.
313fn accumulate_correlation_states(
314    group_indices: &[usize],
315    state_arrays: (
316        &UInt64Array,  // count
317        &Float64Array, // sum_x
318        &Float64Array, // sum_y
319        &Float64Array, // sum_xy
320        &Float64Array, // sum_xx
321        &Float64Array, // sum_yy
322    ),
323    mut value_fn: impl FnMut(usize, u64, &[f64]),
324) {
325    let (counts, sum_x, sum_y, sum_xy, sum_xx, sum_yy) = state_arrays;
326
327    assert_eq!(counts.null_count(), 0);
328    assert_eq!(sum_x.null_count(), 0);
329    assert_eq!(sum_y.null_count(), 0);
330    assert_eq!(sum_xy.null_count(), 0);
331    assert_eq!(sum_xx.null_count(), 0);
332    assert_eq!(sum_yy.null_count(), 0);
333
334    let counts_values = counts.values().as_ref();
335    let sum_x_values = sum_x.values().as_ref();
336    let sum_y_values = sum_y.values().as_ref();
337    let sum_xy_values = sum_xy.values().as_ref();
338    let sum_xx_values = sum_xx.values().as_ref();
339    let sum_yy_values = sum_yy.values().as_ref();
340
341    for (idx, &group_idx) in group_indices.iter().enumerate() {
342        let row = [
343            sum_x_values[idx],
344            sum_y_values[idx],
345            sum_xy_values[idx],
346            sum_xx_values[idx],
347            sum_yy_values[idx],
348        ];
349        value_fn(group_idx, counts_values[idx], &row);
350    }
351}
352
353/// GroupsAccumulator implementation for `corr(x, y)` that computes the Pearson correlation coefficient
354/// between two numeric columns.
355///
356/// Online algorithm for correlation:
357///
358/// r = (n * sum_xy - sum_x * sum_y) / sqrt((n * sum_xx - sum_x^2) * (n * sum_yy - sum_y^2))
359/// where:
360/// n = number of observations
361/// sum_x = sum of x values
362/// sum_y = sum of y values  
363/// sum_xy = sum of (x * y)
364/// sum_xx = sum of x^2 values
365/// sum_yy = sum of y^2 values
366///
367/// Reference: <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient#For_a_sample>
368impl GroupsAccumulator for CorrelationGroupsAccumulator {
369    fn update_batch(
370        &mut self,
371        values: &[ArrayRef],
372        group_indices: &[usize],
373        opt_filter: Option<&BooleanArray>,
374        total_num_groups: usize,
375    ) -> Result<()> {
376        self.count.resize(total_num_groups, 0);
377        self.sum_x.resize(total_num_groups, 0.0);
378        self.sum_y.resize(total_num_groups, 0.0);
379        self.sum_xy.resize(total_num_groups, 0.0);
380        self.sum_xx.resize(total_num_groups, 0.0);
381        self.sum_yy.resize(total_num_groups, 0.0);
382
383        let array_x = downcast_array::<Float64Array>(&values[0]);
384        let array_y = downcast_array::<Float64Array>(&values[1]);
385
386        accumulate_multiple(
387            group_indices,
388            &[&array_x, &array_y],
389            opt_filter,
390            |group_index, batch_index, columns| {
391                let x = columns[0].value(batch_index);
392                let y = columns[1].value(batch_index);
393                self.count[group_index] += 1;
394                self.sum_x[group_index] += x;
395                self.sum_y[group_index] += y;
396                self.sum_xy[group_index] += x * y;
397                self.sum_xx[group_index] += x * x;
398                self.sum_yy[group_index] += y * y;
399            },
400        );
401
402        Ok(())
403    }
404
405    fn merge_batch(
406        &mut self,
407        values: &[ArrayRef],
408        group_indices: &[usize],
409        opt_filter: Option<&BooleanArray>,
410        total_num_groups: usize,
411    ) -> Result<()> {
412        // Resize vectors to accommodate total number of groups
413        self.count.resize(total_num_groups, 0);
414        self.sum_x.resize(total_num_groups, 0.0);
415        self.sum_y.resize(total_num_groups, 0.0);
416        self.sum_xy.resize(total_num_groups, 0.0);
417        self.sum_xx.resize(total_num_groups, 0.0);
418        self.sum_yy.resize(total_num_groups, 0.0);
419
420        // Extract arrays from input values
421        let partial_counts = values[0].as_primitive::<UInt64Type>();
422        let partial_sum_x = values[1].as_primitive::<Float64Type>();
423        let partial_sum_y = values[2].as_primitive::<Float64Type>();
424        let partial_sum_xy = values[3].as_primitive::<Float64Type>();
425        let partial_sum_xx = values[4].as_primitive::<Float64Type>();
426        let partial_sum_yy = values[5].as_primitive::<Float64Type>();
427
428        assert!(opt_filter.is_none(), "aggregate filter should be applied in partial stage, there should be no filter in final stage");
429
430        accumulate_correlation_states(
431            group_indices,
432            (
433                partial_counts,
434                partial_sum_x,
435                partial_sum_y,
436                partial_sum_xy,
437                partial_sum_xx,
438                partial_sum_yy,
439            ),
440            |group_index, count, values| {
441                self.count[group_index] += count;
442                self.sum_x[group_index] += values[0];
443                self.sum_y[group_index] += values[1];
444                self.sum_xy[group_index] += values[2];
445                self.sum_xx[group_index] += values[3];
446                self.sum_yy[group_index] += values[4];
447            },
448        );
449
450        Ok(())
451    }
452
453    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
454        let n = match emit_to {
455            EmitTo::All => self.count.len(),
456            EmitTo::First(n) => n,
457        };
458
459        let mut values = Vec::with_capacity(n);
460        let mut nulls = NullBufferBuilder::new(n);
461
462        // Notes for `Null` handling:
463        // - If the `count` state of a group is 0, no valid records are accumulated
464        //   for this group, so the aggregation result is `Null`.
465        // - Correlation can't be calculated when a group only has 1 record, or when
466        //   the `denominator` state is 0. In these cases, the final aggregation
467        //   result should be `Null` (according to PostgreSQL's behavior).
468        //
469        for i in 0..n {
470            if self.count[i] < 2 {
471                values.push(0.0);
472                nulls.append_null();
473                continue;
474            }
475
476            let count = self.count[i];
477            let sum_x = self.sum_x[i];
478            let sum_y = self.sum_y[i];
479            let sum_xy = self.sum_xy[i];
480            let sum_xx = self.sum_xx[i];
481            let sum_yy = self.sum_yy[i];
482
483            let mean_x = sum_x / count as f64;
484            let mean_y = sum_y / count as f64;
485
486            let numerator = sum_xy - sum_x * mean_y;
487            let denominator =
488                ((sum_xx - sum_x * mean_x) * (sum_yy - sum_y * mean_y)).sqrt();
489
490            if denominator == 0.0 {
491                values.push(0.0);
492                nulls.append_null();
493            } else {
494                values.push(numerator / denominator);
495                nulls.append_non_null();
496            }
497        }
498
499        Ok(Arc::new(Float64Array::new(values.into(), nulls.finish())))
500    }
501
502    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
503        let n = match emit_to {
504            EmitTo::All => self.count.len(),
505            EmitTo::First(n) => n,
506        };
507
508        Ok(vec![
509            Arc::new(UInt64Array::from(self.count[0..n].to_vec())),
510            Arc::new(Float64Array::from(self.sum_x[0..n].to_vec())),
511            Arc::new(Float64Array::from(self.sum_y[0..n].to_vec())),
512            Arc::new(Float64Array::from(self.sum_xy[0..n].to_vec())),
513            Arc::new(Float64Array::from(self.sum_xx[0..n].to_vec())),
514            Arc::new(Float64Array::from(self.sum_yy[0..n].to_vec())),
515        ])
516    }
517
518    fn size(&self) -> usize {
519        size_of_val(&self.count)
520            + size_of_val(&self.sum_x)
521            + size_of_val(&self.sum_y)
522            + size_of_val(&self.sum_xy)
523            + size_of_val(&self.sum_xx)
524            + size_of_val(&self.sum_yy)
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, UInt64Array};

    /// Checks that `accumulate_correlation_states` forwards every row unchanged
    /// and panics when a state array contains a null.
    #[test]
    fn test_accumulate_correlation_states() {
        // Fixtures shared by both scenarios.
        let group_indices = vec![0, 1, 0, 1];
        let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]);
        let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]);
        let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]);
        let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]);

        // Scenario 1: fully populated states are passed to the callback row by row.
        let counts = UInt64Array::from(vec![1, 2, 3, 4]);
        let mut seen = Vec::new();
        accumulate_correlation_states(
            &group_indices,
            (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy),
            |group_idx, count, sums| {
                seen.push((group_idx, count, sums.to_vec()));
            },
        );

        assert_eq!(
            seen,
            vec![
                (0, 1, vec![10.0, 1.0, 10.0, 100.0, 1.0]),
                (1, 2, vec![20.0, 2.0, 40.0, 400.0, 4.0]),
                (0, 3, vec![30.0, 3.0, 90.0, 900.0, 9.0]),
                (1, 4, vec![40.0, 4.0, 160.0, 1600.0, 16.0]),
            ]
        );

        // Scenario 2: a null in any state array must trigger a panic.
        let counts_with_null = UInt64Array::from(vec![Some(1), None, Some(3), Some(4)]);
        let outcome = std::panic::catch_unwind(|| {
            accumulate_correlation_states(
                &group_indices,
                (
                    &counts_with_null,
                    &sum_x,
                    &sum_y,
                    &sum_xy,
                    &sum_xx,
                    &sum_yy,
                ),
                |_, _, _| {},
            )
        });
        assert!(outcome.is_err());
    }
}