1use std::any::Any;
21use std::fmt::Debug;
22use std::mem::size_of_val;
23use std::sync::Arc;
24
25use arrow::array::{
26 Array, AsArray, BooleanArray, Float64Array, NullBufferBuilder, UInt64Array,
27 downcast_array,
28};
29use arrow::compute::{and, filter, is_not_null};
30use arrow::datatypes::{FieldRef, Float64Type, UInt64Type};
31use arrow::{
32 array::ArrayRef,
33 datatypes::{DataType, Field},
34};
35use datafusion_expr::{EmitTo, GroupsAccumulator};
36use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_multiple;
37use log::debug;
38
39use crate::covariance::CovarianceAccumulator;
40use crate::stddev::StddevAccumulator;
41use datafusion_common::{Result, ScalarValue};
42use datafusion_expr::{
43 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
44 function::{AccumulatorArgs, StateFieldsArgs},
45 utils::format_state_name,
46};
47use datafusion_functions_aggregate_common::stats::StatsType;
48use datafusion_macros::user_doc;
49
// Invokes the crate-level `make_udaf_expr_and_func!` macro for `Correlation`.
// Judging by the arguments, this generates an expression helper named `corr`
// taking `y` and `x` (documented with the string below) plus a `corr_udaf`
// accessor for the UDAF — NOTE(review): confirm against the macro definition.
make_udaf_expr_and_func!(
    Correlation,
    corr,
    y x,
    "Correlation between two numeric values.",
    corr_udaf
);
57
/// User-facing definition of the `corr` aggregate function.
///
/// The row-based accumulator computes the population covariance of the two
/// inputs divided by the product of their population standard deviations
/// (see `CorrelationAccumulator`), i.e. the Pearson correlation coefficient.
/// The `#[user_doc]` attribute below describes the function for the generated
/// SQL documentation; `documentation()` returns it via `self.doc()`.
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the coefficient of correlation between two numeric values.",
    syntax_example = "corr(expression1, expression2)",
    sql_example = r#"```sql
