datafusion_functions_aggregate/
regr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines physical expressions that can be evaluated at runtime during query execution
19
20use arrow::array::Float64Array;
21use arrow::datatypes::FieldRef;
22use arrow::{
23    array::{ArrayRef, UInt64Array},
24    compute::cast,
25    datatypes::DataType,
26    datatypes::Field,
27};
28use datafusion_common::{
29    downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, HashMap, Result,
30    ScalarValue,
31};
32use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
33use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
34use datafusion_expr::type_coercion::aggregates::NUMERICS;
35use datafusion_expr::utils::format_state_name;
36use datafusion_expr::{
37    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
38};
39use std::any::Any;
40use std::fmt::Debug;
41use std::mem::size_of_val;
42use std::sync::{Arc, LazyLock};
43
/// Declares the public entry points for one linear-regression aggregate.
///
/// Arguments:
/// - `$EXPR_FN`: identifier of the expression-building function; its
///   stringified form is also passed to [`Regr::new`] as the function name.
/// - `$AGGREGATE_UDF_FN`: identifier forwarded to both helper macros.
/// - `$REGR_TYPE`: the [`RegrType`] expression selecting the statistic.
///
/// The expansion forwards to `make_udaf_expr!` (with the two argument names
/// `expr_y expr_x` and a generated description string) and to `create_func!`
/// (with a `Regr::new(..)` constructor expression). Both helper macros are
/// defined outside this file — presumably they generate the expression
/// function and the UDAF accessor respectively; confirm against their
/// definitions.
macro_rules! make_regr_udaf_expr_and_func {
    ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => {
        make_udaf_expr!($EXPR_FN, expr_y expr_x, concat!("Compute a linear regression of type [", stringify!($REGR_TYPE), "]"), $AGGREGATE_UDF_FN);
        create_func!($EXPR_FN, $AGGREGATE_UDF_FN, Regr::new($REGR_TYPE, stringify!($EXPR_FN)));
    }
}
50
// One declaration per supported regression statistic: each line pairs an
// expression function name and a UDAF function name with the `RegrType`
// variant that `Regr` evaluates.
make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope);
make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept);
make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count);
make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2);
make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX);
make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY);
make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX);
make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY);
make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY);
60
/// Aggregate UDF implementation shared by all `regr_*` functions.
///
/// A single type backs every regression statistic; `regr_type` decides which
/// one is produced and `func_name` is the name reported by
/// `AggregateUDFImpl::name`.
pub struct Regr {
    /// Accepted argument signature, built in [`Regr::new`].
    signature: Signature,
    /// Which regression statistic this instance computes.
    regr_type: RegrType,
    /// SQL-visible function name (e.g. `regr_slope`).
    func_name: &'static str,
}
66
67impl Debug for Regr {
68    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
69        f.debug_struct("regr")
70            .field("name", &self.name())
71            .field("signature", &self.signature)
72            .finish()
73    }
74}
75
76impl Regr {
77    pub fn new(regr_type: RegrType, func_name: &'static str) -> Self {
78        Self {
79            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
80            regr_type,
81            func_name,
82        }
83    }
84}
85
/// Selects which linear-regression statistic a `Regr` aggregate evaluates.
///
/// All variants are computed from the same accumulator state; the variant
/// only determines the final value returned. In every function the first
/// argument is the dependent variable Y and the second the independent
/// variable X, and only rows where both are non-null contribute.
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
// `SXX`, `SYY` and `SXY` intentionally mirror the SQL function names.
#[allow(clippy::upper_case_acronyms)]
pub enum RegrType {
    /// Variant for `regr_slope` aggregate expression
    /// Returns the slope of the linear regression line for non-null pairs in aggregate columns.
    /// Given input column Y and X: `regr_slope(Y, X)` returns the slope (k in Y = k*X + b) using minimal
    /// RSS (Residual Sum of Squares) fitting.
    Slope,
    /// Variant for `regr_intercept` aggregate expression
    /// Returns the intercept of the linear regression line for non-null pairs in aggregate columns.
    /// Given input column Y and X: `regr_intercept(Y, X)` returns the intercept (b in Y = k*X + b) using minimal
    /// RSS fitting.
    Intercept,
    /// Variant for `regr_count` aggregate expression
    /// Returns the number of input rows for which both expressions are not null.
    /// Given input column Y and X: `regr_count(Y, X)` returns the count of non-null pairs.
    Count,
    /// Variant for `regr_r2` aggregate expression
    /// Returns the coefficient of determination (R-squared value) of the linear regression line for non-null pairs in aggregate columns.
    /// The R-squared value represents the proportion of variance in Y that is predictable from X.
    R2,
    /// Variant for `regr_avgx` aggregate expression
    /// Returns the average of the independent variable for non-null pairs in aggregate columns.
    /// Given input column X: `regr_avgx(Y, X)` returns the average of X values.
    AvgX,
    /// Variant for `regr_avgy` aggregate expression
    /// Returns the average of the dependent variable for non-null pairs in aggregate columns.
    /// Given input column Y: `regr_avgy(Y, X)` returns the average of Y values.
    AvgY,
    /// Variant for `regr_sxx` aggregate expression
    /// Returns the sum of squares of the independent variable for non-null pairs in aggregate columns.
    /// Given input column X: `regr_sxx(Y, X)` returns the sum of squares of deviations of X from its mean.
    SXX,
    /// Variant for `regr_syy` aggregate expression
    /// Returns the sum of squares of the dependent variable for non-null pairs in aggregate columns.
    /// Given input column Y: `regr_syy(Y, X)` returns the sum of squares of deviations of Y from its mean.
    SYY,
    /// Variant for `regr_sxy` aggregate expression
    /// Returns the sum of products of pairs of numbers for non-null pairs in aggregate columns.
    /// Given input column Y and X: `regr_sxy(Y, X)` returns the sum of products of the deviations of Y and X from their respective means.
    SXY,
}
128
129impl RegrType {
130    /// return the documentation for the `RegrType`
131    fn documentation(&self) -> Option<&Documentation> {
132        get_regr_docs().get(self)
133    }
134}
135
136static DOCUMENTATION: LazyLock<HashMap<RegrType, Documentation>> = LazyLock::new(|| {
137    let mut hash_map = HashMap::new();
138    hash_map.insert(
139            RegrType::Slope,
140            Documentation::builder(
141                DOC_SECTION_STATISTICAL,
142                    "Returns the slope of the linear regression line for non-null pairs in aggregate columns. \
143                    Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.",
144
145                "regr_slope(expression_y, expression_x)")
146                .with_standard_argument("expression_y", Some("Dependent variable"))
147                .with_standard_argument("expression_x", Some("Independent variable"))
148                .build()
149        );
150
151    hash_map.insert(
152            RegrType::Intercept,
153            Documentation::builder(
154                DOC_SECTION_STATISTICAL,
155                    "Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \
156                    this function returns b.",
157
158                "regr_intercept(expression_y, expression_x)")
159                .with_standard_argument("expression_y", Some("Dependent variable"))
160                .with_standard_argument("expression_x", Some("Independent variable"))
161                .build()
162        );
163
164    hash_map.insert(
165        RegrType::Count,
166        Documentation::builder(
167            DOC_SECTION_STATISTICAL,
168            "Counts the number of non-null paired data points.",
169            "regr_count(expression_y, expression_x)",
170        )
171        .with_standard_argument("expression_y", Some("Dependent variable"))
172        .with_standard_argument("expression_x", Some("Independent variable"))
173        .build(),
174    );
175
176    hash_map.insert(
177            RegrType::R2,
178            Documentation::builder(
179                DOC_SECTION_STATISTICAL,
180                    "Computes the square of the correlation coefficient between the independent and dependent variables.",
181
182                "regr_r2(expression_y, expression_x)")
183                .with_standard_argument("expression_y", Some("Dependent variable"))
184                .with_standard_argument("expression_x", Some("Independent variable"))
185                .build()
186        );
187
188    hash_map.insert(
189            RegrType::AvgX,
190            Documentation::builder(
191                DOC_SECTION_STATISTICAL,
192                    "Computes the average of the independent variable (input) expression_x for the non-null paired data points.",
193
194                "regr_avgx(expression_y, expression_x)")
195                .with_standard_argument("expression_y", Some("Dependent variable"))
196                .with_standard_argument("expression_x", Some("Independent variable"))
197                .build()
198        );
199
200    hash_map.insert(
201            RegrType::AvgY,
202            Documentation::builder(
203                DOC_SECTION_STATISTICAL,
204                    "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.",
205
206                "regr_avgy(expression_y, expression_x)")
207                .with_standard_argument("expression_y", Some("Dependent variable"))
208                .with_standard_argument("expression_x", Some("Independent variable"))
209                .build()
210        );
211
212    hash_map.insert(
213        RegrType::SXX,
214        Documentation::builder(
215            DOC_SECTION_STATISTICAL,
216            "Computes the sum of squares of the independent variable.",
217            "regr_sxx(expression_y, expression_x)",
218        )
219        .with_standard_argument("expression_y", Some("Dependent variable"))
220        .with_standard_argument("expression_x", Some("Independent variable"))
221        .build(),
222    );
223
224    hash_map.insert(
225        RegrType::SYY,
226        Documentation::builder(
227            DOC_SECTION_STATISTICAL,
228            "Computes the sum of squares of the dependent variable.",
229            "regr_syy(expression_y, expression_x)",
230        )
231        .with_standard_argument("expression_y", Some("Dependent variable"))
232        .with_standard_argument("expression_x", Some("Independent variable"))
233        .build(),
234    );
235
236    hash_map.insert(
237        RegrType::SXY,
238        Documentation::builder(
239            DOC_SECTION_STATISTICAL,
240            "Computes the sum of products of paired data points.",
241            "regr_sxy(expression_y, expression_x)",
242        )
243        .with_standard_argument("expression_y", Some("Dependent variable"))
244        .with_standard_argument("expression_x", Some("Independent variable"))
245        .build(),
246    );
247    hash_map
248});
249fn get_regr_docs() -> &'static HashMap<RegrType, Documentation> {
250    &DOCUMENTATION
251}
252
253impl AggregateUDFImpl for Regr {
254    fn as_any(&self) -> &dyn Any {
255        self
256    }
257
258    fn name(&self) -> &str {
259        self.func_name
260    }
261
262    fn signature(&self) -> &Signature {
263        &self.signature
264    }
265
266    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
267        if !arg_types[0].is_numeric() {
268            return plan_err!("Covariance requires numeric input types");
269        }
270
271        if matches!(self.regr_type, RegrType::Count) {
272            Ok(DataType::UInt64)
273        } else {
274            Ok(DataType::Float64)
275        }
276    }
277
278    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
279        Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?))
280    }
281
282    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
283        Ok(vec![
284            Field::new(
285                format_state_name(args.name, "count"),
286                DataType::UInt64,
287                true,
288            ),
289            Field::new(
290                format_state_name(args.name, "mean_x"),
291                DataType::Float64,
292                true,
293            ),
294            Field::new(
295                format_state_name(args.name, "mean_y"),
296                DataType::Float64,
297                true,
298            ),
299            Field::new(
300                format_state_name(args.name, "m2_x"),
301                DataType::Float64,
302                true,
303            ),
304            Field::new(
305                format_state_name(args.name, "m2_y"),
306                DataType::Float64,
307                true,
308            ),
309            Field::new(
310                format_state_name(args.name, "algo_const"),
311                DataType::Float64,
312                true,
313            ),
314        ]
315        .into_iter()
316        .map(Arc::new)
317        .collect())
318    }
319
320    fn documentation(&self) -> Option<&Documentation> {
321        self.regr_type.documentation()
322    }
323}
324
/// `RegrAccumulator` is used to compute linear regression aggregate functions
/// by maintaining statistics needed to compute them in an online fashion.
///
/// This struct uses Welford's online algorithm for calculating variance and covariance:
/// <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm>
///
/// Given the statistics, the following aggregate functions can be calculated:
///
/// - `regr_slope(y, x)`: Slope of the linear regression line, calculated as:
///   cov_pop(x, y) / var_pop(x).
///   It represents the expected change in Y for a one-unit change in X.
///
/// - `regr_intercept(y, x)`: Intercept of the linear regression line, calculated as:
///   mean_y - (regr_slope(y, x) * mean_x).
///   It represents the expected value of Y when X is 0.
///
/// - `regr_count(y, x)`: Count of the input rows where both x and y are non-null.
///
/// - `regr_r2(y, x)`: R-squared value (coefficient of determination), calculated as:
///   (cov_pop(x, y) ^ 2) / (var_pop(x) * var_pop(y)).
///   It provides a measure of how well the model's predictions match the observed data.
///
/// - `regr_avgx(y, x)`: Average of the independent variable X, calculated as: mean_x.
///
/// - `regr_avgy(y, x)`: Average of the dependent variable Y, calculated as: mean_y.
///
/// - `regr_sxx(y, x)`: Sum of squares of the independent variable X, calculated as:
///   m2_x.
///
/// - `regr_syy(y, x)`: Sum of squares of the dependent variable Y, calculated as:
///   m2_y.
///
/// - `regr_sxy(y, x)`: Sum of products of paired values, calculated as:
///   algo_const.
///
/// Here's how the derived statistics are calculated from the maintained state:
/// - `cov_pop(x, y)`: algo_const / count.
/// - `var_pop(x)`: m2_x / count.
/// - `var_pop(y)`: m2_y / count.
#[derive(Debug)]
pub struct RegrAccumulator {
    /// Number of rows accumulated so far in which both X and Y were non-null.
    count: u64,
    /// Running mean of the X values of the accumulated pairs.
    mean_x: f64,
    /// Running mean of the Y values of the accumulated pairs.
    mean_y: f64,
    /// Running sum of squared deviations of X from its mean.
    m2_x: f64,
    /// Running sum of squared deviations of Y from its mean.
    m2_y: f64,
    /// Running sum of products of the X and Y deviations from their means.
    algo_const: f64,
    /// Which statistic `evaluate` returns.
    regr_type: RegrType,
}
374
375impl RegrAccumulator {
376    /// Creates a new `RegrAccumulator`
377    pub fn try_new(regr_type: &RegrType) -> Result<Self> {
378        Ok(Self {
379            count: 0_u64,
380            mean_x: 0_f64,
381            mean_y: 0_f64,
382            m2_x: 0_f64,
383            m2_y: 0_f64,
384            algo_const: 0_f64,
385            regr_type: regr_type.clone(),
386        })
387    }
388}
389
390impl Accumulator for RegrAccumulator {
391    fn state(&mut self) -> Result<Vec<ScalarValue>> {
392        Ok(vec![
393            ScalarValue::from(self.count),
394            ScalarValue::from(self.mean_x),
395            ScalarValue::from(self.mean_y),
396            ScalarValue::from(self.m2_x),
397            ScalarValue::from(self.m2_y),
398            ScalarValue::from(self.algo_const),
399        ])
400    }
401
402    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
403        // regr_slope(Y, X) calculates k in y = k*x + b
404        let values_y = &cast(&values[0], &DataType::Float64)?;
405        let values_x = &cast(&values[1], &DataType::Float64)?;
406
407        let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
408        let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();
409
410        for i in 0..values_y.len() {
411            // skip either x or y is NULL
412            let value_y = if values_y.is_valid(i) {
413                arr_y.next()
414            } else {
415                None
416            };
417            let value_x = if values_x.is_valid(i) {
418                arr_x.next()
419            } else {
420                None
421            };
422            if value_y.is_none() || value_x.is_none() {
423                continue;
424            }
425
426            // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)]
427            let value_y = unwrap_or_internal_err!(value_y);
428            let value_x = unwrap_or_internal_err!(value_x);
429
430            self.count += 1;
431            let delta_x = value_x - self.mean_x;
432            let delta_y = value_y - self.mean_y;
433            self.mean_x += delta_x / self.count as f64;
434            self.mean_y += delta_y / self.count as f64;
435            let delta_x_2 = value_x - self.mean_x;
436            let delta_y_2 = value_y - self.mean_y;
437            self.m2_x += delta_x * delta_x_2;
438            self.m2_y += delta_y * delta_y_2;
439            self.algo_const += delta_x * (value_y - self.mean_y);
440        }
441
442        Ok(())
443    }
444
445    fn supports_retract_batch(&self) -> bool {
446        true
447    }
448
449    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
450        let values_y = &cast(&values[0], &DataType::Float64)?;
451        let values_x = &cast(&values[1], &DataType::Float64)?;
452
453        let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
454        let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();
455
456        for i in 0..values_y.len() {
457            // skip either x or y is NULL
458            let value_y = if values_y.is_valid(i) {
459                arr_y.next()
460            } else {
461                None
462            };
463            let value_x = if values_x.is_valid(i) {
464                arr_x.next()
465            } else {
466                None
467            };
468            if value_y.is_none() || value_x.is_none() {
469                continue;
470            }
471
472            // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)]
473            let value_y = unwrap_or_internal_err!(value_y);
474            let value_x = unwrap_or_internal_err!(value_x);
475
476            if self.count > 1 {
477                self.count -= 1;
478                let delta_x = value_x - self.mean_x;
479                let delta_y = value_y - self.mean_y;
480                self.mean_x -= delta_x / self.count as f64;
481                self.mean_y -= delta_y / self.count as f64;
482                let delta_x_2 = value_x - self.mean_x;
483                let delta_y_2 = value_y - self.mean_y;
484                self.m2_x -= delta_x * delta_x_2;
485                self.m2_y -= delta_y * delta_y_2;
486                self.algo_const -= delta_x * (value_y - self.mean_y);
487            } else {
488                self.count = 0;
489                self.mean_x = 0.0;
490                self.m2_x = 0.0;
491                self.m2_y = 0.0;
492                self.mean_y = 0.0;
493                self.algo_const = 0.0;
494            }
495        }
496
497        Ok(())
498    }
499
500    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
501        let count_arr = downcast_value!(states[0], UInt64Array);
502        let mean_x_arr = downcast_value!(states[1], Float64Array);
503        let mean_y_arr = downcast_value!(states[2], Float64Array);
504        let m2_x_arr = downcast_value!(states[3], Float64Array);
505        let m2_y_arr = downcast_value!(states[4], Float64Array);
506        let algo_const_arr = downcast_value!(states[5], Float64Array);
507
508        for i in 0..count_arr.len() {
509            let count_b = count_arr.value(i);
510            if count_b == 0_u64 {
511                continue;
512            }
513            let (count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a) = (
514                self.count,
515                self.mean_x,
516                self.mean_y,
517                self.m2_x,
518                self.m2_y,
519                self.algo_const,
520            );
521            let (count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b) = (
522                count_b,
523                mean_x_arr.value(i),
524                mean_y_arr.value(i),
525                m2_x_arr.value(i),
526                m2_y_arr.value(i),
527                algo_const_arr.value(i),
528            );
529
530            // Assuming two different batches of input have calculated the states:
531            // batch A of Y, X -> {count_a, mean_x_a, mean_y_a, m2_x_a, algo_const_a}
532            // batch B of Y, X -> {count_b, mean_x_b, mean_y_b, m2_x_b, algo_const_b}
533            // The merged states from A and B are {count_ab, mean_x_ab, mean_y_ab, m2_x_ab,
534            // algo_const_ab}
535            //
536            // Reference for the algorithm to merge states:
537            // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
538            let count_ab = count_a + count_b;
539            let (count_a, count_b) = (count_a as f64, count_b as f64);
540            let d_x = mean_x_b - mean_x_a;
541            let d_y = mean_y_b - mean_y_a;
542            let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64;
543            let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64;
544            let m2_x_ab =
545                m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64;
546            let m2_y_ab =
547                m2_y_a + m2_y_b + d_y * d_y * count_a * count_b / count_ab as f64;
548            let algo_const_ab = algo_const_a
549                + algo_const_b
550                + d_x * d_y * count_a * count_b / count_ab as f64;
551
552            self.count = count_ab;
553            self.mean_x = mean_x_ab;
554            self.mean_y = mean_y_ab;
555            self.m2_x = m2_x_ab;
556            self.m2_y = m2_y_ab;
557            self.algo_const = algo_const_ab;
558        }
559        Ok(())
560    }
561
562    fn evaluate(&mut self) -> Result<ScalarValue> {
563        let cov_pop_x_y = self.algo_const / self.count as f64;
564        let var_pop_x = self.m2_x / self.count as f64;
565        let var_pop_y = self.m2_y / self.count as f64;
566
567        let nullif_or_stat = |cond: bool, stat: f64| {
568            if cond {
569                Ok(ScalarValue::Float64(None))
570            } else {
571                Ok(ScalarValue::Float64(Some(stat)))
572            }
573        };
574
575        match self.regr_type {
576            RegrType::Slope => {
577                // Only 0/1 point or slope is infinite
578                let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
579                nullif_or_stat(nullif_cond, cov_pop_x_y / var_pop_x)
580            }
581            RegrType::Intercept => {
582                let slope = cov_pop_x_y / var_pop_x;
583                // Only 0/1 point or slope is infinite
584                let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
585                nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x)
586            }
587            RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))),
588            RegrType::R2 => {
589                // Only 0/1 point or all x(or y) is the same
590                let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0;
591                nullif_or_stat(
592                    nullif_cond,
593                    (cov_pop_x_y * cov_pop_x_y) / (var_pop_x * var_pop_y),
594                )
595            }
596            RegrType::AvgX => nullif_or_stat(self.count < 1, self.mean_x),
597            RegrType::AvgY => nullif_or_stat(self.count < 1, self.mean_y),
598            RegrType::SXX => nullif_or_stat(self.count < 1, self.m2_x),
599            RegrType::SYY => nullif_or_stat(self.count < 1, self.m2_y),
600            RegrType::SXY => nullif_or_stat(self.count < 1, self.algo_const),
601        }
602    }
603
604    fn size(&self) -> usize {
605        size_of_val(self)
606    }
607}