datafusion_functions_aggregate/
regr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines physical expressions that can be evaluated at runtime during query execution
19
20use arrow::array::Float64Array;
21use arrow::datatypes::FieldRef;
22use arrow::{
23    array::{ArrayRef, UInt64Array},
24    compute::cast,
25    datatypes::DataType,
26    datatypes::Field,
27};
28use datafusion_common::{
29    HashMap, Result, ScalarValue, downcast_value, plan_err, unwrap_or_internal_err,
30};
31use datafusion_doc::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
32use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
33use datafusion_expr::type_coercion::aggregates::NUMERICS;
34use datafusion_expr::utils::format_state_name;
35use datafusion_expr::{
36    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
37};
38use std::any::Any;
39use std::fmt::Debug;
40use std::hash::Hash;
41use std::mem::size_of_val;
42use std::sync::{Arc, LazyLock};
43
// Generates, for a single regression variant, both the expression builder
// `$expr_fn(expr_y, expr_x)` and the UDAF accessor `$udaf_fn()`. The accessor
// wraps `Regr::new($regr_type, "<expr_fn>")`, so the SQL-visible name is the
// stringified expression-function identifier.
macro_rules! make_regr_udaf_expr_and_func {
    ($expr_fn:ident, $udaf_fn:ident, $regr_type:expr) => {
        make_udaf_expr!(
            $expr_fn,
            expr_y expr_x,
            concat!("Compute a linear regression of type [", stringify!($regr_type), "]"),
            $udaf_fn
        );
        create_func!($expr_fn, $udaf_fn, Regr::new($regr_type, stringify!($expr_fn)));
    };
}
50
51make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope);
52make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept);
53make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count);
54make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2);
55make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX);
56make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY);
57make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX);
58make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY);
59make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY);
60
61#[derive(PartialEq, Eq, Hash)]
62pub struct Regr {
63    signature: Signature,
64    regr_type: RegrType,
65    func_name: &'static str,
66}
67
68impl Debug for Regr {
69    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
70        f.debug_struct("regr")
71            .field("name", &self.name())
72            .field("signature", &self.signature)
73            .finish()
74    }
75}
76
77impl Regr {
78    pub fn new(regr_type: RegrType, func_name: &'static str) -> Self {
79        Self {
80            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
81            regr_type,
82            func_name,
83        }
84    }
85}
86
/// Selects which linear-regression statistic a [`Regr`] aggregate computes.
///
/// Every variant considers only rows where both Y (first argument) and X
/// (second argument) are non-null.
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
pub enum RegrType {
    /// `regr_slope(Y, X)`: the slope `k` of the least-squares line `Y = k*X + b`.
    Slope,
    /// `regr_intercept(Y, X)`: the intercept `b` of the least-squares line
    /// `Y = k*X + b`.
    Intercept,
    /// `regr_count(Y, X)`: how many rows have both Y and X non-null.
    Count,
    /// `regr_r2(Y, X)`: the coefficient of determination (R-squared), i.e. the
    /// share of the variance of Y explained by X.
    R2,
    /// `regr_avgx(Y, X)`: the mean of the independent variable X.
    AvgX,
    /// `regr_avgy(Y, X)`: the mean of the dependent variable Y.
    AvgY,
    /// `regr_sxx(Y, X)`: the sum of squared deviations of X from its mean.
    SXX,
    /// `regr_syy(Y, X)`: the sum of squared deviations of Y from its mean.
    SYY,
    /// `regr_sxy(Y, X)`: the sum of products of the deviations of Y and X from
    /// their respective means.
    SXY,
}
128
129impl RegrType {
130    /// return the documentation for the `RegrType`
131    fn documentation(&self) -> Option<&Documentation> {
132        get_regr_docs().get(self)
133    }
134}
135
136static DOCUMENTATION: LazyLock<HashMap<RegrType, Documentation>> = LazyLock::new(|| {
137    let mut hash_map = HashMap::new();
138    hash_map.insert(
139            RegrType::Slope,
140            Documentation::builder(
141                DOC_SECTION_STATISTICAL,
142                    "Returns the slope of the linear regression line for non-null pairs in aggregate columns. \
143                    Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.",
144
145                "regr_slope(expression_y, expression_x)")
146                .with_sql_example(
147                    r#"```sql
148create table weekly_performance(day int, user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
149select * from weekly_performance;
150+-----+--------------+
151| day | user_signups |
152+-----+--------------+
153| 1   | 60           |
154| 2   | 65           |
155| 3   | 70           |
156| 4   | 75           |
157| 5   | 80           |
158+-----+--------------+
159
160SELECT regr_slope(user_signups, day) AS slope FROM weekly_performance;
161+--------+
162| slope  |
163+--------+
164| 5.0    |
165+--------+
166```
167"#
168                )
169                .with_standard_argument("expression_y", Some("Dependent variable"))
170                .with_standard_argument("expression_x", Some("Independent variable"))
171                .build()
172        );
173
174    hash_map.insert(
175            RegrType::Intercept,
176            Documentation::builder(
177                DOC_SECTION_STATISTICAL,
178                    "Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \
179                    this function returns b.",
180
181                "regr_intercept(expression_y, expression_x)")
182                .with_sql_example(
183                    r#"```sql
184create table weekly_performance(week int, productivity_score int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
185select * from weekly_performance;
186+------+---------------------+
187| week | productivity_score  |
188| ---- | ------------------- |
189| 1    | 60                  |
190| 2    | 65                  |
191| 3    | 70                  |
192| 4    | 75                  |
193| 5    | 80                  |
194+------+---------------------+
195
196SELECT regr_intercept(productivity_score, week) AS intercept FROM weekly_performance;
197+----------+
198|intercept|
199|intercept |
200+----------+
201|  55      |
202+----------+
203```
204"#
205                )
206                .with_standard_argument("expression_y", Some("Dependent variable"))
207                .with_standard_argument("expression_x", Some("Independent variable"))
208                .build()
209        );
210
211    hash_map.insert(
212        RegrType::Count,
213        Documentation::builder(
214            DOC_SECTION_STATISTICAL,
215            "Counts the number of non-null paired data points.",
216            "regr_count(expression_y, expression_x)",
217        )
218        .with_sql_example(
219            r#"```sql
220create table daily_metrics(day int, user_signups int) as values (1,100), (2,120), (3, NULL), (4,110), (5,NULL);
221select * from daily_metrics;
222+-----+---------------+
223| day | user_signups  |
224| --- | ------------- |
225| 1   | 100           |
226| 2   | 120           |
227| 3   | NULL          |
228| 4   | 110           |
229| 5   | NULL          |
230+-----+---------------+
231
232SELECT regr_count(user_signups, day) AS valid_pairs FROM daily_metrics;
233+-------------+
234| valid_pairs |
235+-------------+
236| 3           |
237+-------------+
238```
239"#
240        )
241        .with_standard_argument("expression_y", Some("Dependent variable"))
242        .with_standard_argument("expression_x", Some("Independent variable"))
243        .build(),
244    );
245
246    hash_map.insert(
247            RegrType::R2,
248            Documentation::builder(
249                DOC_SECTION_STATISTICAL,
250                    "Computes the square of the correlation coefficient between the independent and dependent variables.",
251
252                "regr_r2(expression_y, expression_x)")
253                .with_sql_example(
254                    r#"```sql
255create table weekly_performance(day int ,user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
256select * from weekly_performance;
257+-----+--------------+
258| day | user_signups |
259+-----+--------------+
260| 1   | 60           |
261| 2   | 65           |
262| 3   | 70           |
263| 4   | 75           |
264| 5   | 80           |
265+-----+--------------+
266
267SELECT regr_r2(user_signups, day) AS r_squared FROM weekly_performance;
268+---------+
269|r_squared|
270+---------+
271| 1.0     |
272+---------+
273```
274"#
275                )
276                .with_standard_argument("expression_y", Some("Dependent variable"))
277                .with_standard_argument("expression_x", Some("Independent variable"))
278                .build()
279        );
280
281    hash_map.insert(
282            RegrType::AvgX,
283            Documentation::builder(
284                DOC_SECTION_STATISTICAL,
285                    "Computes the average of the independent variable (input) expression_x for the non-null paired data points.",
286
287                "regr_avgx(expression_y, expression_x)")
288                .with_sql_example(
289                    r#"```sql
290create table daily_sales(day int, total_sales int) as values (1,100), (2,150), (3,200), (4,NULL), (5,250);
291select * from daily_sales;
292+-----+-------------+
293| day | total_sales |
294| --- | ----------- |
295| 1   | 100         |
296| 2   | 150         |
297| 3   | 200         |
298| 4   | NULL        |
299| 5   | 250         |
300+-----+-------------+
301
302SELECT regr_avgx(total_sales, day) AS avg_day FROM daily_sales;
303+----------+
304| avg_day  |
305+----------+
306|   2.75   |
307+----------+
308```
309"#
310                )
311                .with_standard_argument("expression_y", Some("Dependent variable"))
312                .with_standard_argument("expression_x", Some("Independent variable"))
313                .build()
314        );
315
316    hash_map.insert(
317            RegrType::AvgY,
318            Documentation::builder(
319                DOC_SECTION_STATISTICAL,
320                    "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.",
321
322                "regr_avgy(expression_y, expression_x)")
323                .with_sql_example(
324                    r#"```sql
325create table daily_temperature(day int, temperature int) as values (1,30), (2,32), (3, NULL), (4,35), (5,36);
326select * from daily_temperature;
327+-----+-------------+
328| day | temperature |
329| --- | ----------- |
330| 1   | 30          |
331| 2   | 32          |
332| 3   | NULL        |
333| 4   | 35          |
334| 5   | 36          |
335+-----+-------------+
336
337-- temperature as Dependent Variable(Y), day as Independent Variable(X)
338SELECT regr_avgy(temperature, day) AS avg_temperature FROM daily_temperature;
339+-----------------+
340| avg_temperature |
341+-----------------+
342| 33.25           |
343+-----------------+
344```
345"#
346                )
347                .with_standard_argument("expression_y", Some("Dependent variable"))
348                .with_standard_argument("expression_x", Some("Independent variable"))
349                .build()
350        );
351
352    hash_map.insert(
353        RegrType::SXX,
354        Documentation::builder(
355            DOC_SECTION_STATISTICAL,
356            "Computes the sum of squares of the independent variable.",
357            "regr_sxx(expression_y, expression_x)",
358        )
359        .with_sql_example(
360            r#"```sql
361create table study_hours(student_id int, hours int, test_score int) as values (1,2,55), (2,4,65), (3,6,75), (4,8,85), (5,10,95);
362select * from study_hours;
363+------------+-------+------------+
364| student_id | hours | test_score |
365+------------+-------+------------+
366| 1          | 2     | 55         |
367| 2          | 4     | 65         |
368| 3          | 6     | 75         |
369| 4          | 8     | 85         |
370| 5          | 10    | 95         |
371+------------+-------+------------+
372
373SELECT regr_sxx(test_score, hours) AS sxx FROM study_hours;
374+------+
375| sxx  |
376+------+
377| 40.0 |
378+------+
379```
380"#
381        )
382        .with_standard_argument("expression_y", Some("Dependent variable"))
383        .with_standard_argument("expression_x", Some("Independent variable"))
384        .build(),
385    );
386
387    hash_map.insert(
388        RegrType::SYY,
389        Documentation::builder(
390            DOC_SECTION_STATISTICAL,
391            "Computes the sum of squares of the dependent variable.",
392            "regr_syy(expression_y, expression_x)",
393        )
394        .with_sql_example(
395            r#"```sql
396create table employee_productivity(week int, productivity_score int) as values (1,60), (2,65), (3,70);
397select * from employee_productivity;
398+------+--------------------+
399| week | productivity_score |
400+------+--------------------+
401| 1    | 60                 |
402| 2    | 65                 |
403| 3    | 70                 |
404+------+--------------------+
405
406SELECT regr_syy(productivity_score, week) AS sum_squares_y FROM employee_productivity;
407+---------------+
408| sum_squares_y |
409+---------------+
410|    50.0       |
411+---------------+
412```
413"#
414        )
415        .with_standard_argument("expression_y", Some("Dependent variable"))
416        .with_standard_argument("expression_x", Some("Independent variable"))
417        .build(),
418    );
419
420    hash_map.insert(
421        RegrType::SXY,
422        Documentation::builder(
423            DOC_SECTION_STATISTICAL,
424            "Computes the sum of products of paired data points.",
425            "regr_sxy(expression_y, expression_x)",
426        )
427        .with_sql_example(
428            r#"```sql
429create table employee_productivity(week int, productivity_score int) as values(1,60), (2,65), (3,70);
430select * from employee_productivity;
431+------+--------------------+
432| week | productivity_score |
433+------+--------------------+
434| 1    | 60                 |
435| 2    | 65                 |
436| 3    | 70                 |
437+------+--------------------+
438
439SELECT regr_sxy(productivity_score, week) AS sum_product_deviations FROM employee_productivity;
440+------------------------+
441| sum_product_deviations |
442+------------------------+
443|       10.0             |
444+------------------------+
445```
446"#
447        )
448        .with_standard_argument("expression_y", Some("Dependent variable"))
449        .with_standard_argument("expression_x", Some("Independent variable"))
450        .build(),
451    );
452    hash_map
453});
454fn get_regr_docs() -> &'static HashMap<RegrType, Documentation> {
455    &DOCUMENTATION
456}
457
458impl AggregateUDFImpl for Regr {
459    fn as_any(&self) -> &dyn Any {
460        self
461    }
462
463    fn name(&self) -> &str {
464        self.func_name
465    }
466
467    fn signature(&self) -> &Signature {
468        &self.signature
469    }
470
471    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
472        if !arg_types[0].is_numeric() {
473            return plan_err!("Covariance requires numeric input types");
474        }
475
476        if matches!(self.regr_type, RegrType::Count) {
477            Ok(DataType::UInt64)
478        } else {
479            Ok(DataType::Float64)
480        }
481    }
482
483    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
484        Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?))
485    }
486
487    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
488        Ok(vec![
489            Field::new(
490                format_state_name(args.name, "count"),
491                DataType::UInt64,
492                true,
493            ),
494            Field::new(
495                format_state_name(args.name, "mean_x"),
496                DataType::Float64,
497                true,
498            ),
499            Field::new(
500                format_state_name(args.name, "mean_y"),
501                DataType::Float64,
502                true,
503            ),
504            Field::new(
505                format_state_name(args.name, "m2_x"),
506                DataType::Float64,
507                true,
508            ),
509            Field::new(
510                format_state_name(args.name, "m2_y"),
511                DataType::Float64,
512                true,
513            ),
514            Field::new(
515                format_state_name(args.name, "algo_const"),
516                DataType::Float64,
517                true,
518            ),
519        ]
520        .into_iter()
521        .map(Arc::new)
522        .collect())
523    }
524
525    fn documentation(&self) -> Option<&Documentation> {
526        self.regr_type.documentation()
527    }
528}
529
530/// `RegrAccumulator` is used to compute linear regression aggregate functions
531/// by maintaining statistics needed to compute them in an online fashion.
532///
533/// This struct uses Welford's online algorithm for calculating variance and covariance:
534/// <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm>
535///
536/// Given the statistics, the following aggregate functions can be calculated:
537///
538/// - `regr_slope(y, x)`: Slope of the linear regression line, calculated as:
539///   cov_pop(x, y) / var_pop(x).
540///   It represents the expected change in Y for a one-unit change in X.
541///
542/// - `regr_intercept(y, x)`: Intercept of the linear regression line, calculated as:
543///   mean_y - (regr_slope(y, x) * mean_x).
544///   It represents the expected value of Y when X is 0.
545///
546/// - `regr_count(y, x)`: Count of the non-null(both x and y) input rows.
547///
548/// - `regr_r2(y, x)`: R-squared value (coefficient of determination), calculated as:
549///   (cov_pop(x, y) ^ 2) / (var_pop(x) * var_pop(y)).
550///   It provides a measure of how well the model's predictions match the observed data.
551///
552/// - `regr_avgx(y, x)`: Average of the independent variable X, calculated as: mean_x.
553///
554/// - `regr_avgy(y, x)`: Average of the dependent variable Y, calculated as: mean_y.
555///
556/// - `regr_sxx(y, x)`: Sum of squares of the independent variable X, calculated as:
557///   m2_x.
558///
559/// - `regr_syy(y, x)`: Sum of squares of the dependent variable Y, calculated as:
560///   m2_y.
561///
562/// - `regr_sxy(y, x)`: Sum of products of paired values, calculated as:
563///   algo_const.
564///
565/// Here's how the statistics maintained in this struct are calculated:
566/// - `cov_pop(x, y)`: algo_const / count.
567/// - `var_pop(x)`: m2_x / count.
568/// - `var_pop(y)`: m2_y / count.
569#[derive(Debug)]
570pub struct RegrAccumulator {
571    count: u64,
572    mean_x: f64,
573    mean_y: f64,
574    m2_x: f64,
575    m2_y: f64,
576    algo_const: f64,
577    regr_type: RegrType,
578}
579
580impl RegrAccumulator {
581    /// Creates a new `RegrAccumulator`
582    pub fn try_new(regr_type: &RegrType) -> Result<Self> {
583        Ok(Self {
584            count: 0_u64,
585            mean_x: 0_f64,
586            mean_y: 0_f64,
587            m2_x: 0_f64,
588            m2_y: 0_f64,
589            algo_const: 0_f64,
590            regr_type: regr_type.clone(),
591        })
592    }
593}
594
595impl Accumulator for RegrAccumulator {
596    fn state(&mut self) -> Result<Vec<ScalarValue>> {
597        Ok(vec![
598            ScalarValue::from(self.count),
599            ScalarValue::from(self.mean_x),
600            ScalarValue::from(self.mean_y),
601            ScalarValue::from(self.m2_x),
602            ScalarValue::from(self.m2_y),
603            ScalarValue::from(self.algo_const),
604        ])
605    }
606
607    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
608        // regr_slope(Y, X) calculates k in y = k*x + b
609        let values_y = &cast(&values[0], &DataType::Float64)?;
610        let values_x = &cast(&values[1], &DataType::Float64)?;
611
612        let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
613        let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();
614
615        for i in 0..values_y.len() {
616            // skip either x or y is NULL
617            let value_y = if values_y.is_valid(i) {
618                arr_y.next()
619            } else {
620                None
621            };
622            let value_x = if values_x.is_valid(i) {
623                arr_x.next()
624            } else {
625                None
626            };
627            if value_y.is_none() || value_x.is_none() {
628                continue;
629            }
630
631            // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)]
632            let value_y = unwrap_or_internal_err!(value_y);
633            let value_x = unwrap_or_internal_err!(value_x);
634
635            self.count += 1;
636            let delta_x = value_x - self.mean_x;
637            let delta_y = value_y - self.mean_y;
638            self.mean_x += delta_x / self.count as f64;
639            self.mean_y += delta_y / self.count as f64;
640            let delta_x_2 = value_x - self.mean_x;
641            let delta_y_2 = value_y - self.mean_y;
642            self.m2_x += delta_x * delta_x_2;
643            self.m2_y += delta_y * delta_y_2;
644            self.algo_const += delta_x * (value_y - self.mean_y);
645        }
646
647        Ok(())
648    }
649
650    fn supports_retract_batch(&self) -> bool {
651        true
652    }
653
654    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
655        let values_y = &cast(&values[0], &DataType::Float64)?;
656        let values_x = &cast(&values[1], &DataType::Float64)?;
657
658        let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
659        let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();
660
661        for i in 0..values_y.len() {
662            // skip either x or y is NULL
663            let value_y = if values_y.is_valid(i) {
664                arr_y.next()
665            } else {
666                None
667            };
668            let value_x = if values_x.is_valid(i) {
669                arr_x.next()
670            } else {
671                None
672            };
673            if value_y.is_none() || value_x.is_none() {
674                continue;
675            }
676
677            // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)]
678            let value_y = unwrap_or_internal_err!(value_y);
679            let value_x = unwrap_or_internal_err!(value_x);
680
681            if self.count > 1 {
682                self.count -= 1;
683                let delta_x = value_x - self.mean_x;
684                let delta_y = value_y - self.mean_y;
685                self.mean_x -= delta_x / self.count as f64;
686                self.mean_y -= delta_y / self.count as f64;
687                let delta_x_2 = value_x - self.mean_x;
688                let delta_y_2 = value_y - self.mean_y;
689                self.m2_x -= delta_x * delta_x_2;
690                self.m2_y -= delta_y * delta_y_2;
691                self.algo_const -= delta_x * (value_y - self.mean_y);
692            } else {
693                self.count = 0;
694                self.mean_x = 0.0;
695                self.m2_x = 0.0;
696                self.m2_y = 0.0;
697                self.mean_y = 0.0;
698                self.algo_const = 0.0;
699            }
700        }
701
702        Ok(())
703    }
704
705    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
706        let count_arr = downcast_value!(states[0], UInt64Array);
707        let mean_x_arr = downcast_value!(states[1], Float64Array);
708        let mean_y_arr = downcast_value!(states[2], Float64Array);
709        let m2_x_arr = downcast_value!(states[3], Float64Array);
710        let m2_y_arr = downcast_value!(states[4], Float64Array);
711        let algo_const_arr = downcast_value!(states[5], Float64Array);
712
713        for i in 0..count_arr.len() {
714            let count_b = count_arr.value(i);
715            if count_b == 0_u64 {
716                continue;
717            }
718            let (count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a) = (
719                self.count,
720                self.mean_x,
721                self.mean_y,
722                self.m2_x,
723                self.m2_y,
724                self.algo_const,
725            );
726            let (count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b) = (
727                count_b,
728                mean_x_arr.value(i),
729                mean_y_arr.value(i),
730                m2_x_arr.value(i),
731                m2_y_arr.value(i),
732                algo_const_arr.value(i),
733            );
734
735            // Assuming two different batches of input have calculated the states:
736            // batch A of Y, X -> {count_a, mean_x_a, mean_y_a, m2_x_a, algo_const_a}
737            // batch B of Y, X -> {count_b, mean_x_b, mean_y_b, m2_x_b, algo_const_b}
738            // The merged states from A and B are {count_ab, mean_x_ab, mean_y_ab, m2_x_ab,
739            // algo_const_ab}
740            //
741            // Reference for the algorithm to merge states:
742            // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
743            let count_ab = count_a + count_b;
744            let (count_a, count_b) = (count_a as f64, count_b as f64);
745            let d_x = mean_x_b - mean_x_a;
746            let d_y = mean_y_b - mean_y_a;
747            let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64;
748            let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64;
749            let m2_x_ab =
750                m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64;
751            let m2_y_ab =
752                m2_y_a + m2_y_b + d_y * d_y * count_a * count_b / count_ab as f64;
753            let algo_const_ab = algo_const_a
754                + algo_const_b
755                + d_x * d_y * count_a * count_b / count_ab as f64;
756
757            self.count = count_ab;
758            self.mean_x = mean_x_ab;
759            self.mean_y = mean_y_ab;
760            self.m2_x = m2_x_ab;
761            self.m2_y = m2_y_ab;
762            self.algo_const = algo_const_ab;
763        }
764        Ok(())
765    }
766
767    fn evaluate(&mut self) -> Result<ScalarValue> {
768        let cov_pop_x_y = self.algo_const / self.count as f64;
769        let var_pop_x = self.m2_x / self.count as f64;
770        let var_pop_y = self.m2_y / self.count as f64;
771
772        let nullif_or_stat = |cond: bool, stat: f64| {
773            if cond {
774                Ok(ScalarValue::Float64(None))
775            } else {
776                Ok(ScalarValue::Float64(Some(stat)))
777            }
778        };
779
780        match self.regr_type {
781            RegrType::Slope => {
782                // Only 0/1 point or slope is infinite
783                let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
784                nullif_or_stat(nullif_cond, cov_pop_x_y / var_pop_x)
785            }
786            RegrType::Intercept => {
787                let slope = cov_pop_x_y / var_pop_x;
788                // Only 0/1 point or slope is infinite
789                let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
790                nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x)
791            }
792            RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))),
793            RegrType::R2 => {
794                // Only 0/1 point or all x(or y) is the same
795                let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0;
796                nullif_or_stat(
797                    nullif_cond,
798                    (cov_pop_x_y * cov_pop_x_y) / (var_pop_x * var_pop_y),
799                )
800            }
801            RegrType::AvgX => nullif_or_stat(self.count < 1, self.mean_x),
802            RegrType::AvgY => nullif_or_stat(self.count < 1, self.mean_y),
803            RegrType::SXX => nullif_or_stat(self.count < 1, self.m2_x),
804            RegrType::SYY => nullif_or_stat(self.count < 1, self.m2_y),
805            RegrType::SXY => nullif_or_stat(self.count < 1, self.algo_const),
806        }
807    }
808
809    fn size(&self) -> usize {
810        size_of_val(self)
811    }
812}