datafusion_functions_aggregate/regr.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines physical expressions that can be evaluated at runtime during query execution

use arrow::array::Float64Array;
use arrow::datatypes::FieldRef;
use arrow::{
    array::{ArrayRef, UInt64Array},
    compute::cast,
    datatypes::DataType,
    datatypes::Field,
};
use datafusion_common::{
    downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, HashMap, Result,
    ScalarValue,
};
use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
};
use std::any::Any;
use std::fmt::Debug;
use std::hash::Hash;
use std::mem::size_of_val;
use std::sync::{Arc, LazyLock};

macro_rules! make_regr_udaf_expr_and_func {
    ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => {
        make_udaf_expr!($EXPR_FN, expr_y expr_x, concat!("Compute a linear regression of type [", stringify!($REGR_TYPE), "]"), $AGGREGATE_UDF_FN);
        create_func!($EXPR_FN, $AGGREGATE_UDF_FN, Regr::new($REGR_TYPE, stringify!($EXPR_FN)));
    }
}
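
// Illustrative note (not generated from the source): for each invocation below,
// `make_udaf_expr!` and `create_func!` expand to roughly the following items
// (the exact shapes are defined by those macros elsewhere in this crate), e.g.
// for `regr_slope`:
//
//     pub fn regr_slope(expr_y: Expr, expr_x: Expr) -> Expr { /* builds the aggregate call */ }
//     pub fn regr_slope_udaf() -> Arc<AggregateUDF> { /* shared singleton UDAF instance */ }
//
// so callers can write e.g. `regr_slope(col("y"), col("x"))` in the DataFrame API
// or `regr_slope(y, x)` in SQL.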

make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope);
make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept);
make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count);
make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2);
make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX);
make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY);
make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX);
make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY);
make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY);

#[derive(PartialEq, Eq, Hash)]
pub struct Regr {
    signature: Signature,
    regr_type: RegrType,
    func_name: &'static str,
}

impl Debug for Regr {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("regr")
            .field("name", &self.name())
            .field("signature", &self.signature)
            .finish()
    }
}

impl Regr {
    pub fn new(regr_type: RegrType, func_name: &'static str) -> Self {
        Self {
            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
            regr_type,
            func_name,
        }
    }
}

#[derive(Debug, Clone, PartialEq, Hash, Eq)]
#[allow(clippy::upper_case_acronyms)]
pub enum RegrType {
    /// Variant for `regr_slope` aggregate expression
    /// Returns the slope of the linear regression line for non-null pairs in aggregate columns.
    /// Given input columns Y and X: `regr_slope(Y, X)` returns the slope (k in Y = k*X + b) using minimal
    /// RSS (Residual Sum of Squares) fitting.
    Slope,
    /// Variant for `regr_intercept` aggregate expression
    /// Returns the intercept of the linear regression line for non-null pairs in aggregate columns.
    /// Given input columns Y and X: `regr_intercept(Y, X)` returns the intercept (b in Y = k*X + b) using minimal
    /// RSS fitting.
    Intercept,
    /// Variant for `regr_count` aggregate expression
    /// Returns the number of input rows for which both expressions are not null.
    /// Given input columns Y and X: `regr_count(Y, X)` returns the count of non-null pairs.
    Count,
    /// Variant for `regr_r2` aggregate expression
    /// Returns the coefficient of determination (R-squared value) of the linear regression line for non-null pairs in aggregate columns.
    /// The R-squared value represents the proportion of variance in Y that is predictable from X.
    R2,
    /// Variant for `regr_avgx` aggregate expression
    /// Returns the average of the independent variable for non-null pairs in aggregate columns.
    /// Given input column X: `regr_avgx(Y, X)` returns the average of X values.
    AvgX,
    /// Variant for `regr_avgy` aggregate expression
    /// Returns the average of the dependent variable for non-null pairs in aggregate columns.
    /// Given input column Y: `regr_avgy(Y, X)` returns the average of Y values.
    AvgY,
    /// Variant for `regr_sxx` aggregate expression
    /// Returns the sum of squares of the independent variable for non-null pairs in aggregate columns.
    /// Given input column X: `regr_sxx(Y, X)` returns the sum of squares of deviations of X from its mean.
    SXX,
    /// Variant for `regr_syy` aggregate expression
    /// Returns the sum of squares of the dependent variable for non-null pairs in aggregate columns.
    /// Given input column Y: `regr_syy(Y, X)` returns the sum of squares of deviations of Y from its mean.
    SYY,
    /// Variant for `regr_sxy` aggregate expression
    /// Returns the sum of products of pairs of numbers for non-null pairs in aggregate columns.
    /// Given input columns Y and X: `regr_sxy(Y, X)` returns the sum of products of the deviations of Y and X from their respective means.
    SXY,
}
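
// Reference note (added for readability; it mirrors the `RegrAccumulator` docs
// further down): `RegrAccumulator::evaluate` derives each variant from the same
// running state, for example:
//   Slope     -> cov_pop(x, y) / var_pop(x)
//   Intercept -> mean_y - slope * mean_x
//   R2        -> cov_pop(x, y)^2 / (var_pop(x) * var_pop(y))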

impl RegrType {
    /// return the documentation for the `RegrType`
    fn documentation(&self) -> Option<&Documentation> {
        get_regr_docs().get(self)
    }
}

static DOCUMENTATION: LazyLock<HashMap<RegrType, Documentation>> = LazyLock::new(|| {
    let mut hash_map = HashMap::new();
    hash_map.insert(
            RegrType::Slope,
            Documentation::builder(
                DOC_SECTION_STATISTICAL,
                    "Returns the slope of the linear regression line for non-null pairs in aggregate columns. \
                    Given input columns Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.",

                "regr_slope(expression_y, expression_x)")
                .with_sql_example(
                    r#"```sql
create table weekly_performance(day int, user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
select * from weekly_performance;
+-----+--------------+
| day | user_signups |
+-----+--------------+
| 1   | 60           |
| 2   | 65           |
| 3   | 70           |
| 4   | 75           |
| 5   | 80           |
+-----+--------------+

SELECT regr_slope(user_signups, day) AS slope FROM weekly_performance;
+--------+
| slope  |
+--------+
| 5.0    |
+--------+
```
"#
                )
                .with_standard_argument("expression_y", Some("Dependent variable"))
                .with_standard_argument("expression_x", Some("Independent variable"))
                .build()
        );

    hash_map.insert(
            RegrType::Intercept,
            Documentation::builder(
                DOC_SECTION_STATISTICAL,
                    "Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \
                    this function returns b.",

                "regr_intercept(expression_y, expression_x)")
                .with_sql_example(
                    r#"```sql
create table weekly_performance(week int, productivity_score int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
select * from weekly_performance;
+------+---------------------+
| week | productivity_score  |
+------+---------------------+
| 1    | 60                  |
| 2    | 65                  |
| 3    | 70                  |
| 4    | 75                  |
| 5    | 80                  |
+------+---------------------+

SELECT regr_intercept(productivity_score, week) AS intercept FROM weekly_performance;
+-----------+
| intercept |
+-----------+
| 55.0      |
+-----------+
```
"#
                )
                .with_standard_argument("expression_y", Some("Dependent variable"))
                .with_standard_argument("expression_x", Some("Independent variable"))
                .build()
        );

    hash_map.insert(
        RegrType::Count,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Counts the number of non-null paired data points.",
            "regr_count(expression_y, expression_x)",
        )
        .with_sql_example(
            r#"```sql
create table daily_metrics(day int, user_signups int) as values (1,100), (2,120), (3, NULL), (4,110), (5,NULL);
select * from daily_metrics;
+-----+---------------+
| day | user_signups  |
+-----+---------------+
| 1   | 100           |
| 2   | 120           |
| 3   | NULL          |
| 4   | 110           |
| 5   | NULL          |
+-----+---------------+

SELECT regr_count(user_signups, day) AS valid_pairs FROM daily_metrics;
+-------------+
| valid_pairs |
+-------------+
| 3           |
+-------------+
```
"#
        )
        .with_standard_argument("expression_y", Some("Dependent variable"))
        .with_standard_argument("expression_x", Some("Independent variable"))
        .build(),
    );

    hash_map.insert(
            RegrType::R2,
            Documentation::builder(
                DOC_SECTION_STATISTICAL,
                    "Computes the square of the correlation coefficient between the independent and dependent variables.",

                "regr_r2(expression_y, expression_x)")
                .with_sql_example(
                    r#"```sql
create table weekly_performance(day int, user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
select * from weekly_performance;
+-----+--------------+
| day | user_signups |
+-----+--------------+
| 1   | 60           |
| 2   | 65           |
| 3   | 70           |
| 4   | 75           |
| 5   | 80           |
+-----+--------------+

SELECT regr_r2(user_signups, day) AS r_squared FROM weekly_performance;
+-----------+
| r_squared |
+-----------+
| 1.0       |
+-----------+
```
"#
                )
                .with_standard_argument("expression_y", Some("Dependent variable"))
                .with_standard_argument("expression_x", Some("Independent variable"))
                .build()
        );

    hash_map.insert(
            RegrType::AvgX,
            Documentation::builder(
                DOC_SECTION_STATISTICAL,
                    "Computes the average of the independent variable (input) expression_x for the non-null paired data points.",

                "regr_avgx(expression_y, expression_x)")
                .with_sql_example(
                    r#"```sql
create table daily_sales(day int, total_sales int) as values (1,100), (2,150), (3,200), (4,NULL), (5,250);
select * from daily_sales;
+-----+-------------+
| day | total_sales |
+-----+-------------+
| 1   | 100         |
| 2   | 150         |
| 3   | 200         |
| 4   | NULL        |
| 5   | 250         |
+-----+-------------+

SELECT regr_avgx(total_sales, day) AS avg_day FROM daily_sales;
+----------+
| avg_day  |
+----------+
|   2.75   |
+----------+
```
"#
                )
                .with_standard_argument("expression_y", Some("Dependent variable"))
                .with_standard_argument("expression_x", Some("Independent variable"))
                .build()
        );

    hash_map.insert(
            RegrType::AvgY,
            Documentation::builder(
                DOC_SECTION_STATISTICAL,
                    "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.",

                "regr_avgy(expression_y, expression_x)")
                .with_sql_example(
                    r#"```sql
create table daily_temperature(day int, temperature int) as values (1,30), (2,32), (3, NULL), (4,35), (5,36);
select * from daily_temperature;
+-----+-------------+
| day | temperature |
+-----+-------------+
| 1   | 30          |
| 2   | 32          |
| 3   | NULL        |
| 4   | 35          |
| 5   | 36          |
+-----+-------------+

-- temperature as the dependent variable (Y), day as the independent variable (X)
SELECT regr_avgy(temperature, day) AS avg_temperature FROM daily_temperature;
+-----------------+
| avg_temperature |
+-----------------+
| 33.25           |
+-----------------+
```
"#
                )
                .with_standard_argument("expression_y", Some("Dependent variable"))
                .with_standard_argument("expression_x", Some("Independent variable"))
                .build()
        );

    hash_map.insert(
        RegrType::SXX,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the sum of squares of the independent variable.",
            "regr_sxx(expression_y, expression_x)",
        )
        .with_sql_example(
            r#"```sql
create table study_hours(student_id int, hours int, test_score int) as values (1,2,55), (2,4,65), (3,6,75), (4,8,85), (5,10,95);
select * from study_hours;
+------------+-------+------------+
| student_id | hours | test_score |
+------------+-------+------------+
| 1          | 2     | 55         |
| 2          | 4     | 65         |
| 3          | 6     | 75         |
| 4          | 8     | 85         |
| 5          | 10    | 95         |
+------------+-------+------------+

SELECT regr_sxx(test_score, hours) AS sxx FROM study_hours;
+------+
| sxx  |
+------+
| 40.0 |
+------+
```
"#
        )
        .with_standard_argument("expression_y", Some("Dependent variable"))
        .with_standard_argument("expression_x", Some("Independent variable"))
        .build(),
    );

    hash_map.insert(
        RegrType::SYY,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the sum of squares of the dependent variable.",
            "regr_syy(expression_y, expression_x)",
        )
        .with_sql_example(
            r#"```sql
create table employee_productivity(week int, productivity_score int) as values (1,60), (2,65), (3,70);
select * from employee_productivity;
+------+--------------------+
| week | productivity_score |
+------+--------------------+
| 1    | 60                 |
| 2    | 65                 |
| 3    | 70                 |
+------+--------------------+

SELECT regr_syy(productivity_score, week) AS sum_squares_y FROM employee_productivity;
+---------------+
| sum_squares_y |
+---------------+
|    50.0       |
+---------------+
```
"#
        )
        .with_standard_argument("expression_y", Some("Dependent variable"))
        .with_standard_argument("expression_x", Some("Independent variable"))
        .build(),
    );

    hash_map.insert(
        RegrType::SXY,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the sum of products of paired data points.",
            "regr_sxy(expression_y, expression_x)",
        )
        .with_sql_example(
            r#"```sql
create table employee_productivity(week int, productivity_score int) as values (1,60), (2,65), (3,70);
select * from employee_productivity;
+------+--------------------+
| week | productivity_score |
+------+--------------------+
| 1    | 60                 |
| 2    | 65                 |
| 3    | 70                 |
+------+--------------------+

SELECT regr_sxy(productivity_score, week) AS sum_product_deviations FROM employee_productivity;
+------------------------+
| sum_product_deviations |
+------------------------+
|       10.0             |
+------------------------+
```
"#
        )
        .with_standard_argument("expression_y", Some("Dependent variable"))
        .with_standard_argument("expression_x", Some("Independent variable"))
        .build(),
    );
    hash_map
});

fn get_regr_docs() -> &'static HashMap<RegrType, Documentation> {
    &DOCUMENTATION
}

impl AggregateUDFImpl for Regr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        self.func_name
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        if !arg_types[0].is_numeric() {
            return plan_err!("Regression requires numeric input types");
        }

        if matches!(self.regr_type, RegrType::Count) {
            Ok(DataType::UInt64)
        } else {
            Ok(DataType::Float64)
        }
    }

    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?))
    }

    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        Ok(vec![
            Field::new(
                format_state_name(args.name, "count"),
                DataType::UInt64,
                true,
            ),
            Field::new(
                format_state_name(args.name, "mean_x"),
                DataType::Float64,
                true,
            ),
            Field::new(
                format_state_name(args.name, "mean_y"),
                DataType::Float64,
                true,
            ),
            Field::new(
                format_state_name(args.name, "m2_x"),
                DataType::Float64,
                true,
            ),
            Field::new(
                format_state_name(args.name, "m2_y"),
                DataType::Float64,
                true,
            ),
            Field::new(
                format_state_name(args.name, "algo_const"),
                DataType::Float64,
                true,
            ),
        ]
        .into_iter()
        .map(Arc::new)
        .collect())
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.regr_type.documentation()
    }
}

/// `RegrAccumulator` is used to compute linear regression aggregate functions
/// by maintaining statistics needed to compute them in an online fashion.
///
/// This struct uses Welford's online algorithm for calculating variance and covariance:
/// <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm>
///
/// Given the statistics, the following aggregate functions can be calculated:
///
/// - `regr_slope(y, x)`: Slope of the linear regression line, calculated as:
///   cov_pop(x, y) / var_pop(x).
///   It represents the expected change in Y for a one-unit change in X.
///
/// - `regr_intercept(y, x)`: Intercept of the linear regression line, calculated as:
///   mean_y - (regr_slope(y, x) * mean_x).
///   It represents the expected value of Y when X is 0.
///
/// - `regr_count(y, x)`: Count of the input rows where both x and y are non-null.
///
/// - `regr_r2(y, x)`: R-squared value (coefficient of determination), calculated as:
///   (cov_pop(x, y) ^ 2) / (var_pop(x) * var_pop(y)).
///   It provides a measure of how well the model's predictions match the observed data.
///
/// - `regr_avgx(y, x)`: Average of the independent variable X, calculated as: mean_x.
///
/// - `regr_avgy(y, x)`: Average of the dependent variable Y, calculated as: mean_y.
///
/// - `regr_sxx(y, x)`: Sum of squares of the independent variable X, calculated as:
///   m2_x.
///
/// - `regr_syy(y, x)`: Sum of squares of the dependent variable Y, calculated as:
///   m2_y.
///
/// - `regr_sxy(y, x)`: Sum of products of paired values, calculated as:
///   algo_const.
///
/// Here's how the statistics maintained in this struct are calculated:
/// - `cov_pop(x, y)`: algo_const / count.
/// - `var_pop(x)`: m2_x / count.
/// - `var_pop(y)`: m2_y / count.
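///
/// # Worked example (illustrative)
///
/// Feeding the pairs (y, x) = (2, 1), (4, 2), (6, 3) through `update_batch`
/// leaves the accumulator with:
/// `count = 3`, `mean_x = 2`, `mean_y = 4`, `m2_x = 2`, `m2_y = 8`, `algo_const = 4`,
/// so `regr_slope = (4/3) / (2/3) = 2`, `regr_intercept = 4 - 2 * 2 = 0`, and
/// `regr_r2 = (4/3)^2 / ((2/3) * (8/3)) = 1`.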
#[derive(Debug)]
pub struct RegrAccumulator {
    count: u64,
    mean_x: f64,
    mean_y: f64,
    m2_x: f64,
    m2_y: f64,
    algo_const: f64,
    regr_type: RegrType,
}

impl RegrAccumulator {
    /// Creates a new `RegrAccumulator`
    pub fn try_new(regr_type: &RegrType) -> Result<Self> {
        Ok(Self {
            count: 0_u64,
            mean_x: 0_f64,
            mean_y: 0_f64,
            m2_x: 0_f64,
            m2_y: 0_f64,
            algo_const: 0_f64,
            regr_type: regr_type.clone(),
        })
    }
}

impl Accumulator for RegrAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![
            ScalarValue::from(self.count),
            ScalarValue::from(self.mean_x),
            ScalarValue::from(self.mean_y),
            ScalarValue::from(self.m2_x),
            ScalarValue::from(self.m2_y),
            ScalarValue::from(self.algo_const),
        ])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // regr_slope(Y, X) calculates k in y = k*x + b
        let values_y = &cast(&values[0], &DataType::Float64)?;
        let values_x = &cast(&values[1], &DataType::Float64)?;

        let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
        let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();

        for i in 0..values_y.len() {
            // skip the row if either x or y is NULL
            let value_y = if values_y.is_valid(i) {
                arr_y.next()
            } else {
                None
            };
            let value_x = if values_x.is_valid(i) {
                arr_x.next()
            } else {
                None
            };
            if value_y.is_none() || value_x.is_none() {
                continue;
            }

            // Update the running Welford statistics shared by all regr_* aggregates
            let value_y = unwrap_or_internal_err!(value_y);
            let value_x = unwrap_or_internal_err!(value_x);

            self.count += 1;
            let delta_x = value_x - self.mean_x;
            let delta_y = value_y - self.mean_y;
            self.mean_x += delta_x / self.count as f64;
            self.mean_y += delta_y / self.count as f64;
            let delta_x_2 = value_x - self.mean_x;
            let delta_y_2 = value_y - self.mean_y;
            self.m2_x += delta_x * delta_x_2;
            self.m2_y += delta_y * delta_y_2;
            self.algo_const += delta_x * (value_y - self.mean_y);
        }

        Ok(())
    }

    fn supports_retract_batch(&self) -> bool {
        true
    }

    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let values_y = &cast(&values[0], &DataType::Float64)?;
        let values_x = &cast(&values[1], &DataType::Float64)?;

        let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
        let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();

        for i in 0..values_y.len() {
            // skip the row if either x or y is NULL
            let value_y = if values_y.is_valid(i) {
                arr_y.next()
            } else {
                None
            };
            let value_x = if values_x.is_valid(i) {
                arr_x.next()
            } else {
                None
            };
            if value_y.is_none() || value_x.is_none() {
                continue;
            }

            // Reverse the Welford update applied in `update_batch`
            let value_y = unwrap_or_internal_err!(value_y);
            let value_x = unwrap_or_internal_err!(value_x);

            if self.count > 1 {
                self.count -= 1;
                let delta_x = value_x - self.mean_x;
                let delta_y = value_y - self.mean_y;
                self.mean_x -= delta_x / self.count as f64;
                self.mean_y -= delta_y / self.count as f64;
                let delta_x_2 = value_x - self.mean_x;
                let delta_y_2 = value_y - self.mean_y;
                self.m2_x -= delta_x * delta_x_2;
                self.m2_y -= delta_y * delta_y_2;
                self.algo_const -= delta_x * (value_y - self.mean_y);
            } else {
                self.count = 0;
                self.mean_x = 0.0;
                self.m2_x = 0.0;
                self.m2_y = 0.0;
                self.mean_y = 0.0;
                self.algo_const = 0.0;
            }
        }

        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let count_arr = downcast_value!(states[0], UInt64Array);
        let mean_x_arr = downcast_value!(states[1], Float64Array);
        let mean_y_arr = downcast_value!(states[2], Float64Array);
        let m2_x_arr = downcast_value!(states[3], Float64Array);
        let m2_y_arr = downcast_value!(states[4], Float64Array);
        let algo_const_arr = downcast_value!(states[5], Float64Array);

        for i in 0..count_arr.len() {
            let count_b = count_arr.value(i);
            if count_b == 0_u64 {
                continue;
            }
            let (count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a) = (
                self.count,
                self.mean_x,
                self.mean_y,
                self.m2_x,
                self.m2_y,
                self.algo_const,
            );
            let (count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b) = (
                count_b,
                mean_x_arr.value(i),
                mean_y_arr.value(i),
                m2_x_arr.value(i),
                m2_y_arr.value(i),
                algo_const_arr.value(i),
            );

            // Assuming two different batches of input have calculated the states:
            // batch A of Y, X -> {count_a, mean_x_a, mean_y_a, m2_x_a, algo_const_a}
            // batch B of Y, X -> {count_b, mean_x_b, mean_y_b, m2_x_b, algo_const_b}
            // The merged states from A and B are {count_ab, mean_x_ab, mean_y_ab, m2_x_ab,
            // algo_const_ab}
            //
            // Reference for the algorithm to merge states:
            // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
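            //
            // Illustrative check (matches the worked example on `RegrAccumulator`):
            // batch A = {(2,1), (4,2)} -> count 2, mean_x 1.5, mean_y 3, m2_x 0.5, m2_y 2, algo_const 1
            // batch B = {(6,3)}        -> count 1, mean_x 3,   mean_y 6, m2_x 0,   m2_y 0, algo_const 0
            // merged                   -> count 3, mean_x 2,   mean_y 4, m2_x 2,   m2_y 8, algo_const 4
            // which is exactly what a single pass over all three pairs produces.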
            let count_ab = count_a + count_b;
            let (count_a, count_b) = (count_a as f64, count_b as f64);
            let d_x = mean_x_b - mean_x_a;
            let d_y = mean_y_b - mean_y_a;
            let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64;
            let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64;
            let m2_x_ab =
                m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64;
            let m2_y_ab =
                m2_y_a + m2_y_b + d_y * d_y * count_a * count_b / count_ab as f64;
            let algo_const_ab = algo_const_a
                + algo_const_b
                + d_x * d_y * count_a * count_b / count_ab as f64;

            self.count = count_ab;
            self.mean_x = mean_x_ab;
            self.mean_y = mean_y_ab;
            self.m2_x = m2_x_ab;
            self.m2_y = m2_y_ab;
            self.algo_const = algo_const_ab;
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let cov_pop_x_y = self.algo_const / self.count as f64;
        let var_pop_x = self.m2_x / self.count as f64;
        let var_pop_y = self.m2_y / self.count as f64;

        let nullif_or_stat = |cond: bool, stat: f64| {
            if cond {
                Ok(ScalarValue::Float64(None))
            } else {
                Ok(ScalarValue::Float64(Some(stat)))
            }
        };

        match self.regr_type {
            RegrType::Slope => {
                // NULL if there are 0/1 points or the slope would be infinite (all x equal)
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
                nullif_or_stat(nullif_cond, cov_pop_x_y / var_pop_x)
            }
            RegrType::Intercept => {
                let slope = cov_pop_x_y / var_pop_x;
                // NULL if there are 0/1 points or the slope would be infinite (all x equal)
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
                nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x)
            }
            RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))),
            RegrType::R2 => {
                // NULL if there are 0/1 points or all x (or y) values are the same
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0;
                nullif_or_stat(
                    nullif_cond,
                    (cov_pop_x_y * cov_pop_x_y) / (var_pop_x * var_pop_y),
                )
            }
            RegrType::AvgX => nullif_or_stat(self.count < 1, self.mean_x),
            RegrType::AvgY => nullif_or_stat(self.count < 1, self.mean_y),
            RegrType::SXX => nullif_or_stat(self.count < 1, self.m2_x),
            RegrType::SYY => nullif_or_stat(self.count < 1, self.m2_y),
            RegrType::SXY => nullif_or_stat(self.count < 1, self.algo_const),
        }
    }

    fn size(&self) -> usize {
        size_of_val(self)
    }
}
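
// A minimal, self-contained sketch (added for illustration; not part of the
// upstream file) showing how `RegrAccumulator` can be driven directly: feed a
// batch of (y, x) pairs with `update_batch`, then `evaluate` the requested
// statistic. The data matches the worked example in the accumulator docs.
#[cfg(test)]
mod regr_accumulator_example {
    use super::*;

    #[test]
    fn slope_of_perfectly_linear_data() -> Result<()> {
        // y = 2 * x for x in {1, 2, 3}
        let y: ArrayRef = Arc::new(Float64Array::from(vec![2.0, 4.0, 6.0]));
        let x: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]));

        let mut acc = RegrAccumulator::try_new(&RegrType::Slope)?;
        acc.update_batch(&[y, x])?;

        // cov_pop(x, y) / var_pop(x) = (4/3) / (2/3) = 2
        assert_eq!(acc.evaluate()?, ScalarValue::Float64(Some(2.0)));
        Ok(())
    }
}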