// datafusion_functions_aggregate/regr.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines physical expressions that can evaluated at runtime during query execution
19
20use arrow::datatypes::FieldRef;
21use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
22use datafusion_common::cast::{as_float64_array, as_uint64_array};
23use datafusion_common::{HashMap, Result, ScalarValue};
24use datafusion_doc::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
25use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
26use datafusion_expr::utils::format_state_name;
27use datafusion_expr::{
28    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
29};
30use std::any::Any;
31use std::fmt::Debug;
32use std::hash::Hash;
33use std::mem::size_of_val;
34use std::sync::{Arc, LazyLock};
35
/// Generates the expression-builder function and the singleton `AggregateUDF`
/// accessor for one regression aggregate variant.
///
/// `$EXPR_FN` is the public expression-builder name (e.g. `regr_slope`),
/// `$AGGREGATE_UDF_FN` the accessor function name (e.g. `regr_slope_udaf`),
/// and `$REGR_TYPE` the [`RegrType`] variant passed to [`Regr::new`].
/// The stringified `$EXPR_FN` also becomes the SQL-facing function name.
macro_rules! make_regr_udaf_expr_and_func {
    ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => {
        make_udaf_expr!($EXPR_FN, expr_y expr_x, concat!("Compute a linear regression of type [", stringify!($REGR_TYPE), "]"), $AGGREGATE_UDF_FN);
        create_func!($EXPR_FN, $AGGREGATE_UDF_FN, Regr::new($REGR_TYPE, stringify!($EXPR_FN)));
    }
}
42
// Instantiate the nine `regr_*` UDAFs, one per `RegrType` variant.
make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope);
make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept);
make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count);
make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2);
make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX);
make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY);
make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX);
make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY);
make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY);
52
/// Aggregate UDF implementing the SQL `regr_*` family of linear-regression
/// functions; the statistic actually computed is selected by `regr_type`.
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct Regr {
    /// Exact `(Float64, Float64)` signature, i.e. `(expression_y, expression_x)`.
    signature: Signature,
    /// Which regression statistic this instance computes.
    regr_type: RegrType,
    /// SQL-facing function name, e.g. `"regr_slope"`.
    func_name: &'static str,
}
59
60impl Regr {
61    pub fn new(regr_type: RegrType, func_name: &'static str) -> Self {
62        Self {
63            signature: Signature::exact(
64                vec![DataType::Float64, DataType::Float64],
65                Volatility::Immutable,
66            ),
67            regr_type,
68            func_name,
69        }
70    }
71}
72
/// Identifies which regression statistic a [`Regr`] UDAF computes.
/// Every variant operates on the non-null `(Y, X)` pairs of the input.
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
pub enum RegrType {
    /// Variant for `regr_slope` aggregate expression
    /// Returns the slope of the linear regression line for non-null pairs in aggregate columns.
    /// Given input column Y and X: `regr_slope(Y, X)` returns the slope (k in Y = k*X + b) using minimal
    /// RSS (Residual Sum of Squares) fitting.
    Slope,
    /// Variant for `regr_intercept` aggregate expression
    /// Returns the intercept of the linear regression line for non-null pairs in aggregate columns.
    /// Given input column Y and X: `regr_intercept(Y, X)` returns the intercept (b in Y = k*X + b) using minimal
    /// RSS fitting.
    Intercept,
    /// Variant for `regr_count` aggregate expression
    /// Returns the number of input rows for which both expressions are not null.
    /// Given input column Y and X: `regr_count(Y, X)` returns the count of non-null pairs.
    Count,
    /// Variant for `regr_r2` aggregate expression
    /// Returns the coefficient of determination (R-squared value) of the linear regression line for non-null pairs in aggregate columns.
    /// The R-squared value represents the proportion of variance in Y that is predictable from X.
    R2,
    /// Variant for `regr_avgx` aggregate expression
    /// Returns the average of the independent variable for non-null pairs in aggregate columns.
    /// Given input column X: `regr_avgx(Y, X)` returns the average of X values.
    AvgX,
    /// Variant for `regr_avgy` aggregate expression
    /// Returns the average of the dependent variable for non-null pairs in aggregate columns.
    /// Given input column Y: `regr_avgy(Y, X)` returns the average of Y values.
    AvgY,
    /// Variant for `regr_sxx` aggregate expression
    /// Returns the sum of squares of the independent variable for non-null pairs in aggregate columns.
    /// Given input column X: `regr_sxx(Y, X)` returns the sum of squares of deviations of X from its mean.
    SXX,
    /// Variant for `regr_syy` aggregate expression
    /// Returns the sum of squares of the dependent variable for non-null pairs in aggregate columns.
    /// Given input column Y: `regr_syy(Y, X)` returns the sum of squares of deviations of Y from its mean.
    SYY,
    /// Variant for `regr_sxy` aggregate expression
    /// Returns the sum of products of pairs of numbers for non-null pairs in aggregate columns.
    /// Given input column Y and X: `regr_sxy(Y, X)` returns the sum of products of the deviations of Y and X from their respective means.
    SXY,
}
114
impl RegrType {
    /// Return the user-facing documentation for this `RegrType` variant,
    /// looked up in the lazily-built documentation map.
    fn documentation(&self) -> Option<&Documentation> {
        get_regr_docs().get(self)
    }
}
121
122static DOCUMENTATION: LazyLock<HashMap<RegrType, Documentation>> = LazyLock::new(|| {
123    let mut hash_map = HashMap::new();
124    hash_map.insert(
125            RegrType::Slope,
126            Documentation::builder(
127                DOC_SECTION_STATISTICAL,
128                    "Returns the slope of the linear regression line for non-null pairs in aggregate columns. \
129                    Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.",
130
131                "regr_slope(expression_y, expression_x)")
132                .with_sql_example(
133                    r#"```sql
134create table weekly_performance(day int, user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
135select * from weekly_performance;
136+-----+--------------+
137| day | user_signups |
138+-----+--------------+
139| 1   | 60           |
140| 2   | 65           |
141| 3   | 70           |
142| 4   | 75           |
143| 5   | 80           |
144+-----+--------------+
145
146SELECT regr_slope(user_signups, day) AS slope FROM weekly_performance;
147+--------+
148| slope  |
149+--------+
150| 5.0    |
151+--------+
152```
153"#
154                )
155                .with_standard_argument("expression_y", Some("Dependent variable"))
156                .with_standard_argument("expression_x", Some("Independent variable"))
157                .build()
158        );
159
160    hash_map.insert(
161            RegrType::Intercept,
162            Documentation::builder(
163                DOC_SECTION_STATISTICAL,
164                    "Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \
165                    this function returns b.",
166
167                "regr_intercept(expression_y, expression_x)")
168                .with_sql_example(
169                    r#"```sql
170create table weekly_performance(week int, productivity_score int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
171select * from weekly_performance;
172+------+---------------------+
173| week | productivity_score  |
174| ---- | ------------------- |
175| 1    | 60                  |
176| 2    | 65                  |
177| 3    | 70                  |
178| 4    | 75                  |
179| 5    | 80                  |
180+------+---------------------+
181
182SELECT regr_intercept(productivity_score, week) AS intercept FROM weekly_performance;
183+----------+
184|intercept|
185|intercept |
186+----------+
187|  55      |
188+----------+
189```
190"#
191                )
192                .with_standard_argument("expression_y", Some("Dependent variable"))
193                .with_standard_argument("expression_x", Some("Independent variable"))
194                .build()
195        );
196
197    hash_map.insert(
198        RegrType::Count,
199        Documentation::builder(
200            DOC_SECTION_STATISTICAL,
201            "Counts the number of non-null paired data points.",
202            "regr_count(expression_y, expression_x)",
203        )
204        .with_sql_example(
205            r#"```sql
206create table daily_metrics(day int, user_signups int) as values (1,100), (2,120), (3, NULL), (4,110), (5,NULL);
207select * from daily_metrics;
208+-----+---------------+
209| day | user_signups  |
210| --- | ------------- |
211| 1   | 100           |
212| 2   | 120           |
213| 3   | NULL          |
214| 4   | 110           |
215| 5   | NULL          |
216+-----+---------------+
217
218SELECT regr_count(user_signups, day) AS valid_pairs FROM daily_metrics;
219+-------------+
220| valid_pairs |
221+-------------+
222| 3           |
223+-------------+
224```
225"#
226        )
227        .with_standard_argument("expression_y", Some("Dependent variable"))
228        .with_standard_argument("expression_x", Some("Independent variable"))
229        .build(),
230    );
231
232    hash_map.insert(
233            RegrType::R2,
234            Documentation::builder(
235                DOC_SECTION_STATISTICAL,
236                    "Computes the square of the correlation coefficient between the independent and dependent variables.",
237
238                "regr_r2(expression_y, expression_x)")
239                .with_sql_example(
240                    r#"```sql
241create table weekly_performance(day int ,user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
242select * from weekly_performance;
243+-----+--------------+
244| day | user_signups |
245+-----+--------------+
246| 1   | 60           |
247| 2   | 65           |
248| 3   | 70           |
249| 4   | 75           |
250| 5   | 80           |
251+-----+--------------+
252
253SELECT regr_r2(user_signups, day) AS r_squared FROM weekly_performance;
254+---------+
255|r_squared|
256+---------+
257| 1.0     |
258+---------+
259```
260"#
261                )
262                .with_standard_argument("expression_y", Some("Dependent variable"))
263                .with_standard_argument("expression_x", Some("Independent variable"))
264                .build()
265        );
266
267    hash_map.insert(
268            RegrType::AvgX,
269            Documentation::builder(
270                DOC_SECTION_STATISTICAL,
271                    "Computes the average of the independent variable (input) expression_x for the non-null paired data points.",
272
273                "regr_avgx(expression_y, expression_x)")
274                .with_sql_example(
275                    r#"```sql
276create table daily_sales(day int, total_sales int) as values (1,100), (2,150), (3,200), (4,NULL), (5,250);
277select * from daily_sales;
278+-----+-------------+
279| day | total_sales |
280| --- | ----------- |
281| 1   | 100         |
282| 2   | 150         |
283| 3   | 200         |
284| 4   | NULL        |
285| 5   | 250         |
286+-----+-------------+
287
288SELECT regr_avgx(total_sales, day) AS avg_day FROM daily_sales;
289+----------+
290| avg_day  |
291+----------+
292|   2.75   |
293+----------+
294```
295"#
296                )
297                .with_standard_argument("expression_y", Some("Dependent variable"))
298                .with_standard_argument("expression_x", Some("Independent variable"))
299                .build()
300        );
301
302    hash_map.insert(
303            RegrType::AvgY,
304            Documentation::builder(
305                DOC_SECTION_STATISTICAL,
306                    "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.",
307
308                "regr_avgy(expression_y, expression_x)")
309                .with_sql_example(
310                    r#"```sql
311create table daily_temperature(day int, temperature int) as values (1,30), (2,32), (3, NULL), (4,35), (5,36);
312select * from daily_temperature;
313+-----+-------------+
314| day | temperature |
315| --- | ----------- |
316| 1   | 30          |
317| 2   | 32          |
318| 3   | NULL        |
319| 4   | 35          |
320| 5   | 36          |
321+-----+-------------+
322
323-- temperature as Dependent Variable(Y), day as Independent Variable(X)
324SELECT regr_avgy(temperature, day) AS avg_temperature FROM daily_temperature;
325+-----------------+
326| avg_temperature |
327+-----------------+
328| 33.25           |
329+-----------------+
330```
331"#
332                )
333                .with_standard_argument("expression_y", Some("Dependent variable"))
334                .with_standard_argument("expression_x", Some("Independent variable"))
335                .build()
336        );
337
338    hash_map.insert(
339        RegrType::SXX,
340        Documentation::builder(
341            DOC_SECTION_STATISTICAL,
342            "Computes the sum of squares of the independent variable.",
343            "regr_sxx(expression_y, expression_x)",
344        )
345        .with_sql_example(
346            r#"```sql
347create table study_hours(student_id int, hours int, test_score int) as values (1,2,55), (2,4,65), (3,6,75), (4,8,85), (5,10,95);
348select * from study_hours;
349+------------+-------+------------+
350| student_id | hours | test_score |
351+------------+-------+------------+
352| 1          | 2     | 55         |
353| 2          | 4     | 65         |
354| 3          | 6     | 75         |
355| 4          | 8     | 85         |
356| 5          | 10    | 95         |
357+------------+-------+------------+
358
359SELECT regr_sxx(test_score, hours) AS sxx FROM study_hours;
360+------+
361| sxx  |
362+------+
363| 40.0 |
364+------+
365```
366"#
367        )
368        .with_standard_argument("expression_y", Some("Dependent variable"))
369        .with_standard_argument("expression_x", Some("Independent variable"))
370        .build(),
371    );
372
373    hash_map.insert(
374        RegrType::SYY,
375        Documentation::builder(
376            DOC_SECTION_STATISTICAL,
377            "Computes the sum of squares of the dependent variable.",
378            "regr_syy(expression_y, expression_x)",
379        )
380        .with_sql_example(
381            r#"```sql
382create table employee_productivity(week int, productivity_score int) as values (1,60), (2,65), (3,70);
383select * from employee_productivity;
384+------+--------------------+
385| week | productivity_score |
386+------+--------------------+
387| 1    | 60                 |
388| 2    | 65                 |
389| 3    | 70                 |
390+------+--------------------+
391
392SELECT regr_syy(productivity_score, week) AS sum_squares_y FROM employee_productivity;
393+---------------+
394| sum_squares_y |
395+---------------+
396|    50.0       |
397+---------------+
398```
399"#
400        )
401        .with_standard_argument("expression_y", Some("Dependent variable"))
402        .with_standard_argument("expression_x", Some("Independent variable"))
403        .build(),
404    );
405
406    hash_map.insert(
407        RegrType::SXY,
408        Documentation::builder(
409            DOC_SECTION_STATISTICAL,
410            "Computes the sum of products of paired data points.",
411            "regr_sxy(expression_y, expression_x)",
412        )
413        .with_sql_example(
414            r#"```sql
415create table employee_productivity(week int, productivity_score int) as values(1,60), (2,65), (3,70);
416select * from employee_productivity;
417+------+--------------------+
418| week | productivity_score |
419+------+--------------------+
420| 1    | 60                 |
421| 2    | 65                 |
422| 3    | 70                 |
423+------+--------------------+
424
425SELECT regr_sxy(productivity_score, week) AS sum_product_deviations FROM employee_productivity;
426+------------------------+
427| sum_product_deviations |
428+------------------------+
429|       10.0             |
430+------------------------+
431```
432"#
433        )
434        .with_standard_argument("expression_y", Some("Dependent variable"))
435        .with_standard_argument("expression_x", Some("Independent variable"))
436        .build(),
437    );
438    hash_map
439});
/// Accessor for the lazily-initialized map of per-variant documentation.
fn get_regr_docs() -> &'static HashMap<RegrType, Documentation> {
    &DOCUMENTATION
}
443
444impl AggregateUDFImpl for Regr {
445    fn as_any(&self) -> &dyn Any {
446        self
447    }
448
449    fn name(&self) -> &str {
450        self.func_name
451    }
452
453    fn signature(&self) -> &Signature {
454        &self.signature
455    }
456
457    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
458        if self.regr_type == RegrType::Count {
459            Ok(DataType::UInt64)
460        } else {
461            Ok(DataType::Float64)
462        }
463    }
464
465    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
466        Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?))
467    }
468
469    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
470        Ok(vec![
471            Field::new(
472                format_state_name(args.name, "count"),
473                DataType::UInt64,
474                true,
475            ),
476            Field::new(
477                format_state_name(args.name, "mean_x"),
478                DataType::Float64,
479                true,
480            ),
481            Field::new(
482                format_state_name(args.name, "mean_y"),
483                DataType::Float64,
484                true,
485            ),
486            Field::new(
487                format_state_name(args.name, "m2_x"),
488                DataType::Float64,
489                true,
490            ),
491            Field::new(
492                format_state_name(args.name, "m2_y"),
493                DataType::Float64,
494                true,
495            ),
496            Field::new(
497                format_state_name(args.name, "algo_const"),
498                DataType::Float64,
499                true,
500            ),
501        ]
502        .into_iter()
503        .map(Arc::new)
504        .collect())
505    }
506
507    fn documentation(&self) -> Option<&Documentation> {
508        self.regr_type.documentation()
509    }
510}
511
/// `RegrAccumulator` is used to compute linear regression aggregate functions
/// by maintaining statistics needed to compute them in an online fashion.
///
/// This struct uses Welford's online algorithm for calculating variance and covariance:
/// <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm>
///
/// Given the statistics, the following aggregate functions can be calculated:
///
/// - `regr_slope(y, x)`: Slope of the linear regression line, calculated as:
///   cov_pop(x, y) / var_pop(x).
///   It represents the expected change in Y for a one-unit change in X.
///
/// - `regr_intercept(y, x)`: Intercept of the linear regression line, calculated as:
///   mean_y - (regr_slope(y, x) * mean_x).
///   It represents the expected value of Y when X is 0.
///
/// - `regr_count(y, x)`: Count of the non-null(both x and y) input rows.
///
/// - `regr_r2(y, x)`: R-squared value (coefficient of determination), calculated as:
///   (cov_pop(x, y) ^ 2) / (var_pop(x) * var_pop(y)).
///   It provides a measure of how well the model's predictions match the observed data.
///
/// - `regr_avgx(y, x)`: Average of the independent variable X, calculated as: mean_x.
///
/// - `regr_avgy(y, x)`: Average of the dependent variable Y, calculated as: mean_y.
///
/// - `regr_sxx(y, x)`: Sum of squares of the independent variable X, calculated as:
///   m2_x.
///
/// - `regr_syy(y, x)`: Sum of squares of the dependent variable Y, calculated as:
///   m2_y.
///
/// - `regr_sxy(y, x)`: Sum of products of paired values, calculated as:
///   algo_const.
///
/// Here's how the statistics maintained in this struct are calculated:
/// - `cov_pop(x, y)`: algo_const / count.
/// - `var_pop(x)`: m2_x / count.
/// - `var_pop(y)`: m2_y / count.
#[derive(Debug)]
pub struct RegrAccumulator {
    // Number of accumulated non-null (y, x) pairs.
    count: u64,
    // Running mean of x over the accumulated pairs.
    mean_x: f64,
    // Running mean of y over the accumulated pairs.
    mean_y: f64,
    // Sum of squared deviations of x from its running mean (Welford M2).
    m2_x: f64,
    // Sum of squared deviations of y from its running mean (Welford M2).
    m2_y: f64,
    // Co-moment: sum of products of x- and y-deviations from their means.
    algo_const: f64,
    // Which statistic `evaluate` should derive from the state above.
    regr_type: RegrType,
}
561
562impl RegrAccumulator {
563    /// Creates a new `RegrAccumulator`
564    pub fn try_new(regr_type: &RegrType) -> Result<Self> {
565        Ok(Self {
566            count: 0_u64,
567            mean_x: 0_f64,
568            mean_y: 0_f64,
569            m2_x: 0_f64,
570            m2_y: 0_f64,
571            algo_const: 0_f64,
572            regr_type: regr_type.clone(),
573        })
574    }
575}
576
impl Accumulator for RegrAccumulator {
    /// Serialize the running statistics in the same order as
    /// `Regr::state_fields`: count, mean_x, mean_y, m2_x, m2_y, algo_const.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![
            ScalarValue::from(self.count),
            ScalarValue::from(self.mean_x),
            ScalarValue::from(self.mean_y),
            ScalarValue::from(self.m2_x),
            ScalarValue::from(self.m2_y),
            ScalarValue::from(self.algo_const),
        ])
    }

    /// Accumulate a batch of `(Y, X)` Float64 pairs using Welford's online
    /// update. Rows where either side is NULL are skipped entirely.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // regr_slope(Y, X) calculates k in y = k*x + b
        let values_y = as_float64_array(&values[0])?;
        let values_x = as_float64_array(&values[1])?;

        for (value_y, value_x) in values_y.iter().zip(values_x) {
            // skip either x or y is NULL
            let (value_y, value_x) = match (value_y, value_x) {
                (Some(y), Some(x)) => (y, x),
                // skip either x or y is NULL
                _ => continue,
            };

            // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)]
            self.count += 1;
            // Deltas against the PRE-update means...
            let delta_x = value_x - self.mean_x;
            let delta_y = value_y - self.mean_y;
            self.mean_x += delta_x / self.count as f64;
            self.mean_y += delta_y / self.count as f64;
            // ...and against the POST-update means; the product of an old and
            // a new delta is Welford's numerically stable M2 increment.
            let delta_x_2 = value_x - self.mean_x;
            let delta_y_2 = value_y - self.mean_y;
            self.m2_x += delta_x * delta_x_2;
            self.m2_y += delta_y * delta_y_2;
            // Co-moment increment: old x-delta times new y-delta.
            self.algo_const += delta_x * (value_y - self.mean_y);
        }

        Ok(())
    }

    // Retraction is supported so this accumulator can be used for sliding
    // window frames.
    fn supports_retract_batch(&self) -> bool {
        true
    }

    /// Remove a batch of `(Y, X)` pairs by running Welford's update in
    /// reverse. NULL-containing rows are skipped, mirroring `update_batch`.
    /// When at most one pair remains, the whole state resets to zero.
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let values_y = as_float64_array(&values[0])?;
        let values_x = as_float64_array(&values[1])?;

        for (value_y, value_x) in values_y.iter().zip(values_x) {
            // skip either x or y is NULL
            let (value_y, value_x) = match (value_y, value_x) {
                (Some(y), Some(x)) => (y, x),
                // skip either x or y is NULL
                _ => continue,
            };

            // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)]
            if self.count > 1 {
                // Inverse of the update_batch step: subtract the value's
                // contribution from the means and the second moments.
                self.count -= 1;
                let delta_x = value_x - self.mean_x;
                let delta_y = value_y - self.mean_y;
                self.mean_x -= delta_x / self.count as f64;
                self.mean_y -= delta_y / self.count as f64;
                let delta_x_2 = value_x - self.mean_x;
                let delta_y_2 = value_y - self.mean_y;
                self.m2_x -= delta_x * delta_x_2;
                self.m2_y -= delta_y * delta_y_2;
                self.algo_const -= delta_x * (value_y - self.mean_y);
            } else {
                // Retracting the last remaining pair: reset to the empty state.
                self.count = 0;
                self.mean_x = 0.0;
                self.m2_x = 0.0;
                self.m2_y = 0.0;
                self.mean_y = 0.0;
                self.algo_const = 0.0;
            }
        }

        Ok(())
    }

    /// Merge partial states produced by other accumulators (one state per
    /// row of the input arrays), using the pairwise/parallel variance
    /// combination formulas. Empty partial states (count == 0) are skipped.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // Column order matches `state()` / `Regr::state_fields`.
        let count_arr = as_uint64_array(&states[0])?;
        let mean_x_arr = as_float64_array(&states[1])?;
        let mean_y_arr = as_float64_array(&states[2])?;
        let m2_x_arr = as_float64_array(&states[3])?;
        let m2_y_arr = as_float64_array(&states[4])?;
        let algo_const_arr = as_float64_array(&states[5])?;

        for i in 0..count_arr.len() {
            let count_b = count_arr.value(i);
            if count_b == 0_u64 {
                continue;
            }
            // A = current state of this accumulator, B = incoming partial state.
            let (count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a) = (
                self.count,
                self.mean_x,
                self.mean_y,
                self.m2_x,
                self.m2_y,
                self.algo_const,
            );
            let (count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b) = (
                count_b,
                mean_x_arr.value(i),
                mean_y_arr.value(i),
                m2_x_arr.value(i),
                m2_y_arr.value(i),
                algo_const_arr.value(i),
            );

            // Assuming two different batches of input have calculated the states:
            // batch A of Y, X -> {count_a, mean_x_a, mean_y_a, m2_x_a, algo_const_a}
            // batch B of Y, X -> {count_b, mean_x_b, mean_y_b, m2_x_b, algo_const_b}
            // The merged states from A and B are {count_ab, mean_x_ab, mean_y_ab, m2_x_ab,
            // algo_const_ab}
            //
            // Reference for the algorithm to merge states:
            // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
            let count_ab = count_a + count_b;
            let (count_a, count_b) = (count_a as f64, count_b as f64);
            let d_x = mean_x_b - mean_x_a;
            let d_y = mean_y_b - mean_y_a;
            // Merged means are weighted by each side's pair count.
            let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64;
            let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64;
            // Merged M2 / co-moment add a cross term proportional to the
            // squared (resp. cross) distance between the two means.
            let m2_x_ab =
                m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64;
            let m2_y_ab =
                m2_y_a + m2_y_b + d_y * d_y * count_a * count_b / count_ab as f64;
            let algo_const_ab = algo_const_a
                + algo_const_b
                + d_x * d_y * count_a * count_b / count_ab as f64;

            self.count = count_ab;
            self.mean_x = mean_x_ab;
            self.mean_y = mean_y_ab;
            self.m2_x = m2_x_ab;
            self.m2_y = m2_y_ab;
            self.algo_const = algo_const_ab;
        }
        Ok(())
    }

    /// Derive the requested statistic from the accumulated state.
    ///
    /// Slope/intercept/R2 return NULL for fewer than two pairs or when the
    /// relevant variance is zero (guarding the divisions below); the other
    /// Float64 variants return NULL only for empty input. `regr_count` is
    /// always non-null (0 for empty input).
    fn evaluate(&mut self) -> Result<ScalarValue> {
        // Population statistics; division by a zero count is guarded by the
        // nullif conditions in the match arms below.
        let cov_pop_x_y = self.algo_const / self.count as f64;
        let var_pop_x = self.m2_x / self.count as f64;
        let var_pop_y = self.m2_y / self.count as f64;

        let nullif_or_stat = |cond: bool, stat: f64| {
            if cond {
                Ok(ScalarValue::Float64(None))
            } else {
                Ok(ScalarValue::Float64(Some(stat)))
            }
        };

        match self.regr_type {
            RegrType::Slope => {
                // Only 0/1 point or slope is infinite
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
                nullif_or_stat(nullif_cond, cov_pop_x_y / var_pop_x)
            }
            RegrType::Intercept => {
                let slope = cov_pop_x_y / var_pop_x;
                // Only 0/1 point or slope is infinite
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
                nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x)
            }
            RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))),
            RegrType::R2 => {
                // Only 0/1 point or all x(or y) is the same
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0;
                nullif_or_stat(
                    nullif_cond,
                    (cov_pop_x_y * cov_pop_x_y) / (var_pop_x * var_pop_y),
                )
            }
            RegrType::AvgX => nullif_or_stat(self.count < 1, self.mean_x),
            RegrType::AvgY => nullif_or_stat(self.count < 1, self.mean_y),
            RegrType::SXX => nullif_or_stat(self.count < 1, self.m2_x),
            RegrType::SYY => nullif_or_stat(self.count < 1, self.m2_y),
            RegrType::SXY => nullif_or_stat(self.count < 1, self.algo_const),
        }
    }

    // Shallow size only: RegrType and the numeric fields are inline, so
    // size_of_val covers the whole state.
    fn size(&self) -> usize {
        size_of_val(self)
    }
}