> SELECT corr(column1, column2) FROM table_name;
+--------------------------------+
| corr(column1, column2) |
+--------------------------------+
| 0.85 |
+--------------------------------+
```"#,
    standard_argument(name = "expression1", prefix = "First"),
    standard_argument(name = "expression2", prefix = "Second")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Correlation {
    /// Accepted argument types (exactly two `Float64`) and parameter names
    /// (`y`, `x`); built in `Correlation::new`.
    signature: Signature,
}
77
78impl Default for Correlation {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl Correlation {
85 pub fn new() -> Self {
87 Self {
88 signature: Signature::exact(
89 vec![DataType::Float64, DataType::Float64],
90 Volatility::Immutable,
91 )
92 .with_parameter_names(vec!["y".to_string(), "x".to_string()])
93 .expect("valid parameter names for corr"),
94 }
95 }
96}
97
98impl AggregateUDFImpl for Correlation {
99 fn as_any(&self) -> &dyn Any {
101 self
102 }
103
104 fn name(&self) -> &str {
105 "corr"
106 }
107
108 fn signature(&self) -> &Signature {
109 &self.signature
110 }
111
112 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
113 Ok(DataType::Float64)
114 }
115
116 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
117 Ok(Box::new(CorrelationAccumulator::try_new()?))
118 }
119
120 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
121 let name = args.name;
122 Ok(vec![
123 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
124 Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
125 Field::new(format_state_name(name, "m2_1"), DataType::Float64, true),
126 Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
127 Field::new(format_state_name(name, "m2_2"), DataType::Float64, true),
128 Field::new(
129 format_state_name(name, "algo_const"),
130 DataType::Float64,
131 true,
132 ),
133 ]
134 .into_iter()
135 .map(Arc::new)
136 .collect())
137 }
138
139 fn documentation(&self) -> Option<&Documentation> {
140 self.doc()
141 }
142
143 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
144 true
145 }
146
147 fn create_groups_accumulator(
148 &self,
149 _args: AccumulatorArgs,
150 ) -> Result<Box<dyn GroupsAccumulator>> {
151 debug!("GroupsAccumulator is created for aggregate function `corr(c1, c2)`");
152 Ok(Box::new(CorrelationGroupsAccumulator::new()))
153 }
154}
155
/// Row-based accumulator for `corr`.
///
/// It combines a covariance accumulator with one standard deviation
/// accumulator per input; all three are created with
/// `StatsType::Population` in `try_new`.
#[derive(Debug)]
pub struct CorrelationAccumulator {
    /// Covariance of the two inputs; also the source of the count and means.
    covar: CovarianceAccumulator,
    /// Standard deviation of the first input (`values[0]`).
    stddev1: StddevAccumulator,
    /// Standard deviation of the second input (`values[1]`).
    stddev2: StddevAccumulator,
}
163
164impl CorrelationAccumulator {
165 pub fn try_new() -> Result<Self> {
167 Ok(Self {
168 covar: CovarianceAccumulator::try_new(StatsType::Population)?,
169 stddev1: StddevAccumulator::try_new(StatsType::Population)?,
170 stddev2: StddevAccumulator::try_new(StatsType::Population)?,
171 })
172 }
173}
174
175impl Accumulator for CorrelationAccumulator {
176 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
177 let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
183 let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
184 let values1 = filter(&values[0], &mask)?;
185 let values2 = filter(&values[1], &mask)?;
186
187 vec![values1, values2]
188 } else {
189 values.to_vec()
190 };
191
192 self.covar.update_batch(&values)?;
193 self.stddev1.update_batch(&values[0..1])?;
194 self.stddev2.update_batch(&values[1..2])?;
195 Ok(())
196 }
197
198 fn evaluate(&mut self) -> Result<ScalarValue> {
199 let covar = self.covar.evaluate()?;
200 let stddev1 = self.stddev1.evaluate()?;
201 let stddev2 = self.stddev2.evaluate()?;
202
203 let mean1 = self.covar.get_mean1();
206 let mean2 = self.covar.get_mean2();
207
208 if mean1.is_nan() && mean2.is_nan() {
210 return Ok(ScalarValue::Float64(Some(f64::NAN)));
211 }
212 let n = self.covar.get_count();
213 if mean1.is_nan() || mean2.is_nan() || n < 2 {
214 return Ok(ScalarValue::Float64(None));
215 }
216
217 if let ScalarValue::Float64(Some(c)) = covar
218 && let ScalarValue::Float64(Some(s1)) = stddev1
219 && let ScalarValue::Float64(Some(s2)) = stddev2
220 {
221 if s1 == 0_f64 || s2 == 0_f64 {
222 return Ok(ScalarValue::Float64(None));
223 } else {
224 return Ok(ScalarValue::Float64(Some(c / s1 / s2)));
225 }
226 }
227
228 Ok(ScalarValue::Float64(None))
229 }
230
231 fn size(&self) -> usize {
232 size_of_val(self) - size_of_val(&self.covar) + self.covar.size()
233 - size_of_val(&self.stddev1)
234 + self.stddev1.size()
235 - size_of_val(&self.stddev2)
236 + self.stddev2.size()
237 }
238
239 fn state(&mut self) -> Result<Vec<ScalarValue>> {
240 Ok(vec![
241 ScalarValue::from(self.covar.get_count()),
242 ScalarValue::from(self.covar.get_mean1()),
243 ScalarValue::from(self.stddev1.get_m2()),
244 ScalarValue::from(self.covar.get_mean2()),
245 ScalarValue::from(self.stddev2.get_m2()),
246 ScalarValue::from(self.covar.get_algo_const()),
247 ])
248 }
249
250 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
251 let states_c = [
252 Arc::clone(&states[0]),
253 Arc::clone(&states[1]),
254 Arc::clone(&states[3]),
255 Arc::clone(&states[5]),
256 ];
257 let states_s1 = [
258 Arc::clone(&states[0]),
259 Arc::clone(&states[1]),
260 Arc::clone(&states[2]),
261 ];
262 let states_s2 = [
263 Arc::clone(&states[0]),
264 Arc::clone(&states[3]),
265 Arc::clone(&states[4]),
266 ];
267
268 self.covar.merge_batch(&states_c)?;
269 self.stddev1.merge_batch(&states_s1)?;
270 self.stddev2.merge_batch(&states_s2)?;
271 Ok(())
272 }
273
274 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
275 let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
276 let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
277 let values1 = filter(&values[0], &mask)?;
278 let values2 = filter(&values[1], &mask)?;
279
280 vec![values1, values2]
281 } else {
282 values.to_vec()
283 };
284
285 self.covar.retract_batch(&values)?;
286 self.stddev1.retract_batch(&values[0..1])?;
287 self.stddev2.retract_batch(&values[1..2])?;
288 Ok(())
289 }
290}
291
/// [`GroupsAccumulator`] for `corr` that keeps one set of running sums per
/// group.
///
/// Every vector is indexed by group index. In `update_batch`, x is the first
/// input column (`values[0]`) and y is the second (`values[1]`). The
/// coefficient is derived from these sums in `evaluate`.
///
/// NOTE(review): the sum-of-squares form can lose precision through
/// cancellation when values are large relative to their spread. The
/// row-based `CorrelationAccumulator` uses the covariance/stddev
/// accumulators instead.
#[derive(Default)]
pub struct CorrelationGroupsAccumulator {
    /// Number of rows accumulated into each group.
    count: Vec<u64>,
    /// Σx per group.
    sum_x: Vec<f64>,
    /// Σy per group.
    sum_y: Vec<f64>,
    /// Σ(x·y) per group.
    sum_xy: Vec<f64>,
    /// Σ(x²) per group.
    sum_xx: Vec<f64>,
    /// Σ(y²) per group.
    sum_yy: Vec<f64>,
}
309
310impl CorrelationGroupsAccumulator {
311 pub fn new() -> Self {
312 Default::default()
313 }
314}
315
316fn accumulate_correlation_states(
322 group_indices: &[usize],
323 state_arrays: (
324 &UInt64Array, &Float64Array, &Float64Array, &Float64Array, &Float64Array, &Float64Array, ),
331 mut value_fn: impl FnMut(usize, u64, &[f64]),
332) {
333 let (counts, sum_x, sum_y, sum_xy, sum_xx, sum_yy) = state_arrays;
334
335 assert_eq!(counts.null_count(), 0);
336 assert_eq!(sum_x.null_count(), 0);
337 assert_eq!(sum_y.null_count(), 0);
338 assert_eq!(sum_xy.null_count(), 0);
339 assert_eq!(sum_xx.null_count(), 0);
340 assert_eq!(sum_yy.null_count(), 0);
341
342 let counts_values = counts.values().as_ref();
343 let sum_x_values = sum_x.values().as_ref();
344 let sum_y_values = sum_y.values().as_ref();
345 let sum_xy_values = sum_xy.values().as_ref();
346 let sum_xx_values = sum_xx.values().as_ref();
347 let sum_yy_values = sum_yy.values().as_ref();
348
349 for (idx, &group_idx) in group_indices.iter().enumerate() {
350 let row = [
351 sum_x_values[idx],
352 sum_y_values[idx],
353 sum_xy_values[idx],
354 sum_xx_values[idx],
355 sum_yy_values[idx],
356 ];
357 value_fn(group_idx, counts_values[idx], &row);
358 }
359}
360
361impl GroupsAccumulator for CorrelationGroupsAccumulator {
377 fn update_batch(
378 &mut self,
379 values: &[ArrayRef],
380 group_indices: &[usize],
381 opt_filter: Option<&BooleanArray>,
382 total_num_groups: usize,
383 ) -> Result<()> {
384 self.count.resize(total_num_groups, 0);
385 self.sum_x.resize(total_num_groups, 0.0);
386 self.sum_y.resize(total_num_groups, 0.0);
387 self.sum_xy.resize(total_num_groups, 0.0);
388 self.sum_xx.resize(total_num_groups, 0.0);
389 self.sum_yy.resize(total_num_groups, 0.0);
390
391 let array_x = downcast_array::<Float64Array>(&values[0]);
392 let array_y = downcast_array::<Float64Array>(&values[1]);
393
394 accumulate_multiple(
395 group_indices,
396 &[&array_x, &array_y],
397 opt_filter,
398 |group_index, batch_index, columns| {
399 let x = columns[0].value(batch_index);
400 let y = columns[1].value(batch_index);
401 self.count[group_index] += 1;
402 self.sum_x[group_index] += x;
403 self.sum_y[group_index] += y;
404 self.sum_xy[group_index] += x * y;
405 self.sum_xx[group_index] += x * x;
406 self.sum_yy[group_index] += y * y;
407 },
408 );
409
410 Ok(())
411 }
412
413 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
414 let n = match emit_to {
415 EmitTo::All => self.count.len(),
416 EmitTo::First(n) => n,
417 };
418
419 let mut values = Vec::with_capacity(n);
420 let mut nulls = NullBufferBuilder::new(n);
421
422 for i in 0..n {
432 let count = self.count[i];
433 let sum_x = self.sum_x[i];
434 let sum_y = self.sum_y[i];
435 let sum_xy = self.sum_xy[i];
436 let sum_xx = self.sum_xx[i];
437 let sum_yy = self.sum_yy[i];
438
439 if sum_x.is_nan() && sum_y.is_nan() {
442 values.push(f64::NAN);
444 nulls.append_non_null();
445 continue;
446 } else if count < 2 || sum_x.is_nan() || sum_y.is_nan() {
447 values.push(0.0);
449 nulls.append_null();
450 continue;
451 }
452
453 let mean_x = sum_x / count as f64;
454 let mean_y = sum_y / count as f64;
455
456 let numerator = sum_xy - sum_x * mean_y;
457 let denominator =
458 ((sum_xx - sum_x * mean_x) * (sum_yy - sum_y * mean_y)).sqrt();
459
460 if denominator == 0.0 {
461 values.push(0.0);
462 nulls.append_null();
463 } else {
464 values.push(numerator / denominator);
465 nulls.append_non_null();
466 }
467 }
468
469 Ok(Arc::new(Float64Array::new(values.into(), nulls.finish())))
470 }
471
472 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
473 let n = match emit_to {
474 EmitTo::All => self.count.len(),
475 EmitTo::First(n) => n,
476 };
477
478 Ok(vec![
479 Arc::new(UInt64Array::from(self.count[0..n].to_vec())),
480 Arc::new(Float64Array::from(self.sum_x[0..n].to_vec())),
481 Arc::new(Float64Array::from(self.sum_y[0..n].to_vec())),
482 Arc::new(Float64Array::from(self.sum_xy[0..n].to_vec())),
483 Arc::new(Float64Array::from(self.sum_xx[0..n].to_vec())),
484 Arc::new(Float64Array::from(self.sum_yy[0..n].to_vec())),
485 ])
486 }
487
488 fn merge_batch(
489 &mut self,
490 values: &[ArrayRef],
491 group_indices: &[usize],
492 opt_filter: Option<&BooleanArray>,
493 total_num_groups: usize,
494 ) -> Result<()> {
495 self.count.resize(total_num_groups, 0);
497 self.sum_x.resize(total_num_groups, 0.0);
498 self.sum_y.resize(total_num_groups, 0.0);
499 self.sum_xy.resize(total_num_groups, 0.0);
500 self.sum_xx.resize(total_num_groups, 0.0);
501 self.sum_yy.resize(total_num_groups, 0.0);
502
503 let partial_counts = values[0].as_primitive::<UInt64Type>();
505 let partial_sum_x = values[1].as_primitive::<Float64Type>();
506 let partial_sum_y = values[2].as_primitive::<Float64Type>();
507 let partial_sum_xy = values[3].as_primitive::<Float64Type>();
508 let partial_sum_xx = values[4].as_primitive::<Float64Type>();
509 let partial_sum_yy = values[5].as_primitive::<Float64Type>();
510
511 assert!(
512 opt_filter.is_none(),
513 "aggregate filter should be applied in partial stage, there should be no filter in final stage"
514 );
515
516 accumulate_correlation_states(
517 group_indices,
518 (
519 partial_counts,
520 partial_sum_x,
521 partial_sum_y,
522 partial_sum_xy,
523 partial_sum_xx,
524 partial_sum_yy,
525 ),
526 |group_index, count, values| {
527 self.count[group_index] += count;
528 self.sum_x[group_index] += values[0];
529 self.sum_y[group_index] += values[1];
530 self.sum_xy[group_index] += values[2];
531 self.sum_xx[group_index] += values[3];
532 self.sum_yy[group_index] += values[4];
533 },
534 );
535
536 Ok(())
537 }
538
539 fn size(&self) -> usize {
540 size_of_val(&self.count)
541 + size_of_val(&self.sum_x)
542 + size_of_val(&self.sum_y)
543 + size_of_val(&self.sum_xy)
544 + size_of_val(&self.sum_xx)
545 + size_of_val(&self.sum_yy)
546 }
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, UInt64Array};

    #[test]
    fn test_accumulate_correlation_states() {
        let group_indices = [0usize, 1, 0, 1];
        let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]);
        let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]);
        let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]);
        let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]);

        // Each row is forwarded in order with its own group index and count.
        let counts = UInt64Array::from(vec![1, 2, 3, 4]);
        let mut seen: Vec<(usize, u64, Vec<f64>)> = Vec::new();
        accumulate_correlation_states(
            &group_indices,
            (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy),
            |group, count, sums| seen.push((group, count, sums.to_vec())),
        );
        assert_eq!(
            seen,
            vec![
                (0, 1, vec![10.0, 1.0, 10.0, 100.0, 1.0]),
                (1, 2, vec![20.0, 2.0, 40.0, 400.0, 4.0]),
                (0, 3, vec![30.0, 3.0, 90.0, 900.0, 9.0]),
                (1, 4, vec![40.0, 4.0, 160.0, 1600.0, 16.0]),
            ]
        );

        // A null in the count column violates the "no null states" invariant.
        let counts_with_null = UInt64Array::from(vec![Some(1), None, Some(3), Some(4)]);
        let outcome = std::panic::catch_unwind(|| {
            accumulate_correlation_states(
                &group_indices,
                (&counts_with_null, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy),
                |_, _, _| {},
            )
        });
        assert!(outcome.is_err());
    }
}