1use arrow::array::Float64Array;
21use arrow::datatypes::FieldRef;
22use arrow::{
23 array::{ArrayRef, UInt64Array},
24 compute::cast,
25 datatypes::DataType,
26 datatypes::Field,
27};
28use datafusion_common::{
29 HashMap, Result, ScalarValue, downcast_value, plan_err, unwrap_or_internal_err,
30};
31use datafusion_doc::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
32use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
33use datafusion_expr::type_coercion::aggregates::NUMERICS;
34use datafusion_expr::utils::format_state_name;
35use datafusion_expr::{
36 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
37};
38use std::any::Any;
39use std::fmt::Debug;
40use std::hash::Hash;
41use std::mem::size_of_val;
42use std::sync::{Arc, LazyLock};
43
44macro_rules! make_regr_udaf_expr_and_func {
45 ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => {
46 make_udaf_expr!($EXPR_FN, expr_y expr_x, concat!("Compute a linear regression of type [", stringify!($REGR_TYPE), "]"), $AGGREGATE_UDF_FN);
47 create_func!($EXPR_FN, $AGGREGATE_UDF_FN, Regr::new($REGR_TYPE, stringify!($EXPR_FN)));
48 }
49}
50
51make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope);
52make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept);
53make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count);
54make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2);
55make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX);
56make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY);
57make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX);
58make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY);
59make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY);
60
61#[derive(PartialEq, Eq, Hash)]
62pub struct Regr {
63 signature: Signature,
64 regr_type: RegrType,
65 func_name: &'static str,
66}
67
68impl Debug for Regr {
69 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
70 f.debug_struct("regr")
71 .field("name", &self.name())
72 .field("signature", &self.signature)
73 .finish()
74 }
75}
76
77impl Regr {
78 pub fn new(regr_type: RegrType, func_name: &'static str) -> Self {
79 Self {
80 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
81 regr_type,
82 func_name,
83 }
84 }
85}
86
/// The statistic a `regr_*` aggregate reports once all non-null `(y, x)` pairs
/// have been accumulated.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum RegrType {
    /// Slope of the fitted line: `cov_pop(x, y) / var_pop(x)`.
    Slope,
    /// Intercept of the fitted line: `mean_y - slope * mean_x`.
    Intercept,
    /// Number of pairs in which both values are non-null.
    Count,
    /// Squared correlation: `cov_pop(x, y)^2 / (var_pop(x) * var_pop(y))`.
    R2,
    /// Mean of the independent variable (x) over the non-null pairs.
    AvgX,
    /// Mean of the dependent variable (y) over the non-null pairs.
    AvgY,
    /// Sum of squared deviations of the independent variable (x).
    SXX,
    /// Sum of squared deviations of the dependent variable (y).
    SYY,
    /// Sum of products of the paired deviations of x and y.
    SXY,
}
128
129impl RegrType {
130 fn documentation(&self) -> Option<&Documentation> {
132 get_regr_docs().get(self)
133 }
134}
135
136static DOCUMENTATION: LazyLock<HashMap<RegrType, Documentation>> = LazyLock::new(|| {
137 let mut hash_map = HashMap::new();
138 hash_map.insert(
139 RegrType::Slope,
140 Documentation::builder(
141 DOC_SECTION_STATISTICAL,
142 "Returns the slope of the linear regression line for non-null pairs in aggregate columns. \
143 Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.",
144
145 "regr_slope(expression_y, expression_x)")
146 .with_sql_example(
147 r#"```sql
148create table weekly_performance(day int, user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
149select * from weekly_performance;
150+-----+--------------+
151| day | user_signups |
152+-----+--------------+
153| 1 | 60 |
154| 2 | 65 |
155| 3 | 70 |
156| 4 | 75 |
157| 5 | 80 |
158+-----+--------------+
159
160SELECT regr_slope(user_signups, day) AS slope FROM weekly_performance;
161+--------+
162| slope |
163+--------+
164| 5.0 |
165+--------+
166```
167"#
168 )
169 .with_standard_argument("expression_y", Some("Dependent variable"))
170 .with_standard_argument("expression_x", Some("Independent variable"))
171 .build()
172 );
173
174 hash_map.insert(
175 RegrType::Intercept,
176 Documentation::builder(
177 DOC_SECTION_STATISTICAL,
178 "Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \
179 this function returns b.",
180
181 "regr_intercept(expression_y, expression_x)")
182 .with_sql_example(
183 r#"```sql
184create table weekly_performance(week int, productivity_score int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
185select * from weekly_performance;
186+------+---------------------+
187| week | productivity_score |
188| ---- | ------------------- |
189| 1 | 60 |
190| 2 | 65 |
191| 3 | 70 |
192| 4 | 75 |
193| 5 | 80 |
194+------+---------------------+
195
196SELECT regr_intercept(productivity_score, week) AS intercept FROM weekly_performance;
197+----------+
198|intercept|
199|intercept |
200+----------+
201| 55 |
202+----------+
203```
204"#
205 )
206 .with_standard_argument("expression_y", Some("Dependent variable"))
207 .with_standard_argument("expression_x", Some("Independent variable"))
208 .build()
209 );
210
211 hash_map.insert(
212 RegrType::Count,
213 Documentation::builder(
214 DOC_SECTION_STATISTICAL,
215 "Counts the number of non-null paired data points.",
216 "regr_count(expression_y, expression_x)",
217 )
218 .with_sql_example(
219 r#"```sql
220create table daily_metrics(day int, user_signups int) as values (1,100), (2,120), (3, NULL), (4,110), (5,NULL);
221select * from daily_metrics;
222+-----+---------------+
223| day | user_signups |
224| --- | ------------- |
225| 1 | 100 |
226| 2 | 120 |
227| 3 | NULL |
228| 4 | 110 |
229| 5 | NULL |
230+-----+---------------+
231
232SELECT regr_count(user_signups, day) AS valid_pairs FROM daily_metrics;
233+-------------+
234| valid_pairs |
235+-------------+
236| 3 |
237+-------------+
238```
239"#
240 )
241 .with_standard_argument("expression_y", Some("Dependent variable"))
242 .with_standard_argument("expression_x", Some("Independent variable"))
243 .build(),
244 );
245
246 hash_map.insert(
247 RegrType::R2,
248 Documentation::builder(
249 DOC_SECTION_STATISTICAL,
250 "Computes the square of the correlation coefficient between the independent and dependent variables.",
251
252 "regr_r2(expression_y, expression_x)")
253 .with_sql_example(
254 r#"```sql
255create table weekly_performance(day int ,user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
256select * from weekly_performance;
257+-----+--------------+
258| day | user_signups |
259+-----+--------------+
260| 1 | 60 |
261| 2 | 65 |
262| 3 | 70 |
263| 4 | 75 |
264| 5 | 80 |
265+-----+--------------+
266
267SELECT regr_r2(user_signups, day) AS r_squared FROM weekly_performance;
268+---------+
269|r_squared|
270+---------+
271| 1.0 |
272+---------+
273```
274"#
275 )
276 .with_standard_argument("expression_y", Some("Dependent variable"))
277 .with_standard_argument("expression_x", Some("Independent variable"))
278 .build()
279 );
280
281 hash_map.insert(
282 RegrType::AvgX,
283 Documentation::builder(
284 DOC_SECTION_STATISTICAL,
285 "Computes the average of the independent variable (input) expression_x for the non-null paired data points.",
286
287 "regr_avgx(expression_y, expression_x)")
288 .with_sql_example(
289 r#"```sql
290create table daily_sales(day int, total_sales int) as values (1,100), (2,150), (3,200), (4,NULL), (5,250);
291select * from daily_sales;
292+-----+-------------+
293| day | total_sales |
294| --- | ----------- |
295| 1 | 100 |
296| 2 | 150 |
297| 3 | 200 |
298| 4 | NULL |
299| 5 | 250 |
300+-----+-------------+
301
302SELECT regr_avgx(total_sales, day) AS avg_day FROM daily_sales;
303+----------+
304| avg_day |
305+----------+
306| 2.75 |
307+----------+
308```
309"#
310 )
311 .with_standard_argument("expression_y", Some("Dependent variable"))
312 .with_standard_argument("expression_x", Some("Independent variable"))
313 .build()
314 );
315
316 hash_map.insert(
317 RegrType::AvgY,
318 Documentation::builder(
319 DOC_SECTION_STATISTICAL,
320 "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.",
321
322 "regr_avgy(expression_y, expression_x)")
323 .with_sql_example(
324 r#"```sql
325create table daily_temperature(day int, temperature int) as values (1,30), (2,32), (3, NULL), (4,35), (5,36);
326select * from daily_temperature;
327+-----+-------------+
328| day | temperature |
329| --- | ----------- |
330| 1 | 30 |
331| 2 | 32 |
332| 3 | NULL |
333| 4 | 35 |
334| 5 | 36 |
335+-----+-------------+
336
337-- temperature as Dependent Variable(Y), day as Independent Variable(X)
338SELECT regr_avgy(temperature, day) AS avg_temperature FROM daily_temperature;
339+-----------------+
340| avg_temperature |
341+-----------------+
342| 33.25 |
343+-----------------+
344```
345"#
346 )
347 .with_standard_argument("expression_y", Some("Dependent variable"))
348 .with_standard_argument("expression_x", Some("Independent variable"))
349 .build()
350 );
351
352 hash_map.insert(
353 RegrType::SXX,
354 Documentation::builder(
355 DOC_SECTION_STATISTICAL,
356 "Computes the sum of squares of the independent variable.",
357 "regr_sxx(expression_y, expression_x)",
358 )
359 .with_sql_example(
360 r#"```sql
361create table study_hours(student_id int, hours int, test_score int) as values (1,2,55), (2,4,65), (3,6,75), (4,8,85), (5,10,95);
362select * from study_hours;
363+------------+-------+------------+
364| student_id | hours | test_score |
365+------------+-------+------------+
366| 1 | 2 | 55 |
367| 2 | 4 | 65 |
368| 3 | 6 | 75 |
369| 4 | 8 | 85 |
370| 5 | 10 | 95 |
371+------------+-------+------------+
372
373SELECT regr_sxx(test_score, hours) AS sxx FROM study_hours;
374+------+
375| sxx |
376+------+
377| 40.0 |
378+------+
379```
380"#
381 )
382 .with_standard_argument("expression_y", Some("Dependent variable"))
383 .with_standard_argument("expression_x", Some("Independent variable"))
384 .build(),
385 );
386
387 hash_map.insert(
388 RegrType::SYY,
389 Documentation::builder(
390 DOC_SECTION_STATISTICAL,
391 "Computes the sum of squares of the dependent variable.",
392 "regr_syy(expression_y, expression_x)",
393 )
394 .with_sql_example(
395 r#"```sql
396create table employee_productivity(week int, productivity_score int) as values (1,60), (2,65), (3,70);
397select * from employee_productivity;
398+------+--------------------+
399| week | productivity_score |
400+------+--------------------+
401| 1 | 60 |
402| 2 | 65 |
403| 3 | 70 |
404+------+--------------------+
405
406SELECT regr_syy(productivity_score, week) AS sum_squares_y FROM employee_productivity;
407+---------------+
408| sum_squares_y |
409+---------------+
410| 50.0 |
411+---------------+
412```
413"#
414 )
415 .with_standard_argument("expression_y", Some("Dependent variable"))
416 .with_standard_argument("expression_x", Some("Independent variable"))
417 .build(),
418 );
419
420 hash_map.insert(
421 RegrType::SXY,
422 Documentation::builder(
423 DOC_SECTION_STATISTICAL,
424 "Computes the sum of products of paired data points.",
425 "regr_sxy(expression_y, expression_x)",
426 )
427 .with_sql_example(
428 r#"```sql
429create table employee_productivity(week int, productivity_score int) as values(1,60), (2,65), (3,70);
430select * from employee_productivity;
431+------+--------------------+
432| week | productivity_score |
433+------+--------------------+
434| 1 | 60 |
435| 2 | 65 |
436| 3 | 70 |
437+------+--------------------+
438
439SELECT regr_sxy(productivity_score, week) AS sum_product_deviations FROM employee_productivity;
440+------------------------+
441| sum_product_deviations |
442+------------------------+
443| 10.0 |
444+------------------------+
445```
446"#
447 )
448 .with_standard_argument("expression_y", Some("Dependent variable"))
449 .with_standard_argument("expression_x", Some("Independent variable"))
450 .build(),
451 );
452 hash_map
453});
454fn get_regr_docs() -> &'static HashMap<RegrType, Documentation> {
455 &DOCUMENTATION
456}
457
458impl AggregateUDFImpl for Regr {
459 fn as_any(&self) -> &dyn Any {
460 self
461 }
462
463 fn name(&self) -> &str {
464 self.func_name
465 }
466
467 fn signature(&self) -> &Signature {
468 &self.signature
469 }
470
471 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
472 if !arg_types[0].is_numeric() {
473 return plan_err!("Covariance requires numeric input types");
474 }
475
476 if matches!(self.regr_type, RegrType::Count) {
477 Ok(DataType::UInt64)
478 } else {
479 Ok(DataType::Float64)
480 }
481 }
482
483 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
484 Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?))
485 }
486
487 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
488 Ok(vec![
489 Field::new(
490 format_state_name(args.name, "count"),
491 DataType::UInt64,
492 true,
493 ),
494 Field::new(
495 format_state_name(args.name, "mean_x"),
496 DataType::Float64,
497 true,
498 ),
499 Field::new(
500 format_state_name(args.name, "mean_y"),
501 DataType::Float64,
502 true,
503 ),
504 Field::new(
505 format_state_name(args.name, "m2_x"),
506 DataType::Float64,
507 true,
508 ),
509 Field::new(
510 format_state_name(args.name, "m2_y"),
511 DataType::Float64,
512 true,
513 ),
514 Field::new(
515 format_state_name(args.name, "algo_const"),
516 DataType::Float64,
517 true,
518 ),
519 ]
520 .into_iter()
521 .map(Arc::new)
522 .collect())
523 }
524
525 fn documentation(&self) -> Option<&Documentation> {
526 self.regr_type.documentation()
527 }
528}
529
530#[derive(Debug)]
570pub struct RegrAccumulator {
571 count: u64,
572 mean_x: f64,
573 mean_y: f64,
574 m2_x: f64,
575 m2_y: f64,
576 algo_const: f64,
577 regr_type: RegrType,
578}
579
580impl RegrAccumulator {
581 pub fn try_new(regr_type: &RegrType) -> Result<Self> {
583 Ok(Self {
584 count: 0_u64,
585 mean_x: 0_f64,
586 mean_y: 0_f64,
587 m2_x: 0_f64,
588 m2_y: 0_f64,
589 algo_const: 0_f64,
590 regr_type: regr_type.clone(),
591 })
592 }
593}
594
595impl Accumulator for RegrAccumulator {
596 fn state(&mut self) -> Result<Vec<ScalarValue>> {
597 Ok(vec![
598 ScalarValue::from(self.count),
599 ScalarValue::from(self.mean_x),
600 ScalarValue::from(self.mean_y),
601 ScalarValue::from(self.m2_x),
602 ScalarValue::from(self.m2_y),
603 ScalarValue::from(self.algo_const),
604 ])
605 }
606
607 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
608 let values_y = &cast(&values[0], &DataType::Float64)?;
610 let values_x = &cast(&values[1], &DataType::Float64)?;
611
612 let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
613 let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();
614
615 for i in 0..values_y.len() {
616 let value_y = if values_y.is_valid(i) {
618 arr_y.next()
619 } else {
620 None
621 };
622 let value_x = if values_x.is_valid(i) {
623 arr_x.next()
624 } else {
625 None
626 };
627 if value_y.is_none() || value_x.is_none() {
628 continue;
629 }
630
631 let value_y = unwrap_or_internal_err!(value_y);
633 let value_x = unwrap_or_internal_err!(value_x);
634
635 self.count += 1;
636 let delta_x = value_x - self.mean_x;
637 let delta_y = value_y - self.mean_y;
638 self.mean_x += delta_x / self.count as f64;
639 self.mean_y += delta_y / self.count as f64;
640 let delta_x_2 = value_x - self.mean_x;
641 let delta_y_2 = value_y - self.mean_y;
642 self.m2_x += delta_x * delta_x_2;
643 self.m2_y += delta_y * delta_y_2;
644 self.algo_const += delta_x * (value_y - self.mean_y);
645 }
646
647 Ok(())
648 }
649
650 fn supports_retract_batch(&self) -> bool {
651 true
652 }
653
654 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
655 let values_y = &cast(&values[0], &DataType::Float64)?;
656 let values_x = &cast(&values[1], &DataType::Float64)?;
657
658 let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
659 let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();
660
661 for i in 0..values_y.len() {
662 let value_y = if values_y.is_valid(i) {
664 arr_y.next()
665 } else {
666 None
667 };
668 let value_x = if values_x.is_valid(i) {
669 arr_x.next()
670 } else {
671 None
672 };
673 if value_y.is_none() || value_x.is_none() {
674 continue;
675 }
676
677 let value_y = unwrap_or_internal_err!(value_y);
679 let value_x = unwrap_or_internal_err!(value_x);
680
681 if self.count > 1 {
682 self.count -= 1;
683 let delta_x = value_x - self.mean_x;
684 let delta_y = value_y - self.mean_y;
685 self.mean_x -= delta_x / self.count as f64;
686 self.mean_y -= delta_y / self.count as f64;
687 let delta_x_2 = value_x - self.mean_x;
688 let delta_y_2 = value_y - self.mean_y;
689 self.m2_x -= delta_x * delta_x_2;
690 self.m2_y -= delta_y * delta_y_2;
691 self.algo_const -= delta_x * (value_y - self.mean_y);
692 } else {
693 self.count = 0;
694 self.mean_x = 0.0;
695 self.m2_x = 0.0;
696 self.m2_y = 0.0;
697 self.mean_y = 0.0;
698 self.algo_const = 0.0;
699 }
700 }
701
702 Ok(())
703 }
704
705 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
706 let count_arr = downcast_value!(states[0], UInt64Array);
707 let mean_x_arr = downcast_value!(states[1], Float64Array);
708 let mean_y_arr = downcast_value!(states[2], Float64Array);
709 let m2_x_arr = downcast_value!(states[3], Float64Array);
710 let m2_y_arr = downcast_value!(states[4], Float64Array);
711 let algo_const_arr = downcast_value!(states[5], Float64Array);
712
713 for i in 0..count_arr.len() {
714 let count_b = count_arr.value(i);
715 if count_b == 0_u64 {
716 continue;
717 }
718 let (count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a) = (
719 self.count,
720 self.mean_x,
721 self.mean_y,
722 self.m2_x,
723 self.m2_y,
724 self.algo_const,
725 );
726 let (count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b) = (
727 count_b,
728 mean_x_arr.value(i),
729 mean_y_arr.value(i),
730 m2_x_arr.value(i),
731 m2_y_arr.value(i),
732 algo_const_arr.value(i),
733 );
734
735 let count_ab = count_a + count_b;
744 let (count_a, count_b) = (count_a as f64, count_b as f64);
745 let d_x = mean_x_b - mean_x_a;
746 let d_y = mean_y_b - mean_y_a;
747 let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64;
748 let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64;
749 let m2_x_ab =
750 m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64;
751 let m2_y_ab =
752 m2_y_a + m2_y_b + d_y * d_y * count_a * count_b / count_ab as f64;
753 let algo_const_ab = algo_const_a
754 + algo_const_b
755 + d_x * d_y * count_a * count_b / count_ab as f64;
756
757 self.count = count_ab;
758 self.mean_x = mean_x_ab;
759 self.mean_y = mean_y_ab;
760 self.m2_x = m2_x_ab;
761 self.m2_y = m2_y_ab;
762 self.algo_const = algo_const_ab;
763 }
764 Ok(())
765 }
766
767 fn evaluate(&mut self) -> Result<ScalarValue> {
768 let cov_pop_x_y = self.algo_const / self.count as f64;
769 let var_pop_x = self.m2_x / self.count as f64;
770 let var_pop_y = self.m2_y / self.count as f64;
771
772 let nullif_or_stat = |cond: bool, stat: f64| {
773 if cond {
774 Ok(ScalarValue::Float64(None))
775 } else {
776 Ok(ScalarValue::Float64(Some(stat)))
777 }
778 };
779
780 match self.regr_type {
781 RegrType::Slope => {
782 let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
784 nullif_or_stat(nullif_cond, cov_pop_x_y / var_pop_x)
785 }
786 RegrType::Intercept => {
787 let slope = cov_pop_x_y / var_pop_x;
788 let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
790 nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x)
791 }
792 RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))),
793 RegrType::R2 => {
794 let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0;
796 nullif_or_stat(
797 nullif_cond,
798 (cov_pop_x_y * cov_pop_x_y) / (var_pop_x * var_pop_y),
799 )
800 }
801 RegrType::AvgX => nullif_or_stat(self.count < 1, self.mean_x),
802 RegrType::AvgY => nullif_or_stat(self.count < 1, self.mean_y),
803 RegrType::SXX => nullif_or_stat(self.count < 1, self.m2_x),
804 RegrType::SYY => nullif_or_stat(self.count < 1, self.m2_y),
805 RegrType::SXY => nullif_or_stat(self.count < 1, self.algo_const),
806 }
807 }
808
809 fn size(&self) -> usize {
810 size_of_val(self)
811 }
812}