1use arrow::array::Float64Array;
21use arrow::datatypes::FieldRef;
22use arrow::{
23 array::{ArrayRef, UInt64Array},
24 compute::cast,
25 datatypes::DataType,
26 datatypes::Field,
27};
28use datafusion_common::{
29 downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, HashMap, Result,
30 ScalarValue,
31};
32use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
33use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
34use datafusion_expr::type_coercion::aggregates::NUMERICS;
35use datafusion_expr::utils::format_state_name;
36use datafusion_expr::{
37 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
38};
39use std::any::Any;
40use std::fmt::Debug;
41use std::hash::Hash;
42use std::mem::size_of_val;
43use std::sync::{Arc, LazyLock};
44
/// Generates, for one regression flavour, both the expression-builder function
/// (via `make_udaf_expr!`) and the aggregate UDF accessor (via `create_func!`).
///
/// * `$expr_fn`   - name of the generated expression function; its stringified
///   form is also passed to [`Regr::new`] as the function name.
/// * `$udaf_fn`   - name of the generated UDF accessor function.
/// * `$regr_type` - the [`RegrType`] variant the generated UDF computes.
macro_rules! make_regr_udaf_expr_and_func {
    ($expr_fn:ident, $udaf_fn:ident, $regr_type:expr) => {
        make_udaf_expr!(
            $expr_fn,
            expr_y expr_x,
            concat!("Compute a linear regression of type [", stringify!($regr_type), "]"),
            $udaf_fn
        );
        create_func!(
            $expr_fn,
            $udaf_fn,
            Regr::new($regr_type, stringify!($expr_fn))
        );
    };
}
51
// One expression function + UDF accessor pair per `RegrType` variant.
// Each invocation passes: expression fn name, UDF accessor fn name, variant.
make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope);
make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept);
make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count);
make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2);
make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX);
make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY);
make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX);
make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY);
make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY);
61
/// Aggregate UDF implementation shared by all `regr_*` functions.
///
/// A single type backs every regression flavour; `regr_type` selects which
/// statistic is produced and `func_name` is the name the function reports.
#[derive(PartialEq, Eq, Hash)]
pub struct Regr {
    /// Accepted argument signature (built in [`Regr::new`]).
    signature: Signature,
    /// Which regression statistic this instance computes.
    regr_type: RegrType,
    /// Name returned by `AggregateUDFImpl::name`.
    func_name: &'static str,
}
68
69impl Debug for Regr {
70 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
71 f.debug_struct("regr")
72 .field("name", &self.name())
73 .field("signature", &self.signature)
74 .finish()
75 }
76}
77
78impl Regr {
79 pub fn new(regr_type: RegrType, func_name: &'static str) -> Self {
80 Self {
81 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
82 regr_type,
83 func_name,
84 }
85 }
86}
87
/// The statistic a [`Regr`] aggregate produces from its `(y, x)` input pairs.
///
/// Only rows where both inputs are non-null contribute to any variant.
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
#[allow(clippy::upper_case_acronyms)]
pub enum RegrType {
    /// `regr_slope`: slope of the fitted line, evaluated as
    /// `cov_pop(x, y) / var_pop(x)`.
    Slope,
    /// `regr_intercept`: y-intercept of the fitted line, evaluated as
    /// `mean_y - slope * mean_x`.
    Intercept,
    /// `regr_count`: number of non-null pairs (the only variant returning `UInt64`).
    Count,
    /// `regr_r2`: evaluated as `cov_pop(x, y)^2 / (var_pop(x) * var_pop(y))`.
    R2,
    /// `regr_avgx`: mean of `x` over the non-null pairs.
    AvgX,
    /// `regr_avgy`: mean of `y` over the non-null pairs.
    AvgY,
    /// `regr_sxx`: sum of squared deviations of `x` from its mean.
    SXX,
    /// `regr_syy`: sum of squared deviations of `y` from its mean.
    SYY,
    /// `regr_sxy`: sum of products of the `x` and `y` deviations from their means.
    SXY,
}
130
131impl RegrType {
132 fn documentation(&self) -> Option<&Documentation> {
134 get_regr_docs().get(self)
135 }
136}
137
138static DOCUMENTATION: LazyLock<HashMap<RegrType, Documentation>> = LazyLock::new(|| {
139 let mut hash_map = HashMap::new();
140 hash_map.insert(
141 RegrType::Slope,
142 Documentation::builder(
143 DOC_SECTION_STATISTICAL,
144 "Returns the slope of the linear regression line for non-null pairs in aggregate columns. \
145 Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.",
146
147 "regr_slope(expression_y, expression_x)")
148 .with_sql_example(
149 r#"```sql
150create table weekly_performance(day int, user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
151select * from weekly_performance;
152+-----+--------------+
153| day | user_signups |
154+-----+--------------+
155| 1 | 60 |
156| 2 | 65 |
157| 3 | 70 |
158| 4 | 75 |
159| 5 | 80 |
160+-----+--------------+
161
162SELECT regr_slope(user_signups, day) AS slope FROM weekly_performance;
163+--------+
164| slope |
165+--------+
166| 5.0 |
167+--------+
168```
169"#
170 )
171 .with_standard_argument("expression_y", Some("Dependent variable"))
172 .with_standard_argument("expression_x", Some("Independent variable"))
173 .build()
174 );
175
176 hash_map.insert(
177 RegrType::Intercept,
178 Documentation::builder(
179 DOC_SECTION_STATISTICAL,
180 "Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \
181 this function returns b.",
182
183 "regr_intercept(expression_y, expression_x)")
184 .with_sql_example(
185 r#"```sql
186create table weekly_performance(week int, productivity_score int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
187select * from weekly_performance;
188+------+---------------------+
189| week | productivity_score |
190| ---- | ------------------- |
191| 1 | 60 |
192| 2 | 65 |
193| 3 | 70 |
194| 4 | 75 |
195| 5 | 80 |
196+------+---------------------+
197
198SELECT regr_intercept(productivity_score, week) AS intercept FROM weekly_performance;
199+----------+
200|intercept|
201|intercept |
202+----------+
203| 55 |
204+----------+
205```
206"#
207 )
208 .with_standard_argument("expression_y", Some("Dependent variable"))
209 .with_standard_argument("expression_x", Some("Independent variable"))
210 .build()
211 );
212
213 hash_map.insert(
214 RegrType::Count,
215 Documentation::builder(
216 DOC_SECTION_STATISTICAL,
217 "Counts the number of non-null paired data points.",
218 "regr_count(expression_y, expression_x)",
219 )
220 .with_sql_example(
221 r#"```sql
222create table daily_metrics(day int, user_signups int) as values (1,100), (2,120), (3, NULL), (4,110), (5,NULL);
223select * from daily_metrics;
224+-----+---------------+
225| day | user_signups |
226| --- | ------------- |
227| 1 | 100 |
228| 2 | 120 |
229| 3 | NULL |
230| 4 | 110 |
231| 5 | NULL |
232+-----+---------------+
233
234SELECT regr_count(user_signups, day) AS valid_pairs FROM daily_metrics;
235+-------------+
236| valid_pairs |
237+-------------+
238| 3 |
239+-------------+
240```
241"#
242 )
243 .with_standard_argument("expression_y", Some("Dependent variable"))
244 .with_standard_argument("expression_x", Some("Independent variable"))
245 .build(),
246 );
247
248 hash_map.insert(
249 RegrType::R2,
250 Documentation::builder(
251 DOC_SECTION_STATISTICAL,
252 "Computes the square of the correlation coefficient between the independent and dependent variables.",
253
254 "regr_r2(expression_y, expression_x)")
255 .with_sql_example(
256 r#"```sql
257create table weekly_performance(day int ,user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
258select * from weekly_performance;
259+-----+--------------+
260| day | user_signups |
261+-----+--------------+
262| 1 | 60 |
263| 2 | 65 |
264| 3 | 70 |
265| 4 | 75 |
266| 5 | 80 |
267+-----+--------------+
268
269SELECT regr_r2(user_signups, day) AS r_squared FROM weekly_performance;
270+---------+
271|r_squared|
272+---------+
273| 1.0 |
274+---------+
275```
276"#
277 )
278 .with_standard_argument("expression_y", Some("Dependent variable"))
279 .with_standard_argument("expression_x", Some("Independent variable"))
280 .build()
281 );
282
283 hash_map.insert(
284 RegrType::AvgX,
285 Documentation::builder(
286 DOC_SECTION_STATISTICAL,
287 "Computes the average of the independent variable (input) expression_x for the non-null paired data points.",
288
289 "regr_avgx(expression_y, expression_x)")
290 .with_sql_example(
291 r#"```sql
292create table daily_sales(day int, total_sales int) as values (1,100), (2,150), (3,200), (4,NULL), (5,250);
293select * from daily_sales;
294+-----+-------------+
295| day | total_sales |
296| --- | ----------- |
297| 1 | 100 |
298| 2 | 150 |
299| 3 | 200 |
300| 4 | NULL |
301| 5 | 250 |
302+-----+-------------+
303
304SELECT regr_avgx(total_sales, day) AS avg_day FROM daily_sales;
305+----------+
306| avg_day |
307+----------+
308| 2.75 |
309+----------+
310```
311"#
312 )
313 .with_standard_argument("expression_y", Some("Dependent variable"))
314 .with_standard_argument("expression_x", Some("Independent variable"))
315 .build()
316 );
317
318 hash_map.insert(
319 RegrType::AvgY,
320 Documentation::builder(
321 DOC_SECTION_STATISTICAL,
322 "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.",
323
324 "regr_avgy(expression_y, expression_x)")
325 .with_sql_example(
326 r#"```sql
327create table daily_temperature(day int, temperature int) as values (1,30), (2,32), (3, NULL), (4,35), (5,36);
328select * from daily_temperature;
329+-----+-------------+
330| day | temperature |
331| --- | ----------- |
332| 1 | 30 |
333| 2 | 32 |
334| 3 | NULL |
335| 4 | 35 |
336| 5 | 36 |
337+-----+-------------+
338
339-- temperature as Dependent Variable(Y), day as Independent Variable(X)
340SELECT regr_avgy(temperature, day) AS avg_temperature FROM daily_temperature;
341+-----------------+
342| avg_temperature |
343+-----------------+
344| 33.25 |
345+-----------------+
346```
347"#
348 )
349 .with_standard_argument("expression_y", Some("Dependent variable"))
350 .with_standard_argument("expression_x", Some("Independent variable"))
351 .build()
352 );
353
354 hash_map.insert(
355 RegrType::SXX,
356 Documentation::builder(
357 DOC_SECTION_STATISTICAL,
358 "Computes the sum of squares of the independent variable.",
359 "regr_sxx(expression_y, expression_x)",
360 )
361 .with_sql_example(
362 r#"```sql
363create table study_hours(student_id int, hours int, test_score int) as values (1,2,55), (2,4,65), (3,6,75), (4,8,85), (5,10,95);
364select * from study_hours;
365+------------+-------+------------+
366| student_id | hours | test_score |
367+------------+-------+------------+
368| 1 | 2 | 55 |
369| 2 | 4 | 65 |
370| 3 | 6 | 75 |
371| 4 | 8 | 85 |
372| 5 | 10 | 95 |
373+------------+-------+------------+
374
375SELECT regr_sxx(test_score, hours) AS sxx FROM study_hours;
376+------+
377| sxx |
378+------+
379| 40.0 |
380+------+
381```
382"#
383 )
384 .with_standard_argument("expression_y", Some("Dependent variable"))
385 .with_standard_argument("expression_x", Some("Independent variable"))
386 .build(),
387 );
388
389 hash_map.insert(
390 RegrType::SYY,
391 Documentation::builder(
392 DOC_SECTION_STATISTICAL,
393 "Computes the sum of squares of the dependent variable.",
394 "regr_syy(expression_y, expression_x)",
395 )
396 .with_sql_example(
397 r#"```sql
398create table employee_productivity(week int, productivity_score int) as values (1,60), (2,65), (3,70);
399select * from employee_productivity;
400+------+--------------------+
401| week | productivity_score |
402+------+--------------------+
403| 1 | 60 |
404| 2 | 65 |
405| 3 | 70 |
406+------+--------------------+
407
408SELECT regr_syy(productivity_score, week) AS sum_squares_y FROM employee_productivity;
409+---------------+
410| sum_squares_y |
411+---------------+
412| 50.0 |
413+---------------+
414```
415"#
416 )
417 .with_standard_argument("expression_y", Some("Dependent variable"))
418 .with_standard_argument("expression_x", Some("Independent variable"))
419 .build(),
420 );
421
422 hash_map.insert(
423 RegrType::SXY,
424 Documentation::builder(
425 DOC_SECTION_STATISTICAL,
426 "Computes the sum of products of paired data points.",
427 "regr_sxy(expression_y, expression_x)",
428 )
429 .with_sql_example(
430 r#"```sql
431create table employee_productivity(week int, productivity_score int) as values(1,60), (2,65), (3,70);
432select * from employee_productivity;
433+------+--------------------+
434| week | productivity_score |
435+------+--------------------+
436| 1 | 60 |
437| 2 | 65 |
438| 3 | 70 |
439+------+--------------------+
440
441SELECT regr_sxy(productivity_score, week) AS sum_product_deviations FROM employee_productivity;
442+------------------------+
443| sum_product_deviations |
444+------------------------+
445| 10.0 |
446+------------------------+
447```
448"#
449 )
450 .with_standard_argument("expression_y", Some("Dependent variable"))
451 .with_standard_argument("expression_x", Some("Independent variable"))
452 .build(),
453 );
454 hash_map
455});
456fn get_regr_docs() -> &'static HashMap<RegrType, Documentation> {
457 &DOCUMENTATION
458}
459
460impl AggregateUDFImpl for Regr {
461 fn as_any(&self) -> &dyn Any {
462 self
463 }
464
465 fn name(&self) -> &str {
466 self.func_name
467 }
468
469 fn signature(&self) -> &Signature {
470 &self.signature
471 }
472
473 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
474 if !arg_types[0].is_numeric() {
475 return plan_err!("Covariance requires numeric input types");
476 }
477
478 if matches!(self.regr_type, RegrType::Count) {
479 Ok(DataType::UInt64)
480 } else {
481 Ok(DataType::Float64)
482 }
483 }
484
485 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
486 Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?))
487 }
488
489 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
490 Ok(vec![
491 Field::new(
492 format_state_name(args.name, "count"),
493 DataType::UInt64,
494 true,
495 ),
496 Field::new(
497 format_state_name(args.name, "mean_x"),
498 DataType::Float64,
499 true,
500 ),
501 Field::new(
502 format_state_name(args.name, "mean_y"),
503 DataType::Float64,
504 true,
505 ),
506 Field::new(
507 format_state_name(args.name, "m2_x"),
508 DataType::Float64,
509 true,
510 ),
511 Field::new(
512 format_state_name(args.name, "m2_y"),
513 DataType::Float64,
514 true,
515 ),
516 Field::new(
517 format_state_name(args.name, "algo_const"),
518 DataType::Float64,
519 true,
520 ),
521 ]
522 .into_iter()
523 .map(Arc::new)
524 .collect())
525 }
526
527 fn documentation(&self) -> Option<&Documentation> {
528 self.regr_type.documentation()
529 }
530}
531
/// Running state for every `regr_*` aggregate.
///
/// All variants maintain the same six statistics; `regr_type` only decides
/// which final value `evaluate` derives from them.
#[derive(Debug)]
pub struct RegrAccumulator {
    /// Number of rows accumulated so far (rows with a null `x` or `y` are skipped).
    count: u64,
    /// Running mean of the accumulated `x` values.
    mean_x: f64,
    /// Running mean of the accumulated `y` values.
    mean_y: f64,
    /// Running sum of squared deviations of `x` from its mean.
    m2_x: f64,
    /// Running sum of squared deviations of `y` from its mean.
    m2_y: f64,
    /// Running sum of products of the `x` and `y` deviations from their means
    /// (divided by `count` in `evaluate` to obtain the population covariance).
    algo_const: f64,
    /// Which statistic `evaluate` reports.
    regr_type: RegrType,
}
581
582impl RegrAccumulator {
583 pub fn try_new(regr_type: &RegrType) -> Result<Self> {
585 Ok(Self {
586 count: 0_u64,
587 mean_x: 0_f64,
588 mean_y: 0_f64,
589 m2_x: 0_f64,
590 m2_y: 0_f64,
591 algo_const: 0_f64,
592 regr_type: regr_type.clone(),
593 })
594 }
595}
596
597impl Accumulator for RegrAccumulator {
598 fn state(&mut self) -> Result<Vec<ScalarValue>> {
599 Ok(vec![
600 ScalarValue::from(self.count),
601 ScalarValue::from(self.mean_x),
602 ScalarValue::from(self.mean_y),
603 ScalarValue::from(self.m2_x),
604 ScalarValue::from(self.m2_y),
605 ScalarValue::from(self.algo_const),
606 ])
607 }
608
609 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
610 let values_y = &cast(&values[0], &DataType::Float64)?;
612 let values_x = &cast(&values[1], &DataType::Float64)?;
613
614 let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
615 let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();
616
617 for i in 0..values_y.len() {
618 let value_y = if values_y.is_valid(i) {
620 arr_y.next()
621 } else {
622 None
623 };
624 let value_x = if values_x.is_valid(i) {
625 arr_x.next()
626 } else {
627 None
628 };
629 if value_y.is_none() || value_x.is_none() {
630 continue;
631 }
632
633 let value_y = unwrap_or_internal_err!(value_y);
635 let value_x = unwrap_or_internal_err!(value_x);
636
637 self.count += 1;
638 let delta_x = value_x - self.mean_x;
639 let delta_y = value_y - self.mean_y;
640 self.mean_x += delta_x / self.count as f64;
641 self.mean_y += delta_y / self.count as f64;
642 let delta_x_2 = value_x - self.mean_x;
643 let delta_y_2 = value_y - self.mean_y;
644 self.m2_x += delta_x * delta_x_2;
645 self.m2_y += delta_y * delta_y_2;
646 self.algo_const += delta_x * (value_y - self.mean_y);
647 }
648
649 Ok(())
650 }
651
652 fn supports_retract_batch(&self) -> bool {
653 true
654 }
655
656 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
657 let values_y = &cast(&values[0], &DataType::Float64)?;
658 let values_x = &cast(&values[1], &DataType::Float64)?;
659
660 let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
661 let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();
662
663 for i in 0..values_y.len() {
664 let value_y = if values_y.is_valid(i) {
666 arr_y.next()
667 } else {
668 None
669 };
670 let value_x = if values_x.is_valid(i) {
671 arr_x.next()
672 } else {
673 None
674 };
675 if value_y.is_none() || value_x.is_none() {
676 continue;
677 }
678
679 let value_y = unwrap_or_internal_err!(value_y);
681 let value_x = unwrap_or_internal_err!(value_x);
682
683 if self.count > 1 {
684 self.count -= 1;
685 let delta_x = value_x - self.mean_x;
686 let delta_y = value_y - self.mean_y;
687 self.mean_x -= delta_x / self.count as f64;
688 self.mean_y -= delta_y / self.count as f64;
689 let delta_x_2 = value_x - self.mean_x;
690 let delta_y_2 = value_y - self.mean_y;
691 self.m2_x -= delta_x * delta_x_2;
692 self.m2_y -= delta_y * delta_y_2;
693 self.algo_const -= delta_x * (value_y - self.mean_y);
694 } else {
695 self.count = 0;
696 self.mean_x = 0.0;
697 self.m2_x = 0.0;
698 self.m2_y = 0.0;
699 self.mean_y = 0.0;
700 self.algo_const = 0.0;
701 }
702 }
703
704 Ok(())
705 }
706
707 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
708 let count_arr = downcast_value!(states[0], UInt64Array);
709 let mean_x_arr = downcast_value!(states[1], Float64Array);
710 let mean_y_arr = downcast_value!(states[2], Float64Array);
711 let m2_x_arr = downcast_value!(states[3], Float64Array);
712 let m2_y_arr = downcast_value!(states[4], Float64Array);
713 let algo_const_arr = downcast_value!(states[5], Float64Array);
714
715 for i in 0..count_arr.len() {
716 let count_b = count_arr.value(i);
717 if count_b == 0_u64 {
718 continue;
719 }
720 let (count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a) = (
721 self.count,
722 self.mean_x,
723 self.mean_y,
724 self.m2_x,
725 self.m2_y,
726 self.algo_const,
727 );
728 let (count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b) = (
729 count_b,
730 mean_x_arr.value(i),
731 mean_y_arr.value(i),
732 m2_x_arr.value(i),
733 m2_y_arr.value(i),
734 algo_const_arr.value(i),
735 );
736
737 let count_ab = count_a + count_b;
746 let (count_a, count_b) = (count_a as f64, count_b as f64);
747 let d_x = mean_x_b - mean_x_a;
748 let d_y = mean_y_b - mean_y_a;
749 let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64;
750 let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64;
751 let m2_x_ab =
752 m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64;
753 let m2_y_ab =
754 m2_y_a + m2_y_b + d_y * d_y * count_a * count_b / count_ab as f64;
755 let algo_const_ab = algo_const_a
756 + algo_const_b
757 + d_x * d_y * count_a * count_b / count_ab as f64;
758
759 self.count = count_ab;
760 self.mean_x = mean_x_ab;
761 self.mean_y = mean_y_ab;
762 self.m2_x = m2_x_ab;
763 self.m2_y = m2_y_ab;
764 self.algo_const = algo_const_ab;
765 }
766 Ok(())
767 }
768
769 fn evaluate(&mut self) -> Result<ScalarValue> {
770 let cov_pop_x_y = self.algo_const / self.count as f64;
771 let var_pop_x = self.m2_x / self.count as f64;
772 let var_pop_y = self.m2_y / self.count as f64;
773
774 let nullif_or_stat = |cond: bool, stat: f64| {
775 if cond {
776 Ok(ScalarValue::Float64(None))
777 } else {
778 Ok(ScalarValue::Float64(Some(stat)))
779 }
780 };
781
782 match self.regr_type {
783 RegrType::Slope => {
784 let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
786 nullif_or_stat(nullif_cond, cov_pop_x_y / var_pop_x)
787 }
788 RegrType::Intercept => {
789 let slope = cov_pop_x_y / var_pop_x;
790 let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
792 nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x)
793 }
794 RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))),
795 RegrType::R2 => {
796 let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0;
798 nullif_or_stat(
799 nullif_cond,
800 (cov_pop_x_y * cov_pop_x_y) / (var_pop_x * var_pop_y),
801 )
802 }
803 RegrType::AvgX => nullif_or_stat(self.count < 1, self.mean_x),
804 RegrType::AvgY => nullif_or_stat(self.count < 1, self.mean_y),
805 RegrType::SXX => nullif_or_stat(self.count < 1, self.m2_x),
806 RegrType::SYY => nullif_or_stat(self.count < 1, self.m2_y),
807 RegrType::SXY => nullif_or_stat(self.count < 1, self.algo_const),
808 }
809 }
810
811 fn size(&self) -> usize {
812 size_of_val(self)
813 }
814}