use arrow::datatypes::FieldRef;
use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
use datafusion_common::cast::{as_float64_array, as_uint64_array};
use datafusion_common::{HashMap, Result, ScalarValue};
use datafusion_doc::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
};
use std::any::Any;
use std::fmt::Debug;
use std::hash::Hash;
use std::mem::size_of_val;
use std::sync::{Arc, LazyLock};
/// Generates, for one regression statistic, both the fluent expression
/// function (e.g. `regr_slope(y, x)`) and the singleton `AggregateUDF`
/// accessor (e.g. `regr_slope_udaf()`).
macro_rules! make_regr_udaf_expr_and_func {
    ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => {
        // Expression builder taking the dependent (y) then independent (x) column.
        make_udaf_expr!($EXPR_FN, expr_y expr_x, concat!("Compute a linear regression of type [", stringify!($REGR_TYPE), "]"), $AGGREGATE_UDF_FN);
        // Shared UDAF instance, registered under the expression function's name.
        create_func!($EXPR_FN, $AGGREGATE_UDF_FN, Regr::new($REGR_TYPE, stringify!($EXPR_FN)));
    }
}
// Instantiate the expression function + UDAF accessor pair for each of the
// nine SQL REGR_* statistics.
make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope);
make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept);
make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count);
make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2);
make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX);
make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY);
make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX);
make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY);
make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY);
/// Aggregate UDF implementing the SQL `REGR_*` family of linear-regression
/// statistics; the concrete statistic computed is selected by `regr_type`.
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct Regr {
    // Exact (Float64, Float64) signature shared by all REGR_* variants.
    signature: Signature,
    // Which regression statistic this instance computes.
    regr_type: RegrType,
    // SQL-visible function name, e.g. "regr_slope".
    func_name: &'static str,
}
impl Regr {
    /// Builds a `Regr` UDF computing `regr_type`, registered as `func_name`.
    ///
    /// Every variant accepts exactly two `Float64` arguments: the dependent
    /// variable (y) followed by the independent variable (x).
    pub fn new(regr_type: RegrType, func_name: &'static str) -> Self {
        let signature =
            Signature::exact(vec![DataType::Float64; 2], Volatility::Immutable);
        Self {
            signature,
            regr_type,
            func_name,
        }
    }
}
/// The individual regression statistic a [`Regr`] instance computes.
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
pub enum RegrType {
    /// Slope k of the least-squares line y = k*x + b.
    Slope,
    /// Intercept b of the least-squares line y = k*x + b.
    Intercept,
    /// Number of non-null (y, x) input pairs.
    Count,
    /// Square of the correlation coefficient (coefficient of determination).
    R2,
    /// Average of the independent variable x over non-null pairs.
    AvgX,
    /// Average of the dependent variable y over non-null pairs.
    AvgY,
    /// Sum of squared deviations of x from its mean.
    SXX,
    /// Sum of squared deviations of y from its mean.
    SYY,
    /// Sum of products of the x and y deviations from their means.
    SXY,
}
impl RegrType {
    /// Looks up the user-facing documentation registered for this statistic,
    /// if any.
    fn documentation(&self) -> Option<&Documentation> {
        let docs = get_regr_docs();
        docs.get(self)
    }
}
/// Lazily-built table mapping each [`RegrType`] to its user-facing
/// documentation (description, syntax and a worked SQL example).
///
/// Fixed: the `regr_intercept` example's result table had a duplicated,
/// misaligned header row (`|intercept|` / `|intercept |`) that rendered as
/// broken markdown; it now has a single properly-aligned header.
static DOCUMENTATION: LazyLock<HashMap<RegrType, Documentation>> = LazyLock::new(|| {
    let mut hash_map = HashMap::new();
    hash_map.insert(
        RegrType::Slope,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Returns the slope of the linear regression line for non-null pairs in aggregate columns. \
            Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.",
            "regr_slope(expression_y, expression_x)")
            .with_sql_example(
                r#"```sql
create table weekly_performance(day int, user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
select * from weekly_performance;
+-----+--------------+
| day | user_signups |
+-----+--------------+
| 1 | 60 |
| 2 | 65 |
| 3 | 70 |
| 4 | 75 |
| 5 | 80 |
+-----+--------------+
SELECT regr_slope(user_signups, day) AS slope FROM weekly_performance;
+--------+
| slope |
+--------+
| 5.0 |
+--------+
```
"#
            )
            .with_standard_argument("expression_y", Some("Dependent variable"))
            .with_standard_argument("expression_x", Some("Independent variable"))
            .build(),
    );
    hash_map.insert(
        RegrType::Intercept,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \
            this function returns b.",
            "regr_intercept(expression_y, expression_x)")
            .with_sql_example(
                r#"```sql
create table weekly_performance(week int, productivity_score int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
select * from weekly_performance;
+------+---------------------+
| week | productivity_score |
| ---- | ------------------- |
| 1 | 60 |
| 2 | 65 |
| 3 | 70 |
| 4 | 75 |
| 5 | 80 |
+------+---------------------+
SELECT regr_intercept(productivity_score, week) AS intercept FROM weekly_performance;
+-----------+
| intercept |
+-----------+
| 55        |
+-----------+
```
"#
            )
            .with_standard_argument("expression_y", Some("Dependent variable"))
            .with_standard_argument("expression_x", Some("Independent variable"))
            .build(),
    );
    hash_map.insert(
        RegrType::Count,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Counts the number of non-null paired data points.",
            "regr_count(expression_y, expression_x)",
        )
        .with_sql_example(
            r#"```sql
create table daily_metrics(day int, user_signups int) as values (1,100), (2,120), (3, NULL), (4,110), (5,NULL);
select * from daily_metrics;
+-----+---------------+
| day | user_signups |
| --- | ------------- |
| 1 | 100 |
| 2 | 120 |
| 3 | NULL |
| 4 | 110 |
| 5 | NULL |
+-----+---------------+
SELECT regr_count(user_signups, day) AS valid_pairs FROM daily_metrics;
+-------------+
| valid_pairs |
+-------------+
| 3 |
+-------------+
```
"#
        )
        .with_standard_argument("expression_y", Some("Dependent variable"))
        .with_standard_argument("expression_x", Some("Independent variable"))
        .build(),
    );
    hash_map.insert(
        RegrType::R2,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the square of the correlation coefficient between the independent and dependent variables.",
            "regr_r2(expression_y, expression_x)")
            .with_sql_example(
                r#"```sql
create table weekly_performance(day int ,user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80);
select * from weekly_performance;
+-----+--------------+
| day | user_signups |
+-----+--------------+
| 1 | 60 |
| 2 | 65 |
| 3 | 70 |
| 4 | 75 |
| 5 | 80 |
+-----+--------------+
SELECT regr_r2(user_signups, day) AS r_squared FROM weekly_performance;
+---------+
|r_squared|
+---------+
| 1.0 |
+---------+
```
"#
            )
            .with_standard_argument("expression_y", Some("Dependent variable"))
            .with_standard_argument("expression_x", Some("Independent variable"))
            .build(),
    );
    hash_map.insert(
        RegrType::AvgX,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the average of the independent variable (input) expression_x for the non-null paired data points.",
            "regr_avgx(expression_y, expression_x)")
            .with_sql_example(
                r#"```sql
create table daily_sales(day int, total_sales int) as values (1,100), (2,150), (3,200), (4,NULL), (5,250);
select * from daily_sales;
+-----+-------------+
| day | total_sales |
| --- | ----------- |
| 1 | 100 |
| 2 | 150 |
| 3 | 200 |
| 4 | NULL |
| 5 | 250 |
+-----+-------------+
SELECT regr_avgx(total_sales, day) AS avg_day FROM daily_sales;
+----------+
| avg_day |
+----------+
| 2.75 |
+----------+
```
"#
            )
            .with_standard_argument("expression_y", Some("Dependent variable"))
            .with_standard_argument("expression_x", Some("Independent variable"))
            .build(),
    );
    hash_map.insert(
        RegrType::AvgY,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.",
            "regr_avgy(expression_y, expression_x)")
            .with_sql_example(
                r#"```sql
create table daily_temperature(day int, temperature int) as values (1,30), (2,32), (3, NULL), (4,35), (5,36);
select * from daily_temperature;
+-----+-------------+
| day | temperature |
| --- | ----------- |
| 1 | 30 |
| 2 | 32 |
| 3 | NULL |
| 4 | 35 |
| 5 | 36 |
+-----+-------------+
-- temperature as Dependent Variable(Y), day as Independent Variable(X)
SELECT regr_avgy(temperature, day) AS avg_temperature FROM daily_temperature;
+-----------------+
| avg_temperature |
+-----------------+
| 33.25 |
+-----------------+
```
"#
            )
            .with_standard_argument("expression_y", Some("Dependent variable"))
            .with_standard_argument("expression_x", Some("Independent variable"))
            .build(),
    );
    hash_map.insert(
        RegrType::SXX,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the sum of squares of the independent variable.",
            "regr_sxx(expression_y, expression_x)",
        )
        .with_sql_example(
            r#"```sql
create table study_hours(student_id int, hours int, test_score int) as values (1,2,55), (2,4,65), (3,6,75), (4,8,85), (5,10,95);
select * from study_hours;
+------------+-------+------------+
| student_id | hours | test_score |
+------------+-------+------------+
| 1 | 2 | 55 |
| 2 | 4 | 65 |
| 3 | 6 | 75 |
| 4 | 8 | 85 |
| 5 | 10 | 95 |
+------------+-------+------------+
SELECT regr_sxx(test_score, hours) AS sxx FROM study_hours;
+------+
| sxx |
+------+
| 40.0 |
+------+
```
"#
        )
        .with_standard_argument("expression_y", Some("Dependent variable"))
        .with_standard_argument("expression_x", Some("Independent variable"))
        .build(),
    );
    hash_map.insert(
        RegrType::SYY,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the sum of squares of the dependent variable.",
            "regr_syy(expression_y, expression_x)",
        )
        .with_sql_example(
            r#"```sql
create table employee_productivity(week int, productivity_score int) as values (1,60), (2,65), (3,70);
select * from employee_productivity;
+------+--------------------+
| week | productivity_score |
+------+--------------------+
| 1 | 60 |
| 2 | 65 |
| 3 | 70 |
+------+--------------------+
SELECT regr_syy(productivity_score, week) AS sum_squares_y FROM employee_productivity;
+---------------+
| sum_squares_y |
+---------------+
| 50.0 |
+---------------+
```
"#
        )
        .with_standard_argument("expression_y", Some("Dependent variable"))
        .with_standard_argument("expression_x", Some("Independent variable"))
        .build(),
    );
    hash_map.insert(
        RegrType::SXY,
        Documentation::builder(
            DOC_SECTION_STATISTICAL,
            "Computes the sum of products of paired data points.",
            "regr_sxy(expression_y, expression_x)",
        )
        .with_sql_example(
            r#"```sql
create table employee_productivity(week int, productivity_score int) as values(1,60), (2,65), (3,70);
select * from employee_productivity;
+------+--------------------+
| week | productivity_score |
+------+--------------------+
| 1 | 60 |
| 2 | 65 |
| 3 | 70 |
+------+--------------------+
SELECT regr_sxy(productivity_score, week) AS sum_product_deviations FROM employee_productivity;
+------------------------+
| sum_product_deviations |
+------------------------+
| 10.0 |
+------------------------+
```
"#
        )
        .with_standard_argument("expression_y", Some("Dependent variable"))
        .with_standard_argument("expression_x", Some("Independent variable"))
        .build(),
    );
    hash_map
});
fn get_regr_docs() -> &'static HashMap<RegrType, Documentation> {
&DOCUMENTATION
}
impl AggregateUDFImpl for Regr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        self.func_name
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// `regr_count` produces an unsigned pair count; every other statistic is
    /// a `Float64` value.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        match self.regr_type {
            RegrType::Count => Ok(DataType::UInt64),
            _ => Ok(DataType::Float64),
        }
    }

    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let acc = RegrAccumulator::try_new(&self.regr_type)?;
        Ok(Box::new(acc))
    }

    /// Intermediate state layout shared by all variants: the pair count plus
    /// the running means, M2 moments and co-moment maintained by the
    /// accumulator. Order must match `RegrAccumulator::state`.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        let layout = [
            ("count", DataType::UInt64),
            ("mean_x", DataType::Float64),
            ("mean_y", DataType::Float64),
            ("m2_x", DataType::Float64),
            ("m2_y", DataType::Float64),
            ("algo_const", DataType::Float64),
        ];
        Ok(layout
            .into_iter()
            .map(|(suffix, dtype)| {
                Arc::new(Field::new(format_state_name(args.name, suffix), dtype, true))
            })
            .collect())
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.regr_type.documentation()
    }
}
/// Running state for all `REGR_*` statistics, maintained with Welford's
/// numerically stable online algorithm over the non-null (y, x) pairs.
#[derive(Debug)]
pub struct RegrAccumulator {
    // Number of non-null (y, x) pairs accumulated so far.
    count: u64,
    // Running mean of x.
    mean_x: f64,
    // Running mean of y.
    mean_y: f64,
    // Sum of squared deviations of x from its running mean (M2 for x).
    m2_x: f64,
    // Sum of squared deviations of y from its running mean (M2 for y).
    m2_y: f64,
    // Running co-moment: sum((x - mean_x) * (y - mean_y)).
    algo_const: f64,
    // Which statistic `evaluate` derives from this state.
    regr_type: RegrType,
}
impl RegrAccumulator {
    /// Creates an empty accumulator that will compute the given statistic.
    pub fn try_new(regr_type: &RegrType) -> Result<Self> {
        let acc = Self {
            count: 0,
            mean_x: 0.0,
            mean_y: 0.0,
            m2_x: 0.0,
            m2_y: 0.0,
            algo_const: 0.0,
            regr_type: regr_type.clone(),
        };
        Ok(acc)
    }
}
impl Accumulator for RegrAccumulator {
    /// Serializes the intermediate state in the order declared by
    /// `Regr::state_fields`: count, mean_x, mean_y, m2_x, m2_y, algo_const.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![
            ScalarValue::from(self.count),
            ScalarValue::from(self.mean_x),
            ScalarValue::from(self.mean_y),
            ScalarValue::from(self.m2_x),
            ScalarValue::from(self.m2_y),
            ScalarValue::from(self.algo_const),
        ])
    }

    /// Folds a batch of (y, x) pairs into the running state using Welford's
    /// online algorithm; rows where either side is NULL are skipped.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // Argument order is (y, x): the dependent variable comes first.
        let values_y = as_float64_array(&values[0])?;
        let values_x = as_float64_array(&values[1])?;
        for (value_y, value_x) in values_y.iter().zip(values_x) {
            let (value_y, value_x) = match (value_y, value_x) {
                (Some(y), Some(x)) => (y, x),
                // Only fully non-null pairs participate in the regression.
                _ => continue,
            };
            self.count += 1;
            // Deviations from the means *before* this point is incorporated.
            let delta_x = value_x - self.mean_x;
            let delta_y = value_y - self.mean_y;
            self.mean_x += delta_x / self.count as f64;
            self.mean_y += delta_y / self.count as f64;
            // Deviations from the *updated* means. The old-delta * new-delta
            // products keep the M2 sums numerically stable (Welford's method);
            // the statement order here is load-bearing.
            let delta_x_2 = value_x - self.mean_x;
            let delta_y_2 = value_y - self.mean_y;
            self.m2_x += delta_x * delta_x_2;
            self.m2_y += delta_y * delta_y_2;
            // Co-moment update: old x-delta times post-update y-delta.
            self.algo_const += delta_x * (value_y - self.mean_y);
        }
        Ok(())
    }

    // Enables use in sliding window frames via `retract_batch`.
    fn supports_retract_batch(&self) -> bool {
        true
    }

    /// Removes previously-accumulated (y, x) pairs — the inverse of
    /// `update_batch` — for sliding window evaluation.
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let values_y = as_float64_array(&values[0])?;
        let values_x = as_float64_array(&values[1])?;
        for (value_y, value_x) in values_y.iter().zip(values_x) {
            let (value_y, value_x) = match (value_y, value_x) {
                (Some(y), Some(x)) => (y, x),
                _ => continue,
            };
            if self.count > 1 {
                self.count -= 1;
                // Inverse Welford step; note the means move by
                // delta / already-decremented count.
                let delta_x = value_x - self.mean_x;
                let delta_y = value_y - self.mean_y;
                self.mean_x -= delta_x / self.count as f64;
                self.mean_y -= delta_y / self.count as f64;
                let delta_x_2 = value_x - self.mean_x;
                let delta_y_2 = value_y - self.mean_y;
                self.m2_x -= delta_x * delta_x_2;
                self.m2_y -= delta_y * delta_y_2;
                self.algo_const -= delta_x * (value_y - self.mean_y);
            } else {
                // Retracting the final pair: reset to the exact empty state
                // instead of subtracting, avoiding floating-point drift.
                self.count = 0;
                self.mean_x = 0.0;
                self.m2_x = 0.0;
                self.m2_y = 0.0;
                self.mean_y = 0.0;
                self.algo_const = 0.0;
            }
        }
        Ok(())
    }

    /// Merges partial states from other accumulators using the pairwise
    /// combination formulas for means, M2 moments and the co-moment
    /// (the parallel variant of Welford's algorithm).
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // Column order must match `state()` / `Regr::state_fields`.
        let count_arr = as_uint64_array(&states[0])?;
        let mean_x_arr = as_float64_array(&states[1])?;
        let mean_y_arr = as_float64_array(&states[2])?;
        let m2_x_arr = as_float64_array(&states[3])?;
        let m2_y_arr = as_float64_array(&states[4])?;
        let algo_const_arr = as_float64_array(&states[5])?;
        for i in 0..count_arr.len() {
            let count_b = count_arr.value(i);
            // Empty partial states contribute nothing.
            if count_b == 0_u64 {
                continue;
            }
            // "a" is the state already in `self`, "b" the incoming partial.
            let (count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a) = (
                self.count,
                self.mean_x,
                self.mean_y,
                self.m2_x,
                self.m2_y,
                self.algo_const,
            );
            let (count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b) = (
                count_b,
                mean_x_arr.value(i),
                mean_y_arr.value(i),
                m2_x_arr.value(i),
                m2_y_arr.value(i),
                algo_const_arr.value(i),
            );
            let count_ab = count_a + count_b;
            let (count_a, count_b) = (count_a as f64, count_b as f64);
            let d_x = mean_x_b - mean_x_a;
            let d_y = mean_y_b - mean_y_a;
            // Combined means, shifting "a" by "b"'s weighted distance.
            let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64;
            let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64;
            // Combined moments: sum of parts plus a cross term accounting for
            // the distance between the two partition means.
            let m2_x_ab =
                m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64;
            let m2_y_ab =
                m2_y_a + m2_y_b + d_y * d_y * count_a * count_b / count_ab as f64;
            let algo_const_ab = algo_const_a
                + algo_const_b
                + d_x * d_y * count_a * count_b / count_ab as f64;
            self.count = count_ab;
            self.mean_x = mean_x_ab;
            self.mean_y = mean_y_ab;
            self.m2_x = m2_x_ab;
            self.m2_y = m2_y_ab;
            self.algo_const = algo_const_ab;
        }
        Ok(())
    }

    /// Derives the requested statistic from the accumulated moments.
    ///
    /// Population formulas are used (division by `count`, not `count - 1`).
    /// Returns NULL when there are too few pairs, or when the relevant
    /// variance is zero (degenerate fit).
    fn evaluate(&mut self) -> Result<ScalarValue> {
        // These are NaN when count == 0, but every branch that consumes them
        // guards on the count before the value can be returned.
        let cov_pop_x_y = self.algo_const / self.count as f64;
        let var_pop_x = self.m2_x / self.count as f64;
        let var_pop_y = self.m2_y / self.count as f64;
        // Maps a "not enough data / degenerate" condition to SQL NULL.
        let nullif_or_stat = |cond: bool, stat: f64| {
            if cond {
                Ok(ScalarValue::Float64(None))
            } else {
                Ok(ScalarValue::Float64(Some(stat)))
            }
        };
        match self.regr_type {
            RegrType::Slope => {
                // Slope is undefined with fewer than 2 points or zero
                // x-variance (vertical / degenerate fit).
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
                nullif_or_stat(nullif_cond, cov_pop_x_y / var_pop_x)
            }
            RegrType::Intercept => {
                let slope = cov_pop_x_y / var_pop_x;
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0;
                nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x)
            }
            // COUNT never returns NULL; zero pairs yields 0.
            RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))),
            RegrType::R2 => {
                // R^2 additionally requires non-zero y-variance.
                let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0;
                nullif_or_stat(
                    nullif_cond,
                    (cov_pop_x_y * cov_pop_x_y) / (var_pop_x * var_pop_y),
                )
            }
            // The remaining statistics only require at least one pair.
            RegrType::AvgX => nullif_or_stat(self.count < 1, self.mean_x),
            RegrType::AvgY => nullif_or_stat(self.count < 1, self.mean_y),
            RegrType::SXX => nullif_or_stat(self.count < 1, self.m2_x),
            RegrType::SYY => nullif_or_stat(self.count < 1, self.m2_y),
            RegrType::SXY => nullif_or_stat(self.count < 1, self.algo_const),
        }
    }

    /// In-memory size of the accumulator (fixed-size struct; no heap data
    /// beyond the struct itself).
    fn size(&self) -> usize {
        size_of_val(self)
    }
}