1use std::any::Any;
21use std::fmt::Debug;
22use std::mem::size_of_val;
23use std::sync::Arc;
24
25use arrow::array::{
26 downcast_array, Array, AsArray, BooleanArray, Float64Array, NullBufferBuilder,
27 UInt64Array,
28};
29use arrow::compute::{and, filter, is_not_null};
30use arrow::datatypes::{FieldRef, Float64Type, UInt64Type};
31use arrow::{
32 array::ArrayRef,
33 datatypes::{DataType, Field},
34};
35use datafusion_expr::{EmitTo, GroupsAccumulator};
36use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_multiple;
37use log::debug;
38
39use crate::covariance::CovarianceAccumulator;
40use crate::stddev::StddevAccumulator;
41use datafusion_common::{Result, ScalarValue};
42use datafusion_expr::{
43 function::{AccumulatorArgs, StateFieldsArgs},
44 utils::format_state_name,
45 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
46};
47use datafusion_functions_aggregate_common::stats::StatsType;
48use datafusion_macros::user_doc;
49
50make_udaf_expr_and_func!(
51 Correlation,
52 corr,
53 y x,
54 "Correlation between two numeric values.",
55 corr_udaf
56);
57
58#[user_doc(
59 doc_section(label = "Statistical Functions"),
60 description = "Returns the coefficient of correlation between two numeric values.",
61 syntax_example = "corr(expression1, expression2)",
62 sql_example = r#"```sql
63> SELECT corr(column1, column2) FROM table_name;
64+--------------------------------+
65| corr(column1, column2) |
66+--------------------------------+
67| 0.85 |
68+--------------------------------+
69```"#,
70 standard_argument(name = "expression1", prefix = "First"),
71 standard_argument(name = "expression2", prefix = "Second")
72)]
73#[derive(Debug)]
74pub struct Correlation {
75 signature: Signature,
76}
77
78impl Default for Correlation {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl Correlation {
85 pub fn new() -> Self {
87 Self {
88 signature: Signature::exact(
89 vec![DataType::Float64, DataType::Float64],
90 Volatility::Immutable,
91 ),
92 }
93 }
94}
95
96impl AggregateUDFImpl for Correlation {
97 fn as_any(&self) -> &dyn Any {
99 self
100 }
101
102 fn name(&self) -> &str {
103 "corr"
104 }
105
106 fn signature(&self) -> &Signature {
107 &self.signature
108 }
109
110 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
111 Ok(DataType::Float64)
112 }
113
114 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
115 Ok(Box::new(CorrelationAccumulator::try_new()?))
116 }
117
118 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
119 let name = args.name;
120 Ok(vec![
121 Field::new(format_state_name(name, "count"), DataType::UInt64, true),
122 Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
123 Field::new(format_state_name(name, "m2_1"), DataType::Float64, true),
124 Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
125 Field::new(format_state_name(name, "m2_2"), DataType::Float64, true),
126 Field::new(
127 format_state_name(name, "algo_const"),
128 DataType::Float64,
129 true,
130 ),
131 ]
132 .into_iter()
133 .map(Arc::new)
134 .collect())
135 }
136
137 fn documentation(&self) -> Option<&Documentation> {
138 self.doc()
139 }
140
141 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
142 true
143 }
144
145 fn create_groups_accumulator(
146 &self,
147 _args: AccumulatorArgs,
148 ) -> Result<Box<dyn GroupsAccumulator>> {
149 debug!("GroupsAccumulator is created for aggregate function `corr(c1, c2)`");
150 Ok(Box::new(CorrelationGroupsAccumulator::new()))
151 }
152}
153
154#[derive(Debug)]
156pub struct CorrelationAccumulator {
157 covar: CovarianceAccumulator,
158 stddev1: StddevAccumulator,
159 stddev2: StddevAccumulator,
160}
161
162impl CorrelationAccumulator {
163 pub fn try_new() -> Result<Self> {
165 Ok(Self {
166 covar: CovarianceAccumulator::try_new(StatsType::Population)?,
167 stddev1: StddevAccumulator::try_new(StatsType::Population)?,
168 stddev2: StddevAccumulator::try_new(StatsType::Population)?,
169 })
170 }
171}
172
173impl Accumulator for CorrelationAccumulator {
174 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
175 let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
181 let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
182 let values1 = filter(&values[0], &mask)?;
183 let values2 = filter(&values[1], &mask)?;
184
185 vec![values1, values2]
186 } else {
187 values.to_vec()
188 };
189
190 self.covar.update_batch(&values)?;
191 self.stddev1.update_batch(&values[0..1])?;
192 self.stddev2.update_batch(&values[1..2])?;
193 Ok(())
194 }
195
196 fn evaluate(&mut self) -> Result<ScalarValue> {
197 let covar = self.covar.evaluate()?;
198 let stddev1 = self.stddev1.evaluate()?;
199 let stddev2 = self.stddev2.evaluate()?;
200
201 if let ScalarValue::Float64(Some(c)) = covar {
202 if let ScalarValue::Float64(Some(s1)) = stddev1 {
203 if let ScalarValue::Float64(Some(s2)) = stddev2 {
204 if s1 == 0_f64 || s2 == 0_f64 {
205 return Ok(ScalarValue::Float64(Some(0_f64)));
206 } else {
207 return Ok(ScalarValue::Float64(Some(c / s1 / s2)));
208 }
209 }
210 }
211 }
212
213 Ok(ScalarValue::Float64(None))
214 }
215
216 fn size(&self) -> usize {
217 size_of_val(self) - size_of_val(&self.covar) + self.covar.size()
218 - size_of_val(&self.stddev1)
219 + self.stddev1.size()
220 - size_of_val(&self.stddev2)
221 + self.stddev2.size()
222 }
223
224 fn state(&mut self) -> Result<Vec<ScalarValue>> {
225 Ok(vec![
226 ScalarValue::from(self.covar.get_count()),
227 ScalarValue::from(self.covar.get_mean1()),
228 ScalarValue::from(self.stddev1.get_m2()),
229 ScalarValue::from(self.covar.get_mean2()),
230 ScalarValue::from(self.stddev2.get_m2()),
231 ScalarValue::from(self.covar.get_algo_const()),
232 ])
233 }
234
235 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
236 let states_c = [
237 Arc::clone(&states[0]),
238 Arc::clone(&states[1]),
239 Arc::clone(&states[3]),
240 Arc::clone(&states[5]),
241 ];
242 let states_s1 = [
243 Arc::clone(&states[0]),
244 Arc::clone(&states[1]),
245 Arc::clone(&states[2]),
246 ];
247 let states_s2 = [
248 Arc::clone(&states[0]),
249 Arc::clone(&states[3]),
250 Arc::clone(&states[4]),
251 ];
252
253 self.covar.merge_batch(&states_c)?;
254 self.stddev1.merge_batch(&states_s1)?;
255 self.stddev2.merge_batch(&states_s2)?;
256 Ok(())
257 }
258
259 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
260 let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
261 let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
262 let values1 = filter(&values[0], &mask)?;
263 let values2 = filter(&values[1], &mask)?;
264
265 vec![values1, values2]
266 } else {
267 values.to_vec()
268 };
269
270 self.covar.retract_batch(&values)?;
271 self.stddev1.retract_batch(&values[0..1])?;
272 self.stddev2.retract_batch(&values[1..2])?;
273 Ok(())
274 }
275}
276
/// Per-group running sums used to compute `corr` for many groups at once.
///
/// Every vector is indexed by group index and all six are kept at the same
/// length.
#[derive(Debug, Default)]
pub struct CorrelationGroupsAccumulator {
    /// Number of accumulated `(x, y)` pairs per group.
    count: Vec<u64>,
    /// Sum of `x` per group.
    sum_x: Vec<f64>,
    /// Sum of `y` per group.
    sum_y: Vec<f64>,
    /// Sum of `x * y` per group.
    sum_xy: Vec<f64>,
    /// Sum of `x * x` per group.
    sum_xx: Vec<f64>,
    /// Sum of `y * y` per group.
    sum_yy: Vec<f64>,
}
294
295impl CorrelationGroupsAccumulator {
296 pub fn new() -> Self {
297 Default::default()
298 }
299}
300
301fn accumulate_correlation_states(
307 group_indices: &[usize],
308 state_arrays: (
309 &UInt64Array, &Float64Array, &Float64Array, &Float64Array, &Float64Array, &Float64Array, ),
316 mut value_fn: impl FnMut(usize, u64, &[f64]),
317) {
318 let (counts, sum_x, sum_y, sum_xy, sum_xx, sum_yy) = state_arrays;
319
320 assert_eq!(counts.null_count(), 0);
321 assert_eq!(sum_x.null_count(), 0);
322 assert_eq!(sum_y.null_count(), 0);
323 assert_eq!(sum_xy.null_count(), 0);
324 assert_eq!(sum_xx.null_count(), 0);
325 assert_eq!(sum_yy.null_count(), 0);
326
327 let counts_values = counts.values().as_ref();
328 let sum_x_values = sum_x.values().as_ref();
329 let sum_y_values = sum_y.values().as_ref();
330 let sum_xy_values = sum_xy.values().as_ref();
331 let sum_xx_values = sum_xx.values().as_ref();
332 let sum_yy_values = sum_yy.values().as_ref();
333
334 for (idx, &group_idx) in group_indices.iter().enumerate() {
335 let row = [
336 sum_x_values[idx],
337 sum_y_values[idx],
338 sum_xy_values[idx],
339 sum_xx_values[idx],
340 sum_yy_values[idx],
341 ];
342 value_fn(group_idx, counts_values[idx], &row);
343 }
344}
345
346impl GroupsAccumulator for CorrelationGroupsAccumulator {
362 fn update_batch(
363 &mut self,
364 values: &[ArrayRef],
365 group_indices: &[usize],
366 opt_filter: Option<&BooleanArray>,
367 total_num_groups: usize,
368 ) -> Result<()> {
369 self.count.resize(total_num_groups, 0);
370 self.sum_x.resize(total_num_groups, 0.0);
371 self.sum_y.resize(total_num_groups, 0.0);
372 self.sum_xy.resize(total_num_groups, 0.0);
373 self.sum_xx.resize(total_num_groups, 0.0);
374 self.sum_yy.resize(total_num_groups, 0.0);
375
376 let array_x = downcast_array::<Float64Array>(&values[0]);
377 let array_y = downcast_array::<Float64Array>(&values[1]);
378
379 accumulate_multiple(
380 group_indices,
381 &[&array_x, &array_y],
382 opt_filter,
383 |group_index, batch_index, columns| {
384 let x = columns[0].value(batch_index);
385 let y = columns[1].value(batch_index);
386 self.count[group_index] += 1;
387 self.sum_x[group_index] += x;
388 self.sum_y[group_index] += y;
389 self.sum_xy[group_index] += x * y;
390 self.sum_xx[group_index] += x * x;
391 self.sum_yy[group_index] += y * y;
392 },
393 );
394
395 Ok(())
396 }
397
398 fn merge_batch(
399 &mut self,
400 values: &[ArrayRef],
401 group_indices: &[usize],
402 opt_filter: Option<&BooleanArray>,
403 total_num_groups: usize,
404 ) -> Result<()> {
405 self.count.resize(total_num_groups, 0);
407 self.sum_x.resize(total_num_groups, 0.0);
408 self.sum_y.resize(total_num_groups, 0.0);
409 self.sum_xy.resize(total_num_groups, 0.0);
410 self.sum_xx.resize(total_num_groups, 0.0);
411 self.sum_yy.resize(total_num_groups, 0.0);
412
413 let partial_counts = values[0].as_primitive::<UInt64Type>();
415 let partial_sum_x = values[1].as_primitive::<Float64Type>();
416 let partial_sum_y = values[2].as_primitive::<Float64Type>();
417 let partial_sum_xy = values[3].as_primitive::<Float64Type>();
418 let partial_sum_xx = values[4].as_primitive::<Float64Type>();
419 let partial_sum_yy = values[5].as_primitive::<Float64Type>();
420
421 assert!(opt_filter.is_none(), "aggregate filter should be applied in partial stage, there should be no filter in final stage");
422
423 accumulate_correlation_states(
424 group_indices,
425 (
426 partial_counts,
427 partial_sum_x,
428 partial_sum_y,
429 partial_sum_xy,
430 partial_sum_xx,
431 partial_sum_yy,
432 ),
433 |group_index, count, values| {
434 self.count[group_index] += count;
435 self.sum_x[group_index] += values[0];
436 self.sum_y[group_index] += values[1];
437 self.sum_xy[group_index] += values[2];
438 self.sum_xx[group_index] += values[3];
439 self.sum_yy[group_index] += values[4];
440 },
441 );
442
443 Ok(())
444 }
445
446 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
447 let n = match emit_to {
448 EmitTo::All => self.count.len(),
449 EmitTo::First(n) => n,
450 };
451
452 let mut values = Vec::with_capacity(n);
453 let mut nulls = NullBufferBuilder::new(n);
454
455 for i in 0..n {
465 if self.count[i] < 2 {
466 values.push(0.0);
468 nulls.append_null();
469 continue;
470 }
471
472 let count = self.count[i];
473 let sum_x = self.sum_x[i];
474 let sum_y = self.sum_y[i];
475 let sum_xy = self.sum_xy[i];
476 let sum_xx = self.sum_xx[i];
477 let sum_yy = self.sum_yy[i];
478
479 let mean_x = sum_x / count as f64;
480 let mean_y = sum_y / count as f64;
481
482 let numerator = sum_xy - sum_x * mean_y;
483 let denominator =
484 ((sum_xx - sum_x * mean_x) * (sum_yy - sum_y * mean_y)).sqrt();
485
486 if denominator == 0.0 {
487 values.push(0.0);
489 nulls.append_null();
490 } else {
491 values.push(numerator / denominator);
492 nulls.append_non_null();
493 }
494 }
495
496 Ok(Arc::new(Float64Array::new(values.into(), nulls.finish())))
497 }
498
499 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
500 let n = match emit_to {
501 EmitTo::All => self.count.len(),
502 EmitTo::First(n) => n,
503 };
504
505 Ok(vec![
506 Arc::new(UInt64Array::from(self.count[0..n].to_vec())),
507 Arc::new(Float64Array::from(self.sum_x[0..n].to_vec())),
508 Arc::new(Float64Array::from(self.sum_y[0..n].to_vec())),
509 Arc::new(Float64Array::from(self.sum_xy[0..n].to_vec())),
510 Arc::new(Float64Array::from(self.sum_xx[0..n].to_vec())),
511 Arc::new(Float64Array::from(self.sum_yy[0..n].to_vec())),
512 ])
513 }
514
515 fn size(&self) -> usize {
516 size_of_val(&self.count)
517 + size_of_val(&self.sum_x)
518 + size_of_val(&self.sum_y)
519 + size_of_val(&self.sum_xy)
520 + size_of_val(&self.sum_xx)
521 + size_of_val(&self.sum_yy)
522 }
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, UInt64Array};

    /// Covers `accumulate_correlation_states`: the happy path forwards every
    /// row unchanged, and a null in a state column makes it panic.
    #[test]
    fn test_accumulate_correlation_states() {
        let group_indices = vec![0, 1, 0, 1];

        // Scenario 1: null-free state columns.
        let counts = UInt64Array::from(vec![1, 2, 3, 4]);
        let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]);
        let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]);
        let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]);
        let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]);

        let mut seen = vec![];
        accumulate_correlation_states(
            &group_indices,
            (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy),
            |group, count, sums| seen.push((group, count, sums.to_vec())),
        );

        assert_eq!(
            seen,
            vec![
                (0, 1, vec![10.0, 1.0, 10.0, 100.0, 1.0]),
                (1, 2, vec![20.0, 2.0, 40.0, 400.0, 4.0]),
                (0, 3, vec![30.0, 3.0, 90.0, 900.0, 9.0]),
                (1, 4, vec![40.0, 4.0, 160.0, 1600.0, 16.0]),
            ]
        );

        // Scenario 2: a null count must trip the null-count assertion.
        let counts_with_null = UInt64Array::from(vec![Some(1), None, Some(3), Some(4)]);
        let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]);
        let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]);
        let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]);
        let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]);

        let outcome = std::panic::catch_unwind(|| {
            accumulate_correlation_states(
                &group_indices,
                (
                    &counts_with_null,
                    &sum_x,
                    &sum_y,
                    &sum_xy,
                    &sum_xx,
                    &sum_yy,
                ),
                |_, _, _| {},
            )
        });
        assert!(outcome.is_err());
    }
}