1use std::any::Any;
21use std::fmt::Debug;
22use std::mem::size_of_val;
23use std::sync::Arc;
24
25use arrow::array::{
26 downcast_array, Array, AsArray, BooleanArray, Float64Array, NullBufferBuilder,
27 UInt64Array,
28};
29use arrow::compute::{and, filter, is_not_null, kernels::cast};
30use arrow::datatypes::{FieldRef, Float64Type, UInt64Type};
31use arrow::{
32 array::ArrayRef,
33 datatypes::{DataType, Field},
34};
35use datafusion_expr::{EmitTo, GroupsAccumulator};
36use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_multiple;
37use log::debug;
38
39use crate::covariance::CovarianceAccumulator;
40use crate::stddev::StddevAccumulator;
41use datafusion_common::{plan_err, Result, ScalarValue};
42use datafusion_expr::{
43 function::{AccumulatorArgs, StateFieldsArgs},
44 type_coercion::aggregates::NUMERICS,
45 utils::format_state_name,
46 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
47};
48use datafusion_functions_aggregate_common::stats::StatsType;
49use datafusion_macros::user_doc;
50
// Registers the `Correlation` UDAF through the project macro. Presumably this
// expands to a `corr(y, x)` expression builder and a `corr_udaf()` accessor;
// see the macro definition for the exact items it generates.
make_udaf_expr_and_func!(
    Correlation,
    corr,
    y x,
    "Correlation between two numeric values.",
    corr_udaf
);
58
// The `corr` aggregate UDF: the correlation coefficient of two numeric
// arguments. User-facing documentation is generated from `user_doc` below.
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the coefficient of correlation between two numeric values.",
    syntax_example = "corr(expression1, expression2)",
    sql_example = r#"```sql
