1use arrow::datatypes::FieldRef;
21use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
22use datafusion_common::cast::{as_float64_array, as_uint64_array};
23use datafusion_common::{HashMap, Result, ScalarValue};
24use datafusion_doc::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
25use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
26use datafusion_expr::utils::format_state_name;
27use datafusion_expr::{
28 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
29};
30use std::any::Any;
31use std::fmt::Debug;
32use std::hash::Hash;
33use std::mem::size_of_val;
34use std::sync::{Arc, LazyLock};
35
// Declares the public entry points for one linear-regression aggregate.
//
// * `$EXPR_FN`           - identifier forwarded to `make_udaf_expr!` / `create_func!`
//                          and, stringified, used as the function name given to `Regr::new`.
// * `$AGGREGATE_UDF_FN`  - identifier forwarded to both helper macros.
// * `$REGR_TYPE`         - `RegrType` expression selecting which statistic `Regr` evaluates.
//
// NOTE(review): `make_udaf_expr!` and `create_func!` are defined outside this file;
// presumably they emit the expression-builder fn and the UDF singleton accessor — confirm.
macro_rules! make_regr_udaf_expr_and_func {
    ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => {
        make_udaf_expr!($EXPR_FN, expr_y expr_x, concat!("Compute a linear regression of type [", stringify!($REGR_TYPE), "]"), $AGGREGATE_UDF_FN);
        create_func!($EXPR_FN, $AGGREGATE_UDF_FN, Regr::new($REGR_TYPE, stringify!($EXPR_FN)));
    }
}
42
// One invocation per `RegrType` variant: every regression statistic gets its own
// SQL-visible function name while sharing the `Regr` / `RegrAccumulator` implementation.
make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope);
make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept);
make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count);
make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2);
make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX);
make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY);
make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX);
make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY);
make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY);
52
53#[derive(PartialEq, Eq, Hash, Debug)]
54pub struct Regr {
55 signature: Signature,
56 regr_type: RegrType,
57 func_name: &'static str,
58}
59
60impl Regr {
61 pub fn new(regr_type: RegrType, func_name: &'static str) -> Self {
62 Self {
63 signature: Signature::exact(
64 vec![DataType::Float64, DataType::Float64],
65 Volatility::Immutable,
66 ),
67 regr_type,
68 func_name,
69 }
70 }
71}
72
/// The statistic a `regr_*` aggregate reports over the non-null `(y, x)` pairs.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum RegrType {
    /// Slope of the fitted line (`k` in `y = k*x + b`):
    /// population covariance of x and y divided by the population variance of x.
    Slope,
    /// Y-intercept of the fitted line (`b` in `y = k*x + b`):
    /// `mean_y - slope * mean_x`.
    Intercept,
    /// Number of pairs in which both values are non-null.
    Count,
    /// Square of the correlation coefficient:
    /// `cov_pop(x, y)^2 / (var_pop(x) * var_pop(y))`.
    R2,
    /// Mean of the independent variable (x) over the non-null pairs.
    AvgX,
    /// Mean of the dependent variable (y) over the non-null pairs.
    AvgY,
    /// Sum of squares of the independent variable (the accumulated `m2_x`).
    SXX,
    /// Sum of squares of the dependent variable (the accumulated `m2_y`).
    SYY,
    /// Sum of products of the paired data points (the accumulated `algo_const`).
    SXY,
}
114
115impl RegrType {
116 fn documentation(&self) -> Option<&Documentation> {
118 get_regr_docs().get(self)
119 }
120}
121
122static DOCUMENTATION: LazyLock<HashMap<RegrType, Documentation>> = LazyLock::new(|| {
123 let mut hash_map = HashMap::new();
124 hash_map.insert(
125 RegrType::Slope,
126 Documentation::builder(
127 DOC_SECTION_STATISTICAL,
128 "Returns the slope of the linear regression line for non-null pairs in aggregate columns. \
129 Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.",
130
131 "regr_slope(expression_y, expression_x)")
132 .with_sql_example(
133 r#"```sql
134create table weekly_performance(day int, user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
135select * from weekly_performance;
136+-----+--------------+
137| day | user_signups |
138+-----+--------------+
139| 1 | 60 |
140| 2 | 65 |
141| 3 | 70 |
142| 4 | 75 |
143| 5 | 80 |
144+-----+--------------+
145
146SELECT regr_slope(user_signups, day) AS slope FROM weekly_performance;
147+--------+
148| slope |
149+--------+
150| 5.0 |
151+--------+
152```
153"#
154 )
155 .with_standard_argument("expression_y", Some("Dependent variable"))
156 .with_standard_argument("expression_x", Some("Independent variable"))
157 .build()
158 );
159
160 hash_map.insert(
161 RegrType::Intercept,
162 Documentation::builder(
163 DOC_SECTION_STATISTICAL,
164 "Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \
165 this function returns b.",
166
167 "regr_intercept(expression_y, expression_x)")
168 .with_sql_example(
169 r#"```sql
170create table weekly_performance(week int, productivity_score int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
171select * from weekly_performance;
172+------+---------------------+
173| week | productivity_score |
174| ---- | ------------------- |
175| 1 | 60 |
176| 2 | 65 |
177| 3 | 70 |
178| 4 | 75 |
179| 5 | 80 |
180+------+---------------------+
181
182SELECT regr_intercept(productivity_score, week) AS intercept FROM weekly_performance;
183+----------+
184|intercept|
185|intercept |
186+----------+
187| 55 |
188+----------+
189```
190"#
191 )
192 .with_standard_argument("expression_y", Some("Dependent variable"))
193 .with_standard_argument("expression_x", Some("Independent variable"))
194 .build()
195 );
196
197 hash_map.insert(
198 RegrType::Count,
199 Documentation::builder(
200 DOC_SECTION_STATISTICAL,
201 "Counts the number of non-null paired data points.",
202 "regr_count(expression_y, expression_x)",
203 )
204 .with_sql_example(
205 r#"```sql
206create table daily_metrics(day int, user_signups int) as values (1,100), (2,120), (3, NULL), (4,110), (5,NULL);
207select * from daily_metrics;
208+-----+---------------+
209| day | user_signups |
210| --- | ------------- |
211| 1 | 100 |
212| 2 | 120 |
213| 3 | NULL |
214| 4 | 110 |
215| 5 | NULL |
216+-----+---------------+
217
218SELECT regr_count(user_signups, day) AS valid_pairs FROM daily_metrics;
219+-------------+
220| valid_pairs |
221+-------------+
222| 3 |
223+-------------+
224```
225"#
226 )
227 .with_standard_argument("expression_y", Some("Dependent variable"))
228 .with_standard_argument("expression_x", Some("Independent variable"))
229 .build(),
230 );
231
232 hash_map.insert(
233 RegrType::R2,
234 Documentation::builder(
235 DOC_SECTION_STATISTICAL,
236 "Computes the square of the correlation coefficient between the independent and dependent variables.",
237
238 "regr_r2(expression_y, expression_x)")
239 .with_sql_example(
240 r#"```sql
241create table weekly_performance(day int ,user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
242select * from weekly_performance;
243+-----+--------------+
244| day | user_signups |
245+-----+--------------+
246| 1 | 60 |
247| 2 | 65 |
248| 3 | 70 |
249| 4 | 75 |
250| 5 | 80 |
251+-----+--------------+
252
253SELECT regr_r2(user_signups, day) AS r_squared FROM weekly_performance;
254+---------+
255|r_squared|
256+---------+
257| 1.0 |
258+---------+
259```
260"#
261 )
262 .with_standard_argument("expression_y", Some("Dependent variable"))
263 .with_standard_argument("expression_x", Some("Independent variable"))
264 .build()
265 );
266
267 hash_map.insert(
268 RegrType::AvgX,
269 Documentation::builder(
270 DOC_SECTION_STATISTICAL,
271 "Computes the average of the independent variable (input) expression_x for the non-null paired data points.",
272
273 "regr_avgx(expression_y, expression_x)")
274 .with_sql_example(
275 r#"```sql
276create table daily_sales(day int, total_sales int) as values (1,100), (2,150), (3,200), (4,NULL), (5,250);
277select * from daily_sales;
278+-----+-------------+
279| day | total_sales |
280| --- | ----------- |
281| 1 | 100 |
282| 2 | 150 |
283| 3 | 200 |
284| 4 | NULL |
285| 5 | 250 |
286+-----+-------------+
287
288SELECT regr_avgx(total_sales, day) AS avg_day FROM daily_sales;
289+----------+
290| avg_day |
291+----------+
292| 2.75 |
293+----------+
294```
295"#
296 )
297 .with_standard_argument("expression_y", Some("Dependent variable"))
298 .with_standard_argument("expression_x", Some("Independent variable"))
299 .build()
300 );
301
302 hash_map.insert(
303 RegrType::AvgY,
304 Documentation::builder(
305 DOC_SECTION_STATISTICAL,
306 "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.",
307
308 "regr_avgy(expression_y, expression_x)")
309 .with_sql_example(
310 r#"```sql
311create table daily_temperature(day int, temperature int) as values (1,30), (2,32), (3, NULL), (4,35), (5,36);
312select * from daily_temperature;
313+-----+-------------+
314| day | temperature |
315| --- | ----------- |
316| 1 | 30 |
317| 2 | 32 |
318| 3 | NULL |
319| 4 | 35 |
320| 5 | 36 |
321+-----+-------------+
322
323-- temperature as Dependent Variable(Y), day as Independent Variable(X)
324SELECT regr_avgy(temperature, day) AS avg_temperature FROM daily_temperature;
325+-----------------+
326| avg_temperature |
327+-----------------+
328| 33.25 |
329+-----------------+
330```
331"#
332 )
333 .with_standard_argument("expression_y", Some("Dependent variable"))
334 .with_standard_argument("expression_x", Some("Independent variable"))
335 .build()
336 );
337
338 hash_map.insert(
339 RegrType::SXX,
340 Documentation::builder(
341 DOC_SECTION_STATISTICAL,
342 "Computes the sum of squares of the independent variable.",
343 "regr_sxx(expression_y, expression_x)",
344 )
345 .with_sql_example(
346 r#"```sql
347create table study_hours(student_id int, hours int, test_score int) as values (1,2,55), (2,4,65), (3,6,75), (4,8,85), (5,10,95);
348select * from study_hours;
349+------------+-------+------------+
350| student_id | hours | test_score |
351+------------+-------+------------+
352| 1 | 2 | 55 |
353| 2 | 4 | 65 |
354| 3 | 6 | 75 |
355| 4 | 8 | 85 |
356| 5 | 10 | 95 |
357+------------+-------+------------+
358
359SELECT regr_sxx(test_score, hours) AS sxx FROM study_hours;
360+------+
361| sxx |
362+------+
363| 40.0 |
364+------+
365```
366"#
367 )
368 .with_standard_argument("expression_y", Some("Dependent variable"))
369 .with_standard_argument("expression_x", Some("Independent variable"))
370 .build(),
371 );
372
373 hash_map.insert(
374 RegrType::SYY,
375 Documentation::builder(
376 DOC_SECTION_STATISTICAL,
377 "Computes the sum of squares of the dependent variable.",
378 "regr_syy(expression_y, expression_x)",
379 )
380 .with_sql_example(
381 r#"```sql
382create table employee_productivity(week int, productivity_score int) as values (1,60), (2,65), (3,70);
383select * from employee_productivity;
384+------+--------------------+
385| week | productivity_score |
386+------+--------------------+
387| 1 | 60 |
388| 2 | 65 |
389| 3 | 70 |
390+------+--------------------+
391
392SELECT regr_syy(productivity_score, week) AS sum_squares_y FROM employee_productivity;
393+---------------+
394| sum_squares_y |
395+---------------+
396| 50.0 |
397+---------------+
398```
399"#
400 )
401 .with_standard_argument("expression_y", Some("Dependent variable"))
402 .with_standard_argument("expression_x", Some("Independent variable"))
403 .build(),
404 );
405
406 hash_map.insert(
407 RegrType::SXY,
408 Documentation::builder(
409 DOC_SECTION_STATISTICAL,
410 "Computes the sum of products of paired data points.",
411 "regr_sxy(expression_y, expression_x)",
412 )
413 .with_sql_example(
414 r#"```sql
415create table employee_productivity(week int, productivity_score int) as values(1,60), (2,65), (3,70);
416select * from employee_productivity;
417+------+--------------------+
418| week | productivity_score |
419+------+--------------------+
420| 1 | 60 |
421| 2 | 65 |
422| 3 | 70 |
423+------+--------------------+
424
425SELECT regr_sxy(productivity_score, week) AS sum_product_deviations FROM employee_productivity;
426+------------------------+
427| sum_product_deviations |
428+------------------------+
429| 10.0 |
430+------------------------+
431```
432"#
433 )
434 .with_standard_argument("expression_y", Some("Dependent variable"))
435 .with_standard_argument("expression_x", Some("Independent variable"))
436 .build(),
437 );
438 hash_map
439});
440fn get_regr_docs() -> &'static HashMap<RegrType, Documentation> {
441 &DOCUMENTATION
442}
443
444impl AggregateUDFImpl for Regr {
445 fn as_any(&self) -> &dyn Any {
446 self
447 }
448
449 fn name(&self) -> &str {
450 self.func_name
451 }
452
453 fn signature(&self) -> &Signature {
454 &self.signature
455 }
456
457 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
458 if self.regr_type == RegrType::Count {
459 Ok(DataType::UInt64)
460 } else {
461 Ok(DataType::Float64)
462 }
463 }
464
465 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
466 Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?))
467 }
468
469 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
470 Ok(vec![
471 Field::new(
472 format_state_name(args.name, "count"),
473 DataType::UInt64,
474 true,
475 ),
476 Field::new(
477 format_state_name(args.name, "mean_x"),
478 DataType::Float64,
479 true,
480 ),
481 Field::new(
482 format_state_name(args.name, "mean_y"),
483 DataType::Float64,
484 true,
485 ),
486 Field::new(
487 format_state_name(args.name, "m2_x"),
488 DataType::Float64,
489 true,
490 ),
491 Field::new(
492 format_state_name(args.name, "m2_y"),
493 DataType::Float64,
494 true,
495 ),
496 Field::new(
497 format_state_name(args.name, "algo_const"),
498 DataType::Float64,
499 true,
500 ),
501 ]
502 .into_iter()
503 .map(Arc::new)
504 .collect())
505 }
506
507 fn documentation(&self) -> Option<&Documentation> {
508 self.regr_type.documentation()
509 }
510}
511
/// Running state shared by every `regr_*` aggregate.
///
/// All statistics are derived in `evaluate` from these six running values,
/// which are updated one `(y, x)` pair at a time and can be merged across
/// partial states.
#[derive(Debug)]
pub struct RegrAccumulator {
    /// Number of pairs seen in which both `y` and `x` were non-null.
    count: u64,
    /// Running mean of the x values.
    mean_x: f64,
    /// Running mean of the y values.
    mean_y: f64,
    /// Running second-moment accumulator for x; `m2_x / count` is used as the
    /// population variance of x.
    m2_x: f64,
    /// Running second-moment accumulator for y; `m2_y / count` is used as the
    /// population variance of y.
    m2_y: f64,
    /// Running co-moment of x and y; `algo_const / count` is used as the
    /// population covariance.
    algo_const: f64,
    /// Which statistic `evaluate` reports.
    regr_type: RegrType,
}
561
562impl RegrAccumulator {
563 pub fn try_new(regr_type: &RegrType) -> Result<Self> {
565 Ok(Self {
566 count: 0_u64,
567 mean_x: 0_f64,
568 mean_y: 0_f64,
569 m2_x: 0_f64,
570 m2_y: 0_f64,
571 algo_const: 0_f64,
572 regr_type: regr_type.clone(),
573 })
574 }
575}
576
577impl Accumulator for RegrAccumulator {
578 fn state(&mut self) -> Result<Vec<ScalarValue>> {
579 Ok(vec![
580 ScalarValue::from(self.count),
581 ScalarValue::from(self.mean_x),
582 ScalarValue::from(self.mean_y),
583 ScalarValue::from(self.m2_x),
584 ScalarValue::from(self.m2_y),
585 ScalarValue::from(self.algo_const),
586 ])
587 }
588
589 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
590 let values_y = as_float64_array(&values[0])?;
592 let values_x = as_float64_array(&values[1])?;
593
594 for (value_y, value_x) in values_y.iter().zip(values_x) {
595 let (value_y, value_x) = match (value_y, value_x) {
597 (Some(y), Some(x)) => (y, x),
598 _ => continue,
600 };
601
602 self.count += 1;
604 let delta_x = value_x - self.mean_x;
605 let delta_y = value_y - self.mean_y;
606 self.mean_x += delta_x / self.count as f64;
607 self.mean_y += delta_y / self.count as f64;
608 let delta_x_2 = value_x - self.mean_x;
609 let delta_y_2 = value_y - self.mean_y;
610 self.m2_x += delta_x * delta_x_2;
611 self.m2_y += delta_y * delta_y_2;
612 self.algo_const += delta_x * (value_y - self.mean_y);
613 }
614
615 Ok(())
616 }
617
618 fn supports_retract_batch(&self) -> bool {
619 true
620 }
621
622 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
623 let values_y = as_float64_array(&values[0])?;
624 let values_x = as_float64_array(&values[1])?;
625
626 for (value_y, value_x) in values_y.iter().zip(values_x) {
627 let (value_y, value_x) = match (value_y, value_x) {
629 (Some(y), Some(x)) => (y, x),
630 _ => continue,
632 };
633
634 if self.count > 1 {
636 self.count -= 1;
637 let delta_x = value_x - self.mean_x;
638 let delta_y = value_y - self.mean_y;
639 self.mean_x -= delta_x / self.count as f64;
640 self.mean_y -= delta_y / self.count as f64;
641 let delta_x_2 = value_x - self.mean_x;
642 let delta_y_2 = value_y - self.mean_y;
643 self.m2_x -= delta_x * delta_x_2;
644 self.m2_y -= delta_y * delta_y_2;
645 self.algo_const -= delta_x * (value_y - self.mean_y);
646 } else {
647 self.count = 0;
648 self.mean_x = 0.0;
649 self.m2_x = 0.0;
650 self.m2_y = 0.0;
651 self.mean_y = 0.0;
652 self.algo_const = 0.0;
653 }
654 }
655
656 Ok(())
657 }
658
659 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
660 let count_arr = as_uint64_array(&states[0])?;
661 let mean_x_arr = as_float64_array(&states[1])?;
662 let mean_y_arr = as_float64_array(&states[2])?;
663 let m2_x_arr = as_float64_array(&states[3])?;
664 let m2_y_arr = as_float64_array(&states[4])?;
665 let algo_const_arr = as_float64_array(&states[5])?;
666
667 for i in 0..count_arr.len() {
668 let count_b = count_arr.value(i);
669 if count_b == 0_u64 {
670 continue;
671 }
672 let (count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a) = (
673 self.count,
674 self.mean_x,
675 self.mean_y,
676 self.m2_x,
677 self.m2_y,
678 self.algo_const,
679 );
680 let (count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b) = (
681 count_b,
682 mean_x_arr.value(i),
683 mean_y_arr.value(i),
684 m2_x_arr.value(i),
685 m2_y_arr.value(i),
686 algo_const_arr.value(i),
687 );
688
689 let count_ab = count_a + count_b;
698 let (count_a, count_b) = (count_a as f64, count_b as f64);
699 let d_x = mean_x_b - mean_x_a;
700 let d_y = mean_y_b - mean_y_a;
701 let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64;
702 let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64;
703 let m2_x_ab =
704 m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64;
705 let m2_y_ab =
706 m2_y_a + m2_y_b + d_y * d_y * count_a * count_b / count_ab as f64;
707 let algo_const_ab = algo_const_a
708 + algo_const_b
709 + d_x * d_y * count_a * count_b / count_ab as f64;
710
711 self.count = count_ab;
712 self.mean_x = mean_x_ab;
713 self.mean_y = mean_y_ab;
714 self.m2_x = m2_x_ab;
715 self.m2_y = m2_y_ab;
716 self.algo_const = algo_const_ab;
717 }
718 Ok(())
719 }
720
721 fn evaluate(&mut self) -> Result<ScalarValue> {
722 let cov_pop_x_y = self.algo_const / self.count as f64;
723 let var_pop_x = self.m2_x / self.count as f64;
724 let var_pop_y = self.m2_y / self.count as f64;
725
726 let nullif_or_stat = |cond: bool, stat: f64| {
727 if cond {
728 Ok(ScalarValue::Float64(None))
729 } else {
730 Ok(ScalarValue::Float64(Some(stat)))
731 }
732 };
733
734 match self.regr_type {
735 RegrType::Slope => {
736 let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
738 nullif_or_stat(nullif_cond, cov_pop_x_y / var_pop_x)
739 }
740 RegrType::Intercept => {
741 let slope = cov_pop_x_y / var_pop_x;
742 let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
744 nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x)
745 }
746 RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))),
747 RegrType::R2 => {
748 let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0;
750 nullif_or_stat(
751 nullif_cond,
752 (cov_pop_x_y * cov_pop_x_y) / (var_pop_x * var_pop_y),
753 )
754 }
755 RegrType::AvgX => nullif_or_stat(self.count < 1, self.mean_x),
756 RegrType::AvgY => nullif_or_stat(self.count < 1, self.mean_y),
757 RegrType::SXX => nullif_or_stat(self.count < 1, self.m2_x),
758 RegrType::SYY => nullif_or_stat(self.count < 1, self.m2_y),
759 RegrType::SXY => nullif_or_stat(self.count < 1, self.algo_const),
760 }
761 }
762
763 fn size(&self) -> usize {
764 size_of_val(self)
765 }
766}