Skip to main content

datafusion_functions_aggregate/
correlation.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`Correlation`]: the `corr` aggregate function (Pearson correlation coefficient).
19
20use std::fmt::Debug;
21use std::mem::size_of_val;
22use std::sync::Arc;
23
24use arrow::array::{
25    Array, AsArray, BooleanArray, Float64Array, NullBufferBuilder, UInt64Array,
26    downcast_array,
27};
28use arrow::compute::{and, filter, is_not_null};
29use arrow::datatypes::{FieldRef, Float64Type, UInt64Type};
30use arrow::{
31    array::ArrayRef,
32    datatypes::{DataType, Field},
33};
34use datafusion_expr::{EmitTo, GroupsAccumulator};
35use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_multiple;
36use log::debug;
37
38use crate::covariance::CovarianceAccumulator;
39use crate::stddev::StddevAccumulator;
40use datafusion_common::{Result, ScalarValue};
41use datafusion_expr::{
42    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
43    function::{AccumulatorArgs, StateFieldsArgs},
44    utils::format_state_name,
45};
46use datafusion_functions_aggregate_common::stats::StatsType;
47use datafusion_macros::user_doc;
48
// Declares the `corr` expression function (arguments `y`, `x`) and the
// `corr_udaf` accessor for [`Correlation`]; see `make_udaf_expr_and_func!`
// for the exact items generated.
make_udaf_expr_and_func!(
    Correlation,
    corr,
    y x,
    "Correlation between two numeric values.",
    corr_udaf
);
56
// `Correlation` is the `corr` aggregate UDF. The `user_doc` attribute below
// supplies the user-facing documentation returned by `documentation()`.
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the coefficient of correlation between two numeric values.",
    syntax_example = "corr(expression1, expression2)",
    sql_example = r#"```sql