> SELECT corr(column1, column2) FROM table_name;
+--------------------------------+
| corr(column1, column2) |
+--------------------------------+
| 0.85 |
+--------------------------------+
```"#,
    standard_argument(name = "expression1", prefix = "First"),
    standard_argument(name = "expression2", prefix = "Second")
)]
#[derive(Debug)]
pub struct Correlation {
    // Two-argument signature restricted to the `NUMERICS` type list; built in
    // `Correlation::new`.
    signature: Signature,
}
78
79impl Default for Correlation {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl Correlation {
86 pub fn new() -> Self {
88 Self {
89 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
90 }
91 }
92}
93
94impl AggregateUDFImpl for Correlation {
95 fn as_any(&self) -> &dyn Any {
97 self
98 }
99
100 fn name(&self) -> &str {
101 "corr"
102 }
103
104 fn signature(&self) -> &Signature {
105 &self.signature
106 }
107
108 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
109 if !arg_types[0].is_numeric() {
110 return plan_err!("Correlation requires numeric input types");
111 }
112
113 Ok(DataType::Float64)
114 }
115
116 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
117 Ok(Box::new(CorrelationAccumulator::try_new()?))
118 }
119
120 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
121 let name = args.name;
122 Ok(vec![
123 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
124 Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
125 Field::new(format_state_name(name, "m2_1"), DataType::Float64, true),
126 Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
127 Field::new(format_state_name(name, "m2_2"), DataType::Float64, true),
128 Field::new(
129 format_state_name(name, "algo_const"),
130 DataType::Float64,
131 true,
132 ),
133 ]
134 .into_iter()
135 .map(Arc::new)
136 .collect())
137 }
138
139 fn documentation(&self) -> Option<&Documentation> {
140 self.doc()
141 }
142
143 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
144 true
145 }
146
147 fn create_groups_accumulator(
148 &self,
149 _args: AccumulatorArgs,
150 ) -> Result<Box<dyn GroupsAccumulator>> {
151 debug!("GroupsAccumulator is created for aggregate function `corr(c1, c2)`");
152 Ok(Box::new(CorrelationGroupsAccumulator::new()))
153 }
154}
155
/// Row-at-a-time accumulator for `corr`.
///
/// It combines one covariance accumulator with one standard-deviation
/// accumulator per argument. `evaluate` divides the covariance by both standard
/// deviations. All three are constructed with population statistics
/// (see `try_new`).
#[derive(Debug)]
pub struct CorrelationAccumulator {
    // Covariance of the first and second arguments.
    covar: CovarianceAccumulator,
    // Standard deviation of the first argument (fed `values[0..1]`).
    stddev1: StddevAccumulator,
    // Standard deviation of the second argument (fed `values[1..2]`).
    stddev2: StddevAccumulator,
}
163
164impl CorrelationAccumulator {
165 pub fn try_new() -> Result<Self> {
167 Ok(Self {
168 covar: CovarianceAccumulator::try_new(StatsType::Population)?,
169 stddev1: StddevAccumulator::try_new(StatsType::Population)?,
170 stddev2: StddevAccumulator::try_new(StatsType::Population)?,
171 })
172 }
173}
174
175impl Accumulator for CorrelationAccumulator {
176 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
177 let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
183 let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
184 let values1 = filter(&values[0], &mask)?;
185 let values2 = filter(&values[1], &mask)?;
186
187 vec![values1, values2]
188 } else {
189 values.to_vec()
190 };
191
192 self.covar.update_batch(&values)?;
193 self.stddev1.update_batch(&values[0..1])?;
194 self.stddev2.update_batch(&values[1..2])?;
195 Ok(())
196 }
197
198 fn evaluate(&mut self) -> Result<ScalarValue> {
199 let covar = self.covar.evaluate()?;
200 let stddev1 = self.stddev1.evaluate()?;
201 let stddev2 = self.stddev2.evaluate()?;
202
203 if let ScalarValue::Float64(Some(c)) = covar {
204 if let ScalarValue::Float64(Some(s1)) = stddev1 {
205 if let ScalarValue::Float64(Some(s2)) = stddev2 {
206 if s1 == 0_f64 || s2 == 0_f64 {
207 return Ok(ScalarValue::Float64(Some(0_f64)));
208 } else {
209 return Ok(ScalarValue::Float64(Some(c / s1 / s2)));
210 }
211 }
212 }
213 }
214
215 Ok(ScalarValue::Float64(None))
216 }
217
218 fn size(&self) -> usize {
219 size_of_val(self) - size_of_val(&self.covar) + self.covar.size()
220 - size_of_val(&self.stddev1)
221 + self.stddev1.size()
222 - size_of_val(&self.stddev2)
223 + self.stddev2.size()
224 }
225
226 fn state(&mut self) -> Result<Vec<ScalarValue>> {
227 Ok(vec![
228 ScalarValue::from(self.covar.get_count()),
229 ScalarValue::from(self.covar.get_mean1()),
230 ScalarValue::from(self.stddev1.get_m2()),
231 ScalarValue::from(self.covar.get_mean2()),
232 ScalarValue::from(self.stddev2.get_m2()),
233 ScalarValue::from(self.covar.get_algo_const()),
234 ])
235 }
236
237 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
238 let states_c = [
239 Arc::clone(&states[0]),
240 Arc::clone(&states[1]),
241 Arc::clone(&states[3]),
242 Arc::clone(&states[5]),
243 ];
244 let states_s1 = [
245 Arc::clone(&states[0]),
246 Arc::clone(&states[1]),
247 Arc::clone(&states[2]),
248 ];
249 let states_s2 = [
250 Arc::clone(&states[0]),
251 Arc::clone(&states[3]),
252 Arc::clone(&states[4]),
253 ];
254
255 self.covar.merge_batch(&states_c)?;
256 self.stddev1.merge_batch(&states_s1)?;
257 self.stddev2.merge_batch(&states_s2)?;
258 Ok(())
259 }
260
261 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
262 let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
263 let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
264 let values1 = filter(&values[0], &mask)?;
265 let values2 = filter(&values[1], &mask)?;
266
267 vec![values1, values2]
268 } else {
269 values.to_vec()
270 };
271
272 self.covar.retract_batch(&values)?;
273 self.stddev1.retract_batch(&values[0..1])?;
274 self.stddev2.retract_batch(&values[1..2])?;
275 Ok(())
276 }
277}
278
/// Vectorized per-group accumulator for `corr`.
///
/// Each vector is indexed by group index and holds the running sums needed for
/// the correlation formula. `x` is the first argument and `y` the second.
#[derive(Debug, Default)]
pub struct CorrelationGroupsAccumulator {
    // Number of (x, y) pairs folded into each group.
    count: Vec<u64>,
    // Sum of x per group.
    sum_x: Vec<f64>,
    // Sum of y per group.
    sum_y: Vec<f64>,
    // Sum of x * y per group.
    sum_xy: Vec<f64>,
    // Sum of x * x per group.
    sum_xx: Vec<f64>,
    // Sum of y * y per group.
    sum_yy: Vec<f64>,
}

impl CorrelationGroupsAccumulator {
    /// Creates an accumulator with no groups.
    pub fn new() -> Self {
        Self::default()
    }
}
302
303fn accumulate_correlation_states(
309 group_indices: &[usize],
310 state_arrays: (
311 &UInt64Array, &Float64Array, &Float64Array, &Float64Array, &Float64Array, &Float64Array, ),
318 mut value_fn: impl FnMut(usize, u64, &[f64]),
319) {
320 let (counts, sum_x, sum_y, sum_xy, sum_xx, sum_yy) = state_arrays;
321
322 assert_eq!(counts.null_count(), 0);
323 assert_eq!(sum_x.null_count(), 0);
324 assert_eq!(sum_y.null_count(), 0);
325 assert_eq!(sum_xy.null_count(), 0);
326 assert_eq!(sum_xx.null_count(), 0);
327 assert_eq!(sum_yy.null_count(), 0);
328
329 let counts_values = counts.values().as_ref();
330 let sum_x_values = sum_x.values().as_ref();
331 let sum_y_values = sum_y.values().as_ref();
332 let sum_xy_values = sum_xy.values().as_ref();
333 let sum_xx_values = sum_xx.values().as_ref();
334 let sum_yy_values = sum_yy.values().as_ref();
335
336 for (idx, &group_idx) in group_indices.iter().enumerate() {
337 let row = [
338 sum_x_values[idx],
339 sum_y_values[idx],
340 sum_xy_values[idx],
341 sum_xx_values[idx],
342 sum_yy_values[idx],
343 ];
344 value_fn(group_idx, counts_values[idx], &row);
345 }
346}
347
348impl GroupsAccumulator for CorrelationGroupsAccumulator {
364 fn update_batch(
365 &mut self,
366 values: &[ArrayRef],
367 group_indices: &[usize],
368 opt_filter: Option<&BooleanArray>,
369 total_num_groups: usize,
370 ) -> Result<()> {
371 self.count.resize(total_num_groups, 0);
372 self.sum_x.resize(total_num_groups, 0.0);
373 self.sum_y.resize(total_num_groups, 0.0);
374 self.sum_xy.resize(total_num_groups, 0.0);
375 self.sum_xx.resize(total_num_groups, 0.0);
376 self.sum_yy.resize(total_num_groups, 0.0);
377
378 let array_x = &cast(&values[0], &DataType::Float64)?;
379 let array_x = downcast_array::<Float64Array>(array_x);
380 let array_y = &cast(&values[1], &DataType::Float64)?;
381 let array_y = downcast_array::<Float64Array>(array_y);
382
383 accumulate_multiple(
384 group_indices,
385 &[&array_x, &array_y],
386 opt_filter,
387 |group_index, batch_index, columns| {
388 let x = columns[0].value(batch_index);
389 let y = columns[1].value(batch_index);
390 self.count[group_index] += 1;
391 self.sum_x[group_index] += x;
392 self.sum_y[group_index] += y;
393 self.sum_xy[group_index] += x * y;
394 self.sum_xx[group_index] += x * x;
395 self.sum_yy[group_index] += y * y;
396 },
397 );
398
399 Ok(())
400 }
401
402 fn merge_batch(
403 &mut self,
404 values: &[ArrayRef],
405 group_indices: &[usize],
406 opt_filter: Option<&BooleanArray>,
407 total_num_groups: usize,
408 ) -> Result<()> {
409 self.count.resize(total_num_groups, 0);
411 self.sum_x.resize(total_num_groups, 0.0);
412 self.sum_y.resize(total_num_groups, 0.0);
413 self.sum_xy.resize(total_num_groups, 0.0);
414 self.sum_xx.resize(total_num_groups, 0.0);
415 self.sum_yy.resize(total_num_groups, 0.0);
416
417 let partial_counts = values[0].as_primitive::<UInt64Type>();
419 let partial_sum_x = values[1].as_primitive::<Float64Type>();
420 let partial_sum_y = values[2].as_primitive::<Float64Type>();
421 let partial_sum_xy = values[3].as_primitive::<Float64Type>();
422 let partial_sum_xx = values[4].as_primitive::<Float64Type>();
423 let partial_sum_yy = values[5].as_primitive::<Float64Type>();
424
425 assert!(opt_filter.is_none(), "aggregate filter should be applied in partial stage, there should be no filter in final stage");
426
427 accumulate_correlation_states(
428 group_indices,
429 (
430 partial_counts,
431 partial_sum_x,
432 partial_sum_y,
433 partial_sum_xy,
434 partial_sum_xx,
435 partial_sum_yy,
436 ),
437 |group_index, count, values| {
438 self.count[group_index] += count;
439 self.sum_x[group_index] += values[0];
440 self.sum_y[group_index] += values[1];
441 self.sum_xy[group_index] += values[2];
442 self.sum_xx[group_index] += values[3];
443 self.sum_yy[group_index] += values[4];
444 },
445 );
446
447 Ok(())
448 }
449
450 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
451 let n = match emit_to {
452 EmitTo::All => self.count.len(),
453 EmitTo::First(n) => n,
454 };
455
456 let mut values = Vec::with_capacity(n);
457 let mut nulls = NullBufferBuilder::new(n);
458
459 for i in 0..n {
469 if self.count[i] < 2 {
470 values.push(0.0);
472 nulls.append_null();
473 continue;
474 }
475
476 let count = self.count[i];
477 let sum_x = self.sum_x[i];
478 let sum_y = self.sum_y[i];
479 let sum_xy = self.sum_xy[i];
480 let sum_xx = self.sum_xx[i];
481 let sum_yy = self.sum_yy[i];
482
483 let mean_x = sum_x / count as f64;
484 let mean_y = sum_y / count as f64;
485
486 let numerator = sum_xy - sum_x * mean_y;
487 let denominator =
488 ((sum_xx - sum_x * mean_x) * (sum_yy - sum_y * mean_y)).sqrt();
489
490 if denominator == 0.0 {
491 values.push(0.0);
493 nulls.append_null();
494 } else {
495 values.push(numerator / denominator);
496 nulls.append_non_null();
497 }
498 }
499
500 Ok(Arc::new(Float64Array::new(values.into(), nulls.finish())))
501 }
502
503 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
504 let n = match emit_to {
505 EmitTo::All => self.count.len(),
506 EmitTo::First(n) => n,
507 };
508
509 Ok(vec![
510 Arc::new(UInt64Array::from(self.count[0..n].to_vec())),
511 Arc::new(Float64Array::from(self.sum_x[0..n].to_vec())),
512 Arc::new(Float64Array::from(self.sum_y[0..n].to_vec())),
513 Arc::new(Float64Array::from(self.sum_xy[0..n].to_vec())),
514 Arc::new(Float64Array::from(self.sum_xx[0..n].to_vec())),
515 Arc::new(Float64Array::from(self.sum_yy[0..n].to_vec())),
516 ])
517 }
518
519 fn size(&self) -> usize {
520 size_of_val(&self.count)
521 + size_of_val(&self.sum_x)
522 + size_of_val(&self.sum_y)
523 + size_of_val(&self.sum_xy)
524 + size_of_val(&self.sum_xx)
525 + size_of_val(&self.sum_yy)
526 }
527}
528
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, UInt64Array};

    // Checks that `accumulate_correlation_states` forwards every row to the
    // callback in input order, and panics when a state array contains a null.
    #[test]
    fn test_accumulate_correlation_states() {
        // Rows 0 and 2 belong to group 0; rows 1 and 3 belong to group 1.
        let group_indices = vec![0, 1, 0, 1];
        let counts = UInt64Array::from(vec![1, 2, 3, 4]);
        let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]);
        let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]);
        let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]);
        let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]);

        // Record every callback invocation verbatim.
        let mut accumulated = vec![];
        accumulate_correlation_states(
            &group_indices,
            (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy),
            |group_idx, count, values| {
                accumulated.push((group_idx, count, values.to_vec()));
            },
        );

        // Each entry is (group, count, [sum_x, sum_y, sum_xy, sum_xx, sum_yy]).
        let expected = vec![
            (0, 1, vec![10.0, 1.0, 10.0, 100.0, 1.0]),
            (1, 2, vec![20.0, 2.0, 40.0, 400.0, 4.0]),
            (0, 3, vec![30.0, 3.0, 90.0, 900.0, 9.0]),
            (1, 4, vec![40.0, 4.0, 160.0, 1600.0, 16.0]),
        ];
        assert_eq!(accumulated, expected);

        // A null in `counts` must trip the function's null-count assertion.
        let counts = UInt64Array::from(vec![Some(1), None, Some(3), Some(4)]);
        let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]);
        let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]);
        let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]);
        let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]);

        let result = std::panic::catch_unwind(|| {
            accumulate_correlation_states(
                &group_indices,
                (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy),
                |_, _, _| {},
            )
        });
        assert!(result.is_err());
    }
}