1use std::fmt::Debug;
21use std::mem::size_of_val;
22use std::sync::Arc;
23
24use arrow::array::{
25 Array, AsArray, BooleanArray, Float64Array, NullBufferBuilder, UInt64Array,
26 downcast_array,
27};
28use arrow::compute::{and, filter, is_not_null};
29use arrow::datatypes::{FieldRef, Float64Type, UInt64Type};
30use arrow::{
31 array::ArrayRef,
32 datatypes::{DataType, Field},
33};
34use datafusion_expr::{EmitTo, GroupsAccumulator};
35use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_multiple;
36use log::debug;
37
38use crate::covariance::CovarianceAccumulator;
39use crate::stddev::StddevAccumulator;
40use datafusion_common::{Result, ScalarValue};
41use datafusion_expr::{
42 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
43 function::{AccumulatorArgs, StateFieldsArgs},
44 utils::format_state_name,
45};
46use datafusion_functions_aggregate_common::stats::StatsType;
47use datafusion_macros::user_doc;
48
// Registers the `Correlation` UDAF under the expression-function name `corr`
// (arguments `y`, `x`) and the accessor `corr_udaf`.
// NOTE(review): the exact items generated (e.g. a `corr(y, x)` expr builder and a
// `corr_udaf()` singleton accessor) come from `make_udaf_expr_and_func!` — confirm
// against the macro definition.
make_udaf_expr_and_func!(
    Correlation,
    corr,
    y x,
    "Correlation between two numeric values.",
    corr_udaf
);
55);
56
57#[user_doc(
58 doc_section(label = "Statistical Functions"),
59 description = "Returns the coefficient of correlation between two numeric values.",
60 syntax_example = "corr(expression1, expression2)",
61 sql_example = r#"```sql
62> SELECT corr(column1, column2) FROM table_name;
63+--------------------------------+
64| corr(column1, column2) |
65+--------------------------------+
66| 0.85 |
67+--------------------------------+
68```"#,
69 standard_argument(name = "expression1", prefix = "First"),
70 standard_argument(name = "expression2", prefix = "Second")
71)]
72#[derive(Debug, PartialEq, Eq, Hash)]
73pub struct Correlation {
74 signature: Signature,
75}
76
77impl Default for Correlation {
78 fn default() -> Self {
79 Self::new()
80 }
81}
82
83impl Correlation {
84 pub fn new() -> Self {
86 Self {
87 signature: Signature::exact(
88 vec![DataType::Float64, DataType::Float64],
89 Volatility::Immutable,
90 )
91 .with_parameter_names(vec!["y".to_string(), "x".to_string()])
92 .expect("valid parameter names for corr"),
93 }
94 }
95}
96
97impl AggregateUDFImpl for Correlation {
98 fn name(&self) -> &str {
99 "corr"
100 }
101
102 fn signature(&self) -> &Signature {
103 &self.signature
104 }
105
106 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
107 Ok(DataType::Float64)
108 }
109
110 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
111 Ok(Box::new(CorrelationAccumulator::try_new()?))
112 }
113
114 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
115 let name = args.name;
116 Ok(vec![
117 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
118 Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
119 Field::new(format_state_name(name, "m2_1"), DataType::Float64, true),
120 Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
121 Field::new(format_state_name(name, "m2_2"), DataType::Float64, true),
122 Field::new(
123 format_state_name(name, "algo_const"),
124 DataType::Float64,
125 true,
126 ),
127 ]
128 .into_iter()
129 .map(Arc::new)
130 .collect())
131 }
132
133 fn documentation(&self) -> Option<&Documentation> {
134 self.doc()
135 }
136
137 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
138 true
139 }
140
141 fn create_groups_accumulator(
142 &self,
143 _args: AccumulatorArgs,
144 ) -> Result<Box<dyn GroupsAccumulator>> {
145 debug!("GroupsAccumulator is created for aggregate function `corr(c1, c2)`");
146 Ok(Box::new(CorrelationGroupsAccumulator::new()))
147 }
148}
149
150#[derive(Debug)]
152pub struct CorrelationAccumulator {
153 covar: CovarianceAccumulator,
154 stddev1: StddevAccumulator,
155 stddev2: StddevAccumulator,
156}
157
158impl CorrelationAccumulator {
159 pub fn try_new() -> Result<Self> {
161 Ok(Self {
162 covar: CovarianceAccumulator::try_new(StatsType::Population)?,
163 stddev1: StddevAccumulator::try_new(StatsType::Population)?,
164 stddev2: StddevAccumulator::try_new(StatsType::Population)?,
165 })
166 }
167}
168
169impl Accumulator for CorrelationAccumulator {
170 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
171 let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
177 let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
178 let values1 = filter(&values[0], &mask)?;
179 let values2 = filter(&values[1], &mask)?;
180
181 vec![values1, values2]
182 } else {
183 values.to_vec()
184 };
185
186 self.covar.update_batch(&values)?;
187 self.stddev1.update_batch(&values[0..1])?;
188 self.stddev2.update_batch(&values[1..2])?;
189 Ok(())
190 }
191
192 fn evaluate(&mut self) -> Result<ScalarValue> {
193 let covar = self.covar.evaluate()?;
194 let stddev1 = self.stddev1.evaluate()?;
195 let stddev2 = self.stddev2.evaluate()?;
196
197 let mean1 = self.covar.get_mean1();
200 let mean2 = self.covar.get_mean2();
201
202 if mean1.is_nan() && mean2.is_nan() {
204 return Ok(ScalarValue::Float64(Some(f64::NAN)));
205 }
206 let n = self.covar.get_count();
207 if mean1.is_nan() || mean2.is_nan() || n < 2 {
208 return Ok(ScalarValue::Float64(None));
209 }
210
211 if let ScalarValue::Float64(Some(c)) = covar
212 && let ScalarValue::Float64(Some(s1)) = stddev1
213 && let ScalarValue::Float64(Some(s2)) = stddev2
214 {
215 if s1 == 0_f64 || s2 == 0_f64 {
216 return Ok(ScalarValue::Float64(None));
217 } else {
218 return Ok(ScalarValue::Float64(Some(c / s1 / s2)));
219 }
220 }
221
222 Ok(ScalarValue::Float64(None))
223 }
224
225 fn size(&self) -> usize {
226 size_of_val(self) - size_of_val(&self.covar) + self.covar.size()
227 - size_of_val(&self.stddev1)
228 + self.stddev1.size()
229 - size_of_val(&self.stddev2)
230 + self.stddev2.size()
231 }
232
233 fn state(&mut self) -> Result<Vec<ScalarValue>> {
234 Ok(vec![
235 ScalarValue::from(self.covar.get_count()),
236 ScalarValue::from(self.covar.get_mean1()),
237 ScalarValue::from(self.stddev1.get_m2()),
238 ScalarValue::from(self.covar.get_mean2()),
239 ScalarValue::from(self.stddev2.get_m2()),
240 ScalarValue::from(self.covar.get_algo_const()),
241 ])
242 }
243
244 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
245 let states_c = [
246 Arc::clone(&states[0]),
247 Arc::clone(&states[1]),
248 Arc::clone(&states[3]),
249 Arc::clone(&states[5]),
250 ];
251 let states_s1 = [
252 Arc::clone(&states[0]),
253 Arc::clone(&states[1]),
254 Arc::clone(&states[2]),
255 ];
256 let states_s2 = [
257 Arc::clone(&states[0]),
258 Arc::clone(&states[3]),
259 Arc::clone(&states[4]),
260 ];
261
262 self.covar.merge_batch(&states_c)?;
263 self.stddev1.merge_batch(&states_s1)?;
264 self.stddev2.merge_batch(&states_s2)?;
265 Ok(())
266 }
267
268 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
269 let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
270 let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
271 let values1 = filter(&values[0], &mask)?;
272 let values2 = filter(&values[1], &mask)?;
273
274 vec![values1, values2]
275 } else {
276 values.to_vec()
277 };
278
279 self.covar.retract_batch(&values)?;
280 self.stddev1.retract_batch(&values[0..1])?;
281 self.stddev2.retract_batch(&values[1..2])?;
282 Ok(())
283 }
284}
285
/// Grouped (vectorized) accumulator for `corr`.
///
/// For every group it keeps the row count and the five running sums needed by the
/// one-pass Pearson formula; the coefficient itself is only computed in `evaluate`.
/// Here `x` refers to the first argument (`values[0]`) and `y` to the second.
#[derive(Debug, Default)]
pub struct CorrelationGroupsAccumulator {
    /// Number of rows accumulated per group.
    count: Vec<u64>,
    /// Per-group sum of the first argument.
    sum_x: Vec<f64>,
    /// Per-group sum of the second argument.
    sum_y: Vec<f64>,
    /// Per-group sum of the products of both arguments.
    sum_xy: Vec<f64>,
    /// Per-group sum of squares of the first argument.
    sum_xx: Vec<f64>,
    /// Per-group sum of squares of the second argument.
    sum_yy: Vec<f64>,
}

impl CorrelationGroupsAccumulator {
    /// Creates an accumulator with no groups; the per-group vectors are resized on demand.
    pub fn new() -> Self {
        Self::default()
    }
}
309
310fn accumulate_correlation_states(
316 group_indices: &[usize],
317 state_arrays: (
318 &UInt64Array, &Float64Array, &Float64Array, &Float64Array, &Float64Array, &Float64Array, ),
325 mut value_fn: impl FnMut(usize, u64, &[f64]),
326) {
327 let (counts, sum_x, sum_y, sum_xy, sum_xx, sum_yy) = state_arrays;
328
329 assert_eq!(counts.null_count(), 0);
330 assert_eq!(sum_x.null_count(), 0);
331 assert_eq!(sum_y.null_count(), 0);
332 assert_eq!(sum_xy.null_count(), 0);
333 assert_eq!(sum_xx.null_count(), 0);
334 assert_eq!(sum_yy.null_count(), 0);
335
336 let counts_values = counts.values().as_ref();
337 let sum_x_values = sum_x.values().as_ref();
338 let sum_y_values = sum_y.values().as_ref();
339 let sum_xy_values = sum_xy.values().as_ref();
340 let sum_xx_values = sum_xx.values().as_ref();
341 let sum_yy_values = sum_yy.values().as_ref();
342
343 for (idx, &group_idx) in group_indices.iter().enumerate() {
344 let row = [
345 sum_x_values[idx],
346 sum_y_values[idx],
347 sum_xy_values[idx],
348 sum_xx_values[idx],
349 sum_yy_values[idx],
350 ];
351 value_fn(group_idx, counts_values[idx], &row);
352 }
353}
354
355impl GroupsAccumulator for CorrelationGroupsAccumulator {
371 fn update_batch(
372 &mut self,
373 values: &[ArrayRef],
374 group_indices: &[usize],
375 opt_filter: Option<&BooleanArray>,
376 total_num_groups: usize,
377 ) -> Result<()> {
378 self.count.resize(total_num_groups, 0);
379 self.sum_x.resize(total_num_groups, 0.0);
380 self.sum_y.resize(total_num_groups, 0.0);
381 self.sum_xy.resize(total_num_groups, 0.0);
382 self.sum_xx.resize(total_num_groups, 0.0);
383 self.sum_yy.resize(total_num_groups, 0.0);
384
385 let array_x = downcast_array::<Float64Array>(&values[0]);
386 let array_y = downcast_array::<Float64Array>(&values[1]);
387
388 accumulate_multiple(
389 group_indices,
390 &[&array_x, &array_y],
391 opt_filter,
392 |group_index, batch_index, columns| {
393 let x = columns[0].value(batch_index);
394 let y = columns[1].value(batch_index);
395 self.count[group_index] += 1;
396 self.sum_x[group_index] += x;
397 self.sum_y[group_index] += y;
398 self.sum_xy[group_index] += x * y;
399 self.sum_xx[group_index] += x * x;
400 self.sum_yy[group_index] += y * y;
401 },
402 );
403
404 Ok(())
405 }
406
407 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
408 let counts = emit_to.take_needed(&mut self.count);
410 let sum_xs = emit_to.take_needed(&mut self.sum_x);
411 let sum_ys = emit_to.take_needed(&mut self.sum_y);
412 let sum_xys = emit_to.take_needed(&mut self.sum_xy);
413 let sum_xxs = emit_to.take_needed(&mut self.sum_xx);
414 let sum_yys = emit_to.take_needed(&mut self.sum_yy);
415
416 let n = counts.len();
417 let mut values = Vec::with_capacity(n);
418 let mut nulls = NullBufferBuilder::new(n);
419
420 for i in 0..n {
429 let count = counts[i];
430 let sum_x = sum_xs[i];
431 let sum_y = sum_ys[i];
432 let sum_xy = sum_xys[i];
433 let sum_xx = sum_xxs[i];
434 let sum_yy = sum_yys[i];
435
436 if sum_x.is_nan() && sum_y.is_nan() {
439 values.push(f64::NAN);
441 nulls.append_non_null();
442 continue;
443 } else if count < 2 || sum_x.is_nan() || sum_y.is_nan() {
444 values.push(0.0);
446 nulls.append_null();
447 continue;
448 }
449
450 let mean_x = sum_x / count as f64;
451 let mean_y = sum_y / count as f64;
452
453 let numerator = sum_xy - sum_x * mean_y;
454 let denominator =
455 ((sum_xx - sum_x * mean_x) * (sum_yy - sum_y * mean_y)).sqrt();
456
457 if denominator == 0.0 {
458 values.push(0.0);
459 nulls.append_null();
460 } else {
461 values.push(numerator / denominator);
462 nulls.append_non_null();
463 }
464 }
465
466 Ok(Arc::new(Float64Array::new(values.into(), nulls.finish())))
467 }
468
469 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
470 let count = emit_to.take_needed(&mut self.count);
472 let sum_x = emit_to.take_needed(&mut self.sum_x);
473 let sum_y = emit_to.take_needed(&mut self.sum_y);
474 let sum_xy = emit_to.take_needed(&mut self.sum_xy);
475 let sum_xx = emit_to.take_needed(&mut self.sum_xx);
476 let sum_yy = emit_to.take_needed(&mut self.sum_yy);
477
478 Ok(vec![
479 Arc::new(UInt64Array::from(count)),
480 Arc::new(Float64Array::from(sum_x)),
481 Arc::new(Float64Array::from(sum_y)),
482 Arc::new(Float64Array::from(sum_xy)),
483 Arc::new(Float64Array::from(sum_xx)),
484 Arc::new(Float64Array::from(sum_yy)),
485 ])
486 }
487
488 fn merge_batch(
489 &mut self,
490 values: &[ArrayRef],
491 group_indices: &[usize],
492 opt_filter: Option<&BooleanArray>,
493 total_num_groups: usize,
494 ) -> Result<()> {
495 self.count.resize(total_num_groups, 0);
497 self.sum_x.resize(total_num_groups, 0.0);
498 self.sum_y.resize(total_num_groups, 0.0);
499 self.sum_xy.resize(total_num_groups, 0.0);
500 self.sum_xx.resize(total_num_groups, 0.0);
501 self.sum_yy.resize(total_num_groups, 0.0);
502
503 let partial_counts = values[0].as_primitive::<UInt64Type>();
505 let partial_sum_x = values[1].as_primitive::<Float64Type>();
506 let partial_sum_y = values[2].as_primitive::<Float64Type>();
507 let partial_sum_xy = values[3].as_primitive::<Float64Type>();
508 let partial_sum_xx = values[4].as_primitive::<Float64Type>();
509 let partial_sum_yy = values[5].as_primitive::<Float64Type>();
510
511 assert!(
512 opt_filter.is_none(),
513 "aggregate filter should be applied in partial stage, there should be no filter in final stage"
514 );
515
516 accumulate_correlation_states(
517 group_indices,
518 (
519 partial_counts,
520 partial_sum_x,
521 partial_sum_y,
522 partial_sum_xy,
523 partial_sum_xx,
524 partial_sum_yy,
525 ),
526 |group_index, count, values| {
527 self.count[group_index] += count;
528 self.sum_x[group_index] += values[0];
529 self.sum_y[group_index] += values[1];
530 self.sum_xy[group_index] += values[2];
531 self.sum_xx[group_index] += values[3];
532 self.sum_yy[group_index] += values[4];
533 },
534 );
535
536 Ok(())
537 }
538
539 fn size(&self) -> usize {
540 self.count.capacity() * size_of::<u64>()
541 + self.sum_x.capacity() * size_of::<f64>()
542 + self.sum_y.capacity() * size_of::<f64>()
543 + self.sum_xy.capacity() * size_of::<f64>()
544 + self.sum_xx.capacity() * size_of::<f64>()
545 + self.sum_yy.capacity() * size_of::<f64>()
546 }
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_accumulate_correlation_states() {
        let group_indices = vec![0, 1, 0, 1];

        // The sum columns are identical in both scenarios; only `count` differs.
        let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]);
        let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]);
        let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]);
        let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]);

        // Scenario 1: null-free states are forwarded row by row, in order.
        let counts = UInt64Array::from(vec![1, 2, 3, 4]);
        let mut observed = Vec::new();
        accumulate_correlation_states(
            &group_indices,
            (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy),
            |group_idx, count, sums| observed.push((group_idx, count, sums.to_vec())),
        );
        assert_eq!(
            observed,
            vec![
                (0, 1, vec![10.0, 1.0, 10.0, 100.0, 1.0]),
                (1, 2, vec![20.0, 2.0, 40.0, 400.0, 4.0]),
                (0, 3, vec![30.0, 3.0, 90.0, 900.0, 9.0]),
                (1, 4, vec![40.0, 4.0, 160.0, 1600.0, 16.0]),
            ]
        );

        // Scenario 2: a null in any state column violates the precondition and panics.
        let counts_with_null = UInt64Array::from(vec![Some(1), None, Some(3), Some(4)]);
        let outcome = std::panic::catch_unwind(|| {
            accumulate_correlation_states(
                &group_indices,
                (&counts_with_null, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy),
                |_, _, _| {},
            )
        });
        assert!(outcome.is_err());
    }
}