> SELECT corr(column1, column2) FROM table_name;
+--------------------------------+
| corr(column1, column2)         |
+--------------------------------+
| 0.85                           |
+--------------------------------+
```"#,
    standard_argument(name = "expression1", prefix = "First"),
    standard_argument(name = "expression2", prefix = "Second")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Correlation {
    /// Accepted argument types: exactly two `Float64` arguments named `y` and `x`
    /// (built in `Correlation::new`).
    signature: Signature,
}
76
77impl Default for Correlation {
78    fn default() -> Self {
79        Self::new()
80    }
81}
82
83impl Correlation {
84    /// Create a new CORR aggregate function
85    pub fn new() -> Self {
86        Self {
87            signature: Signature::exact(
88                vec![DataType::Float64, DataType::Float64],
89                Volatility::Immutable,
90            )
91            .with_parameter_names(vec!["y".to_string(), "x".to_string()])
92            .expect("valid parameter names for corr"),
93        }
94    }
95}
96
97impl AggregateUDFImpl for Correlation {
98    fn name(&self) -> &str {
99        "corr"
100    }
101
102    fn signature(&self) -> &Signature {
103        &self.signature
104    }
105
106    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
107        Ok(DataType::Float64)
108    }
109
110    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
111        Ok(Box::new(CorrelationAccumulator::try_new()?))
112    }
113
114    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
115        let name = args.name;
116        Ok(vec![
117            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
118            Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
119            Field::new(format_state_name(name, "m2_1"), DataType::Float64, true),
120            Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
121            Field::new(format_state_name(name, "m2_2"), DataType::Float64, true),
122            Field::new(
123                format_state_name(name, "algo_const"),
124                DataType::Float64,
125                true,
126            ),
127        ]
128        .into_iter()
129        .map(Arc::new)
130        .collect())
131    }
132
133    fn documentation(&self) -> Option<&Documentation> {
134        self.doc()
135    }
136
137    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
138        true
139    }
140
141    fn create_groups_accumulator(
142        &self,
143        _args: AccumulatorArgs,
144    ) -> Result<Box<dyn GroupsAccumulator>> {
145        debug!("GroupsAccumulator is created for aggregate function `corr(c1, c2)`");
146        Ok(Box::new(CorrelationGroupsAccumulator::new()))
147    }
148}
149
/// An accumulator to compute correlation, row-batch at a time.
///
/// Combines three child accumulators (all created with population statistics in
/// `try_new`) and evaluates `corr = covar / stddev1 / stddev2`.
#[derive(Debug)]
pub struct CorrelationAccumulator {
    /// Population covariance of the two inputs.
    covar: CovarianceAccumulator,
    /// Population standard deviation of the first input (`values[0]`).
    stddev1: StddevAccumulator,
    /// Population standard deviation of the second input (`values[1]`).
    stddev2: StddevAccumulator,
}
157
158impl CorrelationAccumulator {
159    /// Creates a new `CorrelationAccumulator`
160    pub fn try_new() -> Result<Self> {
161        Ok(Self {
162            covar: CovarianceAccumulator::try_new(StatsType::Population)?,
163            stddev1: StddevAccumulator::try_new(StatsType::Population)?,
164            stddev2: StddevAccumulator::try_new(StatsType::Population)?,
165        })
166    }
167}
168
169impl Accumulator for CorrelationAccumulator {
170    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
171        // TODO: null input skipping logic duplicated across Correlation
172        // and its children accumulators.
173        // This could be simplified by splitting up input filtering and
174        // calculation logic in children accumulators, and calling only
175        // calculation part from Correlation
176        let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
177            let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
178            let values1 = filter(&values[0], &mask)?;
179            let values2 = filter(&values[1], &mask)?;
180
181            vec![values1, values2]
182        } else {
183            values.to_vec()
184        };
185
186        self.covar.update_batch(&values)?;
187        self.stddev1.update_batch(&values[0..1])?;
188        self.stddev2.update_batch(&values[1..2])?;
189        Ok(())
190    }
191
192    fn evaluate(&mut self) -> Result<ScalarValue> {
193        let covar = self.covar.evaluate()?;
194        let stddev1 = self.stddev1.evaluate()?;
195        let stddev2 = self.stddev2.evaluate()?;
196
197        // First check if we have NaN values by examining the internal state
198        // This handles the case where both inputs are NaN even with count=1
199        let mean1 = self.covar.get_mean1();
200        let mean2 = self.covar.get_mean2();
201
202        // If both means are NaN, then both input columns contain only NaN values
203        if mean1.is_nan() && mean2.is_nan() {
204            return Ok(ScalarValue::Float64(Some(f64::NAN)));
205        }
206        let n = self.covar.get_count();
207        if mean1.is_nan() || mean2.is_nan() || n < 2 {
208            return Ok(ScalarValue::Float64(None));
209        }
210
211        if let ScalarValue::Float64(Some(c)) = covar
212            && let ScalarValue::Float64(Some(s1)) = stddev1
213            && let ScalarValue::Float64(Some(s2)) = stddev2
214        {
215            if s1 == 0_f64 || s2 == 0_f64 {
216                return Ok(ScalarValue::Float64(None));
217            } else {
218                return Ok(ScalarValue::Float64(Some(c / s1 / s2)));
219            }
220        }
221
222        Ok(ScalarValue::Float64(None))
223    }
224
225    fn size(&self) -> usize {
226        size_of_val(self) - size_of_val(&self.covar) + self.covar.size()
227            - size_of_val(&self.stddev1)
228            + self.stddev1.size()
229            - size_of_val(&self.stddev2)
230            + self.stddev2.size()
231    }
232
233    fn state(&mut self) -> Result<Vec<ScalarValue>> {
234        Ok(vec![
235            ScalarValue::from(self.covar.get_count()),
236            ScalarValue::from(self.covar.get_mean1()),
237            ScalarValue::from(self.stddev1.get_m2()),
238            ScalarValue::from(self.covar.get_mean2()),
239            ScalarValue::from(self.stddev2.get_m2()),
240            ScalarValue::from(self.covar.get_algo_const()),
241        ])
242    }
243
244    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
245        let states_c = [
246            Arc::clone(&states[0]),
247            Arc::clone(&states[1]),
248            Arc::clone(&states[3]),
249            Arc::clone(&states[5]),
250        ];
251        let states_s1 = [
252            Arc::clone(&states[0]),
253            Arc::clone(&states[1]),
254            Arc::clone(&states[2]),
255        ];
256        let states_s2 = [
257            Arc::clone(&states[0]),
258            Arc::clone(&states[3]),
259            Arc::clone(&states[4]),
260        ];
261
262        self.covar.merge_batch(&states_c)?;
263        self.stddev1.merge_batch(&states_s1)?;
264        self.stddev2.merge_batch(&states_s2)?;
265        Ok(())
266    }
267
268    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
269        let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
270            let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
271            let values1 = filter(&values[0], &mask)?;
272            let values2 = filter(&values[1], &mask)?;
273
274            vec![values1, values2]
275        } else {
276            values.to_vec()
277        };
278
279        self.covar.retract_batch(&values)?;
280        self.stddev1.retract_batch(&values[0..1])?;
281        self.stddev2.retract_batch(&values[1..2])?;
282        Ok(())
283    }
284}
285
/// Vectorized `corr` state: one set of running sums per group.
///
/// Every vector is indexed by group index. Here `x` is the first input column
/// (`values[0]` in `update_batch`) and `y` is the second. Correlation is
/// symmetric, so this labelling does not affect results.
#[derive(Default)]
pub struct CorrelationGroupsAccumulator {
    /// Number of elements accumulated for each group.
    /// This also tracks nulls: if a group has 0 valid values accumulated,
    /// the final aggregation result is null.
    count: Vec<u64>,
    /// Sum of x values for each group.
    sum_x: Vec<f64>,
    /// Sum of y values for each group.
    sum_y: Vec<f64>,
    /// Sum of x*y for each group.
    sum_xy: Vec<f64>,
    /// Sum of x^2 for each group.
    sum_xx: Vec<f64>,
    /// Sum of y^2 for each group.
    sum_yy: Vec<f64>,
}
303
304impl CorrelationGroupsAccumulator {
305    pub fn new() -> Self {
306        Default::default()
307    }
308}
309
310/// Specialized version of `accumulate_multiple` for correlation's merge_batch
311///
312/// Note: Arrays in `state_arrays` should not have null values, because they are all
313/// intermediate states created within the accumulator, instead of inputs from
314/// outside.
315fn accumulate_correlation_states(
316    group_indices: &[usize],
317    state_arrays: (
318        &UInt64Array,  // count
319        &Float64Array, // sum_x
320        &Float64Array, // sum_y
321        &Float64Array, // sum_xy
322        &Float64Array, // sum_xx
323        &Float64Array, // sum_yy
324    ),
325    mut value_fn: impl FnMut(usize, u64, &[f64]),
326) {
327    let (counts, sum_x, sum_y, sum_xy, sum_xx, sum_yy) = state_arrays;
328
329    assert_eq!(counts.null_count(), 0);
330    assert_eq!(sum_x.null_count(), 0);
331    assert_eq!(sum_y.null_count(), 0);
332    assert_eq!(sum_xy.null_count(), 0);
333    assert_eq!(sum_xx.null_count(), 0);
334    assert_eq!(sum_yy.null_count(), 0);
335
336    let counts_values = counts.values().as_ref();
337    let sum_x_values = sum_x.values().as_ref();
338    let sum_y_values = sum_y.values().as_ref();
339    let sum_xy_values = sum_xy.values().as_ref();
340    let sum_xx_values = sum_xx.values().as_ref();
341    let sum_yy_values = sum_yy.values().as_ref();
342
343    for (idx, &group_idx) in group_indices.iter().enumerate() {
344        let row = [
345            sum_x_values[idx],
346            sum_y_values[idx],
347            sum_xy_values[idx],
348            sum_xx_values[idx],
349            sum_yy_values[idx],
350        ];
351        value_fn(group_idx, counts_values[idx], &row);
352    }
353}
354
355/// GroupsAccumulator implementation for `corr(x, y)` that computes the Pearson correlation coefficient
356/// between two numeric columns.
357///
358/// Online algorithm for correlation:
359///
360/// r = (n * sum_xy - sum_x * sum_y) / sqrt((n * sum_xx - sum_x^2) * (n * sum_yy - sum_y^2))
361/// where:
362/// n = number of observations
363/// sum_x = sum of x values
364/// sum_y = sum of y values
365/// sum_xy = sum of (x * y)
366/// sum_xx = sum of x^2 values
367/// sum_yy = sum of y^2 values
368///
369/// Reference: <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient#For_a_sample>
370impl GroupsAccumulator for CorrelationGroupsAccumulator {
371    fn update_batch(
372        &mut self,
373        values: &[ArrayRef],
374        group_indices: &[usize],
375        opt_filter: Option<&BooleanArray>,
376        total_num_groups: usize,
377    ) -> Result<()> {
378        self.count.resize(total_num_groups, 0);
379        self.sum_x.resize(total_num_groups, 0.0);
380        self.sum_y.resize(total_num_groups, 0.0);
381        self.sum_xy.resize(total_num_groups, 0.0);
382        self.sum_xx.resize(total_num_groups, 0.0);
383        self.sum_yy.resize(total_num_groups, 0.0);
384
385        let array_x = downcast_array::<Float64Array>(&values[0]);
386        let array_y = downcast_array::<Float64Array>(&values[1]);
387
388        accumulate_multiple(
389            group_indices,
390            &[&array_x, &array_y],
391            opt_filter,
392            |group_index, batch_index, columns| {
393                let x = columns[0].value(batch_index);
394                let y = columns[1].value(batch_index);
395                self.count[group_index] += 1;
396                self.sum_x[group_index] += x;
397                self.sum_y[group_index] += y;
398                self.sum_xy[group_index] += x * y;
399                self.sum_xx[group_index] += x * x;
400                self.sum_yy[group_index] += y * y;
401            },
402        );
403
404        Ok(())
405    }
406
407    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
408        // Drain the state vectors for the groups being emitted
409        let counts = emit_to.take_needed(&mut self.count);
410        let sum_xs = emit_to.take_needed(&mut self.sum_x);
411        let sum_ys = emit_to.take_needed(&mut self.sum_y);
412        let sum_xys = emit_to.take_needed(&mut self.sum_xy);
413        let sum_xxs = emit_to.take_needed(&mut self.sum_xx);
414        let sum_yys = emit_to.take_needed(&mut self.sum_yy);
415
416        let n = counts.len();
417        let mut values = Vec::with_capacity(n);
418        let mut nulls = NullBufferBuilder::new(n);
419
420        // Notes for `Null` handling:
421        // - If the `count` state of a group is 0, no valid records are accumulated
422        //   for this group, so the aggregation result is `Null`.
423        // - Correlation can't be calculated when a group only has 1 record, or when
424        //   the `denominator` state is 0. In these cases, the final aggregation
425        //   result should be `Null` (according to PostgreSQL's behavior).
426        // - However, if any of the accumulated values contain NaN, the result should
427        //   be NaN regardless of the count (even for single-row groups).
428        for i in 0..n {
429            let count = counts[i];
430            let sum_x = sum_xs[i];
431            let sum_y = sum_ys[i];
432            let sum_xy = sum_xys[i];
433            let sum_xx = sum_xxs[i];
434            let sum_yy = sum_yys[i];
435
436            // If BOTH sum_x AND sum_y are NaN, then both input values are NaN → return NaN
437            // If only ONE of them is NaN, then only one input value is NaN → return NULL
438            if sum_x.is_nan() && sum_y.is_nan() {
439                // Both inputs are NaN → return NaN
440                values.push(f64::NAN);
441                nulls.append_non_null();
442                continue;
443            } else if count < 2 || sum_x.is_nan() || sum_y.is_nan() {
444                // Only one input is NaN → return NULL
445                values.push(0.0);
446                nulls.append_null();
447                continue;
448            }
449
450            let mean_x = sum_x / count as f64;
451            let mean_y = sum_y / count as f64;
452
453            let numerator = sum_xy - sum_x * mean_y;
454            let denominator =
455                ((sum_xx - sum_x * mean_x) * (sum_yy - sum_y * mean_y)).sqrt();
456
457            if denominator == 0.0 {
458                values.push(0.0);
459                nulls.append_null();
460            } else {
461                values.push(numerator / denominator);
462                nulls.append_non_null();
463            }
464        }
465
466        Ok(Arc::new(Float64Array::new(values.into(), nulls.finish())))
467    }
468
469    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
470        // Drain the state vectors for the groups being emitted
471        let count = emit_to.take_needed(&mut self.count);
472        let sum_x = emit_to.take_needed(&mut self.sum_x);
473        let sum_y = emit_to.take_needed(&mut self.sum_y);
474        let sum_xy = emit_to.take_needed(&mut self.sum_xy);
475        let sum_xx = emit_to.take_needed(&mut self.sum_xx);
476        let sum_yy = emit_to.take_needed(&mut self.sum_yy);
477
478        Ok(vec![
479            Arc::new(UInt64Array::from(count)),
480            Arc::new(Float64Array::from(sum_x)),
481            Arc::new(Float64Array::from(sum_y)),
482            Arc::new(Float64Array::from(sum_xy)),
483            Arc::new(Float64Array::from(sum_xx)),
484            Arc::new(Float64Array::from(sum_yy)),
485        ])
486    }
487
488    fn merge_batch(
489        &mut self,
490        values: &[ArrayRef],
491        group_indices: &[usize],
492        opt_filter: Option<&BooleanArray>,
493        total_num_groups: usize,
494    ) -> Result<()> {
495        // Resize vectors to accommodate total number of groups
496        self.count.resize(total_num_groups, 0);
497        self.sum_x.resize(total_num_groups, 0.0);
498        self.sum_y.resize(total_num_groups, 0.0);
499        self.sum_xy.resize(total_num_groups, 0.0);
500        self.sum_xx.resize(total_num_groups, 0.0);
501        self.sum_yy.resize(total_num_groups, 0.0);
502
503        // Extract arrays from input values
504        let partial_counts = values[0].as_primitive::<UInt64Type>();
505        let partial_sum_x = values[1].as_primitive::<Float64Type>();
506        let partial_sum_y = values[2].as_primitive::<Float64Type>();
507        let partial_sum_xy = values[3].as_primitive::<Float64Type>();
508        let partial_sum_xx = values[4].as_primitive::<Float64Type>();
509        let partial_sum_yy = values[5].as_primitive::<Float64Type>();
510
511        assert!(
512            opt_filter.is_none(),
513            "aggregate filter should be applied in partial stage, there should be no filter in final stage"
514        );
515
516        accumulate_correlation_states(
517            group_indices,
518            (
519                partial_counts,
520                partial_sum_x,
521                partial_sum_y,
522                partial_sum_xy,
523                partial_sum_xx,
524                partial_sum_yy,
525            ),
526            |group_index, count, values| {
527                self.count[group_index] += count;
528                self.sum_x[group_index] += values[0];
529                self.sum_y[group_index] += values[1];
530                self.sum_xy[group_index] += values[2];
531                self.sum_xx[group_index] += values[3];
532                self.sum_yy[group_index] += values[4];
533            },
534        );
535
536        Ok(())
537    }
538
539    fn size(&self) -> usize {
540        self.count.capacity() * size_of::<u64>()
541            + self.sum_x.capacity() * size_of::<f64>()
542            + self.sum_y.capacity() * size_of::<f64>()
543            + self.sum_xy.capacity() * size_of::<f64>()
544            + self.sum_xx.capacity() * size_of::<f64>()
545            + self.sum_yy.capacity() * size_of::<f64>()
546    }
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `accumulate_correlation_states` forwards every state row, in
    /// input order, to the callback, and that it panics on null state values.
    #[test]
    fn test_accumulate_correlation_states() {
        // Test data: four partial-state rows alternating between groups 0 and 1.
        let group_indices = vec![0, 1, 0, 1];
        let counts = UInt64Array::from(vec![1, 2, 3, 4]);
        let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]);
        let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]);
        let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]);
        let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]);

        let mut accumulated = vec![];
        accumulate_correlation_states(
            &group_indices,
            (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy),
            |group_idx, count, values| {
                accumulated.push((group_idx, count, values.to_vec()));
            },
        );

        // Each row arrives unchanged as
        // (group_idx, count, [sum_x, sum_y, sum_xy, sum_xx, sum_yy]).
        let expected = vec![
            (0, 1, vec![10.0, 1.0, 10.0, 100.0, 1.0]),
            (1, 2, vec![20.0, 2.0, 40.0, 400.0, 4.0]),
            (0, 3, vec![30.0, 3.0, 90.0, 900.0, 9.0]),
            (1, 4, vec![40.0, 4.0, 160.0, 1600.0, 16.0]),
        ];
        assert_eq!(accumulated, expected);

        // Test that function panics with null values
        // (here a null in the `counts` column trips the no-null assertion).
        let counts = UInt64Array::from(vec![Some(1), None, Some(3), Some(4)]);
        let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]);
        let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]);
        let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]);
        let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]);

        let result = std::panic::catch_unwind(|| {
            accumulate_correlation_states(
                &group_indices,
                (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy),
                |_, _, _| {},
            )
        });
        assert!(result.is_err());
    }
}